In [19]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.transform import Rotation as R
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [20]:
SOURCE_H5_PATH = "../dataset.h5"
DATASET_PATH = "../easy-dataset.h5"
INTRINSICS = {"focalLength": 1.0, "width": 420.0, "height": 360.0, "zRange": [0.01, 2.0]}

device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda:5


In [21]:
def move_to_device(batch: tuple[torch.Tensor, ...]):
    return tuple(item.to(device, non_blocking=True) for item in batch)

In [22]:
def quat_to_matrix(quaternions: np.typing.ArrayLike):
    return R.from_quat(np.roll(quaternions, -1, axis=-1)).as_matrix()


def project_points(world_point, cam_poses):
    cam_positions = cam_poses[:, :3]
    cam_quaternions = cam_poses[:, 3:]

    rots = quat_to_matrix(cam_quaternions)

    points_rel = world_point[None, :] - cam_positions
    p_cam = np.einsum("nij,nj->ni", rots.transpose(0, 2, 1), points_rel)

    f, w, h = (INTRINSICS["focalLength"], INTRINSICS["width"], INTRINSICS["height"])
    u = (f * p_cam[:, 0] / p_cam[:, 2]) * (w / 2) + w / 2
    v = (f * p_cam[:, 1] / p_cam[:, 2]) * (h / 2) + h / 2
    return np.stack([u, v], axis=1)


def make_gaussian_maps(centers, sigma=8):
    width, height = int(INTRINSICS["width"]), int(INTRINSICS["height"])
    X, Y = np.meshgrid(np.arange(width), np.arange(height))
    u = centers[:, 0, None, None]
    v = centers[:, 1, None, None]

    g = np.exp(-((X - u) ** 2 + (Y - v) ** 2) / (2 * sigma**2))
    g /= g.max(axis=(1, 2), keepdims=True) + 1e-8
    return g.astype(np.float32)


def compute_pose_scalar_diffs(pose, final_pose):
    pos_i, rot_i = pose[..., :3], pose[..., 3:]
    pos_f, rot_f = final_pose[..., :3], final_pose[..., 3:]

    pos_diff = np.linalg.norm(pos_f - pos_i).astype(np.float32)

    rot_i = rot_i / np.linalg.norm(rot_i)
    rot_f = rot_f / np.linalg.norm(rot_f)

    dot = np.clip(np.abs(np.sum(rot_i * rot_f)), -1.0, 1.0)
    rot_diff = (2.0 * np.arccos(dot)).astype(np.float32)

    return np.array([pos_diff, rot_diff], dtype=np.float32)

In [5]:
with h5py.File(SOURCE_H5_PATH, "r") as f_in, h5py.File(DATASET_PATH, "w") as f_out:
    f_out.attrs.update(f_in.attrs)

    for dp_key in tqdm(f_in.keys(), desc="Processing scenes"):
        dp = f_in[dp_key]

        depths = dp["depths"][()]
        cam_poses = dp["cam_poses"][()]
        obj_ids = dp["obj_ids"][()]
        seg_ids = dp["seg_ids"][()]
        poses = dp["poses"][()]
        final_poses = dp["final_poses"][()]

        for oi in range(poses.shape[0]):
            target_pose = dp["target_poses"][oi]
            pose_diffs = compute_pose_scalar_diffs(final_poses[oi], target_pose)

            centers = project_points(target_pose[:3], cam_poses)
            goal_maps = make_gaussian_maps(centers, sigma=6)

            dp_group = f_out.create_group(f"{dp_key}_obj_{oi}")
            dp_group.create_dataset("depths", data=depths)
            dp_group.create_dataset("masks", data=(seg_ids == obj_ids[oi]).astype(np.float32))
            dp_group.create_dataset("goal_maps", data=goal_maps)
            dp_group.create_dataset("quat", data=poses[oi][3:])
            dp_group.create_dataset("feasibles", data=dp["feasibles"][oi])
            dp_group.create_dataset("pose_diffs", data=pose_diffs)

Processing scenes: 100%|██████████| 1300/1300 [01:29<00:00, 14.52it/s]


