# 3D Perception Tutorial

We will explore (1) 3D object classification and (2) instance-level semantic segmentation of indoor scenes. To reach these goals, this tutorial will guide you through the basic concepts and practical applications of WarpConvNet for point cloud processing, sparse voxel convolutions, and 3D attention mechanisms.

## Table of Contents
1. Basic Concepts
2. PointCloud Representation
3. Voxel Representation
4. SparseConvNet
5. Exercise 1: 3D Point Cloud Semantic Segmentation with SparseConvNet
6. Exercise 2: 3D Point Cloud Semantic Segmentation with PointTransformerV3

In [None]:
import subprocess
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, List
import warnings
warnings.filterwarnings('ignore')

# Report the interpreter, PyTorch, and CUDA environment this notebook runs in.
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only query CUDA details when a device is actually usable.
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## PointCloud Representation

In [None]:
import matplotlib.pyplot as plt
import plotly
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Prefer the GPU when one is available; later cells move tensors and models to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def visualize_point_cloud(points, colors=None, title="Point Cloud",
                         point_size=2, colorscale='Viridis', show_axis=True):
    """
    Interactive 3D point cloud visualization using Plotly.

    Args:
        points: array-like or torch.Tensor of shape (N, 3)
        colors: optional per-point colors (torch.Tensor, numpy array or list).
            1-D values are mapped through ``colorscale``; for 2-D input
            (e.g. RGB) only the first channel is used.
        title: title for the plot
        point_size: size of points in the visualization
        colorscale: Plotly colorscale name
        show_axis: whether to show axis labels

    Returns:
        The ``plotly.graph_objects.Figure`` that was shown, so callers can
        keep, tweak, or export it (matches ``visualize_voxels``).
    """
    if isinstance(points, torch.Tensor):
        # detach() so tensors that track gradients can be converted as well
        points = points.detach().cpu().numpy()

    if colors is None:
        # Color by height (z-coordinate)
        colors = points[:, 2]
    elif isinstance(colors, torch.Tensor):
        colors = colors.detach().cpu().numpy()
    else:
        # Accept plain lists/tuples too; the channel check below needs an ndarray
        colors = np.asarray(colors)

    # Use the first channel when multi-channel (e.g. RGB) colors are given
    marker_color = colors if colors.ndim == 1 else colors[:, 0]

    # Hover labels are built only for the first 1000 points to keep the
    # figure light; points beyond that get no custom label.
    hover_text = [
        f"Point {i}<br>x: {x:.3f}<br>y: {y:.3f}<br>z: {z:.3f}"
        for i, (x, y, z) in enumerate(points[:min(1000, len(points))])
    ]

    # Create 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=point_size,
            color=marker_color,
            colorscale=colorscale,
            showscale=True,
            colorbar=dict(
                title="Value",
                thickness=20,
                len=0.7
            )
        ),
        text=hover_text,
        hovertemplate='%{text}<extra></extra>'
    )])

    # Update layout
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            xanchor='center'
        ),
        scene=dict(
            xaxis=dict(title='X' if show_axis else '', showgrid=True, gridwidth=1),
            yaxis=dict(title='Y' if show_axis else '', showgrid=True, gridwidth=1),
            zaxis=dict(title='Z' if show_axis else '', showgrid=True, gridwidth=1),
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)
            ),
            aspectmode='auto'
        ),
        width=900,
        height=700,
        margin=dict(r=20, b=10, l=10, t=40),
        showlegend=False
    )

    fig.show()
    # Previously nothing was returned, so `fig = visualize_point_cloud(...)`
    # bound None.
    return fig


def visualize_multiple_point_clouds(point_clouds_dict, title="Multiple Point Clouds"):
    """
    Draw several point clouds in one Plotly figure, one color per cloud.

    Args:
        point_clouds_dict: mapping of {name: points}; each value is indexed as
            an (N, 3) array, and torch tensors are moved to CPU / numpy first
        title: overall title for the plot

    The figure is shown; nothing is returned.
    """
    # Qualitative palette, cycled when there are more clouds than colors
    palette = px.colors.qualitative.Set1
    figure = go.Figure()

    for slot, (label, cloud) in enumerate(point_clouds_dict.items()):
        cloud_np = cloud.cpu().numpy() if isinstance(cloud, torch.Tensor) else cloud

        # Custom hover labels are only generated for the first 100 points
        hover_labels = [f"{label} - Point {i}" for i in range(min(100, len(cloud_np)))]

        trace = go.Scatter3d(
            x=cloud_np[:, 0],
            y=cloud_np[:, 1],
            z=cloud_np[:, 2],
            mode='markers',
            name=label,
            marker=dict(size=3, color=palette[slot % len(palette)], opacity=0.8),
            text=hover_labels,
            hovertemplate='%{text}<br>x: %{x:.3f}<br>y: %{y:.3f}<br>z: %{z:.3f}<extra></extra>',
        )
        figure.add_trace(trace)

    scene_settings = dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
    )
    legend_settings = dict(x=0.02, y=0.98, bgcolor='rgba(255,255,255,0.8)')

    figure.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        scene=scene_settings,
        width=1000,
        height=700,
        showlegend=True,
        legend=legend_settings,
    )

    figure.show()


