In [6]:
import os
from collections import defaultdict
from pathlib import Path
import yaml
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np

from sklearn.metrics import adjusted_rand_score

import torch
import torch.nn as nn

from torch_geometric.loader import DataLoader
from torch_geometric.nn import knn

from sphenix_benchmark.utils import Cumulator, Checkpointer
from sphenix_benchmark.datasets.tpc_dataset_with_edge import TPCDataset
from sphenix_benchmark.models.mlp import MLP
from sphenix_benchmark.metrics import compute_roc, compute_pr

In [7]:
# Print the class documentation of TPCDataset (constructor signature and
# methods) as a quick API reference. Exploration aid only: no cell in the
# visible part of this notebook consumes its output.
help(TPCDataset)

Help on class TPCDataset in module sphenix_benchmark.datasets.tpc_dataset_with_edge:

class TPCDataset(torch.utils.data.dataset.Dataset)
 |  TPCDataset(mmap_root, split, target=None, gnn=False, load_edge=False, **kwargs)
 |
 |  Load mmap_ninja data with optional filtering on multiplicity
 |
 |  Method resolution order:
 |      TPCDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __getitem__(self, index)
 |
 |  __init__(self, mmap_root, split, target=None, gnn=False, load_edge=False, **kwargs)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |
 |  __len__(self)
 |
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |
 |  __annotations__ = {}
 |
 |  __parameters__ = ()
 |
 |  ----------------------------------------------------------------------
 |  Methods inherited from torch.utils.data.dataset.Dataset:
 |
 |  __add

In [8]:
from processor import Processor
from models import assemble_gnn

In [9]:
import scipy.sparse.csgraph as scigraph
import scipy.sparse as sp

In [10]:
import sys

# Directory that holds `efficiency_purity.py`. The original hard-coded location
# stays the default, so existing runs behave the same; set SPHENIX_EVAL_DIR to
# point at a different checkout without editing the notebook.
eval_dir = os.environ.get(
    'SPHENIX_EVAL_DIR',
    '/home/yhuang2/PROJs/FM_Exploration_benchmark_local/eval/',
)
# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if eval_dir not in sys.path:
    sys.path.append(eval_dir)

from efficiency_purity import calc_efficiency_purity

In [11]:
# Shell out to nvidia-smi to show the GPUs on this machine and their current
# load before a device is selected in the next cell.
! nvidia-smi

Thu Jun 12 13:30:41 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:01:00.0 Off |                  Off |
| 30%   34C    P8             29W /  300W |      18MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               Off |   00

In [12]:
# set up device
# Ask CUDA to enumerate devices in PCI bus order; this is expected to make
# `gpu_id` refer to the same card as the index printed by nvidia-smi.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
gpu_id = 0
if torch.cuda.is_available():
    # Make `gpu_id` the current device so the bare 'cuda' string resolves to it.
    torch.cuda.set_device(gpu_id)
    device = 'cuda'
else:
    # torch.cuda.set_device would raise on a host without CUDA; fall back to
    # CPU so the notebook can still be executed (slowly) end to end.
    print('CUDA is not available; falling back to CPU')
    device = 'cpu'

## submodule/module configurations

In [13]:
# Read the run configuration that is stored next to the checkpoint.
checkpoint_path = Path('checkpoints/')
with open(checkpoint_path/'config.yaml', 'r', encoding='UTF-8') as handle:
    config = yaml.safe_load(handle)

# Build the model from its config section, then restore the saved weights.
gnn = assemble_gnn(config['gnn_model'])
# map_location='cpu' keeps the load independent of the device the checkpoint
# was saved from (that GPU index may be absent or busy here); the model is
# moved to the target device afterwards. weights_only=True restricts
# unpickling to tensors and plain containers.
checkpoint = torch.load(checkpoint_path/'ckpt_last.pth',
                        map_location='cpu',
                        weights_only=True)
gnn.load_state_dict(checkpoint['model'])
# The notebook only runs inference: switch to eval mode and move to `device`.
gnn = gnn.to(device).eval()

## load data

