In [7]:
import uproot
import numpy as np
import os

input_file = "data/atlas/data24_13p6TeV.root"
signal_path = "data/atlas/signal_clusters.npy"
noise_path = "data/atlas/noise_clusters.npy"
N_EVENTS = 10000  # Numero di eventi di segnale e rumore da salvare

load = True  # True per caricare, False per estrarre nuovi dati

if os.path.exists(signal_path) and os.path.exists(noise_path) and load:
    signal_clusters = np.load(signal_path, allow_pickle=True).tolist()
    noise_clusters = np.load(noise_path, allow_pickle=True).tolist()
    print("Dati caricati da file numpy.")
else:
    file = uproot.open(input_file)
    tree = file["BasicTesterTree;1"]

    # Estrai tutti i dati necessari in una sola volta per efficienza
    mmOnTrack_stripCharges = tree["mmOnTrackStripCharges"].array(library="np")
    mmOnTrack_localPosX = tree["mmOnTrackLocalPos_x"].array(library="np")
    mmOnTrack_stripTimes = tree["mmOnTrackStripDriftTimes"].array(library="np")
    mmOnTrack_stationIndex = tree["mmOnTrack_stationIndex"].array(library="np")
    mmOnTrack_stationEta = tree["mmOnTrack_stationEta"].array(library="np")
    mmOnTrack_MuonLink = tree["mmOnTrack_MuonLink"].array(library="np")
    muons_pt = tree["muons_pt"].array(library="np")
    muons_author = tree["muons_author"].array(library="np")

    mmPRDRandomSectorDumped = tree["mmPRDRandomSectorDumped"].array(library="np")
    PRD_MM_stripCharges = tree["PRD_MM_stripCharges"].array(library="np")
    PRD_MM_localPosX = tree["PRD_MM_localPosX"].array(library="np")
    PRD_MM_stripTimes = tree["PRD_MM_stripTimes"].array(library="np")
    PRD_MM_stationIndex = tree["PRD_MM_stationIndex"].array(library="np")
    PRD_MM_stationEta = tree["PRD_MM_stationEta"].array(library="np")

    mmOnTrack_globalPosX = tree["mmOnTrackGlobalPos_x"].array(library="np")
    mmOnTrack_globalPosY = tree["mmOnTrackGlobalPos_y"].array(library="np")
    mmOnTrack_globalPosZ = tree["mmOnTrackGlobalPos_z"].array(library="np")

    PRD_MM_globalPosX = tree["PRD_MM_globalPosX"].array(library="np")
    PRD_MM_globalPosY = tree["PRD_MM_globalPosY"].array(library="np")
    PRD_MM_globalPosZ = tree["PRD_MM_globalPosZ"].array(library="np")

    # --- Estrazione eventi di segnale ---
    signal_clusters = []
    for i in range(len(mmOnTrack_stripCharges)):
        charges = mmOnTrack_stripCharges[i]
        xs = mmOnTrack_localPosX[i]
        times = mmOnTrack_stripTimes[i]
        muon_link = mmOnTrack_MuonLink[i] if len(mmOnTrack_MuonLink[i]) > 0 else None
        pt = muons_pt[i] if len(muons_pt[i]) > 0 else None
        author = muons_author[i] if len(muons_author[i]) > 0 else None
        n_strips = [len(cluster) for cluster in charges]
        station_index = mmOnTrack_stationIndex[i] if len(mmOnTrack_stationIndex[i]) > 0 else None
        station_eta = mmOnTrack_stationEta[i] if len(mmOnTrack_stationEta[i]) > 0 else None
        global_x = mmOnTrack_globalPosX[i]
        global_y = mmOnTrack_globalPosY[i]
        global_z = mmOnTrack_globalPosZ[i]
        # Condizione: almeno 4 cluster di muoni
        selected = []
        if (
            pt is not None and author is not None and
            station_index is not None and station_eta is not None
        ):
            for j in range(len(charges)):
                if (
                    station_index[j] == 55 and
                    station_eta[j] == 1 and
                    pt[muon_link[j]] >= 15 and
                    author[muon_link[j]] == 1
                ):
                    selected.append(j)
        if len(selected) >= 4:
            signal_clusters.append({
                "charge": [charges[j] for j in selected],
                "localPosX": [xs[j] for j in selected],
                "stripTimes": [times[j] for j in selected],
                "n_strips": [n_strips[j] for j in selected],
                #"muons_pt": [pt[muon_link[j]] for j in selected],
                #"muons_author": [author[muon_link[j]] for j in selected],
                "globalPosX": [global_x[j] for j in selected],
                "globalPosY": [global_y[j] for j in selected],
                "globalPosZ": [global_z[j] for j in selected]
            })
        if len(signal_clusters) >= N_EVENTS:
            break

    print(f"Eventi di segnale salvati: {len(signal_clusters)}")

    # --- Estrazione eventi di rumore ---
    noise_clusters = []
    for i in range(len(PRD_MM_stripCharges)):
        charges = PRD_MM_stripCharges[i]
        xs = PRD_MM_localPosX[i]
        times = PRD_MM_stripTimes[i]
        n_strips = [len(cluster) for cluster in charges]
        random_sector = mmPRDRandomSectorDumped[i] if len(mmPRDRandomSectorDumped[i]) > 0 else None
        station_index = PRD_MM_stationIndex[i] if len(PRD_MM_stationIndex[i]) > 0 else None
        station_eta = PRD_MM_stationEta[i] if len(PRD_MM_stationEta[i]) > 0 else None
        global_x = PRD_MM_globalPosX[i]
        global_y = PRD_MM_globalPosY[i]
        global_z = PRD_MM_globalPosZ[i]
        # Trova gli indici dei cluster con station_index pari e station_eta == 1
        selected = []
        if (
            station_index is not None and station_eta is not None and 
            random_sector is not None and random_sector % 2 == 0
        ):
            for j in range(len(charges)):
                if station_index[j] == 55 and station_eta[j] == 1:
                    selected.append(j)
            noise_clusters.append({
                "charge": [charges[j] for j in selected],
                "localPosX": [xs[j] for j in selected],
                "stripTimes": [times[j] for j in selected],
                "n_strips": [n_strips[j] for j in selected],
                "globalPosX": [global_x[j] for j in selected],
                "globalPosY": [global_y[j] for j in selected],
                "globalPosZ": [global_z[j] for j in selected]
            })
        if len(noise_clusters) >= N_EVENTS:
            break

    print(f"Eventi di rumore salvati: {len(noise_clusters)}")


    file.close()
    np.save(signal_path, np.array(signal_clusters, dtype=object))
    np.save(noise_path, np.array(noise_clusters, dtype=object))
    print("Dati estratti e salvati in file numpy.")

