In [50]:
# Compact training script to train a simple NN
# NOTE(review): imports and configuration calls are interleaved in this cell;
# their relative order is left untouched here.
import mplhep as hep
hep.style.use("CMS")  # select the "CMS" style provided by mplhep
import matplotlib
matplotlib.rc('font', size=13)  # default font size for all figures
import os
# Presumably intended to keep OpenBLAS single-threaded. NOTE(review): this runs
# after mplhep/matplotlib were imported, so verify it takes effect.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
import tqdm
import os  # NOTE(review): duplicate of the `import os` a few lines above
import pickle
import torch
import numpy as np
# NOTE(review): `ar` and `exp` are not used in the visible cells, and SciPy has
# deprecated its top-level NumPy aliases - confirm this import still works on
# the installed SciPy version.
from scipy import asarray as ar, exp

# Particle IDs of interest; a later cell labels |pid| == 13 as "Muons".
all_pids = [13, -13]

def calculate_phi(x, y, z=None):
    """Return the azimuthal angle of the point ``(x, y)``.

    The angle is ``arctan2(y, x)``. ``z`` is accepted but never used, so the
    function can be called with the same arguments as ``calculate_eta``.
    """
    azimuth = np.arctan2(y, x)
    return azimuth

def calculate_eta(x, y, z):
    """Return ``-log(tan(theta / 2))`` for the point ``(x, y, z)``.

    ``theta`` is the angle ``arctan2(sqrt(x**2 + y**2), z)``, i.e. the angle
    between the point and the z axis.
    """
    transverse = np.sqrt(x ** 2 + y ** 2)
    polar_angle = np.arctan2(transverse, z)
    half_angle_tan = np.tan(polar_angle / 2)
    return -np.log(half_angle_tan)

import io

