In [1]:
import cuml
import cugraph
import cudf
import os
import torch
import cupy as cp
from onetrack import TrackingData 
from onetrack.file_utils import list_files

## Baseline Tracking Performance (original edge graph)

In [47]:
# Root directory of the processed events used for the baseline evaluation.
event_path = (
    "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/"
    "hetero_gnn_processed/0GeV_v3/"
)

In [48]:
# Keep (at most) the first 100 files of the test split.
file_list = list_files(
    os.path.join(event_path, "test")
)[:100]

In [49]:
# Build track candidates with the "CC" method at an edge-score cut of 0.01.
tracking_data = TrackingData(file_list)
tracking_data.build_candidates(
    building_method="CC",
    sanity_check=False,
    score_cut=0.01,
)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [50]:
# Track-matching thresholds forwarded to `evaluate_candidates`.
matching_config = dict(
    min_hits_truth=9,
    min_hits_reco=5,
    frac_reco_matched=0.5,
    frac_truth_matched=0.5,
)
tracking_data.evaluate_candidates(evaluation_method="matching", **matching_config)

  0%|          | 0/10 [00:00<?, ?it/s]

{'building_method': 'CC', 'evaluation_method': 'matching', 'eff': 0.5713027393640447, 'single_eff': 0.8028714877124989, 'fr': 0.27385870820395697, 'dup': 0.025456131797069466}
n_true_tracks: 11353, n_reco_tracks: 38969, n_matched_particles: 6486, n_matched_tracks: 28297, n_duplicated_tracks: 992


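## Generate Maximum Spanning Trees

Each event's edge graph is reduced to its maximum-score spanning tree, computed as a *minimum* spanning tree on edge weights `1 - score`.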

In [None]:
# Replace every event's edge graph by its maximum-score spanning tree
# (minimum spanning tree on weights 1 - score) and save it for the evaluation below.
mst_output_dir = "/global/cfs/cdirs/m3443/usr/ryanliu/tracking_eff/minimum_spanning_tree/test"
os.makedirs(mst_output_dir, exist_ok=True)  # torch.save does not create parent dirs

# Hidden-state guard: `file_list` is later reassigned to the files in
# `mst_output_dir`, so re-running this cell afterwards would re-process its own output.
assert not any(
    os.path.abspath(path).startswith(os.path.abspath(mst_output_dir) + os.sep)
    for path in file_list
), "file_list points at the MST output directory; re-run the baseline `file_list` cell first"

for i, file in enumerate(file_list):
    event = torch.load(file)
    # Remember where the inputs lived; tensors rebuilt from cupy land on the GPU.
    device = event.edge_index.device

    edges = cudf.DataFrame()
    edges["src"] = cp.asarray(event.edge_index[0])
    edges["dst"] = cp.asarray(event.edge_index[1])
    # Work in float64 so that 1 - (1 - s) recovers float32 scores exactly and
    # edges near the score cut are not pushed across it by rounding.
    edges["weight"] = 1.0 - cp.asarray(event.scores, dtype=cp.float64)

    graph = cugraph.Graph()
    graph.from_cudf_edgelist(edges, source="src", destination="dst", edge_attr="weight", renumber=False)
    mst = cugraph.tree.minimum_spanning_tree(graph, algorithm="boruvka")
    mst_edges = mst.view_edge_list()

    # NOTE(review): the weight column of view_edge_list() has been "weights" in some
    # cugraph releases and the input name "weight" in others; accept either.
    weight_col = "weights" if "weights" in mst_edges.columns else "weight"

    event.edge_index = torch.tensor(
        cp.vstack([mst_edges["src"].to_cupy(), mst_edges["dst"].to_cupy()])
    ).long().to(device)
    event.scores = torch.tensor(1 - mst_edges[weight_col].to_cupy()).float().to(device)
    torch.save(event, os.path.join(mst_output_dir, str(i)))

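## Tracking Performance on Spanning-Tree Events

The metrics below are nearly identical to the baseline: the efficiency is the same, and the reconstructed-track count differs by 2.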

In [51]:
# Root directory of the spanning-tree events written by the generation cell above.
event_path = (
    "/global/cfs/cdirs/m3443/usr/ryanliu/tracking_eff/"
    "minimum_spanning_tree/"
)

In [52]:
# Keep (at most) the first 100 spanning-tree events of the test split.
file_list = list_files(
    os.path.join(event_path, "test")
)[:100]

In [53]:
# Same candidate-building settings as the baseline, applied to the spanning-tree events.
tracking_data = TrackingData(file_list)
tracking_data.build_candidates(
    building_method="CC",
    sanity_check=False,
    score_cut=0.01,
)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [54]:
# Same matching thresholds as the baseline evaluation, redefined here so this
# section can be run on its own.
matching_config = dict(
    min_hits_truth=9,
    min_hits_reco=5,
    frac_reco_matched=0.5,
    frac_truth_matched=0.5,
)
tracking_data.evaluate_candidates(evaluation_method="matching", **matching_config)

  0%|          | 0/10 [00:00<?, ?it/s]

{'building_method': 'CC', 'evaluation_method': 'matching', 'eff': 0.5713027393640447, 'single_eff': 0.8032238174931736, 'fr': 0.27389597392933207, 'dup': 0.025454825382977084}
n_true_tracks: 11353, n_reco_tracks: 38971, n_matched_particles: 6486, n_matched_tracks: 28297, n_duplicated_tracks: 992