def generate_sample_point_cloud(n_points=1000, shape='sphere', noise_level=0.0):
    """Generate a synthetic point cloud with random per-point features.

    Args:
        n_points: number of points to sample
        shape: one of 'sphere', 'cube', 'torus', 'cylinder'
        noise_level: standard deviation of Gaussian noise added to the
            coordinates (no noise is added when <= 0)

    Returns:
        (points, features): float32 arrays of shape (n_points, 3). ``features``
        are uniform random values in [0, 1) (e.g. stand-ins for RGB colors).

    Raises:
        ValueError: if ``shape`` is not one of the supported names.
    """
    if shape == 'sphere':
        # Shell around the unit sphere with radius in [0.8, 1.2].
        theta = np.random.uniform(0, 2*np.pi, n_points)
        # Draw cos(phi) uniformly instead of phi itself: a uniform polar angle
        # concentrates points at the poles, whereas a uniform cosine spreads
        # the directions evenly over the sphere.
        cos_phi = np.random.uniform(-1, 1, n_points)
        sin_phi = np.sqrt(1.0 - cos_phi**2)
        r = np.random.uniform(0.8, 1.2, n_points)

        x = r * sin_phi * np.cos(theta)
        y = r * sin_phi * np.sin(theta)
        z = r * cos_phi
    elif shape == 'cube':
        # Points filling the cube [-1, 1]^3
        x = np.random.uniform(-1, 1, n_points)
        y = np.random.uniform(-1, 1, n_points)
        z = np.random.uniform(-1, 1, n_points)
    elif shape == 'torus':
        # Points on a torus surface
        theta = np.random.uniform(0, 2*np.pi, n_points)
        phi = np.random.uniform(0, 2*np.pi, n_points)
        R, r = 1.0, 0.3  # Major and minor radius

        x = (R + r * np.cos(phi)) * np.cos(theta)
        y = (R + r * np.cos(phi)) * np.sin(theta)
        z = r * np.sin(phi)
    elif shape == 'cylinder':
        # Cylindrical shell with radius in [0.8, 1.0] and height in [-1, 1]
        theta = np.random.uniform(0, 2*np.pi, n_points)
        z = np.random.uniform(-1, 1, n_points)
        r = np.random.uniform(0.8, 1.0, n_points)

        x = r * np.cos(theta)
        y = r * np.sin(theta)
    else:
        raise ValueError(f"Unknown shape: {shape}")

    points = np.stack([x, y, z], axis=1).astype(np.float32)

    # Add noise if specified
    if noise_level > 0:
        noise = np.random.randn(*points.shape) * noise_level
        points += noise.astype(np.float32)

    # Add random features (e.g., RGB colors)
    features = np.random.rand(n_points, 3).astype(np.float32)

    return points, features

In [None]:
# Build one 2000-point cloud per supported shape for the side-by-side comparison.
shapes = ['sphere', 'cube', 'torus', 'cylinder']
point_clouds = {}

# NOTE: `shape`, `points` and `features` stay bound at module level after this
# loop and hold the values from its last iteration ('cylinder').
for shape in shapes:
    points, features = generate_sample_point_cloud(2000, shape)
    point_clouds[shape] = points
    
# Visualize individual shape
# A denser, slightly noisy sphere used for the single-cloud visualization.
points_sphere, features_sphere = generate_sample_point_cloud(5000, 'sphere', noise_level=0.02)
print(f"Points shape: {points_sphere.shape}")
print(f"Features shape: {features_sphere.shape}")

# Interactive visualization with color by height
fig = visualize_point_cloud(
    points_sphere, 
    colors=points_sphere[:, 2],  # Color by z-coordinate
    title="Interactive Sphere Point Cloud (5000 points)",
    point_size=2,
    colorscale='Viridis'
)

