In [19]:
import os

import h5py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.models import MobileNet_V3_Small_Weights, mobilenet_v3_small

In [20]:
# Preferred GPU index. If fewer GPUs are visible, fall back to cuda:0 instead of
# failing later with "invalid device ordinal" on the first .to(device);
# use the CPU when CUDA is unavailable.
CUDA_DEVICE_INDEX = 2

if torch.cuda.is_available():
    gpu_index = CUDA_DEVICE_INDEX if CUDA_DEVICE_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda:2


In [21]:
class VisionPoseRegressionDataset(Dataset):
    """Map-style dataset over an HDF5 file of vision pose-regression samples.

    Every top-level member of the file is one sample; ``__getitem__`` reads the
    sample's ``depths``, ``masks``, ``goal_maps``, ``quat``, ``feasibles`` and
    ``pose_diffs`` datasets in full and returns them as a 6-tuple.

    The HDF5 handle is opened lazily and tracked per process. A handle inherited
    through ``fork`` is replaced by a fresh one, because HDF5 handles should not
    be shared across processes. The handle is also excluded from pickling
    (``spawn`` start method), since h5py objects cannot be pickled.
    """

    def __init__(self, h5_path: str):
        self.h5_path = h5_path
        self._h5_file = None
        self._h5_pid = None  # pid of the process that opened ``_h5_file``

        # Read only the sample index and metadata up front. The file is closed
        # again, so no handle exists yet when DataLoader workers are started.
        with h5py.File(h5_path, "r") as f:
            self.keys = list(f.keys())
            self.primitives = list(f.attrs["primitives"])

    def _init_h5(self):
        """Ensure the current process holds its own read-only HDF5 handle."""
        pid = os.getpid()
        if self._h5_file is None or self._h5_pid != pid:
            # A handle opened by another process (inherited via fork) is not
            # reused; this process opens its own.
            self._h5_file = h5py.File(self.h5_path, "r")
            self._h5_pid = pid

    def close(self):
        """Close the HDF5 handle held by this object, if any."""
        # getattr: __del__ may run on an instance whose __init__ did not finish.
        h5_file = getattr(self, "_h5_file", None)
        if h5_file is not None:
            h5_file.close()
        self._h5_file = None
        self._h5_pid = None

    def __del__(self):
        self.close()

    def __getstate__(self):
        """Return picklable state: everything except the open HDF5 handle."""
        state = self.__dict__.copy()
        state["_h5_file"] = None
        state["_h5_pid"] = None
        return state

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(depths, masks, goal_maps, quat, feasibles, pose_diffs)`` for sample ``idx``."""
        self._init_h5()
        dp = self._h5_file[self.keys[idx]]
        # ``[()]`` reads each HDF5 dataset in full into memory.
        return (dp["depths"][()], dp["masks"][()], dp["goal_maps"][()], dp["quat"][()], dp["feasibles"][()], dp["pose_diffs"][()])

In [None]:
# Fixed seed so the train/test split, and therefore pos_weight and the reported
# validation metrics, are reproducible across kernel restarts.
SPLIT_SEED = 0

dataset = VisionPoseRegressionDataset("../vision_pose_regression_dataset.h5")
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
primitives = dataset.primitives
num_primitives = len(primitives)
print(primitives)

In [23]:
n_samples = len(dataset)  # one sample per top-level HDF5 key
n_samples

5943

In [24]:
def calculate_pos_weight(subset=None, num_classes=None, num_workers: int = 4) -> torch.Tensor:
    """Per-primitive ``pos_weight`` for ``BCEWithLogitsLoss``: #negatives / #positives.

    Args:
        subset: dataset whose items hold the ``feasibles`` labels at tuple index 4;
            defaults to the notebook's ``train_dataset``.
        num_classes: number of primitives (columns of ``feasibles``); defaults
            to the notebook's ``num_primitives``.
        num_workers: DataLoader worker processes used for the scan.

    Returns:
        CPU float tensor of shape ``(num_classes,)``.

    The positive count is clamped to at least 1. A primitive with no positive
    samples therefore gets weight ``N``, as if it had one positive. Dividing by a
    tiny epsilon instead would give roughly ``N * 1e8``, and any positive of that
    primitive (e.g. in the validation split) would dominate the loss.
    """
    if subset is None:
        subset = train_dataset
    if num_classes is None:
        num_classes = num_primitives

    positive_counts = torch.zeros(num_classes)
    for batch in DataLoader(subset, batch_size=128, shuffle=False, num_workers=num_workers):
        positive_counts += batch[4].sum(dim=0)
    return (len(subset) - positive_counts) / positive_counts.clamp(min=1.0)


# BCE positive-class weights, computed on the CPU and then moved to the training device.
pos_weight_cpu = calculate_pos_weight()
pos_weight = pos_weight_cpu.to(device)

In [25]:
class VisionPoseRegressionNet(nn.Module):
    """Multi-view feasibility classifier and pose-difference regressor.

    For each view, the (depth, mask, goal_map) slices are stacked into a
    3-channel image and encoded by the MobileNetV3-Small feature extractor with
    global average pooling. The per-view embeddings are fused with learned
    softmax attention over views and concatenated with an embedding of
    ``quat``. The result feeds two MLP heads: ``head_feas`` (one logit per
    primitive) and ``head_pose`` (7 values per primitive).

    Args:
        hidden_dim: width of the ``quat`` embedding and of the hidden head layers.
        head_layers: number of Linear layers per head (``head_layers - 1`` hidden).
        dropout: dropout after each hidden head layer; skipped when 0.
        n_primitives: number of primitives; defaults to the notebook-level
            ``num_primitives`` (read at construction time).
        pretrained: load the ImageNet weights for the backbone. Set False to
            build an untrained backbone, e.g. offline.
    """

    def __init__(self, hidden_dim: int = 64, head_layers: int = 2, dropout: float = 0.0,
                 n_primitives=None, pretrained: bool = True):
        super().__init__()
        if n_primitives is None:
            n_primitives = num_primitives

        weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None
        mnet = mobilenet_v3_small(weights=weights)
        self.cnn = mnet.features
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Channel count produced by ``mnet.features`` (576 for MobileNetV3-Small).
        # The first classifier Linear consumes exactly these features, so derive
        # the value instead of hard-coding it.
        cnn_out_dim = mnet.classifier[0].in_features

        self.pose_fc = nn.Sequential(nn.Linear(4, hidden_dim), nn.ReLU())

        # Scores each view embedding; the scores are softmax-normalised over views.
        self.view_attn = nn.Sequential(
            nn.Linear(cnn_out_dim, cnn_out_dim // 4),
            nn.ReLU(),
            nn.Linear(cnn_out_dim // 4, 1),
        )

        mlp_input_dim = cnn_out_dim + hidden_dim

        def mlp(in_dim, out_dim):
            """(head_layers - 1) x [Linear(hidden_dim) + ReLU (+ Dropout)], then Linear(out_dim)."""
            layers = []
            h = hidden_dim
            for _ in range(head_layers - 1):
                layers += [nn.Linear(in_dim, h), nn.ReLU()]
                if dropout > 0.0:
                    layers.append(nn.Dropout(dropout))
                in_dim = h
            layers.append(nn.Linear(in_dim, out_dim))
            return nn.Sequential(*layers)

        self.head_feas = mlp(mlp_input_dim, n_primitives)
        self.head_pose = mlp(mlp_input_dim, n_primitives * 7)

    def forward(self, depths, masks, goal_maps, quat):
        """Return ``(feas_logits, pose_diff)`` for a batch of multi-view samples.

        Args:
            depths: tensor unpacked as ``(B, V, H, W)``.
            masks, goal_maps: indexed per view like ``depths``; each ``[:, v]``
                slice must match ``depths[:, v]`` so the three can be stacked.
            quat: ``(B, 4)`` input to ``pose_fc``.

        Returns:
            feas_logits: ``(B, n_primitives)`` feasibility logits.
            pose_diff: ``(B, n_primitives, 7)`` predicted pose differences.
        """
        B, V, H, W = depths.shape
        per_view_feats = []

        # The backbone runs once per view, so in train mode its BatchNorm layers
        # see B samples at a time. Folding views into the batch would change
        # those statistics.
        for v in range(V):
            x_in = torch.stack([depths[:, v], masks[:, v], goal_maps[:, v]], dim=1)  # (B, 3, H, W)
            f = self.global_pool(self.cnn(x_in)).flatten(1)  # (B, cnn_out_dim)
            per_view_feats.append(f)

        per_view_feats = torch.stack(per_view_feats, dim=1)  # (B, V, cnn_out_dim)

        # Softmax over the view axis -> convex combination of the view embeddings.
        attn_logits = self.view_attn(per_view_feats)  # (B, V, 1)
        attn = torch.softmax(attn_logits, dim=1)
        x_img = (attn * per_view_feats).sum(dim=1)  # (B, cnn_out_dim)

        x_pose = self.pose_fc(quat)
        x = torch.cat([x_img, x_pose], dim=1)

        feas_logits = self.head_feas(x)
        pose_diff = self.head_pose(x).reshape(B, -1, 7)
        return feas_logits, pose_diff

In [26]:
visionmodel = VisionPoseRegressionNet(hidden_dim=256, head_layers=4, dropout=0.125)
visionmodel = visionmodel.to(device)
optimizer = torch.optim.Adam(params=visionmodel.parameters(), lr=1e-3)

bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # class-balanced feasibility loss
mse_loss = nn.MSELoss(reduction="none")                  # unreduced, per-element squared error

# Number of trainable parameters (shown as the cell output).
trainable = [p for p in visionmodel.parameters() if p.requires_grad]
n_trainable_params = sum(p.numel() for p in trainable)
n_trainable_params

1717633

In [27]:
BATCH_SIZE = 32

# Options shared by both loaders; only shuffling differs between train and test.
loader_opts = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_opts)

train_dataset_size = len(train_dataset)
test_dataset_size = len(test_dataset)

In [31]:
POSE_LOSS_WEIGHT = 20.0


def move_to_device(batch: tuple[torch.Tensor, ...]):
    """Copy every tensor in ``batch`` to the notebook-level ``device``.

    With ``non_blocking=True``, host-to-GPU copies can overlap with compute
    when the source tensors are in pinned memory.
    """
    return tuple(item.to(device, non_blocking=True) for item in batch)


def compute_masked_pose_loss(pred_pose: torch.Tensor, target_pose: torch.Tensor, feas_mask: torch.Tensor):
    """Weighted pose MSE, averaged over the entries selected by ``feas_mask``.

    Args:
        pred_pose, target_pose: tensors whose last dimension is averaged per
            entry; the remaining leading dims must line up with ``feas_mask``.
        feas_mask: per-entry weights (the callers pass the ``feasibles``
            labels, presumably 0/1).

    Returns:
        Scalar ``POSE_LOSS_WEIGHT * sum(mask * mse) / sum(mask)``, or 0 when the
        mask sums to zero.
    """
    # Same values as nn.MSELoss(reduction="none"), without depending on a loss
    # object created in another cell.
    per_entry = (pred_pose - target_pose).pow(2).mean(dim=-1)
    masked_sum = (per_entry * feas_mask).sum()
    # Swap a zero denominator for 1 on-device instead of branching on
    # ``feas_mask.sum() == 0``. The branch forces a GPU->CPU sync every step;
    # here an all-zero mask still yields 0.
    denom = feas_mask.sum()
    denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    return POSE_LOSS_WEIGHT * masked_sum / denom

In [32]:
@torch.no_grad()
def evaluate():
    """Score ``visionmodel`` on ``test_loader`` with gradient tracking disabled.

    Returns:
        ``(feas_loss, pose_loss, acc)``:
        - ``feas_loss``, ``pose_loss``: batch-size-weighted means of the BCE loss
          and the masked pose loss over the test split.
        - ``acc``: the element-wise accuracy of thresholding the feasibility
          probabilities at 0.5, over ``test_dataset_size * num_primitives`` entries.
    """
    visionmodel.eval()
    feas_sum = pose_sum = 0.0
    n_correct = 0

    for batch in test_loader:
        depths, masks, goal_maps, quat, feasibles, pose_diffs = move_to_device(batch)
        logits, pose_pred = visionmodel(depths, masks, goal_maps, quat)

        batch_feas = bce_loss(logits, feasibles)
        batch_pose = compute_masked_pose_loss(pose_pred, pose_diffs, feasibles)

        n_in_batch = depths.size(0)
        feas_sum += batch_feas * n_in_batch
        pose_sum += batch_pose * n_in_batch

        is_feasible = torch.sigmoid(logits) > 0.5
        n_correct += (is_feasible == feasibles).sum()

    n_entries = test_dataset_size * num_primitives
    return feas_sum / test_dataset_size, pose_sum / test_dataset_size, n_correct / n_entries

In [33]:
NUM_EPOCHS = 10
EVAL_EVERY = 1

for epoch in range(NUM_EPOCHS):
    visionmodel.train()
    total_feas_loss, total_pose_loss = 0.0, 0.0

    for batch in train_loader:
        depths, masks, goal_maps, quat, feasibles, pose_diffs = move_to_device(batch)

        optimizer.zero_grad()
        feas_logits, pose_pred = visionmodel(depths, masks, goal_maps, quat)
        feas_loss = bce_loss(feas_logits, feasibles)
        pose_loss = compute_masked_pose_loss(pose_pred, pose_diffs, feasibles)
        loss = feas_loss + pose_loss
        loss.backward()
        optimizer.step()

        # Accumulate detached values. Adding the graph-attached losses would
        # chain every batch's autograd history into the running totals for the
        # whole epoch. detach() keeps the totals on-device, so there is no
        # per-step host sync.
        batch_size = depths.size(0)
        total_feas_loss += feas_loss.detach() * batch_size
        total_pose_loss += pose_loss.detach() * batch_size

    if (epoch + 1) % EVAL_EVERY == 0:
        train_feas_loss = total_feas_loss / train_dataset_size
        train_pose_loss = total_pose_loss / train_dataset_size
        val_feas_loss, val_pose_loss, val_acc = evaluate()
        print(
            f"epoch {epoch + 1:03d}: "
            f"train=(feas_loss={train_feas_loss:.4f}, pose_loss={train_pose_loss:.4f}), "
            f"val=(feas_loss={val_feas_loss:.4f}, pose_loss={val_pose_loss:.4f}, acc={val_acc:.2%})"
        )

epoch 001: train=(feas_loss=0.9633, pose_loss=0.2448), val=(feas_loss=0.9734, pose_loss=0.2400, acc=54.97%)
epoch 002: train=(feas_loss=0.9019, pose_loss=0.2418), val=(feas_loss=0.8646, pose_loss=0.2481, acc=59.86%)
epoch 003: train=(feas_loss=0.7682, pose_loss=0.2377), val=(feas_loss=2.5582, pose_loss=0.2525, acc=56.52%)
epoch 004: train=(feas_loss=0.5893, pose_loss=0.2335), val=(feas_loss=1.8807, pose_loss=0.2407, acc=62.32%)
epoch 005: train=(feas_loss=0.5431, pose_loss=0.2305), val=(feas_loss=1.7328, pose_loss=0.2426, acc=64.88%)
epoch 006: train=(feas_loss=0.5353, pose_loss=0.2291), val=(feas_loss=1.5449, pose_loss=0.2400, acc=66.29%)
epoch 007: train=(feas_loss=0.5279, pose_loss=0.2290), val=(feas_loss=1.6971, pose_loss=0.2381, acc=66.33%)
epoch 008: train=(feas_loss=0.5297, pose_loss=0.2282), val=(feas_loss=1.6933, pose_loss=0.2436, acc=64.25%)
epoch 009: train=(feas_loss=0.5245, pose_loss=0.2274), val=(feas_loss=1.7764, pose_loss=0.2398, acc=64.17%)
epoch 010: train=(feas_loss=