In [14]:
# Build the input processor from the `data_processor` section of the config
# loaded above; it is called once per event in the evaluation cells below.
processor = Processor(**config['data_processor'])

In [19]:
# Single-event sanity check: run the GNN on one test event, keep the edges
# scored above `threshold`, and compare the connected components of the
# resulting graph with the true track ids.
config['data']['target'] = ['reg', 'seg']
valid_ds  = TPCDataset(split='test', **config['data'])
# Seed the shuffle so that "Restart & Run All" always inspects the same event.
sample_rng = torch.Generator().manual_seed(0)
valid_ldr = DataLoader(valid_ds, batch_size=1, shuffle=True, generator=sample_rng)

data = next(iter(valid_ldr))

points    = data['features'].x.to(device)
track_ids = data['seg_target'].x.to(device)
batch     = data['features'].batch.to(device)

edge_index = data['edge_index'].x.to(device)
edge_batch = data['edge_index'].batch.to(device)


# processing filter_input
inputs, head_indices, tail_indices, labels, truncated = processor(
    points     = points,
    track_ids  = track_ids,
    batch      = batch,
    edge_index = edge_index,
    edge_batch = edge_batch
)

# run gnn model (inference only, so no autograd bookkeeping is needed)
with torch.no_grad():
    logits = gnn(inputs, head_indices, tail_indices)
    probs = torch.sigmoid(logits)

# `probs` and `labels` are the inputs for edge-level diagnostics; the imported
# compute_roc / compute_pr helpers can be applied to them here if needed.

# connected component: keep only edges scored above the threshold
threshold = .9
mask = probs > threshold
edges = torch.stack([head_indices[mask], tail_indices[mask]]).detach().cpu().numpy()
num_edges = edges.shape[1]
num_vertices = len(inputs)
print(num_vertices, num_edges)

# get connected components
sparse_edges = sp.coo_matrix((np.ones(num_edges), edges),
                             shape=(num_vertices, num_vertices))

# connected_components returns (n_components, labels); only the per-vertex
# labels are used as predicted track assignments.
connected_components = scigraph.connected_components(sparse_edges)[1]

# Move the true ids to host memory once and reuse them for both metrics.
track_ids_np = track_ids.cpu().numpy()

ari = adjusted_rand_score(connected_components, track_ids_np)
print(f'ARI: {ari}')

efficiency, purity, cell = calc_efficiency_purity(connected_components,
                                                  track_ids_np,
                                                  return_df=True)
print(f'efficiency: {efficiency}')
print(f'purity: {purity}')


Loaded 13106 events from /home/sphenix_fm/data/pp_100k_mmap-with_charge!

775 24023
ARI: 0.8421893254497874
efficiency: 0.9333333333333333
purity: 0.7368421052631579


In [30]:
# Full test-set evaluation: for every event, threshold the GNN edge scores,
# build track candidates from connected components, record ARI / efficiency /
# purity, and dump a per-event table for the aggregation cells below.
config['data']['target'] = ['reg', 'seg']
valid_ds  = TPCDataset(split='test', **config['data'])
valid_ldr = DataLoader(valid_ds, batch_size=1, shuffle=False)

# Column names for [track id | reg_target columns]; the order must match the
# column order of `reg_target` — TODO confirm against the dataset definition.
columns = ['true_label', 'px', 'py', 'pz', 'vtx_x', 'vtx_y', 'vtx_z', 'q', 'e']
cumulator = Cumulator()

pbar = tqdm(enumerate(valid_ldr, start=1), total=len(valid_ldr), desc='evaluation')

threshold = .8
# NOTE: int() truncates, so the tag is only exact for thresholds whose
# percentage is representable (e.g. int(0.29 * 100) == 28). The aggregation
# cells rebuild the same expression, so keep them in sync if this changes.
folder = Path(f'results/events_gnn-{int(threshold * 100)}')
folder.mkdir(parents=True, exist_ok=True)