# Visualize all shapes together for comparison
visualize_multiple_point_clouds(point_clouds, title="Comparison of Different Point Cloud Shapes")

In [None]:
from warpconvnet.geometry.types.points import Points

# Build the tensors from the sphere cloud explicitly. The previous version read
# `points` / `features`, which only exist as leftovers of the shape-generation
# loop in the earlier cell (i.e. whichever shape happened to be generated last).
points_tensor = torch.from_numpy(points_sphere).to(device)
features_tensor = torch.from_numpy(features_sphere).to(device)

# Create batch indices (single batch for now)
# NOTE: kept for illustration; it is not passed to `Points` below.
batch_indices = torch.zeros(len(points_tensor), dtype=torch.long).to(device)

# Create PointCloud object
# Coordinates and features are each passed as a list with one tensor per batch item.
point_cloud = Points(
    [points_tensor],
    [features_tensor],
)
print(point_cloud)

## Voxel Representation

## Conversion between PointCloud and Voxel

In [None]:
def visualize_voxels(voxel_coords, voxel_size=0.1, colors=None, title="Sparse Voxels"):
    """
    Show sparse voxel coordinates as square markers in an interactive 3D plot.

    Args:
        voxel_coords: voxel coordinates indexed as an (N, 3) array; torch
            tensors are moved to CPU / numpy first
        voxel_size: accepted for API symmetry; the body does not use it
        colors: optional per-voxel values (tensor or array). Defaults to the
            z coordinate. For 2-D input only the first channel is used.
        title: plot title

    Returns:
        The Plotly figure that was shown.
    """
    if isinstance(voxel_coords, torch.Tensor):
        voxel_coords = voxel_coords.cpu().numpy()

    # Resolve the scalar used for coloring: height by default
    if colors is None:
        colors = voxel_coords[:, 2]
    elif isinstance(colors, torch.Tensor):
        colors = colors.cpu().numpy()

    # Reduce multi-channel colors to their first channel
    if len(colors.shape) > 1:
        colors = colors[:, 0]

    # Hover labels are limited to the first 500 voxels
    label_count = min(500, len(voxel_coords))
    hover_labels = [
        f"Voxel {i}<br>x: {x:.2f}<br>y: {y:.2f}<br>z: {z:.2f}"
        for i, (x, y, z) in enumerate(voxel_coords[:label_count])
    ]

    marker_style = dict(
        size=5,
        color=colors,
        colorscale='Viridis',
        showscale=True,
        symbol='square',
        colorbar=dict(title="Value", thickness=20, len=0.7),
    )

    figure = go.Figure()
    figure.add_trace(go.Scatter3d(
        x=voxel_coords[:, 0],
        y=voxel_coords[:, 1],
        z=voxel_coords[:, 2],
        mode='markers',
        marker=marker_style,
        text=hover_labels,
        hovertemplate='%{text}<extra></extra>',
        name='Voxels',
    ))

    scene_settings = dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
        aspectmode='data',
    )
    figure.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        scene=scene_settings,
        width=900,
        height=700,
        showlegend=False,
    )

    figure.show()
    return figure

### Point Clouds -> Voxels

In [None]:
# Create sparse voxel representation with different resolutions
voxel_sizes = [0.05, 0.1, 0.2]
# Maps a "Size <voxel_size>" label to the voxelized result for that resolution.
voxel_representations = {}

# NOTE: `voxel_size` and `sparse_voxels` stay bound after the loop; a later
# cell feeds `sparse_voxels` (the last, 0.2 result) to the sparse conv network.
for voxel_size in voxel_sizes:
    # Voxelize the `point_cloud` object built in the earlier cell.
    sparse_voxels = point_cloud.to_voxels(voxel_size)
    voxel_representations[f"Size {voxel_size}"] = sparse_voxels
    visualize_voxels(sparse_voxels.coordinates, sparse_voxels.voxel_size)

### Quiz: Voxels -> Point Clouds
- reference: https://github.com/NVlabs/WarpConvNet/blob/eda68fa3e3759dddadfc53d76038fd9246bbf885/warpconvnet/geometry/types/voxels.py

## SparseConvNet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from warpconvnet.nn.modules.sparse_conv import SparseConv3d


