In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import torch; torch.set_grad_enabled(False); # turn off gradients so memory doesn't explode

# Fine-tuning for Semantic Segmentation

This notebook walks through the mechanics of fine-tuning the model for semantic segmentation. Fine-tuning adds a point feature upscaling module and a classification head to the end of the encoder.

## Feature Upscaling

The feature upscaling is done using a PointNet++ upscaling module. The basic procedure is as follows:

1. Pass the point cloud through the encoder to get the token features at all intermediate layers.
2. Average the token features over the selected intermediate layers (usually layers [3,7,11]), then mean- and max-pool the result over the tokens of each event, and concatenate the two pooled vectors to get the per-event global features.
3. Treating each token from the last layer of the encoder as a point in 3D space, interpolate the latent features for all points in the cloud via inverse distance weighting.
4. Concatenate the 3D positions of each point with its interpolated latent features, and pass the result through an MLP to encode positional context.
5. Concatenate the 'global features' with the resulting per-point features, and pass the result through the segmentation head (an MLP) to classify each point.

## Segmentation Head

The segmentation head is a simple MLP that takes the upscaled features and outputs per-point classification logits.
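
As a rough illustration of what such a head can look like, here is a minimal sketch of a per-point MLP built from 1D convolutions. `TinySegHead` and its layer sizes are hypothetical and are not the library's `SegmentationHead`. It also uses plain `nn.BatchNorm1d`, so padded points are not excluded from the batch statistics.

```python
import torch
import torch.nn as nn

class TinySegHead(nn.Module):
    """Hypothetical per-point classification head (not polarmae's SegmentationHead)."""

    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int, dropout: float = 0.5):
        super().__init__()
        # kernel_size=1 convolutions apply the same MLP to every point independently
        self.net = nn.Sequential(
            nn.Conv1d(in_dim, hidden_dim, kernel_size=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim, hidden_dim // 2, kernel_size=1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim // 2, num_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, N) -> logits: (B, num_classes, N)
        return self.net(x)

head = TinySegHead(in_dim=32, hidden_dim=16, num_classes=4).eval()
features = torch.randn(2, 10, 32)                         # (B, N, C) per-point features
logits = head(features.transpose(1, 2)).transpose(1, 2)   # (B, N, num_classes)
pred = logits.argmax(dim=-1)                              # (B, N) predicted class per point
```

Because every convolution has `kernel_size=1`, each point is classified from its own feature vector only. Any context from neighboring points has to already be present in the upscaled features.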

---

In [2]:
from polarmae.datasets import PILArNetDataModule
dataset = PILArNetDataModule(
    data_path='/sdf/home/y/youngsam/data/dune/larnet/h5/DataAccessExamples/train/generic_v2*.h5',
    batch_size=24,
    num_workers=0,
    dataset_kwargs={
        'emin': 1.0e-6,                      # min energy for log transform
        'emax': 20.0,                        # max energy for log transform
        'energy_threshold': 0.13,            # remove points with energy < 0.13
        'remove_low_energy_scatters': True,  # remove low energy scatters (PID=4)
        'maxlen': -1,                        # max number of events to load
        'min_points': 1024,
    }
)
dataset.setup()

INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=1e-06, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=True
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 1045215 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 10 files were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=1e-06, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=True
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 10473 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 1 files were loaded


In [3]:
! wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/polarmae_pretrain.ckpt

Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.
ERROR: could not open HSTS store at '/sdf/home/y/youngsam/.wget-hsts'. HSTS will be disabled.
--2025-02-07 17:52:36--  https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/polarmae_pretrain.ckpt
Resolving github.com (github.com)... 140.82.116.3
Connecting to github.com (github.com)|140.82.116.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/927478490/ade5074b-3d24-4d8a-b0e0-65297a6fa9cd?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250208%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250208T015209Z&X-Amz-Expires=300&X-Amz-Signature=3743eb82ceb3fad54858cc9c520ce139f89a3e307b1938b5b3ecc3d1c5db40ba&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dpolarmae_pretrain.ckpt&response-content-type=application%2Foctet-stream 


2025-02-07 17:52:42 (64.5 MB/s) - ‘polarmae_pretrain.ckpt.2’ saved [356044228/356044228]



Instantiate pre-trained model

In [3]:
from polarmae.models.ssl import PoLArMAE

model = PoLArMAE.load_from_checkpoint("polarmae_pretrain.ckpt").cuda()
model.eval();

/sdf/home/y/youngsam/sw/dune/.conda/envs/py310_torch/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
/sdf/home/y/youngsam/sw/dune/.conda/envs/py310_torch/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['decoder'])`.
INFO:polarmae.models.ssl.polarmae:[rank: 0] ⚙️  MAE prediction: full patch reconstruction


Given token embeddings, their centers, and the points we want to upscale to, our point feature upsampler will:

1. Find the K nearest token centers to each point.
2. Interpolate the embeddings of those K tokens via inverse distance weighting to get an embedding for each point.
3. Apply a 2-layer MLP with batch normalization to the interpolated embeddings.


In practice, the embeddings are the average of the token embeddings taken from a list of encoder layers. For this we will use layers [3,7,11] and K=5. Note that the cell below uses `output.last_hidden_state` as the embeddings.
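
The nearest-neighbor search and inverse distance weighting can be sketched for a single event as follows. This is a simplified illustration, not the code inside `PointNetFeatureUpsampling`: the function `idw_upsample` is hypothetical, it ignores batching and padding, it omits the MLP, and it assumes weights proportional to `1 / distance` (the library may use a different power of the distance).

```python
import torch

def idw_upsample(points, centers, embeddings, k=3, eps=1e-8):
    """Interpolate token embeddings onto points by inverse distance weighting.

    points:     (N, 3) xyz of the points to upsample to
    centers:    (T, 3) xyz of the token centers
    embeddings: (T, C) one embedding per token
    returns:    (N, C) one interpolated embedding per point
    """
    dists = torch.cdist(points, centers)                        # (N, T) pairwise distances
    knn_dists, knn_idx = dists.topk(k, dim=-1, largest=False)   # (N, k) nearest centers
    weights = 1.0 / (knn_dists + eps)                           # closer centers weigh more
    weights = weights / weights.sum(dim=-1, keepdim=True)       # normalize weights to sum to 1
    neighbors = embeddings[knn_idx]                             # (N, k, C)
    return (weights.unsqueeze(-1) * neighbors).sum(dim=1)       # (N, C)

torch.manual_seed(0)
centers = torch.rand(8, 3)        # toy token centers
embeddings = torch.randn(8, 16)   # toy token embeddings
points = torch.rand(50, 3)        # toy point cloud
upsampled = idw_upsample(points, centers, embeddings, k=5)      # (50, 16)
```

Since the weights are non-negative and sum to one, each interpolated embedding is a convex combination of its K neighbors, and a point that coincides with a token center recovers (almost exactly) that token's embedding.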

In [4]:
from polarmae.layers.feature_upsampling import PointNetFeatureUpsampling

# Feature Upsampler ============================
point_dim = 3
upsampling_dim = model.encoder.embed_dim
feature_upsampler = PointNetFeatureUpsampling(
    in_channel=upsampling_dim,
    mlp=[upsampling_dim, upsampling_dim],
).cuda()

batch = next(iter(dataset.train_dataloader()))
points = batch['points'].cuda()                    # (B, N, 4)
lengths = batch['lengths'].cuda()                  # (B,)
labels = batch['semantic_id'].cuda().squeeze(-1)   # (B, N)
B, N, C = points.shape

out = model.encoder.prepare_tokens(points, lengths)
output = model.encoder(out["x"], out["pos_embed"], out["emb_mask"], return_hidden_states=True)
embeddings = output.last_hidden_state
batch_lengths = out["emb_mask"].sum(dim=1)         # number of valid tokens per event
# True for real points, False for padding
point_mask = torch.arange(lengths.max(), device=lengths.device).expand(
    len(lengths), -1
) < lengths.unsqueeze(-1)

upsampled_features,_ = feature_upsampler(
    points[..., :3],                 # xyz1
    out['centers'][..., :3],       # xyz2
    points[..., :3],                 # points1
    embeddings,             # points2
    lengths,                # point_lens
    batch_lengths,          # embedding_lens
    point_mask,             # point_mask for masked bn
) # (B, N, C)

In [5]:
print('upsampled_features.shape', upsampled_features.shape)
print('points.shape', points.shape)

upsampled_features.shape torch.Size([24, 7839, 384])
points.shape torch.Size([24, 7839, 4])


We now have a latent feature for each individual point in the point cloud, so we can perform point classification by running each point through a simple classification head!

In Point-MAE/PoLAr-MAE, we also concatenate two global feature vectors onto each point's features to give a per-event summary of the point cloud. These are the maximum and the mean of the token features for each event. We first take the token features from the 3rd, 7th, and 11th layers of the encoder and average them. The averaged features are then passed through the masked max and mean functions, and the two results are concatenated.

`TransformerEncoder.combine_intermediate_layers` does this fetching and averaging of the intermediate token features for us. Here's the code for it:

```python
    def combine_intermediate_layers(
        self,
        output: TransformerOutput,
        mask: Optional[torch.Tensor] = None,
        layers: List[int] = [0],
    ) -> torch.Tensor:
        hidden_states = [
            masked_layer_norm(output.hidden_states[i], output.hidden_states[i].shape[-1], mask)
            for i in layers
        ]
        return torch.stack(hidden_states, dim=0).mean(0)
```
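
To make the stacking and averaging step concrete, here is a toy sketch that leaves out the `masked_layer_norm` call and works on dummy hidden states. The list `hidden_states` below is a stand-in made up for illustration, not real encoder output.

```python
import torch

# dummy hidden states for 12 layers, each of shape (B, T, C); layer i is filled with the value i
hidden_states = [torch.full((2, 3, 4), float(i)) for i in range(12)]
layers = [3, 7, 11]

# stack the selected layers along a new leading dimension and average over it
combined = torch.stack([hidden_states[i] for i in layers], dim=0).mean(0)   # (B, T, C)
```

Every entry of `combined` is the average of the values 3, 7, and 11, and the shape is unchanged from a single layer's hidden state.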

In [6]:
def masked_mean(group, point_mask):
    """
    perform a mean over the second-to-last (token) dimension of the input,
    taking care to only include valid entries
    """
    valid_elements = point_mask.sum(-1).float().clamp(min=1)
    return (group * point_mask.unsqueeze(-1)).sum(-2) / valid_elements.unsqueeze(-1)

def masked_max(group, point_mask):
    """
    perform a max over the second-to-last (token) dimension of the input,
    taking care to only include valid entries
    """
    return (group - 1e10 * (~point_mask.unsqueeze(-1))).max(-2).values

B, N, C = points.shape

intermediate_features = model.encoder.combine_intermediate_layers(output, out["emb_mask"], [3,7,11])

global_feature = torch.cat(
    [masked_max(intermediate_features, out["emb_mask"]), masked_mean(intermediate_features, out["emb_mask"])], dim=-1
)
upsampled_features = torch.cat(
    [upsampled_features, global_feature.unsqueeze(-1).expand(-1, -1, N).transpose(1, 2)], dim=-1
)
print('upsampled_features.shape', upsampled_features.shape)

upsampled_features.shape torch.Size([24, 7839, 1152])


We now have a whopping 1152 (384 $\times$ 3) features for each point in the point cloud: 384 upsampled features, plus 384 max-pooled and 384 mean-pooled global features!
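
A toy example with hand-checkable numbers shows how the masked pooling ignores padded tokens and how the global feature is broadcast to every point. The tensors below are made up for illustration, and `point_features` is a zero-filled stand-in for the upsampled features.

```python
import torch

def masked_mean(tokens, mask):
    # tokens: (B, T, C), mask: (B, T) with True for valid tokens
    count = mask.sum(-1).float().clamp(min=1)
    return (tokens * mask.unsqueeze(-1)).sum(-2) / count.unsqueeze(-1)

def masked_max(tokens, mask):
    # push padded tokens to a large negative value so they never win the max
    return (tokens - 1e10 * (~mask.unsqueeze(-1))).max(-2).values

# 2 events, 3 token slots, 2 channels; the last token of the second event is padding
tokens = torch.tensor([[[1., 2.], [3., 4.], [5., 0.]],
                       [[2., 2.], [4., 6.], [9., 9.]]])
mask = torch.tensor([[True, True, True],
                     [True, True, False]])

# per-event summary: [max, mean] along the channel dimension -> (B, 2C)
global_feature = torch.cat([masked_max(tokens, mask), masked_mean(tokens, mask)], dim=-1)

N = 4                                      # points per event
point_features = torch.zeros(2, N, 2)      # stand-in for upsampled features, (B, N, C)
combined = torch.cat(
    [point_features, global_feature.unsqueeze(1).expand(-1, N, -1)], dim=-1
)                                          # (B, N, C + 2C)
```

For the second event, the padded token `[9., 9.]` affects neither the max (`[4., 6.]`) nor the mean (`[3., 4.]`). Every point in an event receives the same global vector.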

Our final segmentation head will be a 3-layer MLP with batch normalization and dropout. Each layer will halve the feature dimension.

In [7]:
import torch.nn as nn
from polarmae.layers.seg_head import SegmentationHead

seg_head_dim = 512
num_seg_classes = dataset.num_seg_classes # 4 for us
print('Number of segmentation classes:', num_seg_classes)
seg_head_dropout = 0.5

segmentation_head = SegmentationHead(
    encoder_dim=model.encoder.embed_dim,
    label_embedding_dim=0, # event-wide label embedding -- 0 for our dataset!
    upsampling_dim=model.encoder.embed_dim,
    seg_head_dim=seg_head_dim,
    seg_head_dropout=seg_head_dropout,
    num_seg_classes=num_seg_classes,
).cuda()

# note the need to transpose the features: the head uses 1D convolutions (conv1d),
# which expect channels-first input of shape (B, C, N).
cls_logits = segmentation_head(upsampled_features.transpose(1,2), point_mask).transpose(1,2)

pred_label = torch.max(cls_logits, dim=-1).indices

print('First 10 predicted and true labels for the first event:')
print(pred_label[0, :10], labels.squeeze()[0, :10])

Number of segmentation classes: 4
First 10 predicted and true labels for the first event:
tensor([1, 1, 3, 3, 1, 3, 1, 1, 1, 1], device='cuda:0') tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')


The predictions here are poor, but this is expected: the feature upsampler and segmentation head are freshly initialized, and we haven't fine-tuned the model yet!
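
Fine-tuning needs a per-point loss that skips padded points. The following is a minimal sketch on random stand-in tensors, assuming cross-entropy as the objective and `-100` as the ignore label. Both are assumptions made for illustration, and this is not necessarily the loss used inside the library.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, num_classes = 2, 5, 4
logits = torch.randn(B, N, num_classes)          # stand-in for cls_logits, (B, N, num_classes)
labels = torch.randint(0, num_classes, (B, N))   # stand-in for the true labels, (B, N)
lengths = torch.tensor([5, 3])                   # number of real points per event

# same construction as point_mask in the notebook: True for real points
point_mask = torch.arange(N).expand(B, -1) < lengths.unsqueeze(-1)

# mark padded points with the ignore label so cross_entropy skips them
targets = labels.masked_fill(~point_mask, -100)
loss = F.cross_entropy(
    logits.reshape(-1, num_classes), targets.reshape(-1), ignore_index=-100
)
```

With `ignore_index`, the loss is averaged over the real points only, so events with heavy padding do not dilute the gradient signal.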

___

This entire model is encapsulated in `polarmae.models.finetune.semantic_segmentation.SemanticSegmentation`. That code is a little more complex than the other models because it has to handle both the upscaling and the segmentation head, and it has more options (e.g., whether to use this global feature conditioning, or whether to use a transformer-based segmentation decoder after the encoder).