# End-to-End 3D Detection & Segmentation

In [None]:
# Fetch the PCT_pytorch repository together with its git submodules.
!git clone --recursive https://github.com/rafaymhddn/PCT_pytorch.git

In [None]:
# Bring the existing PCT_pytorch checkout up to date with its remote.
!cd PCT_pytorch && git pull

#### Install Libs / Dependencies

In [None]:
# Report which PyTorch version is installed in this environment.
!python -c "import torch; print(torch.__version__)"
# Install the third-party packages with pip.
!pip install open3d
!pip install plotly
!pip install ninja
# Install PyTorch3D from the `stable` ref of its GitHub repository.
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'


#### Import Libraries

In [None]:
import os
import sys
import torch
import subprocess
import glob
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from vis import *
%matplotlib inline

### DataLoader

In [None]:
def custom_collate(batch):
    """Collate dataset samples into a dict of per-sample lists.

    Nothing is stacked: for each of the four fields the returned dict holds
    a plain list with one entry per sample, in batch order.
    """
    fields = ('point_cloud', 'mask', 'bbox3d', 'centroid')
    return {field: [sample[field] for sample in batch] for field in fields}

class PickPlaceDataset(Dataset):
    """Pick-and-place scenes stored as one directory of ``.npy`` files per sample.

    Every directory ``root_dir/<sample_id>`` must contain ``pc.npy``,
    ``mask.npy`` and ``bbox3d.npy``.

    Args:
        root_dir: Directory holding one sub-directory per sample.
        sample_ids: Names of the sample sub-directories served by this dataset.
        augment: When True, a random rotation, scaling and translation is
            applied to the points and boxes of every returned sample.
    """

    def __init__(self, root_dir, sample_ids, augment=False):
        self.root_dir = root_dir
        self.sample_ids = sample_ids
        self.augment = augment

    def __len__(self):
        return len(self.sample_ids)

    @staticmethod
    def _to_points(pc):
        """Flatten an organized point cloud to ``[P, 3]``.

        Accepts channel-first (``[3, ...]``) and channel-last (``[..., 3]``)
        layouts. When both the first and the last axis have length 3 the
        channel-first interpretation is used.

        Raises:
            ValueError: If neither the first nor the last axis has length 3.
        """
        if pc.shape[0] == 3:
            return pc.reshape(3, -1).transpose(1, 0)
        if pc.shape[-1] == 3:
            return pc.reshape(-1, 3)
        raise ValueError(
            f"expected a point cloud with a leading or trailing axis of "
            f"length 3, got shape {pc.shape}"
        )

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict of numpy arrays.

        Returns:
            dict with ``point_cloud`` ``[P, 3]`` float32, ``mask`` ``[N, P]``
            int64, ``bbox3d`` ``[N, 8, 3]`` float32 and ``centroid``
            ``[N, 3]`` float32 (the mean of the eight box corners).
        """
        sample_path = os.path.join(self.root_dir, self.sample_ids[idx])

        pc = np.load(os.path.join(sample_path, 'pc.npy'))        # [3, H, W] or [H, W, 3]
        mask = np.load(os.path.join(sample_path, 'mask.npy'))    # [N, H, W]
        bbox = np.load(os.path.join(sample_path, 'bbox3d.npy'))  # [N, 8, 3]

        pc = self._to_points(pc)  # [P, 3]

        # Flatten each instance mask to one entry per point. The size is
        # spelled out because numpy cannot infer a -1 axis when N == 0.
        mask = mask.reshape(mask.shape[0], int(np.prod(mask.shape[1:])))

        if self.augment:
            pc, bbox = self.apply_augmentations(pc, bbox)

        return {
            'point_cloud': pc.astype(np.float32),
            'mask': mask.astype(np.int64),
            'bbox3d': bbox.astype(np.float32),
            'centroid': bbox.mean(axis=1).astype(np.float32),
        }

    def apply_augmentations(self, pc, bbox):
        """Apply the same random rigid-plus-scale transform to points and boxes.

        Args:
            pc: ``[P, 3]`` point cloud.
            bbox: ``[N, 8, 3]`` box corners.

        Returns:
            Tuple ``(pc, bbox)`` of new arrays; the inputs are not modified.
        """
        # 1. Random rotation around the Z-axis
        angle = random.uniform(0, 2 * np.pi)
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        rotation_matrix = np.array([
            [cos_a, -sin_a, 0],
            [sin_a,  cos_a, 0],
            [0,      0,     1],
        ])
        pc = pc @ rotation_matrix.T
        bbox = bbox @ rotation_matrix.T

        # 2. Random isotropic scaling
        scale = random.uniform(0.9, 1.1)
        pc = pc * scale
        bbox = bbox * scale

        # 3. Random translation shared by every point and box corner
        jitter = (np.random.rand(1, 3) - 0.5) * 0.02
        pc = pc + jitter
        bbox = bbox + jitter

        return pc, bbox

In [48]:
def get_dataloaders(data_root, batch_size=16, seed=42, train_size=0.8):
    """Build train/val/test DataLoaders from the sample folders in ``data_root``.

    Every sub-directory of ``data_root`` counts as one sample. ``train_size``
    of the samples go to training and the rest is halved into validation and
    test. Only the training set is augmented and shuffled.

    Note: this seeds Python's global ``random`` module with ``seed``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    sample_dirs = sorted(
        entry for entry in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, entry))
    )

    random.seed(seed)
    random.shuffle(sample_dirs)

    train_ids, holdout_ids = train_test_split(
        sample_dirs, test_size=(1 - train_size), random_state=seed
    )
    # The held-out part is divided 50:50 into validation and test.
    val_ids, test_ids = train_test_split(
        holdout_ids, test_size=0.5, random_state=seed
    )

    train_set = PickPlaceDataset(data_root, train_ids, augment=True)
    val_set = PickPlaceDataset(data_root, val_ids)
    test_set = PickPlaceDataset(data_root, test_ids)

    print(f"Dataset sizes - Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")

    loader_args = dict(batch_size=batch_size, collate_fn=custom_collate)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, **loader_args)
    test_loader = DataLoader(test_set, shuffle=False, **loader_args)

    return train_loader, val_loader, test_loader