class SparseConvNet(nn.Module):
    """Simple sparse voxel convolution network.

    Five ``SparseConv3d`` layers, each followed by ``LayerNorm`` and ``ReLU``.
    The second and fourth convolutions use ``kernel_size=2, stride=2``.

    Args:
        in_channels: number of input feature channels per voxel
        out_channels: number of feature channels produced by the last layer
    """

    def __init__(self, in_channels=3, out_channels=32):
        super().__init__()

        # `Sequential` was referenced without ever being imported. Import it
        # locally so the class works regardless of which cells ran before.
        # NOTE(review): presumably WarpConvNet's Sequential is what lets plain
        # torch layers (LayerNorm, ReLU) act on voxel features — confirm
        # against the WarpConvNet docs.
        from warpconvnet.nn.modules.sequential import Sequential

        # The former conv1..conv3 / bn1..bn3 attributes were never used in
        # forward() and only added dead parameters, so they are gone. The last
        # stage now honours `out_channels` instead of a hard-coded 256.
        self.sparse_conv = Sequential(
            SparseConv3d(in_channels, 16, kernel_size=3, stride=1),
            nn.LayerNorm(16),
            nn.ReLU(),
            SparseConv3d(16, 32, kernel_size=2, stride=2),  # stride
            nn.LayerNorm(32),
            nn.ReLU(),
            SparseConv3d(32, 64, kernel_size=3, stride=1),
            nn.LayerNorm(64),
            nn.ReLU(),
            SparseConv3d(64, 128, kernel_size=2, stride=2),  # stride
            nn.LayerNorm(128),
            nn.ReLU(),
            SparseConv3d(128, out_channels, kernel_size=3, stride=1),
            nn.LayerNorm(out_channels),
            nn.ReLU(),
        )

    def forward(self, sparse_voxels):
        """Run the voxel geometry through the convolution stack."""
        return self.sparse_conv(sparse_voxels)

# Create and test the network
# The class defined above is `SparseConvNet`; the previous `SparseVoxelNet`
# name does not exist anywhere in this notebook and raised NameError.
# `sparse_voxels` is the voxelization left bound by the voxel-conversion cell.
sparse_net = SparseConvNet(in_channels=3, out_channels=64).to(device)
output_voxels = sparse_net(sparse_voxels)

print(f"Input voxels: {sparse_voxels}")
print(f"Output voxels: {output_voxels}")
print(f"Output feature dims: {output_voxels.features.shape[1]}")

## Exercise 1: 3D Point Cloud Semantic Segmentation with SparseConvNet

In [None]:
from typing import Dict, List, Optional, Tuple, Union
import yaml

try:
    import hydra
    from hydra.core.config_store import ConfigStore
    from omegaconf import DictConfig, OmegaConf
except ImportError:
    print("Hydra and OmegaConf not installed, pip install hydra-core omegaconf")
    exit(1)

import time
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warp as wp
from torch import Tensor
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Subset
from torchmetrics.classification import MulticlassConfusionMatrix
from tqdm import tqdm
from warpconvnet.dataset.scannet import ScanNetDataset
from warpconvnet.geometry.base.geometry import Geometry
from warpconvnet.geometry.types.points import Points
from warpconvnet.nn.modules.sparse_pool import PointToSparseWrapper

from mink_unet import MinkUNetBase

In [None]:
def set_seed(seed: int):
    """Seed every random number generator this notebook may draw from.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    and switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed: the seed value applied to all generators
    """
    # Local import: `random` is not imported at the top of this notebook.
    import random

    # Python's own RNG was previously left unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def collate_fn(batch: List[Dict[str, Tensor]]):
    """
    Regroup a list of per-sample dicts into one dict of per-sample tensor lists.

    The key set is taken from the first sample; every value is converted with
    ``torch.tensor``. Samples are kept as separate list entries (no stacking).
    """
    collated = {}
    for field in batch[0].keys():
        collated[field] = [torch.tensor(sample[field]) for sample in batch]
    return collated


class DataToTensor:
    """Turn a collated batch dict into a ``Points`` geometry plus flat tensors.

    Args:
        device: device that the geometry and the concatenated tensors are moved to
    """

    def __init__(self, device: str = "cuda"):
        self.device = device

    def __call__(self, batch_dict: Dict[str, Tensor]) -> Tuple[Geometry, Dict[str, Tensor]]:
        """Return ``(points, flat)`` for one batch.

        ``flat`` holds, for every key, the per-sample tensors concatenated
        along dim 0 and moved to ``self.device``. ``points`` is built from the
        per-sample "coords" list with the "colors" list as features.
        """
        target = self.device

        # Concatenate every field across samples into a single tensor
        flat = {}
        for name, tensors in batch_dict.items():
            flat[name] = torch.cat(tensors, dim=0).to(target)

        # The geometry is built from the un-concatenated per-sample lists
        geometry = Points.from_list_of_coordinates(
            batch_dict["coords"],
            features=batch_dict["colors"],
        ).to(target)

        return geometry, flat