Dati caricati da file numpy.


In [8]:
from sklearn.model_selection import train_test_split
import torch
from torch_geometric.data import Data
from sklearn.neighbors import NearestNeighbors
import numpy as np

# --- Crea tutti i grafi senza normalizzazione ---
event_graphs = []
k = 8  # Numero di vicini per ogni nodo

def extract_features(c):
    total_charge = [np.sum(cluster) for cluster in c["charge"]]
    mean_x = [np.mean(cluster) for cluster in c["localPosX"]]
    mean_time = [np.mean(cluster) for cluster in c["stripTimes"]]
    n_strips = [len(cluster) for cluster in c["charge"]]
    features = np.stack([total_charge, mean_x, mean_time, n_strips], axis=1)
    global_x = [np.mean(cluster) for cluster in c["globalPosX"]]
    global_y = [np.mean(cluster) for cluster in c["globalPosY"]]
    global_z = [np.mean(cluster) for cluster in c["globalPosZ"]]
    global_positions = np.stack([global_x, global_y, global_z], axis=1)
    return features, global_positions

for c in signal_clusters:
    features, global_positions = extract_features(c)
    labels = np.ones(features.shape[0], dtype=np.int64)
    if len(features) < 2:
        continue
    x = torch.tensor(features, dtype=torch.float)
    y = torch.tensor(labels, dtype=torch.long)
    pos = torch.tensor(global_positions, dtype=torch.float)
    N = x.shape[0]
    coords = x[:, 1:3].numpy()
    nbrs = NearestNeighbors(n_neighbors=min(k+1, N), algorithm='ball_tree').fit(coords)
    _, indices = nbrs.kneighbors(coords)
    edge_index = []
    for idx, neighbors in enumerate(indices):
        for n in neighbors[1:]:
            edge_index.append([idx, n])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    event_graphs.append(Data(x=x, edge_index=edge_index, y=y, pos=pos))