In [51]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

def visualize_sample_plotly(pc, mask, bbox3d, show_background=False):
    """
    Interactive Plotly visualization of 3D point cloud with instance masks, centroids, and 3D bounding boxes.

    Each instance's points are drawn in their own color together with a
    black 'x' marker at the mean of those points. Instances whose mask
    selects no point are skipped. When masks overlap, a point is drawn with
    the instance of highest index.

    Args:
        pc: [P, 3] point cloud
        mask: [N, P] binary masks; a single [P] mask is treated as N == 1
        bbox3d: [N, 8, 3] bounding boxes with 8 corner points; a single
            [8, 3] box is treated as N == 1
        show_background: bool, whether to display unlabeled points (background)
    """
    pc = np.asarray(pc)
    mask = np.asarray(mask)
    bbox3d = np.asarray(bbox3d)

    num_points = pc.shape[0]

    # Promote single-instance inputs so the loops below always see a
    # leading instance axis.
    if mask.ndim == 1 and mask.shape[0] == num_points:
        mask = mask[np.newaxis, :]
    if bbox3d.ndim == 2:
        bbox3d = bbox3d[np.newaxis, ...]

    num_instances = mask.shape[0]

    # Assign each point its instance ID (-1 = background)
    instance_ids = np.full((num_points,), -1)
    for idx in range(num_instances):
        instance_ids[mask[idx] > 0] = idx

    # Color map
    colorscale = px.colors.qualitative.Dark24  # 24 distinct colors
    num_colors = len(colorscale)

    fig = go.Figure()

    # Add point cloud and centroid marker per instance
    for i in range(num_instances):
        inds = np.where(instance_ids == i)[0]
        if inds.size == 0:
            continue
        pts = pc[inds]
        cent = pts.mean(axis=0)

        fig.add_trace(go.Scatter3d(
            x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
            mode='markers',
            marker=dict(size=2, color=colorscale[i % num_colors]),
            showlegend=False
        ))

        fig.add_trace(go.Scatter3d(
            x=[cent[0]], y=[cent[1]], z=[cent[2]],
            mode='markers',
            marker=dict(size=7, color='black', symbol='x'),
            showlegend=False
        ))

    # Add unlabeled/background points if enabled
    if show_background:
        bg_inds = np.where(instance_ids == -1)[0]
        if bg_inds.size > 0:
            pts = pc[bg_inds]
            fig.add_trace(go.Scatter3d(
                x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
                mode='markers',
                marker=dict(size=1, color='gray', opacity=0.3),
                showlegend=False
            ))

    # Corner-index pairs forming the 12 edges of a box
    edges = [
        (0, 1), (1, 2), (2, 3), (3, 0),
        (4, 5), (5, 6), (6, 7), (7, 4),
        (0, 4), (1, 5), (2, 6), (3, 7)
    ]

    for box in bbox3d:
        for s, e in edges:
            x, y, z = zip(box[s], box[e])
            fig.add_trace(go.Scatter3d(
                x=x, y=y, z=z,
                mode='lines',
                line=dict(color='black', width=2),
                showlegend=False
            ))

    fig.update_layout(
        title='3D Point Cloud with Instance Masks, Centroids, and BBoxes',
        scene=dict(
            xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
            aspectmode='data',
        ),
        showlegend=False
    )

    fig.show()