def confusion_matrix_to_metrics(conf_matrix: Tensor) -> Dict[str, float]:
    """
    Return accuracy, miou, class_iou, class_accuracy

    Rows are ground truth, columns are predictions.

    Args:
        conf_matrix: square (num_classes, num_classes) matrix of counts

    Returns:
        Dict with
          - "accuracy": overall accuracy in percent (float)
          - "miou": mean IoU in percent over the classes whose IoU is defined (float)
          - "class_iou": per-class IoU tensor as fractions in [0, 1]; NaN for a
            class that appears in neither the labels nor the predictions
          - "class_accuracy": per-class accuracy tensor in percent; NaN for a
            class with no ground-truth samples
    """
    conf_matrix = conf_matrix.cpu()
    true_positive = conf_matrix.diag()
    gt_count = conf_matrix.sum(dim=1)
    pred_count = conf_matrix.sum(dim=0)

    accuracy = (true_positive.sum() / conf_matrix.sum()).item() * 100
    class_accuracy = (true_positive / gt_count) * 100
    class_iou = true_positive / (gt_count + pred_count - true_positive)
    # A class absent from both labels and predictions yields 0/0 = NaN. A plain
    # mean() would turn the whole mIoU into NaN (likely when evaluating on a
    # small subset), so average only over the classes with a defined IoU.
    miou = torch.nanmean(class_iou).item() * 100
    return {
        "accuracy": accuracy,
        "miou": miou,
        "class_iou": class_iou,
        "class_accuracy": class_accuracy,
    }

In [None]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    epoch: int,
    cfg: DictConfig,
):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network called on the geometry built from each batch
        train_loader: loader yielding collated batch dicts
        optimizer: optimizer stepped once per batch
        epoch: epoch number, used only for the progress-bar text
        cfg: config providing ``device`` and ``data.ignore_index``
    """
    model.train()
    bar = tqdm(train_loader)
    dict_to_data = DataToTensor(device=cfg.device)
    for batch_dict in bar:
        optimizer.zero_grad()
        st, batch_dict = dict_to_data(batch_dict)
        # Autocast only the forward pass and the loss. It used to be applied
        # as a decorator on the whole function, which also ran backward() and
        # optimizer.step() under autocast.
        # NOTE(review): CUDA autocast is commonly paired with a gradient
        # scaler — confirm whether one is needed for this model.
        with torch.amp.autocast(device_type="cuda", enabled=True):
            output = model(st.to(cfg.device))
            loss = F.cross_entropy(
                output.features,
                batch_dict["labels"].long().to(cfg.device),
                reduction="mean",
                ignore_index=cfg.data.ignore_index,
            )
        loss.backward()
        optimizer.step()
        bar.set_description(f"Train Epoch: {epoch} Loss: {loss.item(): .3f}")


@torch.amp.autocast(device_type="cuda", enabled=True)
@torch.inference_mode()
def test(
    model: nn.Module,
    test_loader: DataLoader,
    cfg: DictConfig,
    num_test_batches: Optional[int] = None,
    save_visuals: bool = False,
):
    """Evaluate ``model`` on ``test_loader`` and return segmentation metrics.

    Args:
        model: network called on the geometry built from each batch
        test_loader: loader yielding collated batch dicts
        cfg: config providing ``device``, ``data.num_classes``,
            ``data.ignore_index`` and (when saving) ``paths.output_dir``
        num_test_batches: stop after this many batches; ``None`` evaluates everything
        save_visuals: when True, per-sample coords/colors/predictions/labels are
            written to ``<output_dir>/visual_predictions.pt``

    Returns:
        The dict produced by ``confusion_matrix_to_metrics``.
    """
    # Local import: `os` is not imported anywhere else in this notebook, so
    # the save branch below used to fail with NameError.
    import os

    model.eval()
    torch.cuda.empty_cache()
    confusion_matrix = MulticlassConfusionMatrix(
        num_classes=cfg.data.num_classes, ignore_index=cfg.data.ignore_index
    ).to(cfg.device)
    test_loss = 0
    num_batches = 0

    visual_data = []
    dict_to_data = DataToTensor(device=cfg.device)
    for batch_dict in tqdm(test_loader):
        # Keep the per-sample lists; `batch_dict` is rebound to the
        # concatenated tensors on the next line.
        original_batch_dict = batch_dict
        st, batch_dict = dict_to_data(batch_dict)
        output = model(st.to(cfg.device))
        labels = batch_dict["labels"].long().to(cfg.device)
        test_loss += F.cross_entropy(
            output.features,
            labels,
            reduction="mean",
            ignore_index=cfg.data.ignore_index,
        ).item()
        pred = output.features.argmax(dim=1)
        confusion_matrix.update(pred, labels)

        if save_visuals:
            # `st.offsets` delimits each sample's rows in the concatenated batch.
            num_items_in_batch = len(st.offsets) - 1
            for i in range(num_items_in_batch):
                start_idx = st.offsets[i]
                end_idx = st.offsets[i+1]

                visual_data.append({
                    "coords": original_batch_dict["coords"][i],
                    "colors": original_batch_dict["colors"][i],
                    "preds": pred[start_idx:end_idx].cpu(),
                    "labels": labels[start_idx:end_idx].cpu(),
                })

        num_batches += 1
        if num_test_batches is not None and num_batches >= num_test_batches:
            break

    if save_visuals and visual_data:
        # Join with a path separator: plain string concatenation produced
        # "<output_dir>visual_predictions.pt", i.e. a file next to the
        # directory instead of inside it.
        save_path = os.path.join(cfg.paths.output_dir, "visual_predictions.pt")
        os.makedirs(cfg.paths.output_dir, exist_ok=True)
        torch.save(visual_data, save_path)
        print(f"\nSaved visualization data for {len(visual_data)} point clouds to {save_path}")

    metrics = confusion_matrix_to_metrics(confusion_matrix.compute())

    # Guard against an empty loader so reporting does not divide by zero.
    average_loss = test_loss / max(num_batches, 1)
    print(
        f"\nTest set: Average loss: {average_loss: .4f}, Accuracy: {metrics['accuracy']: .2f}%, mIoU: {metrics['miou']: .2f}%\n"
    )
    return metrics

In [None]:
# Embedded YAML configuration
# Settings for the baseline run. The training cell parses this text with
# yaml.safe_load and wraps the result with OmegaConf.create.
# NOTE(review): `model._target_` is not read by the code visible in this
# notebook (the baseline cell constructs MinkUNetBase directly) — confirm
# which class is intended.
CONFIG_YAML_BASE = """
# Path configuration
paths:
  data_dir: /data/scannet_3d
  output_dir: ./results/scannet_3d
  ckpt_path: null

