In [27]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv


In [28]:

# === Caricamento dei dati ===
signal_path = "data/atlas/signal_clusters.npy"
noise_path = "data/atlas/noise_clusters.npy"

signal_clusters = np.load(signal_path, allow_pickle=True).tolist()
noise_clusters = np.load(noise_path, allow_pickle=True).tolist()
print("Dati caricati da file numpy.")


Dati caricati da file numpy.


In [29]:

# === Funzione per estrazione delle feature ===
def extract_features(c):
    total_charge = [np.sum(cluster) for cluster in c["charge"]]
    x = c["localPosX"]
    mean_time = [np.average(t, weights=q) for t, q in zip(c["stripTimes"], c["charge"])]
    n_strips = [len(cluster) for cluster in c["charge"]]
    features = np.stack([total_charge, x, mean_time, n_strips], axis=1)

    global_x = c["globalPosX"]
    global_y = c["globalPosY"]
    global_z = c["globalPosZ"]
    global_positions = np.stack([global_x, global_y, global_z], axis=1)

    return features, global_positions


In [30]:

# === Creazione dei grafi per ogni evento ===
event_graphs = []
k = 4  # numero di vicini per il grafo

def create_graphs(clusters, label):
    for c in clusters:
        features, global_positions = extract_features(c)
        if len(features) < 2:
            continue
        x = torch.tensor(features, dtype=torch.float)
        y = torch.tensor(np.full(features.shape[0], label), dtype=torch.long)
        pos = torch.tensor(global_positions, dtype=torch.float)
        coords = global_positions
        N = x.shape[0]
        nbrs = NearestNeighbors(n_neighbors=min(k+1, N), algorithm='ball_tree').fit(coords)
        _, indices = nbrs.kneighbors(coords)
        edge_index = []
        for idx, neighbors in enumerate(indices):
            for n in neighbors[1:]:
                edge_index.append([idx, n])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        event_graphs.append(Data(x=x, edge_index=edge_index, y=y, pos=pos))

create_graphs(signal_clusters, label=1)
create_graphs(noise_clusters, label=0)

print(f"Numero di eventi/grafi creati: {len(event_graphs)}")


Numero di eventi/grafi creati: 20000


In [31]:

# === Divisione train/val/test ===
train_graphs, temp_graphs = train_test_split(event_graphs, test_size=0.4, random_state=42)
val_graphs, test_graphs = train_test_split(temp_graphs, test_size=0.5, random_state=42)
print(f"Train set: {len(train_graphs)} eventi, Validation set: {len(val_graphs)} eventi, Test set: {len(test_graphs)} eventi")


Train set: 12000 eventi, Validation set: 4000 eventi, Test set: 4000 eventi


In [32]:

# === Normalizzazione delle feature ===
all_train_features = np.concatenate([data.x.numpy() for data in train_graphs], axis=0)
mean = all_train_features.mean(axis=0)
std = all_train_features.std(axis=0)

def normalize_features(features, mean, std):
    return (features - mean) / std

for dataset in [train_graphs, val_graphs, test_graphs]:
    for data in dataset:
        data.x = torch.tensor(normalize_features(data.x.numpy(), mean, std), dtype=torch.float)


In [33]:

# === Modello GNN ===
channels = 4  # total_charge, mean_x, mean_time, n_strips

