In [9]:
import torch.nn as nn
import torch.nn.functional as F

class BoundingBox3DNet(nn.Module):
    def __init__(self, num_bins=8, feature_extractor='resnet50'):
        super(BoundingBox3DNet, self).__init__()
        
        # Store num_bins
        self.num_bins = num_bins
        
        # Feature extractor (pre-trained backbone)
        self.backbone, feature_dim = self.get_backbone(feature_extractor)

        # Dimension regression branch
        self.dim_branch = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 3)  # Output: dx, dy, dz
        )

        # Yaw orientation (confidence and residuals)
        self.yaw_conf_branch = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_bins)
        )
        self.yaw_res_branch = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * num_bins)  # sin(Δθ), cos(Δθ) for each bin
        )

        # Pitch orientation
        self.pitch_conf_branch = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_bins)
        )
        self.pitch_res_branch = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * num_bins)
        )

        # Roll orientation
        self.roll_conf_branch = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_bins)
        )
        self.roll_res_branch = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * num_bins)
        )

        # Translation regression branch
        self.trans_branch = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 3)  # Output: tx, ty, tz
        )

        # Optional: Corner prediction branch
        self.corner_branch = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 24)  # Output: 8 corners (normalized positions)
        )

    def forward(self, x):
        # Extract features
        features = self.backbone(x)
        features = features.view(features.size(0), -1)  # Flatten

        # Dimensions
        dims = self.dim_branch(features)

        # Yaw orientation
        yaw_conf = self.yaw_conf_branch(features)
        yaw_res = self.yaw_res_branch(features)
        yaw_res = yaw_res.view(-1, self.num_bins, 2)  # Reshape to (batch, bins, [sin, cos])

        # Pitch orientation
        pitch_conf = self.pitch_conf_branch(features)
        pitch_res = self.pitch_res_branch(features)
        pitch_res = pitch_res.view(-1, self.num_bins, 2)

        # Roll orientation
        roll_conf = self.roll_conf_branch(features)
        roll_res = self.roll_res_branch(features)
        roll_res = roll_res.view(-1, self.num_bins, 2)

        # Translation
        translation = self.trans_branch(features)

        # Corners (optional)
        corners = self.corner_branch(features)

        return dims, yaw_conf, yaw_res, pitch_conf, pitch_res, roll_conf, roll_res, translation, corners

    def get_backbone(self, model_name):
        if model_name == 'resnet50':
            from torchvision.models import resnet50
            backbone = resnet50(pretrained=True)
            # Remove the fully connected layer
            backbone = nn.Sequential(*list(backbone.children())[:-1])
            feature_dim = 2048
        elif model_name == 'vgg16':
            from torchvision.models import vgg16
            backbone = vgg16(pretrained=True)
            backbone = nn.Sequential(*list(backbone.features.children()))
            feature_dim = 512
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        return backbone, feature_dim

# Instantiate the model
model = BoundingBox3DNet(num_bins=8)
print(model)




BoundingBox3DNet(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): C

In [10]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
import os


