# Model imports


In [1]:
"""
Model for Bracket Point Prediction
Place this file in: pointcept/models/bracket_point_model.py
"""

import torch
import torch.nn as nn
from pointcept.models.builder import MODELS
from pointcept.models.losses import build_criteria

# Dataset Imports

In [1]:
import os
import json
import numpy as np
#import torch
from torch.utils.data import Dataset
from pointcept.datasets.builder import DATASETS
from pointcept.datasets.transform import Compose
import trimesh
from torch.utils.data import DataLoader

# Model

In [3]:
class BracketPointPredictor(nn.Module):
    """
    Model for predicting bracket_point (3D coordinate) from point clouds.

    The backbone is called on ``data_dict["feat"]``. A 3-D backbone output is
    treated as per-point features: it is pooled over dim 1 (mean and max,
    concatenated) and projected to 3 values by ``self.head``. Any other
    output rank is used as the prediction unchanged.

    Args:
        backbone (dict): Backbone network config
        criteria (list): Loss functions
    """

    def __init__(self, backbone, criteria):
        super().__init__()
        from pointcept.models.builder import build_model

        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

        # The head is registered at construction time so that it is part of
        # ``parameters()`` / ``state_dict()`` before any forward pass. Creating
        # it inside ``forward`` (as was done previously) hides it from an
        # optimizer built right after the model and from checkpoint loading.
        # LazyLinear infers its input width (2 * C) on first use and
        # materialises its parameters in place.
        self.head = nn.LazyLinear(3)

    def forward(self, data_dict):
        """
        Forward pass.

        Args:
            data_dict (dict): Input data dictionary containing:
                - feat: [B, N, C] point features
                - coord: [B, N, 3] point coordinates
                - bracket_point: [B, 3] target bracket point (only in training)

        Returns:
            dict: ``bracket_point_pred`` always; ``loss`` only when the module
            is in training mode and ``bracket_point`` is present in
            ``data_dict``.
        """
        # Only the features are passed to the backbone.
        feat = data_dict["feat"]  # [B, N, C]
        pred = self.backbone(feat)  # [B, num_classes] or per-point [B, N, C]

        # Per-point output: aggregate over the point dimension, then regress.
        if pred.dim() == 3:  # [B, N, C]
            pred_mean = torch.mean(pred, dim=1)  # [B, C]
            pred_max, _ = torch.max(pred, dim=1)  # [B, C]
            pred = torch.cat([pred_mean, pred_max], dim=1)  # [B, 2*C]
            pred = self.head(pred)  # [B, 3]

        result_dict = {"bracket_point_pred": pred}

        # Calculate loss if in training mode
        if self.training and "bracket_point" in data_dict:
            target = data_dict["bracket_point"]  # [B, 3]
            result_dict["loss"] = self.criteria(pred, target)

        return result_dict

class SimplePointNet(nn.Module):
    """
    Simple PointNet-style backbone for regression.
    Can be used as an alternative to PointNet++.
    """
    
    def __init__(self, in_channels=6, hidden_dim=128, num_classes=3):
        super().__init__()
        
        # Point-wise MLPs
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        
        # Global feature extraction
        self.conv4 = nn.Sequential(
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
        )
        
        # Fully connected layers for regression
        self.fc1 = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        self.fc3 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        """
        Args:
            x: [B, N, C] point features
        Returns:
            [B, num_classes] predictions
        """
        # Transpose for Conv1d: [B, C, N]
        x = x.transpose(1, 2)
        
        # Feature extraction
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        
        # Global max pooling
        x = torch.max(x, dim=2)[0]  # [B, 512]
        
        # Regression head
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        return x

class SimpleBracketPredictor(nn.Module):
    """
    Simple end-to-end model for bracket point prediction.

    Wraps a SimplePointNet backbone with 3 outputs and a loss. The loss is
    built from ``criteria`` via ``build_criteria`` when ``criteria`` is
    truthy, otherwise ``nn.MSELoss`` is used.
    """

    def __init__(self, in_channels=3, hidden_dim=128, criteria=None):
        super().__init__()
        self.backbone = SimplePointNet(in_channels, hidden_dim, num_classes=3)
        if criteria:
            self.criteria = build_criteria(criteria)
        else:
            self.criteria = nn.MSELoss()

    def forward(self, data_dict):
        """
        Predict the bracket point from ``data_dict["coord"]``.

        Returns a dict with ``bracket_point_pred`` and, only when the module
        is in training mode and ``bracket_point`` is present in the input,
        ``loss``.
        """
        # The raw coordinates are the only backbone input.
        prediction = self.backbone(data_dict["coord"])  # [B, N, C] -> backbone
        outputs = {"bracket_point_pred": prediction}

        if not (self.training and "bracket_point" in data_dict):
            return outputs

        outputs["loss"] = self.criteria(prediction, data_dict["bracket_point"])
        return outputs