def get_dataset(save_ckpt=None, path=None):
    """Load the merged cluster-feature dataset, preferring a cached checkpoint.

    Parameters
    ----------
    save_ckpt : str or None
        Optional checkpoint file. If it exists, the merged dataset dict is
        unpickled from it and the feature directory is not scanned. If it does
        not exist, the freshly merged dict is pickled to it before returning.
    path : str or None
        Directory containing the per-file pickles. ``None`` selects the default
        location that was previously hard-coded in this function.

    Returns
    -------
    tuple
        ``(x, feature_names, e_true, e_true_corrected_daughters, e_reco,
        y_particles, coords_y)``. ``feature_names`` is a list of 14 strings;
        every other entry is taken from the merged dict under the key of the
        same name (``y_particles`` is overwritten by ``pid_y`` when present).

    Notes
    -----
    All files are read with ``pickle``, which can execute arbitrary code while
    loading. Only point this function at files you trust.
    """
    class CPU_Unpickler(pickle.Unpickler):
        """Unpickler that loads pickled torch storages onto the CPU."""

        def find_class(self, module, name):
            if module == 'torch.storage' and name == '_load_from_bytes':
                return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
            return super().find_class(module, name)

    def _load_pickle(file_path):
        # Context manager so the file handle is closed even if loading fails.
        with open(file_path, "rb") as handle:
            return CPU_Unpickler(handle).load()

    def _as_tensor(value):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
        # detach() gives the same values without it, and torch.cat copies anyway.
        if isinstance(value, torch.Tensor):
            return value.detach()
        return torch.as_tensor(value)

    r = None
    if save_ckpt is not None and os.path.exists(save_ckpt):
        print("Loading dataset from", save_ckpt)
        r = _load_pickle(save_ckpt)
        print("len x", len(r["x"]))

    # Manual switch for an older file layout; always off in this version.
    old_dataset = False
    if r is None:
        if path is None:
            path = "/eos/user/g/gkrzmanc/results/2024/PID_muons_GT_clusters_025_dataset_save/cluster_features"
        r = {}
        nmax = 1000000  # upper bound on the number of directory entries visited
        expected_name_len = len("8510eujir6.pkl")
        print("Dataset path:", path)
        for n, file in enumerate(tqdm.tqdm(os.listdir(path)), start=1):
            if n > nmax:
                break
            file_path = os.path.join(path, file)
            if os.path.isdir(file_path):
                continue
            # Filter on the name BEFORE unpickling, so temporary files that may
            # still be stored in the directory are never opened.
            if len(file) != expected_name_len:
                continue
            f = _load_pickle(file_path)
            if (f["e_reco"].flatten() == 1.).all():
                continue  # Some old files, ignore them for now
            for key in f:
                if key == "pid_y":
                    value = _as_tensor(f[key])
                    r[key] = value if key not in r else torch.cat([r[key], value])
                elif key != "y_particles" or old_dataset:
                    if key not in r:
                        r[key] = f[key]
                    else:
                        r[key] = torch.cat([_as_tensor(r[key]), _as_tensor(f[key])], dim=0)
                elif "pid_y" not in f:
                    # Files without "pid_y": take the labels and coordinates
                    # from the "y_particles" object instead.
                    print(key)
                    pids = f[key].pid.flatten()
                    coords = f[key].coord
                    if key not in r:
                        r[key] = pids
                        r["part_coord"] = coords
                    else:
                        r[key] = torch.cat((r[key], pids), dim=0)
                        r["part_coord"] = torch.cat((r["part_coord"], coords))
                        assert len(r["part_coord"]) == len(r[key])

    x_names = ["ecal_E", "hcal_E", "num_hits", "track_p", "ecal_dispersion", "hcal_dispersion", "sum_e", "num_tracks", "track_p_chis"]
    h_names = ["hit_x_avg", "hit_y_avg", "hit_z_avg"]
    h1_names = ["hit_eta_avg", "hit_phi_avg"]
    print(r.keys())
    print("x shape:", r["x"].shape)
    if old_dataset:
        r["y_particles"] = r["y_particles"][:, 6]
        xyz = r["node_features_avg"][:, [0, 1, 2]].cpu()
        eta_phi = torch.stack([calculate_eta(xyz[:, 0], xyz[:, 1], xyz[:, 2]), calculate_phi(xyz[:, 0], xyz[:, 1])], dim=1)
        r["x"] = torch.cat([r["x"], xyz, eta_phi], dim=1)
    # "pid_y" is optional: datasets built from files without it already carry
    # their labels under "y_particles".
    if "pid_y" in r:
        r["y_particles"] = r["pid_y"]
    if save_ckpt is not None and not os.path.exists(save_ckpt):
        with open(save_ckpt, "wb") as handle:
            pickle.dump(r, handle)
    return (
        r["x"],
        x_names + h_names + h1_names,
        r["e_true"],
        r["e_true_corrected_daughters"],
        r["e_reco"],
        r["y_particles"],
        r["coords_y"],
    )
def get_split(ds, overfit=False):
    """Split a dataset tuple into train and test parts (80/20, fixed seed).

    Parameters
    ----------
    ds : tuple
        7-tuple as returned by ``get_dataset``. Entries 0, 2, 3, 5 and 6 are
        split; entries 1 and 4 are ignored.
    overfit : bool
        If true, every returned part is truncated to its first 100 rows.

    Returns
    -------
    tuple
        Ten entries, as train/test pairs in input order: indices 0-1 come from
        ``ds[0]``, 2-3 from ``ds[2]``, 4-5 from ``ds[3]``, 6-7 from ``ds[5]``
        and 8-9 from ``ds[6]``. The layout is the same with and without
        ``overfit``; previously the overfit branch dropped entries 8 and 9.
    """
    from sklearn.model_selection import train_test_split
    x, _, y, etrue, _, pids, positions = ds
    parts = train_test_split(
        x, y, etrue, pids, positions, test_size=0.2, random_state=42
    )
    if overfit:
        # Truncate all ten parts so indices 8 and 9 stay available.
        parts = [part[:100] for part in parts]
    return tuple(parts)