class BoundingBox3DDataset(Dataset):
    def __init__(self, image_dir, data_2d_dir, data_3d_dir, camera_intrinsics, transform=None):
        self.image_dir = image_dir
        self.data_2d_dir = data_2d_dir
        self.data_3d_dir = data_3d_dir
        self.camera_intrinsics = camera_intrinsics
        self.transform = transform

        # Match files across all directories
        self.file_indices = [
            os.path.splitext(f)[0]
            for f in os.listdir(self.image_dir)
            if f.endswith(".png") and
            os.path.exists(os.path.join(self.data_2d_dir, f"{os.path.splitext(f)[0]}.csv")) and
            os.path.exists(os.path.join(self.data_3d_dir, f"{os.path.splitext(f)[0]}.csv"))
        ]
        print(f"Found {len(self.file_indices)} matching files.")

    def __len__(self):
        return sum(self._get_num_objects(file_index) for file_index in self.file_indices)

    def _get_num_objects(self, file_index):
        """Helper function to count the number of objects in a file."""
        data_3d_path = os.path.join(self.data_3d_dir, f"{file_index}.csv")
        data_3d = pd.read_csv(data_3d_path, header=None)
        return len(data_3d)

    def __getitem__(self, global_idx):
        # Map global index to file index and object index within that file
        current_idx = global_idx
        for file_index in self.file_indices:
            num_objects = self._get_num_objects(file_index)
            if current_idx < num_objects:
                object_idx = current_idx
                break
            current_idx -= num_objects
        else:
            raise IndexError("Index out of range")

        # Load image
        image_path = os.path.join(self.image_dir, f"{file_index}.png")
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Load 2D data
        data_2d_path = os.path.join(self.data_2d_dir, f"{file_index}.csv")
        data_2d = pd.read_csv(data_2d_path, header=None).iloc[object_idx].values.flatten()
        bb_center = torch.tensor(data_2d[1:3], dtype=torch.float32)
        bb_size = torch.tensor(data_2d[3:5], dtype=torch.float32)
        corners_2d = torch.tensor(data_2d[5:], dtype=torch.float32)

        # Load 3D data
        data_3d_path = os.path.join(self.data_3d_dir, f"{file_index}.csv")
        data_3d = pd.read_csv(data_3d_path, header=None).iloc[object_idx].values.flatten()

        # Validate 3D corner data
        if len(data_3d[7:]) != 24:
            raise ValueError(f"Skipping invalid object in file {file_index}: Expected 24 corner values, got {len(data_3d[7:])}")

        corners_3d = torch.tensor(data_3d[7:], dtype=torch.float32).reshape(8, 3)

        # Calculate dimensions (dx, dy, dz)
        dims = torch.max(corners_3d, dim=0).values - torch.min(corners_3d, dim=0).values

        # Camera intrinsics
        camera_intrinsics = torch.tensor(self.camera_intrinsics, dtype=torch.float32)

        sample = {
            "image": image,
            "bb_center": bb_center,
            "bb_size": bb_size,
            "corners_2d": corners_2d,
            "position": torch.tensor(data_3d[1:4], dtype=torch.float32),
            "rotation": torch.tensor(data_3d[4:7], dtype=torch.float32),
            "corners_3d": corners_3d,
            "dims": dims,
            "camera_intrinsics": camera_intrinsics,
        }
        return sample