In [None]:

# Build the loaders, then print the shapes of the first sample of the first
# training batch and plot it.
data_root = 'data/pick_place'
train_loader, val_loader, test_loader = get_dataloaders(data_root)

for batch in train_loader:
    # Each field is a list with one array per sample.
    pc, mask, bbox, centroid = (
        batch[field] for field in ('point_cloud', 'mask', 'bbox3d', 'centroid')
    )

    first_sample = (
        ("point_cloud", pc[0]),
        ("mask", mask[0]),
        ("bbox3d", bbox[0]),
        ("centroid", centroid[0]),
    )
    for label, value in first_sample:
        print(f"{label} shape:", np.array(value).shape)

    visualize_sample_plotly(pc[0], mask[0], bbox[0])
    break


# Model

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from module import Embedding, NeighborEmbedding, OA, SA

In [65]:
# Encoder
class NaivePCT(nn.Module):
    """Encoder: an embedding followed by four chained ``SA`` blocks.

    The outputs of the four blocks are concatenated along dim 1 (4 x 128 =
    512 channels) and projected to 1024 channels by a 1x1 convolution with
    batch norm and LeakyReLU.
    """

    def __init__(self):
        super().__init__()

        self.embedding = Embedding(3, 128)

        # Applied one after another in forward(); each stage output is kept.
        self.sa1 = SA(128)
        self.sa2 = SA(128)
        self.sa3 = SA(128)
        self.sa4 = SA(128)

        self.linear = nn.Sequential(
            nn.Conv1d(512, 1024, kernel_size=1, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, x):
        """Return ``(features, max_pooled, mean_pooled)``.

        Pooling is taken over the last dimension of the projected features.
        """
        hidden = self.embedding(x)

        stage_outputs = []
        for block in (self.sa1, self.sa2, self.sa3, self.sa4):
            hidden = block(hidden)
            stage_outputs.append(hidden)

        fused = self.linear(torch.cat(stage_outputs, dim=1))

        pooled_max = torch.max(fused, dim=-1)[0]
        pooled_mean = torch.mean(fused, dim=-1)

        return fused, pooled_max, pooled_mean


In [62]:
class Segmentation(nn.Module):
    """Per-point segmentation head on top of encoder features.

    Args:
        part_num: Number of output channels (one score per part, per point).
        num_categories: Length of the one-hot object-category vector passed
            as ``cls_label``. Defaults to 16.
    """

    def __init__(self, part_num, num_categories=16):
        super().__init__()

        self.part_num = part_num
        self.num_categories = num_categories

        # Embeds the one-hot category vector into 64 channels.
        self.label_conv = nn.Sequential(
            nn.Conv1d(num_categories, 64, kernel_size=1, bias=False),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(negative_slope=0.2)
        )

        # Input = per-point (1024) + max-pooled (1024) + mean-pooled (1024)
        # features plus the 64-channel category embedding.
        self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1)
        self.convs2 = nn.Conv1d(512, 256, 1)
        self.convs3 = nn.Conv1d(256, self.part_num, 1)

        self.bns1 = nn.BatchNorm1d(512)
        self.bns2 = nn.BatchNorm1d(256)

        self.dp1 = nn.Dropout(0.5)

    def forward(self, x, x_max, x_mean, cls_label):
        """Compute per-point part scores.

        Args:
            x: ``[B, 1024, N]`` per-point features.
            x_max: ``[B, 1024]`` max-pooled features.
            x_mean: ``[B, 1024]`` mean-pooled features.
            cls_label: One-hot category labels with ``B * num_categories``
                elements.

        Returns:
            ``[B, part_num, N]`` tensor of unnormalized scores.
        """
        batch_size, _, N = x.size()

        # expand() broadcasts without copying; torch.cat materializes the
        # values below.
        x_max_feature = x_max.unsqueeze(-1).expand(-1, -1, N)
        x_mean_feature = x_mean.unsqueeze(-1).expand(-1, -1, N)

        cls_label_one_hot = cls_label.reshape(batch_size, self.num_categories, 1)
        cls_label_feature = self.label_conv(cls_label_one_hot).expand(-1, -1, N)

        x = torch.cat([x, x_max_feature, x_mean_feature, cls_label_feature], dim=1)  # 1024 * 3 + 64

        x = F.relu(self.bns1(self.convs1(x)))
        x = self.dp1(x)
        x = F.relu(self.bns2(self.convs2(x)))
        x = self.convs3(x)

        return x