class GNN(nn.Module):
    def __init__(self, in_channels=channels, hidden_channels=64, out_channels=1, num_layers=3, dropout=0.2):
        super().__init__()
        self.pre_mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU()
        )
        self.convs = nn.ModuleList([GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        self.post_mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.pre_mlp(x)
        for conv in self.convs:
            x_res = x
            x = F.relu(conv(x, edge_index))
            x = self.dropout(x)
            x = x + x_res 
        x = self.post_mlp(x)
        return x.view(-1)


In [34]:

# === Funzione di valutazione ===
def evaluate_metrics(model, loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = torch.sigmoid(model(data))
            preds = (out > 0.5).long().cpu().numpy()
            labels = data.y.cpu().numpy()
            all_preds.append(preds)
            all_labels.append(labels)
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='binary', zero_division=0)
    rec = recall_score(all_labels, all_preds, average='binary', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)
    return acc, prec, rec, f1


In [35]:
# === Inizializzazione e loss ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = GNN(hidden_channels=32, out_channels=1, num_layers=4, dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_fn = nn.BCEWithLogitsLoss()

train_loader = DataLoader(train_graphs, batch_size=16, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=16)
test_loader = DataLoader(test_graphs, batch_size=16)

Using device: cuda


In [36]:
# === Funzioni di checkpoint ===
def save_checkpoint(model, optimizer, epoch, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

def save_model(model, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Modello finale salvato in {path}")

def load_model(model, optimizer, checkpoint_path, load_checkpoint=True):
    if load_checkpoint and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        print(f"Checkpoint caricato da {checkpoint_path}, riprendo da epoch {start_epoch}")
    else:
        print("Nessun checkpoint caricato.")
        start_epoch = 0
    return model, optimizer, start_epoch

checkpoint_path = "gnn_model/atlas/checkpoint.pt"
model_path = "gnn_model/atlas/model.pt"
model, optimizer, start_epoch = load_model(model, optimizer, checkpoint_path, load_checkpoint=False)


Nessun checkpoint caricato.


In [37]:

# === Training ===
epochs = 20
for epoch in range(start_epoch, epochs):
    start_time = time.time()
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        target = data.y.float().to(device)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    avg_train_loss = total_loss / len(train_loader.dataset)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data)
            target = data.y.float().to(device)
            loss = loss_fn(out, target)
            val_loss += loss.item() * data.num_graphs
    avg_val_loss = val_loss / len(val_loader.dataset)
    elapsed = time.time() - start_time

    print(f"Epoch {epoch+1}/{epochs} - Train loss: {avg_train_loss:.4f} - Validation loss: {avg_val_loss:.4f} - Time: {elapsed:.2f}s")
    save_checkpoint(model, optimizer, epoch, checkpoint_path)

print("Training completo.")
save_model(model, model_path)

Epoch 1/20 - Train loss: 0.1290 - Validation loss: 0.0831 - Time: 3.81s
Epoch 2/20 - Train loss: 0.0824 - Validation loss: 0.0779 - Time: 3.70s
Epoch 3/20 - Train loss: 0.0760 - Validation loss: 0.0728 - Time: 3.66s
Epoch 4/20 - Train loss: 0.0713 - Validation loss: 0.0696 - Time: 3.71s
Epoch 5/20 - Train loss: 0.0689 - Validation loss: 0.0679 - Time: 3.66s
Epoch 6/20 - Train loss: 0.0680 - Validation loss: 0.0647 - Time: 3.64s
Epoch 7/20 - Train loss: 0.0640 - Validation loss: 0.0698 - Time: 3.63s
Epoch 8/20 - Train loss: 0.0647 - Validation loss: 0.0658 - Time: 3.68s
Epoch 9/20 - Train loss: 0.0632 - Validation loss: 0.0632 - Time: 3.63s
Epoch 10/20 - Train loss: 0.0619 - Validation loss: 0.0634 - Time: 3.59s
Epoch 11/20 - Train loss: 0.0631 - Validation loss: 0.0661 - Time: 3.65s
Epoch 12/20 - Train loss: 0.0610 - Validation loss: 0.0655 - Time: 3.72s
Epoch 13/20 - Train loss: 0.0593 - Validation loss: 0.0618 - Time: 3.74s
Epoch 14/20 - Train loss: 0.0604 - Validation loss: 0.0629 -

In [38]:

# === Metriche ===
val_acc, val_prec, val_rec, val_f1 = evaluate_metrics(model, val_loader, device)
print(f"Val metrics - Acc: {val_acc:.3f} Prec: {val_prec:.3f} Rec: {val_rec:.3f} F1: {val_f1:.3f}")

test_acc, test_prec, test_rec, test_f1 = evaluate_metrics(model, test_loader, device)
print(f"Test metrics - Acc: {test_acc:.3f} Prec: {test_prec:.3f} Rec: {test_rec:.3f} F1: {test_f1:.3f}")


Val metrics - Acc: 0.981 Prec: 0.923 Rec: 0.950 F1: 0.936
Test metrics - Acc: 0.982 Prec: 0.932 Rec: 0.960 F1: 0.946


In [39]:

# === Visualizzazione esempi ===
output_dir = "images/atlas"
os.makedirs(output_dir, exist_ok=True)

for j in range(20):
    data_es = test_graphs[j]
    x = data_es.x.cpu().numpy()
    labels_true = data_es.y.cpu().numpy()
    pos = data_es.pos.cpu().numpy()

    with torch.no_grad():
        out = torch.sigmoid(model(data_es.to(device)))
        pred_labels = (out.squeeze() > 0.5).cpu().numpy()

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(pos[labels_true == 0, 0], pos[labels_true == 0, 1], pos[labels_true == 0, 2],
               c='gray', s=40, label='Rumore', alpha=0.7)
    ax.scatter(pos[labels_true == 1, 0], pos[labels_true == 1, 1], pos[labels_true == 1, 2],
               c='orange', s=60, label='Segnale', alpha=0.8)

    for i in range(len(pos)):
        if pred_labels[i] == 1:
            ax.plot([pos[i, 0]], [pos[i, 1]], [pos[i, 2]], marker='o', markersize=18,
                    markerfacecolor='none', markeredgecolor='blue', markeredgewidth=2, alpha=0.7)

    margin = 0.05
    for idx, set_lim in enumerate([ax.set_xlim, ax.set_ylim, ax.set_zlim]):
        data = pos[:, idx]
        delta = (data.max() - data.min()) * margin
        set_lim(data.min() - delta, data.max() + delta)

    ax.set_xlabel('Global X')
    ax.set_ylabel('Global Y')
    ax.set_zlabel('Global Z')
    ax.set_title(f'Evento {j}: arancione=vero segnale, grigio=rumore, blu=predetto segnale')
    ax.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/event_{j:02d}.png")
    plt.close(fig)

print("Salvati i primi 20 esempi in images/atlas/")


Salvati i primi 20 esempi in images/atlas/