for c in noise_clusters:
    features, global_positions = extract_features(c)
    labels = np.zeros(features.shape[0], dtype=np.int64)
    if len(features) < 2:
        continue
    x = torch.tensor(features, dtype=torch.float)
    y = torch.tensor(labels, dtype=torch.long)
    pos = torch.tensor(global_positions, dtype=torch.float)
    N = x.shape[0]
    coords = x[:, 1:3].numpy()
    nbrs = NearestNeighbors(n_neighbors=min(k+1, N), algorithm='ball_tree').fit(coords)
    _, indices = nbrs.kneighbors(coords)
    edge_index = []
    for idx, neighbors in enumerate(indices):
        for n in neighbors[1:]:
            edge_index.append([idx, n])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    event_graphs.append(Data(x=x, edge_index=edge_index, y=y, pos=pos))

print(f"Numero di eventi/grafi creati: {len(event_graphs)}")

# --- Split train/val/test a livello di evento ---
train_graphs, temp_graphs = train_test_split(event_graphs, test_size=0.4, random_state=42)
val_graphs, test_graphs = train_test_split(temp_graphs, test_size=0.5, random_state=42)
print(f"Train set: {len(train_graphs)} eventi, Validation set: {len(val_graphs)} eventi, Test set: {len(test_graphs)} eventi")

# --- Calcola mean e std SOLO sui dati di train ---
all_train_features = []
for data in train_graphs:
    all_train_features.append(data.x.numpy())
all_train_features = np.concatenate(all_train_features, axis=0)
mean = all_train_features.mean(axis=0)
std = all_train_features.std(axis=0)

def normalize_features(features, mean, std):
    return (features - mean) / std

# --- Normalizza tutti i set con mean/std del train ---
for dataset in [train_graphs, val_graphs, test_graphs]:
    for data in dataset:
        data.x = torch.tensor(normalize_features(data.x.numpy(), mean, std), dtype=torch.float)

Numero di eventi/grafi creati: 18605
Train set: 11163 eventi, Validation set: 3721 eventi, Test set: 3721 eventi


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def evaluate_metrics(model, loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = torch.sigmoid(model(data))
            preds = (out > 0.5).long().cpu().numpy()
            labels = data.y.cpu().numpy()
            all_preds.append(preds)
            all_labels.append(labels)
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='binary', zero_division=0)
    rec = recall_score(all_labels, all_preds, average='binary', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)
    return acc, prec, rec, f1

# --- Modello GNN per classificazione grafo ---
channels = 4  # total_charge, mean_x, mean_time, n_strips
class GNN(nn.Module):
    def __init__(self, in_channels=channels, hidden_channels=64, out_channels=1, num_layers=3, dropout=0.2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers-1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.dropout = nn.Dropout(dropout)
        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            x = self.dropout(x)
        x = self.lin(x)
        return x.view(-1)


In [10]:
import time

# --- Training e test ---
batch_size = 16
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=batch_size)
test_loader = DataLoader(test_graphs, batch_size=batch_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = GNN(hidden_channels = 32, out_channels = 1, num_layers = 6, dropout = 0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Calcola il rapporto tra classi
loss_fn = nn.BCEWithLogitsLoss()

epochs = 20
for epoch in range(epochs):
    start_time = time.time()
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        target = data.y.float().to(device)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    avg_train_loss = total_loss / len(train_loader.dataset)
    elapsed = time.time() - start_time

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data)
            target = data.y.float().to(device)
            loss = loss_fn(out, target)
            val_loss += loss.item() * data.num_graphs
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs} - Train loss: {avg_train_loss:.4f} - Validation loss: {avg_val_loss:.4f} - Time: {elapsed:.2f}s")

print("Training complete.")
# Test accuracy
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        out = torch.sigmoid(model(data))
        pred = (out > 0.5).long()
        correct += (pred == data.y.to(device)).sum().item()
        total += data.y.size(0)
print(f"Test accuracy: {correct/total:.2%}")