# Training configuration.
train:
  batch_size: 32
  lr: 0.005
  epochs: 2
  step_size: 20
  gamma: 0.7
  num_workers: 8

# Testing configuration
test:
  batch_size: 12
  num_workers: 4

# Dataset configuration
data:
  num_classes: 20
  voxel_size: 0.02
  ignore_index: 255

# Model configuration
model:
  _target_: mink_unet.MinkUNet18
  in_channels: 3
  out_channels: 20
  in_type: "voxel"

# General configuration
device: "cuda"
use_wandb: false
seed: 42
"""

In [None]:
# Load configs
# Parse the embedded YAML and wrap it so fields can be read as attributes (cfg.train.lr, ...).
cfg_dict = yaml.safe_load(CONFIG_YAML_BASE)
cfg = OmegaConf.create(cfg_dict)

set_seed(cfg.seed)
# NOTE: rebinds the notebook-level `device` from the config value.
device = torch.device(cfg.device)

# Define dataloaders
# Only the first 100 training and the first 50 validation samples are used.
train_dataset = ScanNetDataset(cfg.paths.data_dir, split="train",)
train_dataset = Subset(train_dataset, range(100))
test_dataset = ScanNetDataset(cfg.paths.data_dir, split="val",)
test_dataset = Subset(test_dataset, range(50))
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.train.batch_size,
    num_workers=cfg.train.num_workers,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.test.batch_size,
    num_workers=cfg.test.num_workers,
    shuffle=False,
    collate_fn=collate_fn,
)

# Model initialization
model = MinkUNetBase(
    in_channels=cfg.model.in_channels,
    out_channels=cfg.model.out_channels,
).to(device)

# When the config asks for voxel input, wrap the network so it can be called
# on the point geometry produced by DataToTensor.
if hasattr(cfg.model, "in_type") and cfg.model.in_type == "voxel":
    model = PointToSparseWrapper(
        inner_module=model,
        voxel_size=cfg.data.voxel_size,
        concat_unpooled_pc=False,
    )

optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
# The scheduler is stepped once per epoch (see the loop below).
scheduler = StepLR(optimizer, step_size=cfg.train.step_size, gamma=cfg.train.gamma)

# Test before training
# Baseline numbers from the untrained model, limited to 5 batches.
metrics = test(model, test_loader, cfg, num_test_batches=5)

for epoch in range(1, cfg.train.epochs + 1):
    train(
        model,
        train_loader,
        optimizer,
        epoch,
        cfg,
    )
    # Prediction dumps are only written after the final epoch.
    should_save_visuals = (epoch == cfg.train.epochs)
    metrics = test(model, test_loader, cfg, save_visuals=should_save_visuals)
    scheduler.step()

print(f"Final mIoU: {metrics['miou']: .2f}%")

# Drop the model reference and release cached GPU memory before the next experiment.
# NOTE(review): `optimizer` still references the parameters, so the memory may
# not actually be freed until `optimizer` is rebound — confirm.
del model
torch.cuda.empty_cache()

In [None]:
class MinkUNetCustom(MinkUNetBase):
    """Exercise placeholder: a custom MinkUNet variant to be implemented by the reader.

    Args:
        in_channels: number of input feature channels
        out_channels: number of output channels (classes)
        **kwargs: additional options for the implementation

    Raises:
        NotImplementedError: always, until the exercise is implemented.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        # `NotImplemented` is a comparison sentinel, not an exception class;
        # raising it fails with "TypeError: exceptions must derive from
        # BaseException". NotImplementedError is the intended signal.
        raise NotImplementedError("Exercise: implement MinkUNetCustom.__init__")