# Dataset

In [2]:
"""
Custom Dataset for Bracket Point Prediction from STL files
Place this file in: pointcept/datasets/custom_bracket_dataset.py
"""

# @DATASETS.register_module()
class BracketPointDataset(Dataset):
    """
    Dataset for predicting bracket_point from STL files.

    Every sample ``<name>`` is expected as a pair ``<name>.stl`` /
    ``<name>.json`` directly inside ``data_root``; the JSON file must hold a
    ``bracket_point`` entry.

    Args:
        data_root (str): Root directory containing stl and json files
        split (str): 'train', 'val', or 'test'
        transform (list): List of transforms to apply
        test_mode (bool): Whether in test mode (stored; not used by this class)
        loop (int): Number of times to loop through dataset (for training)
        num_points (int): Number of surface points sampled from each mesh
    """

    def __init__(
        self,
        data_root,
        split="train",
        transform=None,
        test_mode=False,
        loop=1,
        num_points=1024,
    ):
        super().__init__()
        self.data_root = data_root
        self.split = split
        self.transform = Compose(transform) if transform is not None else None
        self.test_mode = test_mode
        self.loop = loop
        self.num_points = num_points

        # Names (without extension) of the samples belonging to this split
        self.data_list = self._load_data_list()

    def _load_data_list(self):
        """Return the sample names (file stems) that make up ``self.split``.

        If ``<data_root>/<split>.txt`` exists, its non-blank lines are used.
        Otherwise all ``*.stl`` files in ``data_root`` are listed, sorted by
        name, and cut 80/10/10 into train/val/test. Any other split name
        yields the full sorted list.
        """
        split_file = os.path.join(self.data_root, f"{self.split}.txt")

        if os.path.exists(split_file):
            with open(split_file, 'r') as f:
                # Skip blank lines (e.g. a trailing newline) so they do not
                # turn into empty sample names.
                return [line.strip() for line in f if line.strip()]

        # os.listdir returns entries in arbitrary order; sorting makes the
        # train/val/test partition reproducible and disjoint across dataset
        # instances. splitext strips only the extension, unlike
        # str.replace, which would also remove '.stl' inside a name.
        # NOTE(review): a seeded shuffle may still be wanted here if the file
        # name order correlates with the data.
        file_names = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(self.data_root)
            if f.endswith('.stl')
        )

        train_end = int(len(file_names) * 0.8)
        val_end = int(len(file_names) * 0.9)
        if self.split == 'train':
            return file_names[:train_end]
        if self.split == 'val':
            return file_names[train_end:val_end]
        if self.split == 'test':
            return file_names[val_end:]
        return file_names

    def _load_stl(self, stl_path):
        """Load an STL mesh and sample a point cloud from its surface.

        Returns:
            tuple: ``(points, normals)`` as float32 arrays; each normal is
            the normal of the face its point was sampled from.
        """
        mesh = trimesh.load(stl_path, force='mesh')
        points, face_indices = trimesh.sample.sample_surface(
            mesh, count=self.num_points
        )
        normals = mesh.face_normals[face_indices]

        return points.astype(np.float32), normals.astype(np.float32)

    def _load_json(self, json_path):
        """Load JSON file and return its ``bracket_point`` as a float32 array."""
        with open(json_path, 'r') as f:
            data = json.load(f)

        return np.array(data['bracket_point'], dtype=np.float32)

    def __getitem__(self, idx):
        """Return the data dict of sample ``idx`` (after transforms, if any)."""
        # __len__ reports len(data_list) * loop, so wrap the index back
        # into the real sample range.
        idx = idx % len(self.data_list)
        file_name = self.data_list[idx]

        stl_path = os.path.join(self.data_root, f"{file_name}.stl")
        json_path = os.path.join(self.data_root, f"{file_name}.json")

        # Point cloud from the mesh, target from the JSON side-car file
        coord, normal = self._load_stl(stl_path)
        bracket_point = self._load_json(json_path)

        data_dict = {
            "coord": coord,
            "normal": normal,
            "name": file_name,
            "bracket_point": bracket_point,
        }

        if self.transform is not None:
            data_dict = self.transform(data_dict)

        return data_dict

    def __len__(self):
        """Number of samples in the split multiplied by ``loop``."""
        return len(self.data_list) * self.loop


