# Download pre-trained checkpoints

In [None]:
!wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/mae_pretrain.ckpt
!wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/polarmae_pretrain.ckpt

In [1]:
from polarmae.models.ssl.pointmae_multitask import PointMAE_Multitask
from polarmae.models.ssl.pointmae import PointMAE

In [2]:
# Restore the pretrained polarmae weights from the downloaded checkpoint,
# then move the model to the GPU.
polarmae_checkpoint = "polarmae_pretrain.ckpt"
model = PointMAE_Multitask.load_from_checkpoint(polarmae_checkpoint)
model = model.cuda()

# Switch to evaluation mode. `eval()` returns the module itself, so leaving it as
# the cell's last expression also displays the model architecture.
model.eval()

/sdf/home/y/youngsam/sw/dune/.conda/envs/py310_torch/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
/sdf/home/y/youngsam/sw/dune/.conda/envs/py310_torch/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['decoder'])`.
INFO:polarmae.models.ssl.pointmae_multitask:[rank: 0] ⚙️  MAE prediction: full patch reconstruction


PointMAE_Multitask(
  (encoder): TransformerEncoder(
    (transformer): Transformer(
      (blocks): ModuleList(
        (0): Block(
          (drop_path): Identity()
          (norm1): MaskedLayerNorm()
          (attn): Attention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.05, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (norm2): MaskedLayerNorm()
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Identity()
          )
        )
        (1): Block(
          (drop_path): MaskedDropPath(drop_prob=0.023)
          (norm1): MaskedLayerNorm()
          (attn): Attention(
            (qkv): Linear(in_features=384, out_features

# Dataset Initialization

In [68]:
from polarmae.datasets import PILArNetDataModule
import torch

# Glob pattern matching the HDF5 files to load -- point this at your own copy of the data.
pilarnet_files = '/sdf/home/y/youngsam/data/dune/larnet/h5/DataAccessExamples/train/generic_v2*.h5'

# Options passed to the data module as `dataset_kwargs`.
pilarnet_kwargs = dict(
    emin=1.0e-2,                      # min energy for log transform
    emax=20.0,                        # max energy for log transform
    energy_threshold=0.13,            # remove points with energy < 0.13
    remove_low_energy_scatters=True,  # remove low energy scatters (PID=4)
    maxlen=-1,                        # max number of events to load
    min_points=1024,                  # minimum number of points in an event
)

dataset = PILArNetDataModule(
    data_path=pilarnet_files,
    batch_size=24,
    num_workers=0,
    dataset_kwargs=pilarnet_kwargs,
)

# Creates and initializes the train/val datasets, removes low energy deposits
# from the possible classes, etc.
dataset.setup()

INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=0.01, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=True
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 1045215 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 10 files were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] self.emin=0.01, self.emax=20.0, self.energy_threshold=0.13, self.remove_low_energy_scatters=True
INFO:polarmae.datasets.PILArNet:[rank: 0] Building index
INFO:polarmae.datasets.PILArNet:[rank: 0] 10473 point clouds were loaded
INFO:polarmae.datasets.PILArNet:[rank: 0] 1 files were loaded


In [69]:
from polarmae.utils import transforms
from math import sqrt

# Parameters of the fallback normalization transform: `center` is the midpoint of
# a cube with side 768, and `scale_factor` is the reciprocal of half that cube's
# space diagonal (768 * sqrt(3) / 2).
# NOTE(review): 768 is presumably the side length of the detector volume in
# voxels -- confirm against the dataset documentation.
cube_side = 768
half_space_diagonal = cube_side * sqrt(3) / 2

normalize = transforms.PointcloudCenterAndNormalize(
    center=[cube_side // 2] * 3,
    scale_factor=1 / half_space_diagonal,
)

# Inference on one batch.
# The forward passes run under torch.no_grad(): no autograd graph is needed for
# inference, and tensors that require grad cannot be converted with `.numpy()`,
# which the visualization cells below rely on.
with torch.no_grad():
    # Only the first validation batch is used.
    batch = next(iter(dataset.val_dataloader()))
    points = batch['points'].cuda()
    lengths = batch['lengths'].cuda()

    # in the past (when this model was trained), the centering and
    # scaling was done in the data module. now it's done in the model
    # in train_transformations and val_transformations.
    # for backwards compatibility, we check if there are zero transforms
    # in val_transformations and apply normalization if so.
    if len(model.val_transformations.transforms) > 0:
        points = model.val_transformations(points)
    else:
        points = normalize(points)  # center and scale

    # group & encode toks
    out = model.encoder.prepare_tokens_with_masks(points, lengths)
    masked, unmasked = out['masked'], out['unmasked']

    # run visible tokens through encoder
    encoder_output = model.encoder.transformer(
        out['x'], out['pos_embed'], unmasked
    ).last_hidden_state

    # corrupt embeddings with masked tokens: keep the encoder output at the
    # unmasked positions and put the model's mask token at the masked ones
    corrupted_embeddings = (
        encoder_output * unmasked.unsqueeze(-1)
        + model.mask_token * masked.unsqueeze(-1)
    )

    # run corrupted embeddings through decoder
    decoder_output = model.decoder.transformer(
        corrupted_embeddings, out['pos_embed'], out['emb_mask']
    ).last_hidden_state
    masked_output = decoder_output[masked]

    # Full patch reconstruction task: project every masked token and reshape the
    # result to (num_masked_tokens, -1, model.mae_channels)
    upscaled = model.increase_dim(masked_output.transpose(0, 1)).transpose(0, 1)
    upscaled = upscaled.reshape(upscaled.shape[0], -1, model.mae_channels)

`upscaled` holds only the masked groups, stacked along its first dimension, so we scatter it back into a tensor with the original shape of `groups`.

In [70]:
# Tokenizer outputs used below.
groups, centers = out['groups'], out['centers']
mask, point_mask = out['emb_mask'], out['point_mask']

# Scatter the predictions for the masked groups into a zero tensor shaped like
# `groups`; every other group slot stays zero.
upscaled_pts = torch.zeros_like(groups)
upscaled_pts[masked] = upscaled

group_radius = model.encoder.tokenizer.grouping.group_radius


def _restore_group_coordinates(local_groups):
    """Return a copy of `local_groups` whose first three channels are multiplied
    by `group_radius` and then offset by the matching group center."""
    restored = local_groups.clone()
    # undo the scaling by group_radius, then add back on the centers
    restored[:, :, :, :3] = restored[:, :, :, :3] * group_radius + centers[:, :, None, :3]
    return restored


rescaled_groups_truth = _restore_group_coordinates(groups)
rescaled_upscaled_pred = _restore_group_coordinates(upscaled_pts)

# Get point lengths for each event: `point_mask` summed over its last axis
point_lengths = point_mask.sum(-1)

In [71]:
event_idx = 2  # change me to see different events!

# Group selectors for the chosen event.
visible_groups = unmasked[event_idx]
hidden_groups = masked[event_idx]

# Get original and upscaled (reconstructed) points: the true points of the
# visible groups and the reconstructed points of the masked groups.
# NOTE: this rebinds `upscaled_pts` from the full-batch tensor to a NumPy array
# holding only this event's masked groups.
orig_pts = rescaled_groups_truth[event_idx][visible_groups].cpu().numpy()
upscaled_pts = rescaled_upscaled_pred[event_idx][hidden_groups].cpu().numpy()

# Get point lengths for just our event, split the same way
unmasked_lengths = point_lengths[event_idx][visible_groups].cpu().numpy()
masked_lengths = point_lengths[event_idx][hidden_groups].cpu().numpy()

Below we plot the reconstructed (masked) points in red and the original points of the visible groups in black.

In [72]:
import plotly.graph_objects as go
import numpy as np

def _valid_points(groups, lengths):
    """Stack the first `lengths[i]` points of every group `i` into one (N, C) array.

    `groups` is indexed as (group, slot, channel) and `lengths` holds one count
    per group. Zero groups yields an empty (0, C) array.
    """
    slot_index = np.arange(groups.shape[1])[None, :]
    keep = slot_index < lengths[:, None]  # True for the first lengths[i] slots of group i
    return groups[keep]


def _points_trace(points, color, name):
    """Build a single Scatter3d marker trace from the first three columns of `points`."""
    return go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=1, color=color),
        name=name,
        showlegend=True,
    )


fig = go.Figure()

# One trace per category instead of one per group: the legend entry then toggles
# every point of that category, and the figure holds two traces rather than one
# for each group.

# Plot the reconstructed points in red
fig.add_trace(_points_trace(_valid_points(upscaled_pts, masked_lengths), 'red', 'prediction'))

# Plot the original points in black
fig.add_trace(_points_trace(_valid_points(orig_pts, unmasked_lengths), 'black', 'visible groups'))

fig.update_layout(
    title="Masked-group reconstruction",
    scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"), width=600, height=600
)

fig.show()

Looks pretty great!

# Embedding PCA

We may also be interested in seeing the variation of the embeddings in a single event.

Since only the visible tokens went through the encoder, we can only visualize the visible ("unmasked") tokens.

In [73]:
from sklearn.decomposition import PCA
import plotly.graph_objects as go

event_idx = 2

# Embeddings and group centers of the visible ("unmasked") tokens of one event.
visible_tokens = out["unmasked"][event_idx]
example_embeddings = encoder_output[event_idx][visible_tokens].cpu().numpy()
centers = out["centers"]
example_centers = centers[event_idx][visible_tokens].cpu().numpy()

# Project the embeddings onto their first three principal components, then
# min-max scale every component to [0, 1] so it can be used as a color channel.
pca = PCA(n_components=3).fit(example_embeddings)
scaled_embeddings = pca.transform(example_embeddings)
component_min = scaled_embeddings.min(axis=0)
component_max = scaled_embeddings.max(axis=0)
scaled_embeddings = (scaled_embeddings - component_min) / (component_max - component_min)

# Convert the scaled embeddings to RGB color strings (one per token)
colors = [
    "rgb({},{},{})".format(*(int(channel * 255) for channel in row))
    for row in scaled_embeddings
]

# Draw each token at its group center, colored by its PCA projection.
center_markers = go.Scatter3d(
    x=example_centers[:, 0],
    y=example_centers[:, 1],
    z=example_centers[:, 2],
    mode="markers",
    marker=dict(size=5, opacity=1.0, color=colors),
)
fig = go.Figure(data=center_markers)

axis_titles = dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z")
fig.update_layout(scene=axis_titles, width=600, height=600)

fig.show()

Below we encode ALL tokens by switching the attention mask from `unmasked` to `emb_mask` -- the mask for all non-padded tokens -- during the encoder forward pass.

In [75]:
# Re-run the encoder with `emb_mask` as the attention mask instead of `unmasked`.
# torch.no_grad() keeps the output free of autograd state so that it can be
# converted with `.numpy()` afterwards.
# NOTE: this overwrites `encoder_output` from the inference cell; re-run that
# cell before re-running the visible-token PCA above.
with torch.no_grad():
    encoder_output = model.encoder.transformer(
        out['x'], out['pos_embed'], out['emb_mask']
    ).last_hidden_state

In [77]:
# Embeddings and group centers of every token selected by `emb_mask` for the
# event chosen earlier (`event_idx`).
all_tokens = out["emb_mask"][event_idx]
example_embeddings = encoder_output[event_idx][all_tokens].cpu().numpy()
centers = out["centers"]
example_centers = centers[event_idx][all_tokens].cpu().numpy()

# Three-component PCA of the embeddings, min-max scaled per component to [0, 1].
pca = PCA(n_components=3).fit(example_embeddings)
scaled_embeddings = pca.transform(example_embeddings)
lowest = scaled_embeddings.min(axis=0)
highest = scaled_embeddings.max(axis=0)
scaled_embeddings = (scaled_embeddings - lowest) / (highest - lowest)

# Convert the scaled embeddings to RGB color strings
colors = []
for red, green, blue in scaled_embeddings:
    colors.append(
        "rgb({},{},{})".format(int(red * 255), int(green * 255), int(blue * 255))
    )

# One marker per token, placed at its group center and colored by the PCA projection.
marker_style = dict(size=5, opacity=1.0, color=colors)
fig = go.Figure(
    data=go.Scatter3d(
        x=example_centers[:, 0],
        y=example_centers[:, 1],
        z=example_centers[:, 2],
        mode="markers",
        marker=marker_style,
    )
)

fig.update_layout(
    scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"),
    width=600,
    height=600,
)

fig.show()