In [None]:
class NaivePCTSeg(nn.Module):
    """Segmentation network: a ``NaivePCT`` encoder feeding a ``Segmentation`` head."""

    def __init__(self, part_num=50):
        super().__init__()

        self.encoder = NaivePCT()
        self.seg = Segmentation(part_num)

    def forward(self, x, cls_label):
        """Encode ``x`` and return the head's output for ``cls_label``."""
        point_feats, pooled_max, pooled_mean = self.encoder(x)
        return self.seg(point_feats, pooled_max, pooled_mean, cls_label)
    

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimplePointNet(nn.Module):
    """Minimal shared-MLP network with segmentation, centroid and box heads.

    Works on a single cloud ``[P, 3]`` as well as on a batch ``[B, P, 3]``;
    leading batch dimensions are carried through to every output.

    Args:
        feat_dim: Width of the per-point feature vector.
    """

    def __init__(self, feat_dim=64):
        super().__init__()
        self.feat_dim = feat_dim

        # Point-wise feature extraction (shared MLP)
        self.mlp1 = nn.Linear(3, 64)
        self.mlp2 = nn.Linear(64, 128)
        self.mlp3 = nn.Linear(128, feat_dim)

        # Segmentation head: one foreground logit per point
        self.seg_head = nn.Sequential(
            nn.Linear(feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Detection heads operate on the pooled (global) feature
        self.centroid_head = nn.Sequential(
            nn.Linear(feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3)  # centroid (x, y, z)
        )

        self.bbox_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 24),  # 8 corners x 3 coords
        )

    def forward(self, pc):
        """Run the network on ``pc`` of shape ``[P, 3]`` or ``[B, P, 3]``.

        Returns:
            dict with ``seg_logits`` ``[..., P]``, ``centroid_pred``
            ``[..., 3]`` and ``bbox_pred`` ``[..., 8, 3]``, where ``...`` is
            empty for a single cloud and ``B`` for a batch.
        """
        x = F.relu(self.mlp1(pc))
        x = F.relu(self.mlp2(x))
        features = self.mlp3(x)  # [..., P, feat_dim]

        # Segmentation (instance-agnostic)
        seg_logits = self.seg_head(features).squeeze(-1)  # [..., P]

        # Global feature: max over the point axis (second to last), so that
        # a leading batch axis is never pooled away.
        global_feat = torch.max(features, dim=-2, keepdim=True)[0]  # [..., 1, feat_dim]

        centroid_pred = self.centroid_head(global_feat).squeeze(-2)  # [..., 3]
        bbox_pred = self.bbox_head(global_feat).reshape(*features.shape[:-2], 8, 3)

        return {
            "seg_logits": seg_logits,       # [..., P]
            "centroid_pred": centroid_pred, # [..., 3]
            "bbox_pred": bbox_pred          # [..., 8, 3]
        }


