# Fine-tuning for Semantic Segmentation - Results

This notebook contains code for inspecting the outputs of the model fine-tuned for semantic segmentation.


In [1]:
from polarmae.datasets import PILArNetDataModule
import torch

# Inference only: turn off gradient tracking so we don't run out of memory.
# The trailing `;` keeps the call's return value from being echoed as cell output.
torch.set_grad_enabled(False);

Set up the dataset as usual

In [5]:
# Build the PILArNet data module and prepare its datasets.
#
# NOTE(review): `data_path` is a hard-coded absolute path; adjust it to your
# environment before running.
# NOTE(review): the log printed when the fine-tuned checkpoint is loaded reports
# emin=0.01 and remove_low_energy_scatters=True for the model's own data config,
# which differs from the values used here — confirm the mismatch is intended.
dataset = PILArNetDataModule(
    data_path='/home/sc5303/sbnd_data/inss2025/data/schung_xyze_1e4.h5',
    batch_size=32,
    num_workers=0,
    dataset_kwargs={
        'emin': 1.0e-3,                       # min energy for log transform
        'emax': 20.0,                         # max energy for log transform
        'energy_threshold': 0.13,             # remove points with energy < 0.13
        'remove_low_energy_scatters': False,  # whether to remove low energy scatters (PID=4)
        'maxlen': -1,                         # max number of events to load
        'min_points': 0,                      # min number of points/event to load
    }
)
dataset.setup()

INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=0.001, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=False
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 10000 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 1 files were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=0.001, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=False
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 10000 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 1 files were loaded


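Instantiate the fine-tuned model from a local checkpoint (download the fine-tuned and pretrained checkpoints first if you do not already have them, and update the paths below).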

In [6]:
from polarmae.models.finetune import SemanticSegmentation
from polarmae.utils.checkpoint import load_finetune_checkpoint


# Checkpoint locations: the fine-tuned segmentation weights and the pretrained
# checkpoint passed alongside them to the loader.
FINETUNED_CKPT_PATH = '/exp/sbnd/data/users/sc5303/inss2025/PoLAr-MAE/tutorial/polarmae_fft_segsem.ckpt'
PRETRAINED_CKPT_PATH = '/exp/sbnd/data/users/sc5303/inss2025/PoLAr-MAE/tutorial/polarmae_pretrain.ckpt'

# Restore the fine-tuned model, reusing the data path of the dataset built above.
model = load_finetune_checkpoint(
    SemanticSegmentation,
    FINETUNED_CKPT_PATH,
    data_path=dataset.hparams.data_path,
    pretrained_ckpt_path=PRETRAINED_CKPT_PATH,
)

# Move the model to the GPU and switch it to evaluation mode.
model = model.cuda()
model.eval();

INFO:polarmae.layers.grouping:[rank: 0] Using CNMS for grouping. Using `num_groups` as the K in the ball query (2048)! Make sure it's not too large!
INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=0.01, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=True
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 9758 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 1 files were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=0.01, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=True
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 9758 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 1 files were loaded
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning mo

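The forward pass of the model returns the class logits and a point mask. We can get predictions by taking the argmax of the logits; the cells below read the predictions from the output's `id_pred` entry and index them with the `point_mask` entry.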

In [25]:
from polarmae.utils import transforms
from math import sqrt

# Grab one validation batch and move its tensors to the GPU.
batch = next(iter(dataset.val_dataloader()))
lengths = batch['lengths'].cuda()
points = batch['points'].cuda()
# Ground-truth labels, if needed for comparison:
# labels = batch['semantic_id'].cuda().squeeze(-1)

# Fallback centering/scaling transform, used only when the model carries no
# validation transforms of its own.
normalize = transforms.PointcloudCenterAndNormalize(
    center=[384, 384, 384],
    scale_factor=1 / (768 * sqrt(3) / 2),
)

# Historically (when this model was trained) centering and scaling lived in the
# data module; it now lives in the model's train_transformations and
# val_transformations. For backwards compatibility, an empty
# val_transformations list means we must normalize here ourselves.
has_model_transforms = len(model.val_transformations.transforms) > 0
points = model.val_transformations(points) if has_model_transforms else normalize(points)

# Run the model and read the predictions from its output.
output = model(points, lengths)
pred = output['id_pred']

        230, 200, 127, 533, 578, 185, 447, 309, 432,  87, 160, 595, 269, 192,
        292, 610, 628,  48], device='cuda:0') > 512)! This should not happen!


In [17]:
# Keep only the predictions selected by the model's point mask.
# Index output['id_pred'] directly rather than the previous `pred`, so that
# re-running this cell does not apply the mask to an already-masked tensor.
point_mask = output['point_mask']
pred = output['id_pred'][point_mask]

In [7]:
from math import sqrt
from pathlib import Path

import numpy as np

from polarmae.utils import transforms

# Fallback centering/scaling transform, used only when the model carries no
# validation transforms of its own.
normalize = transforms.PointcloudCenterAndNormalize(
                    center=[384, 384, 384],
                    scale_factor=1 / (768 * sqrt(3) / 2)
                )

val_loader = dataset.val_dataloader()  # or DataLoader(dataset.val_dataset, batch_size=..., shuffle=False)

# np.save does not create missing directories, so make sure the target exists
# before the (long-running) loop starts.
predictions_dir = Path('predictions')
predictions_dir.mkdir(parents=True, exist_ok=True)

# Run inference over the whole validation loader and write one .npy file of
# masked predictions per batch.
with torch.no_grad():
    for batch_idx, batch in enumerate(val_loader):
        points = batch['points'].cuda()
        # Use the model's own validation transforms when it has any; otherwise
        # fall back to the explicit normalization defined above.
        if len(model.val_transformations.transforms) > 0:
            points = model.val_transformations(points)
        else:
            points = normalize(points) # scale and normalize
        lengths = batch['lengths'].cuda()

        output = model(points, lengths)
        point_mask = output['point_mask']
        pred = output['id_pred'][point_mask]

        # Save all predictions from this batch
        np.save(predictions_dir / f'predictions_batch_{batch_idx}.npy', pred.cpu().numpy())

        153, 591,  47, 218, 163, 336, 303, 196, 372, 303, 507, 641, 281,  97,
        257, 637, 231, 514], device='cuda:0') > 512)! This should not happen!
        166, 123, 573, 265, 358, 230, 199, 158, 286, 427, 151, 802, 578, 129,
        335, 255, 118, 187], device='cuda:0') > 512)! This should not happen!
        238, 393, 232, 226, 141, 394, 188, 349,  61, 228, 314, 204, 381, 293,
         47, 216, 190, 221], device='cuda:0') > 512)! This should not happen!
        278, 246, 139, 535, 131, 145, 216,  98, 385, 252, 638, 519, 321, 168,
        375, 309, 220, 204], device='cuda:0') > 512)! This should not happen!
        430, 319, 226, 479, 432, 232, 584, 334, 513, 685, 664, 288, 305, 156,
        270, 239, 133, 297], device='cuda:0') > 512)! This should not happen!
        179, 129, 337, 354, 295, 216, 264, 759, 389, 305,  93, 348,  89, 546,
        344, 717, 622, 533], device='cuda:0') > 512)! This should not happen!
        166, 494, 371,  59, 111,  98, 313, 151,  41, 369, 277, 2

KeyboardInterrupt: 