In [2]:
import torch
import MinkowskiEngine as ME
from MinkowskiOps import (
    to_sparse,
)

In [34]:
# Dense 2-D probe input laid out as (batch, channel, H, W), with two non-zero
# pixels in the last column: (1, 4) = 1.0 and (3, 4) = 1.5.
# Other probe points tried earlier: (0, 3) = 1 and (2, 3) = 1.5.
a = torch.zeros(1, 1, 5, 5)
a[0, 0, [1, 3], 4] = torch.tensor([1.0, 1.5])

# Convert to a MinkowskiEngine SparseTensor that stores only the non-zero pixels.
a_sparse = to_sparse(a)

In [35]:
# Display the sparse tensor: one (batch, row, col) coordinate plus one feature per non-zero pixel.
a_sparse

SparseTensor(
  coordinates=tensor([[0, 1, 4],
        [0, 3, 4]], dtype=torch.int32)
  features=tensor([[1.0000],
        [1.5000]])
  coordinate_map_key=coordinate map key:[1, 1]
  coordinate_manager=CoordinateMapManagerCPU(
	[1, 1, ]:	CoordinateMapCPU:2x3
	algorithm=MinkowskiAlgorithm.DEFAULT
  )
  spatial dimension=2)

In [36]:
# Densify back to the input's shape. `.dense()` returns a sequence whose first
# element is the dense tensor (the output below is 4-D, like `a`).
a_dense = a_sparse.dense(shape=a.shape)[0]

In [37]:
# Original dense input, for comparison with the sparse -> dense round trip.
a

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.5000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

In [38]:
# Dense reconstruction from the sparse tensor; it matches `a`.
a_dense

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.5000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

In [39]:
# Compare global average pooling on the sparse vs. the dense representation.
# The sparse pooling averages only over the stored (active) coordinates,
# (1.0 + 1.5) / 2 = 1.25, while the dense mean also counts the 23 empty pixels,
# (1.0 + 1.5) / 25 = 0.1 -- see the displayed outputs.
global_avg_pool = ME.MinkowskiGlobalAvgPooling()
pooled_sparse = global_avg_pool(a_sparse)
pooled_dense = a_dense.mean(
    dim=tuple(range(2, a_dense.ndim))  # every spatial axis, i.e. all axes after (batch, channel)
)


In [40]:
# Sparse global average: a single pooled coordinate whose feature is the mean over the active sites only.
pooled_sparse

SparseTensor(
  coordinates=tensor([[0, 0, 0]], dtype=torch.int32)
  features=tensor([[1.2500]])
  coordinate_map_key=coordinate map key:[0, 0]
  coordinate_manager=CoordinateMapManagerCPU(
	[0, 0, ]:	CoordinateMapCPU:1x3
	[1, 1, ]:	CoordinateMapCPU:2x3
	algorithm=MinkowskiAlgorithm.DEFAULT
  )
  spatial dimension=2)

In [41]:
# Re-display the dense tensor: it is unchanged by the pooling above.
a_dense

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.5000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

In [42]:
# Dense global average over all 25 pixels, zeros included: 2.5 / 25 = 0.1 (vs. 1.25 from the sparse pooling).
pooled_dense

tensor([[0.1000]])

In [None]:
import sys
import os
import tqdm
import gc
import torch
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap, BoundaryNorm

# Make the project root (the parent of the notebook's working directory) importable,
# so that `utils`, `dataset` and `model` below resolve to the local modules.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
    
from utils import ini_argparse, split_dataset
from dataset import *
from model import MaskedAutoencoderViT3DSparse

import matplotlib
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import font_manager
import matplotlib.colors as mcolors
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt

# Start from matplotlib's default configuration before customising fonts.
plt.rcdefaults()

# Register Computer Modern Roman (cmr10, shipped in matplotlib's data directory)
# and use it as the font for the 'sans-serif' family.
from pathlib import Path
font_path = str(Path(matplotlib.get_data_path(), "fonts/ttf/cmr10.ttf"))
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)

# Tick labels go through mathtext in regular (upright) style -- presumably to
# work around cmr10 lacking a unicode minus glyph. Keys are applied in order.
params = {'mathtext.default': 'regular'}
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': prop.get_name(),
    'axes.formatter.use_mathtext': True,
    **params,
})

In [None]:
# manually specify the GPUs to use
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