In [None]:
# Smoke test: run an untrained SimplePointNet on random inputs.
batch = dict(
    point_cloud=torch.randn(2048, 3),
    mask=torch.randint(0, 2, (4, 2048)),
    bbox3d=torch.randn(4, 8, 3),
    centroid=torch.randn(4, 3),
)

model = SimplePointNet()
out = model(batch['point_cloud'])
print(out)


In [None]:
# Dimension test: push the first sample of one training batch through an
# untrained model and print the shapes involved.

import torch
import numpy as np

# Use MPS device if available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Instantiate and move model to the selected device
model = SimplePointNet().to(device)
model.eval()

for batch in train_loader:
    # custom_collate keeps every field as a list with one array per sample.
    pc = batch['point_cloud']    # list of [Points, 3]
    mask = batch['mask']         # list of [N, Points]
    bbox = batch['bbox3d']       # list of [N, 8, 3]
    centroid = batch['centroid'] # list of [N, 3]

    # Use only the first sample in the batch
    pc_sample = pc[0]             # [Points, 3]
    mask_sample = mask[0]         # [N, Points]
    bbox_sample = bbox[0]         # [N, 8, 3]
    centroid_sample = centroid[0] # [N, 3]

    print("point_cloud shape:", np.array(pc_sample).shape)
    print("mask shape:", np.array(mask_sample).shape)
    print("bbox3d shape:", np.array(bbox_sample).shape)
    print("centroid shape:", np.array(centroid_sample).shape)

    # Optional: still visualize using NumPy
    #visualize_sample_plotly(pc_sample, mask_sample, bbox_sample)

    # Convert to a tensor on the selected device
    pc_tensor = torch.tensor(pc_sample, dtype=torch.float32).to(device)

    # Forward pass
    with torch.no_grad():
        outputs = model(pc_tensor)

    print("segmentation logits shape:", outputs['seg_logits'].shape)
    print("predicted centroid:", outputs['centroid_pred'].cpu().numpy())
    # Print the shape, as the label says.
    print("predicted bbox shape:", outputs['bbox_pred'].shape)

    break


In [None]:
# Overfit the model to a single instance of a single training sample.
import torch
import torch.nn.functional as F
import numpy as np

# Device setup (MPS if available, else CPU)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Instantiate model and move to device
model = SimplePointNet().to(device)
model.train()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take the first sample of the first batch from train_loader; the loop is
# left immediately, so pc / mask / bbox / centroid stay bound to that sample.
for batch in train_loader:
    pc = batch['point_cloud'][0]     # [Points, 3]
    mask = batch['mask'][0]          # [N, Points]
    bbox = batch['bbox3d'][0]        # [N, 8, 3]
    centroid = batch['centroid'][0]  # [N, 3]
    break  # only one batch/sample

# Convert inputs & GT to tensors on device
pc_tensor = torch.tensor(pc, dtype=torch.float32).to(device)
mask_tensor = torch.tensor(mask, dtype=torch.float32).to(device)   # shape: [N, Points]
bbox_tensor = torch.tensor(bbox, dtype=torch.float32).to(device)   # shape: [N, 8, 3]
centroid_tensor = torch.tensor(centroid, dtype=torch.float32).to(device) # [N, 3]

# The model predicts a single mask / centroid / box, so only the first
# instance (index 0) is used as the target. Requires N >= 1.
mask_single = mask_tensor[0]      # [Points]
bbox_single = bbox_tensor[0]      # [8,3]
centroid_single = centroid_tensor[0]  # [3]

num_epochs = 500