In [51]:
# Build the dataset without a checkpoint file (save_ckpt defaults to None),
# so get_dataset does not read or write a checkpoint.
ds = get_dataset()
# Train/test split; indices into `split` follow get_split's return order
# (0/1 = x train/test, 6/7 = pid train/test).
split = get_split(ds)

Dataset path: /eos/user/g/gkrzmanc/results/2024/PID_muons_GT_clusters_025_dataset_save/cluster_features


  r[key] = torch.concatenate([torch.tensor(r[key]), torch.tensor(f[key])], axis=0)
100%|██████████| 56/56 [00:03<00:00, 15.41it/s]


dict_keys(['x', 'e_true', 'e_reco', 'true_e_corr', 'e_true_corrected_daughters', 'coords_y', 'pid_y'])
x shape: torch.Size([15464, 16])


In [52]:
# Boolean mask over the training PIDs (split[6]) selecting |pid| == 13.
# NOTE(review): `filter` shadows the Python builtin of the same name; later
# cells reference this name, so renaming it requires updating them together.
filter = split[6].abs()==13
filter.sum() # Muons


tensor(56)

In [53]:
# Column 9 of the training feature matrix (split[0]).
# Presumably per_graph_e_hits_muon / sum_e, going by the feature list quoted in
# a later cell - TODO confirm against the code that builds `x`.
split[0][:, 9]

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [54]:
# Split training-feature columns 10 and 9 into the rows selected by `filter`
# (|pid| == 13) and all remaining rows.
num_muon_hits_muons, num_muon_hits_other = (
    split[0][:, 10][mask] for mask in (filter, ~filter)
)
e_muon_hits_muons, e_muon_hits_other = (
    split[0][:, 9][mask] for mask in (filter, ~filter)
)

In [55]:
# Column-10 values for the training rows selected by `filter` (|pid| == 13).
num_muon_hits_muons

tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  8.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  3.,  0.,  7.,  0.,  0.,  0.,  0.,
         0.,  0., 11.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

In [56]:
# Mean of column 10 for the |pid| == 13 rows vs. all other rows.
tuple(hits.mean() for hits in (num_muon_hits_muons, num_muon_hits_other))

(tensor(0.5536), tensor(0.1038))

In [57]:
# Fraction of rows with a non-zero column-10 value, for each of the two groups.
tuple(hits.ne(0).float().mean() for hits in (num_muon_hits_muons, num_muon_hits_other))

(tensor(0.0893), tensor(0.0249))

In [58]:
# Mean of column 9 for the |pid| == 13 rows vs. all other rows.
tuple(energy.mean() for energy in (e_muon_hits_muons, e_muon_hits_other))

(tensor(6.5641e-05), tensor(3.5406e-05))

In [59]:
# Reference note kept as a bare string literal, so the cell simply displays it.
# It appears to quote the expression that assembles the first ten feature
# columns (indices 0-9) of `x` - TODO confirm against the producing code.
'''
([per_graph_e_hits_ecal / sum_e,
                                per_graph_e_hits_hcal / sum_e,
                                num_hits, track_p,
                                per_graph_e_hits_ecal_dispersion,
                                per_graph_e_hits_hcal_dispersion,
                                sum_e, num_tracks, torch.clamp(chis_tracks, -5, 5),
                                per_graph_e_hits_muon / sum_e,
                             ]).T'''

'\n([per_graph_e_hits_ecal / sum_e,\n                                per_graph_e_hits_hcal / sum_e,\n                                num_hits, track_p,\n                                per_graph_e_hits_ecal_dispersion,\n                                per_graph_e_hits_hcal_dispersion,\n                                sum_e, num_tracks, torch.clamp(chis_tracks, -5, 5),\n                                per_graph_e_hits_muon / sum_e,\n                             ]).T'