# Fine-tuning for Semantic Segmentation

Here we will fine-tune the model for semantic segmentation. This is done by adding a point-feature upsampling module and a new classification head on top of the pre-trained encoder.

In [1]:
from point2vec.datasets import LArNetDataModule
import matplotlib.pyplot as plt
import torch
import numpy as np

# Turn off gradient tracking globally so we don't run out of memory (no backward passes are run in this notebook).
# The trailing ';' suppresses the repr of the returned grad-mode object in the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb8f045b400>

In [2]:
# NOTE(review): absolute, user-specific path; consider moving it to a config cell or an environment variable.
DATA_GLOB = '/sdf/home/y/youngsam/data/dune/larnet/h5/DataAccessExamples/train/generic_v2*.h5'

dataset = LArNetDataModule(
    data_path=DATA_GLOB,
    batch_size=24,
    num_workers=0,
    dataset_kwargs={
        'emin': 1.0e-6,                      # min energy for log transform
        'emax': 20.0,                        # max energy for log transform
        'energy_threshold': 0.13,            # remove points with energy < 0.13
        'remove_low_energy_scatters': True,  # remove low energy scatters (PID=4)
        'maxlen': -1,                        # max number of events to load (-1 presumably means no limit — TODO confirm)
        'normalize': True,                   # normalize point cloud to unit sphere
    }
)
dataset.setup()

[DATASET] self.emin=1e-06, self.emax=20.0, self.energy_threshold=0.13, self.normalize=True, self.remove_low_energy_scatters=True
[DATASET] Building index
[DATASET] 864064 point clouds were loaded
[DATASET] 10 files were loaded
[DATASET] self.emin=1e-06, self.emax=20.0, self.energy_threshold=0.13, self.normalize=True, self.remove_low_energy_scatters=True
[DATASET] Building index
[DATASET] 8531 point clouds were loaded
[DATASET] 1 files were loaded


Instantiate the pre-trained model

In [3]:
from point2vec.models import PointMAE
from glob import glob
import os

def get_newest_ckpt(ckpt_path):
    """Return the most recently created ``.ckpt`` file found under ``ckpt_path``.

    ``ckpt_path`` itself and all of its subdirectories (at any depth) are searched.
    "Newest" is decided by ``os.path.getctime``.

    Raises:
        FileNotFoundError: if no ``.ckpt`` file exists under ``ckpt_path``.
    """
    # Without recursive=True, '**' behaves like '*' and only matches exactly one directory level.
    candidates = glob(os.path.join(ckpt_path, '**', '*.ckpt'), recursive=True)
    if not candidates:
        # max() on an empty list would raise an unhelpful ValueError.
        raise FileNotFoundError(f'No .ckpt files found under {ckpt_path!r}')
    return max(candidates, key=os.path.getctime)

wandb_run_id = 'fjnp0snd'
ckpt_path = f'/sdf/home/y/youngsam/sw/dune/representations/point2vec/PointMAE-Pretraining-LArNet-5voxel/{wandb_run_id}'

# Load the most recent checkpoint of this run onto the GPU.
newest_ckpt = get_newest_ckpt(ckpt_path)
model = PointMAE.load_from_checkpoint(newest_ckpt).cuda()

# Override the tokenizer's grouping hyper-parameters, as a bug was fixed in the latest version of the code
# (see https://github.com/youngsm/point2vec/commit/b32552088422d5210897dd548b3c77fbf1b0c0b5).
# Applied in the same order as the original one-by-one assignments.
grouping_overrides = {
    'num_groups': 1024,
    'context_length': 640,
    'group_size': 32,
    'upscale_group_size': 256,
    'group_radius': 5 / 760,  # presumably a 5-voxel radius in normalized coordinates ("5voxel" run) — TODO confirm the 760 scale
    'overlap_factor': 0.6,
}
for attr_name, attr_value in grouping_overrides.items():
    setattr(model.tokenizer.grouping, attr_name, attr_value)

model.eval();

  return torch.load(f, map_location=map_location)


Our point feature upsampler will do the following:


Given the token embeddings, their group centers, and the points we want to upsample to, we will

1. For each point, find the K nearest group centers (and hence their token embeddings).
2. Interpolate these K embeddings via inverse distance weighting to get an embedding for each point.
3. Concatenate each interpolated embedding with the point's raw features and apply a 2-layer MLP with batch normalization.


In practice, the embeddings are the average of the (layer-normalized) hidden states from several encoder layers. Here we average layers [3, 7, 11] and use K = 5.

In [4]:
from point2vec.modules.masking import masked_layer_norm

def get_embeddings(model, points, lengths, seg_head_fetch_layers=(3, 7, 11)):
    """Encode a batch of point clouds and return token features averaged over several encoder layers.

    Args:
        model: pre-trained model exposing ``tokenizer``, ``positional_encoding`` and ``encoder``
            (e.g. the ``PointMAE`` loaded above).
        points: padded point tensor; assumed (B, N, C) with xyz in the first 3 channels — TODO confirm.
        lengths: number of valid points per event, shape (B,).
        seg_head_fetch_layers: indices into the encoder's hidden states. The selected layers are
            masked-layer-normalized and then averaged.

    Returns:
        token_features: averaged token features, (B, T, C).
        centers: token centers returned by the tokenizer (xyz in the first 3 channels).
        embedding_mask: mask of valid (non-padded) tokens; presumably boolean (B, T) — TODO confirm.

    Raises:
        ValueError: if ``seg_head_fetch_layers`` is empty.
    """
    if len(seg_head_fetch_layers) == 0:
        raise ValueError('seg_head_fetch_layers must name at least one encoder layer')

    tokens, centers, embedding_mask, _, _ = model.tokenizer(points, lengths)
    pos = model.positional_encoding(centers[..., :3])
    output = model.encoder(tokens, pos, embedding_mask, return_hidden_states=True)

    # Normalize each selected layer over its channel dim (ignoring padded tokens), then average across layers.
    hidden_states = [
        masked_layer_norm(output.hidden_states[i], output.hidden_states[i].shape[-1], embedding_mask)
        for i in seg_head_fetch_layers
    ]
    token_features = torch.stack(hidden_states, dim=0).mean(0)  # (B, T, C)

    return token_features, centers, embedding_mask

In [7]:
from point2vec.modules.feature_upsampling import PointNetFeatureUpsampling

# Feature Upsampler ============================
# Per-point input = raw point features (point_dim channels) concatenated with the token
# embedding interpolated from the nearest centers (upsampling_dim channels).
point_dim = model.hparams.num_channels
upsampling_dim = model.hparams.encoder_dim
feature_upsampler = PointNetFeatureUpsampling(
    in_channel=point_dim + upsampling_dim,
    mlp=[upsampling_dim, upsampling_dim],
).cuda()

points, lengths, labels, _ = next(iter(dataset.train_dataloader()))
B, N, C = points.shape
points = points.cuda()              # (B, N, 4)
lengths = lengths.cuda()            # (B,)
labels = labels.cuda().squeeze(-1)  # (B, N)

# Get embeddings
# Valid-point mask. It is built against the padded length N of `points` (not lengths.max()),
# so its shape always matches the point tensor it masks, even if padding exceeds the longest event.
point_mask = torch.arange(N, device=lengths.device).expand(
    len(lengths), -1
) < lengths.unsqueeze(-1)  # (B, N)

embeddings, centers, embedding_mask = get_embeddings(model, points, lengths)
group_lengths = embedding_mask.sum(dim=1)  # number of valid tokens per event

upsampled_features, _ = feature_upsampler(
    points[..., :3],   # xyz1: coordinates of the points to upsample to
    centers[..., :3],  # xyz2: coordinates of the token centers
    points,            # points1: raw per-point features
    embeddings,        # points2: token embeddings
    lengths,           # point_lens
    group_lengths,     # embedding_lens
    point_mask,        # point_mask for masked bn
)  # (B, N, upsampling_dim)

In [8]:
# Each point now carries an upsampled feature vector alongside its raw features.
for _label, _shape in (
    ('upsampled_features.shape', upsampled_features.shape),
    ('points.shape', points.shape),
):
    print(_label, _shape)

upsampled_features.shape torch.Size([24, 6347, 384])
points.shape torch.Size([24, 6347, 4])


We now have a latent feature for every point in the point cloud, so we can classify each point by running it through a classification head.

In Point-MAE/point2vec, we additionally concatenate two global feature vectors to each point's features, giving a per-event summary of the point cloud. These are the (masked) maximum and mean of the token features for each event.

In [9]:
def masked_mean(group, point_mask):
    """Mean of ``group`` over dim -2, counting only positions where ``point_mask`` is True.

    ``group`` is (..., T, C) and ``point_mask`` is (..., T). Rows with no valid entries
    divide by 1 instead of 0, so they come out as zeros.
    """
    counts = point_mask.sum(dim=-1).float().clamp(min=1)
    masked_sum = (group * point_mask.unsqueeze(-1)).sum(dim=-2)
    return masked_sum / counts.unsqueeze(-1)

def masked_max(group, point_mask):
    """Max of ``group`` over dim -2, ignoring positions where ``point_mask`` is False.

    Invalid positions are pushed down by 1e10 rather than removed, so a row with no valid
    entries yields (its true max - 1e10).
    """
    penalty = (~point_mask).unsqueeze(-1) * 1e10
    return (group - penalty).max(dim=-2).values

B, N, C = points.shape

# Per-event summary: masked max and masked mean of the token embeddings, concatenated on channels.
per_event_max = masked_max(embeddings, embedding_mask)
per_event_mean = masked_mean(embeddings, embedding_mask)
global_feature = torch.cat([per_event_max, per_event_mean], dim=-1)  # (B, 2 * C_emb)

# Broadcast the summary to every point and append it to that point's upsampled features.
# NOTE(review): this reassigns `upsampled_features`, so re-running the cell appends the global features again.
per_point_global = global_feature.unsqueeze(1).expand(-1, N, -1)  # (B, N, 2 * C_emb)
upsampled_features = torch.cat([upsampled_features, per_point_global], dim=-1)
print('upsampled_features.shape', upsampled_features.shape)

upsampled_features.shape torch.Size([24, 6347, 1152])


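Our final segmentation head is a 3-layer per-point MLP (implemented as 1×1 convolutions). Masked batch normalization and ReLU follow the first two layers, and dropout follows the first. The first layer projects the concatenated features to `seg_head_dim`, the second halves that width, and the third maps to the number of classes.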

In [10]:
import torch.nn as nn
from point2vec.modules.masking import MaskedBatchNorm1d

class SegmentationHead(nn.Module):
    """Per-point classification head built from three 1x1 convolutions over (B, C, N) features.

    Stage 1: (2 * encoder_dim + upsampling_dim) -> seg_head_dim, masked batch norm, ReLU, dropout.
    Stage 2: seg_head_dim -> seg_head_dim // 2, masked batch norm, ReLU (no dropout).
    Stage 3: seg_head_dim // 2 -> num_seg_classes, raw logits.

    The input width corresponds to the upsampled per-point features (upsampling_dim) concatenated
    with the global max and mean token features (encoder_dim each).
    """

    def __init__(
        self,
        encoder_dim: int,
        upsampling_dim: int,
        seg_head_dim: int,
        seg_head_dropout: float,
        num_seg_classes: int,
    ):
        super().__init__()
        in_channels = 2 * encoder_dim + upsampling_dim
        half_dim = seg_head_dim // 2

        # Stage 1 (modules are registered in the same order as before, so parameter init is unchanged)
        self.conv1 = nn.Conv1d(in_channels, seg_head_dim, 1, bias=False)
        self.bn1 = MaskedBatchNorm1d(seg_head_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(seg_head_dropout)

        # Stage 2
        self.conv2 = nn.Conv1d(seg_head_dim, half_dim, 1, bias=False)
        self.bn2 = MaskedBatchNorm1d(half_dim)
        self.relu2 = nn.ReLU()

        # Stage 3: classifier (keeps its bias, no normalization/activation)
        self.conv3 = nn.Conv1d(half_dim, num_seg_classes, 1)

    def forward(self, x, point_mask):
        """Compute per-point class logits.

        x: features of shape [B, C, N], where N is the (padded) number of points.
        point_mask: boolean tensor of shape [B, N]; True marks a valid point.
        Returns logits of shape [B, num_seg_classes, N].
        """
        # Masked batch norm receives the mask as a float tensor with a singleton channel axis: [B, 1, N].
        bn_mask = point_mask.unsqueeze(1).float()

        hidden = self.dropout1(self.relu1(self.bn1(self.conv1(x), bn_mask)))
        hidden = self.relu2(self.bn2(self.conv2(hidden), bn_mask))
        return self.conv3(hidden)

seg_head_dim = 512
num_seg_classes = dataset.num_seg_classes
print('num_seg_classes', num_seg_classes)
seg_head_dropout = 0.5

# Freshly constructed (untrained) head: its predictions are not expected to match the labels yet.
segmentation_head = SegmentationHead(
    encoder_dim=model.hparams.encoder_dim,
    upsampling_dim=upsampling_dim,
    seg_head_dim=seg_head_dim,
    seg_head_dropout=seg_head_dropout,
    num_seg_classes=num_seg_classes,
).cuda()

# The head is channel-first: features go in as (B, C, N), and logits (B, K, N) come back as (B, N, K).
cls_logits = segmentation_head(upsampled_features.transpose(1, 2), point_mask).transpose(1, 2)

pred_label = cls_logits.argmax(dim=-1)  # (B, N)

print('First 10 predictions and labels:')
# `labels` is already (B, N). A bare .squeeze() would also drop the batch dim when B == 1.
print(pred_label[0, :10], labels[0, :10])

num_seg_classes 4
First 10 predictions and labels:
tensor([0, 3, 2, 2, 2, 2, 3, 0, 3, 2], device='cuda:0') tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')


___

This entire model is encapsulated in `point2vec.models.part_segmentation.Point2VecPartSegmentation`.