# Every "epoch" is one optimizer step on the same single sample.
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass
    outputs = model(pc_tensor)  # outputs: dict with keys 'seg_logits', 'centroid_pred', 'bbox_pred'

    # seg_logits: [Points] raw scores for binary segmentation (point belongs to object or not)
    seg_logits = outputs['seg_logits']  # [Points]
    centroid_pred = outputs['centroid_pred']  # [3]
    bbox_pred = outputs['bbox_pred']  # [8,3]

    # Segmentation loss (Binary Cross Entropy on the raw logits)
    seg_loss = F.binary_cross_entropy_with_logits(seg_logits, mask_single)

    # Centroid loss (mean squared error)
    centroid_loss = F.mse_loss(centroid_pred, centroid_single)

    # Bounding box loss (mean squared error over all 8 corners)
    bbox_loss = F.mse_loss(bbox_pred, bbox_single)

    # Total loss: unweighted sum of the three terms
    total_loss = seg_loss + centroid_loss + bbox_loss

    # Backpropagation
    total_loss.backward()
    optimizer.step()

    # Log the first step and then every 50th step
    if (epoch + 1) % 50 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss.item():.4f}")

print("Finished training to overfit one example.")


In [None]:
import torch
import numpy as np

# Compare the model's predictions with the ground truth of the first
# instance of the notebook-level sample (pc / mask / bbox / centroid).
model.eval()

# Ground truth and input as float32 tensors on the compute device
pc_tensor, mask_tensor, bbox_tensor, centroid_tensor = (
    torch.tensor(array, dtype=torch.float32).to(device)
    for array in (pc, mask, bbox, centroid)
)

# First instance (index 0) is the reference
mask_single = mask_tensor[0]          # [Points]
bbox_single = bbox_tensor[0]          # [8,3]
centroid_single = centroid_tensor[0]  # [3]

with torch.no_grad():
    outputs = model(pc_tensor)

    seg_logits = outputs['seg_logits']        # [Points]
    centroid_pred = outputs['centroid_pred']  # [3]
    bbox_pred = outputs['bbox_pred']          # [8, 3]

    # Logits -> probabilities -> hard 0/1 labels at threshold 0.5
    seg_probs = torch.sigmoid(seg_logits)
    seg_pred = (seg_probs > 0.5).float()

# Back to NumPy on the CPU for printing
seg_pred_np = seg_pred.cpu().numpy()
seg_gt_np = mask_single.cpu().numpy()

centroid_pred_np = centroid_pred.cpu().numpy()
centroid_gt_np = centroid_single.cpu().numpy()

bbox_pred_np = bbox_pred.cpu().numpy()
bbox_gt_np = bbox_single.cpu().numpy()

# Report: fraction of points labelled correctly, then prediction / GT pairs
print(f"Segmentation Accuracy: {(seg_pred_np == seg_gt_np).mean():.4f}")
print("Centroid Prediction (pred vs GT):")
print(centroid_pred_np, centroid_gt_np, sep="\n")
print("Bounding Box Prediction vs GT (first 3 points):")
print(bbox_pred_np[:3], bbox_gt_np[:3], sep="\n")


In [None]:
# Plot the notebook-level sample (pc, mask, bbox) with its ground-truth
# masks and boxes.
print("Visualizing Ground Truth:")
visualize_sample_plotly(pc, mask, bbox)

In [None]:
print("Visualizing Prediction:")
# visualize_sample_plotly expects a leading instance axis on both inputs:
# masks as [N, P] and boxes as [N, 8, 3]. The model predicts one instance,
# so add that axis to the [P] mask and the [8, 3] box.
seg_pred_np_exp = np.expand_dims(seg_pred_np, axis=0)
bbox_pred_np_exp = np.expand_dims(bbox_pred_np, axis=0)

visualize_sample_plotly(pc, seg_pred_np_exp, bbox_pred_np_exp)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

# Device setup (MPS if available, else CPU)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Define fixed number of points per cloud
NUM_POINTS = 100000

