In [1]:
import sys
import os
import tqdm
import gc
import glob
import torch
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap, BoundaryNorm

module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
    
from utils import ini_argparse, split_dataset, collate_test
from dataset import *
from model import MinkMAEViT, mae_vit_base, mae_vit_large, mae_vit_huge

import matplotlib
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import font_manager
import matplotlib.colors as mcolors
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt

# reset the plot configurations to default
plt.rcdefaults()

from pathlib import Path
font_path = str(Path(matplotlib.get_data_path(), "fonts/ttf/cmr10.ttf"))
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = prop.get_name()
plt.rcParams["axes.formatter.use_mathtext"] = True
params = {'mathtext.default': 'regular' }          
plt.rcParams.update(params)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# manually specify the GPUs to use
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

parser = ini_argparse()
args = parser.parse_args([])
#args.dataset_path = "/scratch/salonso/sparse-nns/faser/events_v3.5"
args.dataset_path = "/scratch/salonso/sparse-nns/faser/events_v5.1b"
args.batch_size = 32
args.sets_path = None
args.num_workers = 32
args.load_seg = False
args.stage1 = True
args.train = False
args.preprocessing_input = "sqrt"
args.preprocessing_output = "log"
args.standardize_input = "unit-var"
args.standardize_output = "unit-var"

print("\n- Arguments:")
for arg, value in vars(args).items():
    print(f"  {arg}: {value}")
nb_gpus = len(args.gpus)
gpus = [int(gpu) for gpu in args.gpus]


- Arguments:
  train: False
  stage1: True
  preprocessing_input: sqrt
  preprocessing_output: log
  standardize_input: unit-var
  standardize_output: unit-var
  augmentations_enabled: True
  dataset_path: /scratch/salonso/sparse-nns/faser/events_v5.1b
  mask_ratio: 0.75
  eps: 1e-12
  batch_size: 32
  epochs: 50
  layer_decay: 0.9
  num_workers: 32
  lr: 0.0001
  accum_grad_batches: 1
  warmup_steps: 0
  cosine_annealing_steps: 0
  weight_decay: 0.05
  beta1: 0.9
  beta2: 0.999
  ema_decay: 0.9999
  head_init: 0.001
  dropout: 0.1
  save_dir: /scratch/salonso/sparse-nns/faser/deep_learning/faserDL
  name: v1
  log_every_n_steps: 50
  early_stop_patience: 0
  save_top_k: 1
  checkpoint_path: /scratch/salonso/sparse-nns/faser/deep_learning/faserDL/checkpoints
  checkpoint_name: v1
  load_checkpoint: None
  gpus: [0]
  sets_path: None
  load_seg: False


In [3]:
dataset = SparseFASERCALDataset(args)
print("- Dataset size: {} events".format(len(dataset)))
train_loader, valid_loader, test_loader = split_dataset(dataset, args, splits=[0.6, 0.1, 0.3], test=True)

- Dataset size: 110938 events


In [4]:
from collections import defaultdict
import torch

def count_parameters(model):
    parts = defaultdict(int)
    
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "downsample_layers" in name:
            parts["1. Downsample Layers"] += param.numel()
        elif "blocks" in name and "decoder_blocks" not in name:
            parts["2. Encoder Blocks"] += param.numel()
        elif "cls_token" in name or "pos_embed" in name or "global_feats_encoder" in name:
            parts["3. Encoder Misc"] += param.numel()
        elif "decoder_embed" in name or "mask_token" in name or "decoder_pos_embed" in name:
            parts["4. Decoder Input"] += param.numel()
        elif "decoder_blocks" in name:
            parts["5. Decoder Blocks"] += param.numel()
        elif "decoder_norm" in name or "final_embed" in name:
            parts["6. Decoder Output"] += param.numel()
        elif "upsample_layers" in name:
            parts["7. Upsample Layers"] += param.numel()
        elif "reg_head" in name:
            parts["8. Regression Head"] += param.numel()
        elif "cls_head" in name:
            parts["9. Classification Head"] += param.numel()
        else:
            print(name)
            parts["Other"] += param.numel()

    total = sum(parts.values())
    for k, v in sorted(parts.items(), key=lambda x: x[0]):
        print(f"{k:<25}: {v:,}")
    print(f"\nTotal Parameters        : {total:,}")

# Example usage:
count_parameters(mae_vit_base())
count_parameters(mae_vit_large())
count_parameters(mae_vit_huge())

norm.weight
norm.bias
1. Downsample Layers     : 1,788,864
2. Encoder Blocks        : 85,054,464
3. Encoder Misc          : 9,487,104
4. Decoder Input         : 406,560
5. Decoder Blocks        : 26,818,176
6. Decoder Output        : 407,328
7. Upsample Layers       : 3,246,048
8. Regression Head       : 97
9. Classification Head   : 388
Other                    : 1,536

Total Parameters        : 127,210,565
norm.weight
norm.bias
1. Downsample Layers     : 3,073,644
2. Encoder Blocks        : 292,940,928
3. Encoder Misc          : 16,322,544
4. Decoder Input         : 533,280
5. Decoder Blocks        : 26,818,176
6. Decoder Output        : 534,288
7. Upsample Layers       : 5,590,998
8. Regression Head       : 127
9. Classification Head   : 508
Other                    : 2,016

Total Parameters        : 345,816,509
norm.weight
norm.bias
1. Downsample Layers     : 5,071,572
2. Encoder Blocks        : 645,511,680
3. Encoder Misc          : 26,958,096
4. Decoder Input         : 685,344
5.