In [23]:
class EasyDataset(Dataset):
    def __init__(self):
        self._h5_file = None

        with h5py.File(DATASET_PATH, "r") as f:
            self.keys = list(f.keys())
            self.primitives = list(f.attrs["primitives"])

    def _init_h5(self):
        if self._h5_file is None:
            self._h5_file = h5py.File(DATASET_PATH, "r")

    def close(self):
        if self._h5_file is not None:
            self._h5_file.close()
            self._h5_file = None

    def __del__(self):
        self.close()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        self._init_h5()
        dp = self._h5_file[self.keys[idx]]
        return (dp["depths"][()], dp["masks"][()], dp["goal_maps"][()], dp["quat"][()], dp["feasibles"][()], dp["pose_diffs"][()])

In [24]:
dataset = EasyDataset()
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
primitives = dataset.primitives
num_primitives = len(primitives)
print(len(dataset))

5943


In [25]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class MLPBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, dropout: float = 0.2):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.relu(self.linear(x)))


class EasyNet(nn.Module):
    def __init__(self, num_views: int, num_primitives: int):
        super().__init__()
        self.num_views = num_views

        in_channels = 3  # depth, mask, goal
        self.feature_extractor = nn.Sequential(
            ConvBlock(in_channels, 32, stride=2),
            ConvBlock(32, 64, stride=1),
            ConvBlock(64, 64, stride=2),
            ConvBlock(64, 128, stride=1),
            ConvBlock(128, 128, stride=2),
            ConvBlock(128, 256, stride=1),
        )

        self.spatial_pool = nn.AdaptiveAvgPool2d((7, 6))

        self.fusion_conv = nn.Sequential(
            ConvBlock(256 * num_views, 512),
            ConvBlock(512, 256),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )

        self.quat_encoder = nn.Sequential(
            nn.Linear(4, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 64),
            nn.ReLU(inplace=True),
        )

        self.mlp = nn.Sequential(
            MLPBlock(256 + 64, 512),
            MLPBlock(512, 512),
            MLPBlock(512, 256),
        )

        self.feasibility_head = nn.Linear(256, num_primitives)
        self.pos_diff_head = nn.Linear(256, num_primitives)
        self.rot_diff_head = nn.Linear(256, num_primitives)

    def forward(self, depths, masks, goals, quat):
        per_view_features = []
        for v in range(self.num_views):
            x = torch.stack([depths[:, v], masks[:, v], goals[:, v]], dim=1)
            features = self.feature_extractor(x)
            features = self.spatial_pool(features)
            per_view_features.append(features)

        multi_view = torch.cat(per_view_features, dim=1)
        visual_features = self.fusion_conv(multi_view)
        quat_features = self.quat_encoder(quat)

        combined = torch.cat([visual_features, quat_features], dim=1)
        features = self.mlp(combined)

        feasibility_logits = self.feasibility_head(features)
        pos_diffs = self.pos_diff_head(features)
        rot_diffs = self.rot_diff_head(features)
        return feasibility_logits, pos_diffs, rot_diffs

In [26]:
def compute_feasibility_weights(dataset):
    pos_counts = np.zeros(num_primitives, dtype=np.float32)
    neg_counts = np.zeros(num_primitives, dtype=np.float32)

    for idx in tqdm(range(len(dataset))):
        _, _, _, _, feasibles, _ = dataset[idx]
        pos_counts += feasibles
        neg_counts += 1 - feasibles

    pos_weights = neg_counts / (pos_counts + 1e-8)

    for i, prim in enumerate(dataset.primitives):
        total = pos_counts[i] + neg_counts[i]
        pos_ratio = pos_counts[i] / total
        print(f"{prim:15s}: {pos_ratio:.2%} feasible, pos_weight={pos_weights[i]:.3f}")

    return torch.tensor(pos_weights, dtype=torch.float32)

In [27]:
pos_weights = compute_feasibility_weights(dataset).to(device)
bce_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

100%|██████████| 5943/5943 [00:16<00:00, 360.57it/s]

push_x_pos     : 16.84% feasible, pos_weight=4.937
push_x_neg     : 17.75% feasible, pos_weight=4.633
push_y_pos     : 17.99% feasible, pos_weight=4.559
push_y_neg     : 16.19% feasible, pos_weight=5.178
lift_x         : 46.47% feasible, pos_weight=1.152
lift_y         : 47.28% feasible, pos_weight=1.115
pull_x         : 40.57% feasible, pos_weight=1.465
pull_y         : 40.80% feasible, pos_weight=1.451





In [32]:
model = EasyNet(num_views=3, num_primitives=num_primitives).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

5861944


In [33]:
train_losses = []
val_losses = []