parser = ini_argparse()
args = parser.parse_args([])
args.train = False
args.dataset_path = "/scratch/salonso/sparse-nns/faser/events_v5.1b"
args.batch_size = 32
args.sets_path = None#"/scratch/salonso/sparse-nns/faser/events_v5.1/sets.pkl"
args.num_workers = 32
args.load_seg = False
args.stage1 = True
args.train = False
args.preprocessing_input = "sqrt"
args.preprocessing_output = "log"
args.standardize_input = "z-score"
args.standardize_output = None

print("\n- Arguments:")
for arg, value in vars(args).items():
    print(f"  {arg}: {value}")
nb_gpus = len(args.gpus)
gpus = [int(gpu) for gpu in args.gpus]

In [None]:
dataset = SparseFASERCALDataset(args)
# One event per batch for the loaders built below (presumably read by
# split_dataset): the analysis cells drop the batch dimension with .squeeze(0).
args.batch_size = 1
print("- Dataset size: {} events".format(len(dataset)))
train_loader, valid_loader, test_loader = split_dataset(
    dataset,
    args,
    splits=[0.6, 0.1, 0.3],
    test=True,
)

In [None]:
model = MaskedAutoencoderViT3DSparse()
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load("/scratch/salonso/sparse-nns/faser/deep_learning/faserDLTrans/checkpoints_final/mae_vit_v1/loss_val_total/last.ckpt", map_location='cpu')

# Remove the leading "model." prefix from the keys in the checkpoint's state_dict.
# Only the prefix is stripped: str.replace would also mangle inner occurrences
# (e.g. "sub_model.weight" -> "sub_weight"). "loss_occupancy.alpha" is dropped;
# presumably it belongs to the training loss rather than to the network itself.
state_dict = {
    (key[len("model."):] if key.startswith("model.") else key): value
    for key, value in checkpoint['state_dict'].items()
    if key != "loss_occupancy.alpha"
}
model.load_state_dict(state_dict, strict=True)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total trainable params model (total): {}".format(total_params))

In [None]:
from utils import arrange_truth
from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix
from typing import Optional
from collections import defaultdict

def _arrange_batch(batch, device):
    """Move the tensors the model consumes from a collated batch dict onto ``device``.

    Args:
        batch: mapping that holds the keys listed below, each value having a ``.to`` method.
        device: target device passed to ``.to``.

    Returns:
        A 9-tuple in this order: patches_charge, patches_lepton, patches_seg,
        patch_ids, attn_mask, faser_cal_modules, rear_cal_modules,
        rear_hcal_modules, f_glob (global scalars).
    """
    wanted = (
        'patches_charge',
        'patches_lepton',
        'patches_seg',
        'patch_ids',
        'attn_mask',
        'faser_cal_modules',
        'rear_cal_modules',
        'rear_hcal_modules',
        'f_glob',
    )
    return tuple(batch[name].to(device) for name in wanted)
    

In [None]:
# Debug check: largest predicted occupancy probability for the inspected event.
# `occ_pred` is only created by the inference cell further down, so on a fresh
# kernel (Restart & Run All) it does not exist yet -- skip instead of raising NameError.
occ_pred.max() if 'occ_pred' in globals() else None

In [None]:
model.eval()

target_idx = 16  # test batch to inspect (indices 7 and 11 were also looked at)

threshold = 0.5  # occupancy probability at or above which a predicted voxel is kept

# Seed the global RNGs so the random patch masking is reproducible across runs.
# NOTE(review): `gen` is never passed to the model below, so seeding it alone has
# no effect; this assumes the model draws its mask from the global RNG.
torch.manual_seed(0)
gen = torch.Generator(device=device)
gen.manual_seed(0)

t = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), disable=False)
# Pure inference: no autograd graph is needed, which saves (GPU) memory.
with torch.no_grad():
    for i, batch in t:
        # Skip ahead to the requested batch (earlier batches are still loaded).
        if i < target_idx:
            continue

        torch.cuda.empty_cache()
        gc.collect()

        patches_charge, patches_lepton, patches_seg, patch_ids, attn_mask, *glob = _arrange_batch(batch, device)

        # Forward pass; mask_ratio is presumably the fraction of patches that get masked.
        _, preds, mask, individual_losses = model(
            patches_charge, patches_lepton, patches_seg, patch_ids, attn_mask, glob, mask_ratio=0.5
        )

        # Batch size is 1 (args.batch_size = 1): drop the batch dimension and
        # split the patch ids into visible (non-masked) and masked ones.
        patch_ids = patch_ids.cpu().squeeze()
        mask = mask.cpu().squeeze().bool()
        patch_ids_nonmasked = patch_ids[~mask]
        patch_ids_masked = patch_ids[mask]

        patches_all = [x.cpu().squeeze(0) for x in (patches_charge, patches_lepton, patches_seg)]
        patches_nonmasked = [x.cpu().squeeze(0)[~mask] for x in (patches_charge, patches_lepton, patches_seg)]
        pred_patches = [x.detach().cpu().squeeze(0)[mask] for x in (preds['occ'], preds['charge'], preds['lepton'], preds['seg'])]

        # Map patches back to voxel coordinates and per-voxel values.
        coords_true_all, values_true_all = dataset.patcher.unpatchify(patches_all, patch_ids)
        coords_true_nonmasked, values_true_nonmasked = dataset.patcher.unpatchify(patches_nonmasked, patch_ids_nonmasked)
        coords_pred, values_pred = dataset.patcher.unpatchify(pred_patches, patch_ids_masked)

        charge_true_all, lepton_true_all, seg_true_all = values_true_all
        charge_true_nonmasked, lepton_true_nonmasked, seg_true_nonmasked = values_true_nonmasked
        occ_pred, charge_pred, lepton_pred, seg_pred = values_pred
        occ_pred = torch.sigmoid(occ_pred)  # occupancy logits -> probabilities

        # Keep only voxels predicted as occupied (mask computed once and reused).
        active = occ_pred >= threshold
        coords_pred_act = coords_pred[active]
        charge_pred_act = charge_pred[active]
        lepton_pred_act = lepton_pred[active]
        seg_pred_act = seg_pred[active]
        break
    else:
        # The loader ran out before reaching target_idx: fail loudly instead of
        # leaving the variables above undefined (or stale from a previous run).
        raise IndexError(
            f"target_idx={target_idx} is out of range: the test loader has {len(test_loader)} batches"
        )

# NOTE(review): reference numbers recorded by the author; their meaning is not stated here.
# (0.9139406313093339, 0.4111643100478804, 0.36962088240746027) same

In [None]:
import matplotlib.pyplot as plt

def plot_nonzero_voxels(coords, charges, s=1, cmap='viridis', extent=(48, 48, 200), title=None):
    """Scatter-plot sparse voxels in 3D, coloured by their value.

    Z is drawn along the horizontal axis. The box aspect follows the grid
    extent, so the long Z dimension is shown to scale.

    Args:
        coords (torch.Tensor): voxel coordinates with 3 columns ordered (x, y, z),
            i.e. shape (N, 3); may live on any device.
        charges (torch.Tensor): one value per voxel (N entries), used for the colours.
        s (float): marker size passed to ``Axes.scatter`` (default: 1).
        cmap (str): name of the Matplotlib colormap to use (default: 'viridis').
        extent (tuple[int, int, int]): grid size (nx, ny, nz). Axis labels, limits
            and box aspect are derived from it (default: (48, 48, 200)).
        title (str, optional): axes title; none is drawn when omitted.
    """
    nx, ny, nz = extent

    # Extract coordinates and values as numpy arrays on the CPU.
    x, y, z = coords.T.cpu().numpy()
    values = charges.cpu().numpy()

    # 3D scatter: plot axes are (Z, X, Y).
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    sc = ax.scatter(z, x, y, c=values, cmap=cmap, s=s)

    # Labels & limits
    ax.set_xlabel(f'Z (0–{nz - 1})')
    ax.set_ylabel(f'X (0–{nx - 1})')
    ax.set_zlabel(f'Y (0–{ny - 1})')
    ax.set_xlim(0, nz - 1)
    ax.set_ylim(0, nx - 1)
    ax.set_zlim(0, ny - 1)

    # Box aspect proportional to the grid size, so the Z axis is drawn to scale.
    ax.set_box_aspect((nz, nx, ny))

    cbar = fig.colorbar(sc, ax=ax, pad=0.1)
    cbar.set_label('Voxel Value')
    if title is not None:
        ax.set_title(title)

    plt.tight_layout()
    plt.show()


plot_nonzero_voxels(coords_true_all, charge_true_all,
                    title='Ground truth: all voxels')
plot_nonzero_voxels(coords_true_nonmasked, charge_true_nonmasked,
                    title='Ground truth: non-masked patches')
plot_nonzero_voxels(coords_pred_act, charge_pred_act,
                    title='Prediction: active voxels in masked patches')
plot_nonzero_voxels(torch.cat((coords_true_nonmasked, coords_pred_act)),
                    torch.cat((charge_true_nonmasked, charge_pred_act)),
                    title='Non-masked ground truth + predictions')