stat = defaultdict(list)
for event_idx, data in pbar:

    points    = data['features'].x.to(device)
    track_ids = data['seg_target'].x.to(device)
    batch     = data['features'].batch.to(device)
    particles = data['reg_target'].x # we don't need it to be on gpu

    edge_index = data['edge_index'].x.to(device)
    edge_batch = data['edge_index'].batch.to(device)

    # processing filter_input
    inputs, head_indices, tail_indices, labels, truncated = processor(
        points     = points,
        track_ids  = track_ids,
        batch      = batch,
        edge_index = edge_index,
        edge_batch = edge_batch
    )

    # run gnn model (inference only)
    with torch.no_grad():
        logits = gnn(inputs, head_indices, tail_indices)
        probs = torch.sigmoid(logits)

    # connected component: keep only edges scored above the threshold
    mask = probs > threshold
    edges = torch.stack([head_indices[mask], tail_indices[mask]]).detach().cpu().numpy()
    num_edges = edges.shape[1]
    num_vertices = len(inputs)

    # get connected components
    sparse_edges = sp.coo_matrix((np.ones(num_edges), edges),
                                 shape=(num_vertices, num_vertices))

    # [1] selects the per-vertex component labels (predicted track assignment)
    connected_components = scigraph.connected_components(sparse_edges)[1]

    # Transfer the true ids to host memory once; they are needed by both
    # metrics and by the per-event table below.
    track_ids_cpu = track_ids.cpu()
    track_ids_np = track_ids_cpu.numpy()

    ari = adjusted_rand_score(connected_components, track_ids_np)

    efficiency, purity, cell = calc_efficiency_purity(connected_components,
                                                      track_ids_np,
                                                      return_df=True)
    event_metrics = {'ari': ari,
                     'efficiency': efficiency,
                     'purity': purity}

    # save metrics
    for key, val in event_metrics.items():
        stat[key].append(val)
    stat['event_idx'].append(event_idx)

    # show the running average on the progress bar
    cumulator.update(event_metrics)
    pbar.set_postfix(cumulator.get_average())

    # per-event track hit metric: attach the truth quantities of each track
    # (one row per true_label) to the table returned by calc_efficiency_purity
    record = torch.hstack([track_ids_cpu.unsqueeze(-1), particles]).numpy()
    truth = pd.DataFrame(data=record, columns=columns)
    truth['true_label'] = truth['true_label'].astype(int)
    truth_per_track = truth.drop_duplicates(subset='true_label', keep='first')
    cell = cell.merge(truth_per_track, on='true_label')
    # reuse `folder` so the directory created above and the files written here
    # cannot drift apart
    cell.to_csv(folder / f'test_event-{event_idx}.csv', index=False)

metrics = pd.DataFrame(data=stat)
metrics.to_csv(f'results/test_gnn-{int(threshold * 100)}_metrics.csv', index=False)
metrics.describe()


Loaded 13106 events from /home/sphenix_fm/data/pp_100k_mmap-with_charge!



evaluation:   0%|          | 0/13106 [00:00<?, ?it/s]

Unnamed: 0,ari,efficiency,purity,event_idx
count,13106.0,13106.0,13106.0,13106.0
mean,0.859604,0.900199,0.767196,6553.5
std,0.174839,0.122639,0.232422,3783.520649
min,-0.100212,0.0,0.0,1.0
25%,0.779448,0.846154,0.636364,3277.25
50%,0.930678,0.931034,0.823529,6553.5
75%,0.994689,1.0,1.0,9829.75
max,1.0,1.0,1.0,13106.0


In [31]:
# Tracking efficiency: one row per (event, true track) telling whether the
# track was matched, together with its transverse momentum.
event_dir = Path(f'results/events_gnn-{int(threshold * 100)}/')
# Take the event index from the file name (test_event-<idx>.csv) instead of
# the enumeration position, so a missing or stale file cannot shift the index
# of every event after it. Sorting the (idx, path) pairs orders by event index.
indexed_csvs = sorted((int(fname.stem.split('-')[-1]), fname)
                      for fname in event_dir.glob('*csv'))
assert indexed_csvs, f'no per-event CSV files found in {event_dir}'