# Camera intrinsics matrix
camera_intrinsics = np.array([
    [1108.513, 0, 640, 0],
    [0, 623.5383, 360, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# Paths to directories
image_dir = r"C:\Users\sakar\OneDrive\mt-datas\synthetic_data\8_correct_relative\images\train"
data_2d_dir = r"C:\Users\sakar\OneDrive\mt-datas\synthetic_data\8_correct_relative\2d_data"
data_3d_dir = r"C:\Users\sakar\OneDrive\mt-datas\synthetic_data\8_correct_relative\3d_data"

# Define transformations for images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize dataset and dataloaders
dataset = BoundingBox3DDataset(image_dir, data_2d_dir, data_3d_dir, camera_intrinsics, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


Found 2400 matching files.


In [13]:
import torch
import torch.nn.functional as F
import numpy as np
import os

# Loss functions
def multibin_loss(conf, residuals, gt_orientation, num_bins=8):
    bin_width = 2 * np.pi / num_bins
    bin_indices = (gt_orientation // bin_width).long()
    bin_indices = torch.clamp(bin_indices, min=0, max=conf.size(1) - 1)  # Ensure indices are valid

    # Confidence loss
    conf_loss = F.cross_entropy(conf, bin_indices)

    # Residual loss
    delta = gt_orientation - bin_indices * bin_width
    sin_gt, cos_gt = torch.sin(delta), torch.cos(delta)
    residual_loss = F.mse_loss(residuals[:, :, 0], sin_gt) + F.mse_loss(residuals[:, :, 1], cos_gt)

    return conf_loss + residual_loss

def dimension_loss(pred_dims, gt_dims):
    return F.mse_loss(pred_dims, gt_dims)

def translation_loss(pred_trans, gt_trans):
    return F.mse_loss(pred_trans, gt_trans)

# Early stopping class
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Training loop with early stopping
def train_with_early_stopping(
    model, train_loader, val_loader, optimizer, num_epochs=20, device="cuda", patience=5, save_dir="./models"
):
    """
    Train the model with early stopping and save the best model.

    Args:
        model: The model to train.
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        optimizer: Optimizer for training.
        num_epochs: Number of epochs to train.
        device: Device to train on ('cuda' or 'cpu').
        patience: Number of epochs to wait before stopping if no improvement.
        save_dir: Directory to save the model weights.
    """
    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    model = model.to(device)
    early_stopping = EarlyStopping(patience=patience)
    best_model = None
    best_epoch = -1
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()

            # Move data to device
            images = batch["image"].to(device)
            gt_dims = batch["dims"].to(device)  # Use 3D dimensions (dx, dy, dz)
            gt_orientations = batch["rotation"].to(device)
            gt_positions = batch["position"].to(device)

            # Forward pass
            pred_dims, yaw_conf, yaw_res, pitch_conf, pitch_res, roll_conf, roll_res, pred_trans, _ = model(images)

            # Extract individual ground-truth orientation values
            gt_yaw = gt_orientations[:, 0]  # Yaw
            gt_pitch = gt_orientations[:, 1]  # Pitch
            gt_roll = gt_orientations[:, 2]  # Roll

            # Loss calculation
            loss_dim = dimension_loss(pred_dims, gt_dims)
            loss_yaw = multibin_loss(yaw_conf, yaw_res, gt_yaw)
            loss_pitch = multibin_loss(pitch_conf, pitch_res, gt_pitch)
            loss_roll = multibin_loss(roll_conf, roll_res, gt_roll)
            loss_orient = loss_yaw + loss_pitch + loss_roll
            loss_trans = translation_loss(pred_trans, gt_positions)

            # Total loss
            loss = loss_dim + loss_orient + loss_trans
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device)
                gt_dims = batch["dims"].to(device)  # Use 3D dimensions (dx, dy, dz)
                gt_orientations = batch["rotation"].to(device)
                gt_positions = batch["position"].to(device)

                pred_dims, yaw_conf, yaw_res, pitch_conf, pitch_res, roll_conf, roll_res, pred_trans, _ = model(images)

                gt_yaw = gt_orientations[:, 0]
                gt_pitch = gt_orientations[:, 1]
                gt_roll = gt_orientations[:, 2]

                loss_dim = dimension_loss(pred_dims, gt_dims)
                loss_yaw = multibin_loss(yaw_conf, yaw_res, gt_yaw)
                loss_pitch = multibin_loss(pitch_conf, pitch_res, gt_pitch)
                loss_roll = multibin_loss(roll_conf, roll_res, gt_roll)
                loss_orient = loss_yaw + loss_pitch + loss_roll
                loss_trans = translation_loss(pred_trans, gt_positions)

                val_loss += (loss_dim + loss_orient + loss_trans).item()

        val_loss /= len(val_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Validation Loss: {val_loss:.4f}")

        # Save model if validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_model = model.state_dict()
            model_save_path = os.path.join(save_dir, f"best_model_epoch_{epoch+1}.pth")
            torch.save(best_model, model_save_path)
            print(f"Model saved at {model_save_path}")

        # Check early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Final save of the best model
    final_model_path = os.path.join(save_dir, "best_model_final.pth")
    if best_model:
        torch.save(best_model, final_model_path)
        print(f"Final best model saved at {final_model_path}")
    else:
        print("No best model was saved.")


# Initialize model and optimizer
model = BoundingBox3DNet(num_bins=8)  # Ensure this aligns with the model definition
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Start training with early stopping and model saving
train_with_early_stopping(
    model, train_loader, val_loader, optimizer, num_epochs=20, patience=5, save_dir="./models"
)




RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