val_acc, val_prec, val_rec, val_f1 = evaluate_metrics(model, val_loader, device)
print(f"Val metrics - Acc: {val_acc:.3f} Prec: {val_prec:.3f} Rec: {val_rec:.3f} F1: {val_f1:.3f}")

test_acc, test_prec, test_rec, test_f1 = evaluate_metrics(model, test_loader, device)
print(f"Test metrics - Acc: {test_acc:.3f} Prec: {test_prec:.3f} Rec: {test_rec:.3f} F1: {test_f1:.3f}")

Using device: cuda
Epoch 1/20 - Train loss: 0.2355 - Validation loss: 0.1439 - Time: 3.56s
Epoch 2/20 - Train loss: 0.1416 - Validation loss: 0.1150 - Time: 3.48s
Epoch 3/20 - Train loss: 0.1253 - Validation loss: 0.1090 - Time: 3.90s
Epoch 4/20 - Train loss: 0.1270 - Validation loss: 0.1059 - Time: 7.07s
Epoch 5/20 - Train loss: 0.1250 - Validation loss: 0.1022 - Time: 7.31s
Epoch 6/20 - Train loss: 0.1229 - Validation loss: 0.1109 - Time: 7.34s
Epoch 7/20 - Train loss: 0.1170 - Validation loss: 0.1003 - Time: 7.20s
Epoch 8/20 - Train loss: 0.1143 - Validation loss: 0.1017 - Time: 7.00s
Epoch 9/20 - Train loss: 0.1118 - Validation loss: 0.1100 - Time: 7.27s
Epoch 10/20 - Train loss: 0.1148 - Validation loss: 0.0968 - Time: 4.36s
Epoch 11/20 - Train loss: 0.1099 - Validation loss: 0.1016 - Time: 3.79s
Epoch 12/20 - Train loss: 0.1081 - Validation loss: 0.1044 - Time: 3.63s
Epoch 13/20 - Train loss: 0.1084 - Validation loss: 0.0993 - Time: 3.42s
Epoch 14/20 - Train loss: 0.1106 - Valida

In [11]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import os

# Crea la cartella se non esiste
output_dir = "images/atlas"
os.makedirs(output_dir, exist_ok=True)

for j in range(20):
    data_es = test_graphs[j]
    x = data_es.x.cpu().numpy()
    labels_true = data_es.y.cpu().numpy()
    pos = data_es.pos.cpu().numpy()  # shape: [num_clusters, 3]

    model.eval()
    with torch.no_grad():
        out = torch.sigmoid(model(data_es.to(device)))
        pred_labels = (out.squeeze() > 0.5).cpu().numpy()

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Plot rumore (label 0)
    mask_noise = labels_true == 0
    ax.scatter(pos[mask_noise, 0], pos[mask_noise, 1], pos[mask_noise, 2], c='gray', s=40, label='Rumore', alpha=0.7)

    # Plot segnale (label 1)
    mask_signal = labels_true == 1
    ax.scatter(pos[mask_signal, 0], pos[mask_signal, 1], pos[mask_signal, 2], c='orange', s=60, label='Segnale', alpha=0.8)

    # Cerchio blu attorno ai cluster predetti come segnale dal modello
    for i in range(len(pos)):
        if pred_labels[i] == 1:
            ax.plot([pos[i,0]], [pos[i,1]], [pos[i,2]], marker='o', markersize=18, markerfacecolor='none', markeredgecolor='blue', markeredgewidth=2, alpha=0.7)

    # Zoom automatico sui dati
    margin = 0.05
    for idx, set_lim in enumerate([ax.set_xlim, ax.set_ylim, ax.set_zlim]):
        data = pos[:, idx]
        delta = (data.max() - data.min()) * margin
        set_lim(data.min() - delta, data.max() + delta)

    ax.set_xlabel('Global X')
    ax.set_ylabel('Global Y')
    ax.set_zlabel('Global Z')
    ax.set_title(f'Evento {j}: arancione=vero segnale, grigio=rumore, blu=predetto segnale')
    ax.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/event_{j:02d}.png")
    plt.close(fig)
print("Salvati i primi 20 esempi in images/atlas/")

Salvati i primi 20 esempi in images/atlas/