def sample_points(x, num_points=NUM_POINTS):
    """Resample ``x`` along its first axis to exactly ``num_points`` rows.

    If ``x`` has at least ``num_points`` rows, a random subset without
    replacement is returned. Otherwise all rows are kept in order and the
    missing ones are filled with randomly chosen duplicates.

    Args:
        x: Tensor, array or nested sequence whose first axis is resampled.
        num_points: Number of rows in the result.

    Returns:
        torch.Tensor with ``num_points`` rows.

    Raises:
        ValueError: If ``x`` is empty but ``num_points`` is positive.
    """
    # torch.tensor() on an existing tensor warns about copy construction,
    # so tensors are used as they are; indexing below never aliases x.
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)

    n = x.shape[0]
    if n >= num_points:
        keep = torch.randperm(n)[:num_points]
        return x[keep]

    if n == 0:
        raise ValueError("cannot pad an empty input up to num_points rows")

    pad_idx = torch.randint(0, n, (num_points - n,))
    return torch.cat([x, x[pad_idx]], dim=0)

# Overfit the model to one batch of training samples.

# Instantiate model and move to device
model = SimplePointNet().to(device)
model.train()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Load a single batch. custom_collate keeps every field as a list with one
# array per sample, and the number of instances N_i may differ per sample.
for batch in train_loader:
    pc_batch = batch['point_cloud']     # list of [P_i, 3]
    mask_batch = batch['mask']          # list of [N_i, P_i]
    bbox_batch = batch['bbox3d']        # list of [N_i, 8, 3]
    centroid_batch = batch['centroid']  # list of [N_i, 3]
    break  # take one batch only

# The model predicts one mask / centroid / box per cloud, so the first
# instance of every sample is the target (each sample needs N_i >= 1).
# Points and their mask values are packed into one [P_i, 4] array before
# resampling so that both are drawn with the same random indices.
sampled_clouds, sampled_masks = [], []
for cloud, inst_masks in zip(pc_batch, mask_batch):
    packed = np.concatenate(
        [
            np.asarray(cloud, dtype=np.float32),
            np.asarray(inst_masks[0], dtype=np.float32).reshape(-1, 1),
        ],
        axis=1,
    )
    packed = sample_points(packed, NUM_POINTS)  # [NUM_POINTS, 4]
    sampled_clouds.append(packed[:, :3])
    sampled_masks.append(packed[:, 3])

pc_tensor = torch.stack(sampled_clouds).float().to(device)   # [B, NUM_POINTS, 3]
mask_tensor = torch.stack(sampled_masks).float().to(device)  # [B, NUM_POINTS]

# First-instance box and centroid per sample; stacking the selected [8, 3]
# and [3] arrays is well defined even when N_i varies between samples.
bbox_tensor = torch.stack([
    torch.as_tensor(boxes[0], dtype=torch.float32) for boxes in bbox_batch
]).to(device)  # [B, 8, 3]
centroid_tensor = torch.stack([
    torch.as_tensor(cents[0], dtype=torch.float32) for cents in centroid_batch
]).to(device)  # [B, 3]

num_epochs = 500

for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass, one cloud at a time: calling the model with a single
    # [NUM_POINTS, 3] cloud is valid whether or not it also accepts batches.
    per_cloud = [model(cloud) for cloud in pc_tensor]

    seg_logits = torch.stack([out['seg_logits'] for out in per_cloud])        # [B, NUM_POINTS]
    centroid_pred = torch.stack([out['centroid_pred'] for out in per_cloud])  # [B, 3]
    bbox_pred = torch.stack([out['bbox_pred'] for out in per_cloud])          # [B, 8, 3]

    # Losses
    seg_loss = F.binary_cross_entropy_with_logits(seg_logits, mask_tensor)
    centroid_loss = F.mse_loss(centroid_pred, centroid_tensor)
    bbox_loss = F.mse_loss(bbox_pred, bbox_tensor)

    total_loss = seg_loss + centroid_loss + bbox_loss

    total_loss.backward()
    optimizer.step()

    # Log the first step and then every 50th step
    if (epoch + 1) % 50 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss.item():.4f}")

print("Finished training to overfit batch of examples.")