combined_dfs = []
for event_idx, csv_fname in tqdm(indexed_csvs, total=len(indexed_csvs)):
    df = pd.read_csv(csv_fname)
    df['pT'] = np.sqrt(df['px']**2 + df['py']**2)
    by_track = df.groupby('true_label')
    # a true track counts as matched if any of its rows is flagged `matched`
    temp = pd.concat([by_track['matched'].any(),
                      by_track['pT'].mean()], axis=1).reset_index()
    temp['event_idx'] = event_idx
    combined_dfs.append(temp)

combined_df = pd.concat(combined_dfs, axis=0)
combined_df.to_csv(f'results/test_gnn-{int(threshold * 100)}_tracking_efficiency.csv', index=False)

  0%|          | 0/13106 [00:00<?, ?it/s]

In [32]:
# Fraction of matched true tracks among those with pT above 1 (same units as
# px / py in the per-event tables). This expects `combined_df` to be the
# tracking-efficiency table built in the cell just before this one; the
# hit-efficiency and hit-purity cells further down reuse that name.
high_pt_rows = combined_df['pT'] > 1
df_pt = combined_df.loc[high_pt_rows]
df_pt['matched'].sum() / len(df_pt)

0.9259878419452887

In [33]:
# Hit efficiency: for every (event, true track) keep the largest `true_ratio`
# found among its rows, then average over all tracks.
event_dir = Path(f'results/events_gnn-{int(threshold * 100)}/')
# Take the event index from the file name (test_event-<idx>.csv) instead of
# the enumeration position, so a missing or stale file cannot shift the index
# of every event after it. Sorting the (idx, path) pairs orders by event index.
indexed_csvs = sorted((int(fname.stem.split('-')[-1]), fname)
                      for fname in event_dir.glob('*csv'))
assert indexed_csvs, f'no per-event CSV files found in {event_dir}'

combined_dfs = []
for event_idx, csv_fname in tqdm(indexed_csvs, total=len(indexed_csvs)):

    df = pd.read_csv(csv_fname)
    df['pT'] = np.sqrt(df['px']**2 + df['py']**2)
    temp = df.groupby('true_label')[['true_ratio', 'pT']].max().reset_index()

    temp['event_idx'] = event_idx
    combined_dfs.append(temp)

combined_df = pd.concat(combined_dfs, axis=0)
combined_df.to_csv(f'results/test_gnn-{int(threshold * 100)}_hit_efficiency.csv', index=False)
hit_efficiency = combined_df['true_ratio'].mean()
print(f'overall hit efficiency = {hit_efficiency}')

  0%|          | 0/13106 [00:00<?, ?it/s]

overall hit efficiency = 0.9639638341489517


In [34]:
# Hit purity: for every (event, predicted track) keep the largest `pred_ratio`
# found among its rows, then average over all predicted tracks.
event_dir = Path(f'results/events_gnn-{int(threshold * 100)}/')
# Take the event index from the file name (test_event-<idx>.csv) instead of
# the enumeration position, so a missing or stale file cannot shift the index
# of every event after it. Sorting the (idx, path) pairs orders by event index.
indexed_csvs = sorted((int(fname.stem.split('-')[-1]), fname)
                      for fname in event_dir.glob('*csv'))
assert indexed_csvs, f'no per-event CSV files found in {event_dir}'

combined_dfs = []
for event_idx, csv_fname in tqdm(indexed_csvs, total=len(indexed_csvs)):

    df = pd.read_csv(csv_fname)
    temp = df.groupby('pred_label')[['pred_ratio']].max().reset_index()
    temp['event_idx'] = event_idx
    combined_dfs.append(temp)

combined_df = pd.concat(combined_dfs, axis=0)
combined_df.to_csv(f'results/test_gnn-{int(threshold * 100)}_hit_purity.csv', index=False)
hit_purity = combined_df['pred_ratio'].mean()
print(f'overall hit purity = {hit_purity}')

  0%|          | 0/13106 [00:00<?, ?it/s]

overall hit purity = 0.9774567254274198