In [None]:
# Embedded YAML configuration
# Settings for the custom-model run. Parsed with yaml.safe_load and wrapped
# with OmegaConf.create by the cell that trains MinkUNetCustom.
# NOTE(review): `model._target_` is not read by the code visible in this
# notebook; that cell instantiates MinkUNetCustom directly.
CONFIG_YAML_CUSTOM = """
# Path configuration
paths:
  data_dir: /data/scannet_3d
  output_dir: ./results/scannet_3d
  ckpt_path: null

# Training configuration.
train:
  batch_size: 64
  lr: 0.005
  epochs: 2
  step_size: 20
  gamma: 0.7
  num_workers: 8

# Testing configuration
test:
  batch_size: 12
  num_workers: 4

# Dataset configuration
data:
  num_classes: 20
  voxel_size: 0.02
  ignore_index: 255

# Model configuration
model:
  _target_: mink_unet.MinkUNet18
  in_channels: 3
  out_channels: 20
  in_type: "voxel"

# General configuration
device: "cuda"
use_wandb: false
seed: 42
"""

In [None]:
# Load configs
# Parse the embedded YAML and wrap it so fields can be read as attributes (cfg.train.lr, ...).
cfg_dict = yaml.safe_load(CONFIG_YAML_CUSTOM)
cfg = OmegaConf.create(cfg_dict)

set_seed(cfg.seed)
# NOTE: rebinds the notebook-level `device` from the config value.
device = torch.device(cfg.device)

# Define dataloaders
# Only the first 100 training and the first 50 validation samples are used.
train_dataset = ScanNetDataset(cfg.paths.data_dir, split="train",)
train_dataset = Subset(train_dataset, range(100))
test_dataset = ScanNetDataset(cfg.paths.data_dir, split="val",)
test_dataset = Subset(test_dataset, range(50))
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.train.batch_size,
    num_workers=cfg.train.num_workers,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.test.batch_size,
    num_workers=cfg.test.num_workers,
    shuffle=False,
    collate_fn=collate_fn,
)

# Model initialization
# MinkUNetCustom.__init__ as defined in this notebook only raises, so this
# cell fails here until the exercise has been implemented.
model = MinkUNetCustom(
    in_channels=cfg.model.in_channels,
    out_channels=cfg.model.out_channels,
).to(device)

# When the config asks for voxel input, wrap the network so it can be called
# on the point geometry produced by DataToTensor.
if hasattr(cfg.model, "in_type") and cfg.model.in_type == "voxel":
    model = PointToSparseWrapper(
        inner_module=model,
        voxel_size=cfg.data.voxel_size,
        concat_unpooled_pc=False,
    )

optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
# The scheduler is stepped once per epoch (see the loop below).
scheduler = StepLR(optimizer, step_size=cfg.train.step_size, gamma=cfg.train.gamma)