for epoch in range(10):
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        depths, masks, goals, quat, feasibles, _ = move_to_device(batch)
        optimizer.zero_grad()
        feasibility_logits, _, _ = model(depths, masks, goals, quat)
        loss = bce_loss_fn(feasibility_logits, feasibles)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * depths.size(0)
    train_loss /= len(train_dataset)
    train_losses.append(train_loss)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            depths, masks, goals, quat, feasibles, _ = move_to_device(batch)
            feasibility_logits, _, _ = model(depths, masks, goals, quat)
            loss = bce_loss_fn(feasibility_logits, feasibles)
            val_loss += loss.item() * depths.size(0)
    val_loss /= len(test_dataset)
    val_losses.append(val_loss)

    # scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"epoch {epoch + 1:02d}, train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, lr={current_lr:.2e}")

epoch 01, train_loss=0.9700, val_loss=0.9660, lr=1.00e-02
epoch 02, train_loss=0.9639, val_loss=0.9657, lr=1.00e-02
epoch 03, train_loss=0.9636, val_loss=0.9654, lr=1.00e-02
epoch 04, train_loss=0.9637, val_loss=0.9655, lr=1.00e-02
epoch 05, train_loss=0.9638, val_loss=0.9656, lr=1.00e-02
epoch 06, train_loss=0.9636, val_loss=0.9655, lr=1.00e-02
epoch 07, train_loss=0.9638, val_loss=0.9658, lr=1.00e-02
epoch 08, train_loss=0.9638, val_loss=0.9662, lr=1.00e-02
epoch 09, train_loss=0.9637, val_loss=0.9663, lr=1.00e-02
epoch 10, train_loss=0.9637, val_loss=0.9660, lr=1.00e-02


In [17]:
model.eval()
tp = np.zeros(num_primitives)
fp = np.zeros(num_primitives)
tn = np.zeros(num_primitives)
fn = np.zeros(num_primitives)

with torch.no_grad():
    for batch in test_loader:
        depths, masks, goals, quat, feasibles, _ = move_to_device(batch)
        feasibility_logits, _, _ = model(depths, masks, goals, quat)
        preds = (torch.sigmoid(feasibility_logits) > 0.5).float()
        feasibles_np = feasibles.cpu().numpy()
        preds_np = preds.cpu().numpy()

        tp += ((preds_np == 1) & (feasibles_np == 1)).sum(axis=0)
        fp += ((preds_np == 1) & (feasibles_np == 0)).sum(axis=0)
        tn += ((preds_np == 0) & (feasibles_np == 0)).sum(axis=0)
        fn += ((preds_np == 0) & (feasibles_np == 1)).sum(axis=0)

print("Per-primitive metrics:")
print(f"{'Primitive':<15} {'Accuracy':>8} {'Precision':>9} {'Recall':>8} {'F1':>8}")

for i, prim in enumerate(primitives):
    accuracy = (tp[i] + tn[i]) / (tp[i] + tn[i] + fp[i] + fn[i])
    precision = tp[i] / (tp[i] + fp[i]) if (tp[i] + fp[i]) > 0 else 0.0
    recall = tp[i] / (tp[i] + fn[i]) if (tp[i] + fn[i]) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    print(f"{prim:<15} {accuracy:>8.2%} {precision:>9.2%} {recall:>8.2%} {f1:>8.3f}")

overall_accuracy = (tp.sum() + tn.sum()) / (tp.sum() + tn.sum() + fp.sum() + fn.sum())
overall_precision = tp.sum() / (tp.sum() + fp.sum())
overall_recall = tp.sum() / (tp.sum() + fn.sum())
overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall)

print(f"{'Overall':<15} {overall_accuracy:>8.2%} {overall_precision:>9.2%} {overall_recall:>8.2%} {overall_f1:>8.3f}")

Per-primitive metrics:
Primitive       Accuracy Precision   Recall       F1
push_x_pos        93.69%    73.38%   97.47%    0.837
push_x_neg        91.92%    70.37%  100.00%    0.826
push_y_pos        90.91%    65.59%   99.51%    0.791
push_y_neg        91.67%    65.98%  100.00%    0.795
lift_x            56.23%    54.09%   48.31%    0.510
lift_y            54.46%    51.45%   84.15%    0.639
pull_x            58.16%    49.91%   54.23%    0.520
pull_y            53.03%    46.75%   77.18%    0.582
Overall           73.76%    55.74%   75.34%    0.641