# Optional: Custom collate function if needed
def collate_fn(batch):
    """Collate a list of sample dicts into one batch dict.

    ``coord``, ``normal`` and ``bracket_point`` are converted from NumPy and
    stacked along a new leading dimension; ``name`` becomes a list. Any other
    key is stacked the same way only if its value in the first sample is a
    NumPy array; all remaining keys are left out of the result.
    """

    def stack(key):
        # NumPy arrays of every sample -> one tensor with a leading batch dim
        return torch.stack([torch.from_numpy(sample[key]) for sample in batch])

    fixed_keys = ('coord', 'normal', 'bracket_point', 'name')

    collated = {
        'coord': stack('coord'),
        'normal': stack('normal'),
        'bracket_point': stack('bracket_point'),
        'name': [sample['name'] for sample in batch],
    }

    # Extra keys (e.g. added by transforms) are judged by the first sample.
    for key, value in batch[0].items():
        if key in fixed_keys or not isinstance(value, np.ndarray):
            continue
        collated[key] = stack(key)

    return collated

In [3]:
# Build the dataset for the "val" split and wrap it in a shuffling DataLoader
# (batches of 4, incomplete last batch dropped).
# NOTE(review): the loader is named `train_loader` although it reads the "val"
# split — confirm which split is intended.
# No collate_fn is passed, so DataLoader's default collation is used.
dataset = BracketPointDataset("/work/grana_maxillo/Mlugli/brackets_melted/flattened", "val")
train_loader = DataLoader(dataset, batch_size = 4, shuffle=True, drop_last=True)

In [None]:
# Walk through the loader once, printing "<batch index>/<number of batches>".
for i,x in enumerate(train_loader):
    print(f"{i}/{len(train_loader)}")
    # Move every tensor in the batch to the GPU; non-tensor values are kept as-is.
    # The result is rebuilt each iteration and not used further in this cell.
    data_dict = {k: v.cuda() if torch.is_tensor(v) else v for k, v in x.items()}

In [3]:
from pointcept.models.builder import build_model
from torch_scatter import segment_csr  
from pointcept.models.utils.structure import Point  

class VoxelBracketPredictor(nn.Module):
    """
    Voxel-based backbone + regression head for 3D point prediction.

    Args:
        backbone (dict): Config passed to ``build_model`` to create the backbone.
        backbone_out_channels (int): Feature width the head expects as input.
        output_dim (int): Size of the predicted vector (3 for a 3D point).
    """

    def __init__(
        self,
        backbone,
        backbone_out_channels=96,
        output_dim=3,  # 3D point coordinates
    ):
        super().__init__()

        self.backbone = build_model(backbone)

        # Regression head: two hidden stages (256, 128) of
        # Linear -> BatchNorm -> ReLU -> Dropout, then a final projection.
        widths = (backbone_out_channels, 256, 128)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend(
                [
                    nn.Linear(c_in, c_out),
                    nn.BatchNorm1d(c_out),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=0.3),
                ]
            )
        layers.append(nn.Linear(128, output_dim))
        self.head = nn.Sequential(*layers)

    def forward(self, input_dict):
        """
        Run the backbone and regress the bracket point.

        Returns a dict with ``bracket_point_pred`` and, whenever
        ``bracket_point`` is present in ``input_dict`` (regardless of
        train/eval mode), an MSE ``loss``.
        """
        backbone_out = self.backbone(input_dict)

        if isinstance(backbone_out, Point):
            # Mean-pool the features over the segments delimited by
            # ``offset`` (prefixed with a 0 to form the CSR index pointer).
            # The pooled result is also written back to ``feat`` on the
            # Point object.
            # NOTE(review): presumably one pooled row per sample — confirm.
            segment_ptr = nn.functional.pad(backbone_out.offset, (1, 0))
            backbone_out.feat = segment_csr(
                src=backbone_out.feat,
                indptr=segment_ptr,
                reduce="mean",
            )
            pooled = backbone_out.feat
        else:
            # Anything else is fed to the head unchanged.
            pooled = backbone_out

        prediction = self.head(pooled)
        result = {"bracket_point_pred": prediction}

        if "bracket_point" in input_dict:
            result["loss"] = nn.functional.mse_loss(
                prediction, input_dict["bracket_point"]
            )

        return result

In [4]:
# Instantiate the predictor with a "SpUNet-v1m1" backbone config.
# backbone_out_channels (96) equals the last entry of `channels`.
# NOTE(review): presumably num_classes=0 makes the backbone return features
# rather than class logits — confirm against the SpUNet implementation.
model = VoxelBracketPredictor(    
    backbone=dict(  
        type="SpUNet-v1m1",  
        in_channels=3,  # xyz coordinates only  
        num_classes=0,  
        channels=(32, 64, 128, 256, 256, 128, 96, 96),  
        layers=(2, 3, 4, 6, 2, 2, 2, 2),  
    ),  
    backbone_out_channels=96,  
    output_dim=3, )

In [None]:
# Placeholder for a dummy model input.
# NOTE(review): torch.rand() is called without a size, so this line raises as
# written; it has to be completed once the expected input shape is known.
x = torch.rand()
# need to know the shape of the input first.