# Test before training
# Baseline numbers from the untrained model, limited to 5 batches.
metrics = test(model, test_loader, cfg, num_test_batches=5)

for epoch in range(1, cfg.train.epochs + 1):
    train(
        model,
        train_loader,
        optimizer,
        epoch,
        cfg,
    )
    # Prediction dumps are only written after the final epoch.
    should_save_visuals = (epoch == cfg.train.epochs)
    metrics = test(model, test_loader, cfg, save_visuals=should_save_visuals)
    scheduler.step()

print(f"Final mIoU: {metrics['miou']: .2f}%")

# Drop the model reference and release cached GPU memory before the next experiment.
# NOTE(review): `optimizer` still references the parameters, so the memory may
# not actually be freed until `optimizer` is rebound — confirm.
del model
torch.cuda.empty_cache()

## Exercise 2: 3D Point Cloud Semantic Segmentation with PointTransformerV3

In [None]:
# Embedded YAML configuration
# Settings for the PointTransformerV3 run. Parsed with yaml.safe_load and
# wrapped with OmegaConf.create by the training cell that follows.
# `model._target_` now names the class that cell actually builds
# (point_transformer_v3.PointTransformerV3); it previously still carried the
# copy-pasted mink_unet.MinkUNet18 value.
CONFIG_YAML_PTV3 = """
# Path configuration
paths:
  data_dir: /data/scannet_3d
  output_dir: ./results/scannet_3d
  ckpt_path: null

# Training configuration.
train:
  batch_size: 4
  lr: 0.001
  epochs: 2
  step_size: 20
  gamma: 0.7
  num_workers: 8

# Testing configuration
test:
  batch_size: 8
  num_workers: 4

# Dataset configuration
data:
  num_classes: 20
  voxel_size: 0.02
  ignore_index: 255

# Model configuration
model:
  _target_: point_transformer_v3.PointTransformerV3
  in_channels: 3
  out_channels: 20
  in_type: "voxel"

# General configuration
device: "cuda"
use_wandb: false
seed: 42
"""

In [None]:
from point_transformer_v3 import PointTransformerV3

# Load configs
# Parse the embedded YAML and wrap it so fields can be read as attributes (cfg.train.lr, ...).
cfg_dict = yaml.safe_load(CONFIG_YAML_PTV3)
cfg = OmegaConf.create(cfg_dict)

set_seed(cfg.seed)
# NOTE: rebinds the notebook-level `device` from the config value.
device = torch.device(cfg.device)

# Define dataloaders
# Only the first 100 training and the first 50 validation samples are used.
train_dataset = ScanNetDataset(cfg.paths.data_dir, split="train",)
train_dataset = Subset(train_dataset, range(100))
test_dataset = ScanNetDataset(cfg.paths.data_dir, split="val",)
test_dataset = Subset(test_dataset, range(50))
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.train.batch_size,
    num_workers=cfg.train.num_workers,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.test.batch_size,
    num_workers=cfg.test.num_workers,
    shuffle=False,
    collate_fn=collate_fn,
)

# Model initialization
model = PointTransformerV3(
    in_channels=cfg.model.in_channels,
    out_channels=cfg.model.out_channels,
).to(device)

# CONFIG_YAML_PTV3 sets in_type to "voxel", so the transformer is wrapped in
# PointToSparseWrapper as well.
# NOTE(review): confirm that wrapping is intended for this model.
if hasattr(cfg.model, "in_type") and cfg.model.in_type == "voxel":
    model = PointToSparseWrapper(
        inner_module=model,
        voxel_size=cfg.data.voxel_size,
        concat_unpooled_pc=False,
    )

optimizer = optim.AdamW(model.parameters(), lr=cfg.train.lr)
# The scheduler is stepped once per epoch (see the loop below).
scheduler = StepLR(optimizer, step_size=cfg.train.step_size, gamma=cfg.train.gamma)

# Test before training
# Baseline numbers from the untrained model, limited to 5 batches.
metrics = test(model, test_loader, cfg, num_test_batches=5)

for epoch in range(1, cfg.train.epochs + 1):
    train(
        model,
        train_loader,
        optimizer,
        epoch,
        cfg,
    )
    # Prediction dumps are only written after the final epoch.
    should_save_visuals = (epoch == cfg.train.epochs)
    metrics = test(model, test_loader, cfg, save_visuals=should_save_visuals)
    scheduler.step()

print(f"Final mIoU: {metrics['miou']: .2f}%")