In [26]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, BatchNorm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [27]:
# =============================================================================
# PARAMETRI DI CLUSTERING
# =============================================================================
CLUSTER_PARAMS = dict(
    charge_threshold=750,
    min_consecutive_strips=2,
    max_gap=2,
    max_cluster_size=15,
    max_internal_gap=2
)


In [28]:

# =============================================================================
# FUNZIONI DI CLUSTERING
# =============================================================================

def cluster_strips_with_charges(event, params):
    """Identifica cluster consecutivi di strip con carica sopra soglia."""
    clusters = []
    charges = []
    current_cluster = []
    current_charges = []
    previous_index = None

    for strip_index, charge in enumerate(event):
        if charge >= params['charge_threshold']:
            if previous_index is not None and strip_index - previous_index - 1 >= params['max_gap']:
                if params['min_consecutive_strips'] <= len(current_cluster) < params['max_cluster_size']:
                    n_gaps = (max(current_cluster) - min(current_cluster) + 1) - len(current_cluster)
                    if n_gaps <= params['max_internal_gap']:
                        clusters.append(current_cluster)
                        charges.append(current_charges)
                current_cluster, current_charges = [], []
            current_cluster.append(strip_index)
            current_charges.append(charge)
            previous_index = strip_index

    # Aggiungi ultimo cluster se valido
    if params['min_consecutive_strips'] <= len(current_cluster) < params['max_cluster_size']:
        n_gaps = (max(current_cluster) - min(current_cluster) + 1) - len(current_cluster)
        if n_gaps <= params['max_internal_gap']:
            clusters.append(current_cluster)
            charges.append(current_charges)

    return clusters, charges

def assign_labels(event, params):
    """Assegna etichetta 1 alle strip che fanno parte di un cluster valido."""
    labels = np.zeros(len(event), dtype=int)
    clusters, _ = cluster_strips_with_charges(event, params)
    for cluster in clusters:
        for idx in range(min(cluster), max(cluster) + 1):
            labels[idx] = 1
    return labels, len(clusters)


In [29]:

# =============================================================================
# CARICAMENTO DATI E LABELING
# =============================================================================

x_events = np.load('data/mimega/signal_x.npy', allow_pickle=True)
y_events = np.load('data/mimega/signal_y.npy', allow_pickle=True)

x_labels, y_labels = [], []
n_x_clusters = n_y_clusters = 0

for x_ev, y_ev in zip(x_events, y_events):
    x_lab, n_x = assign_labels(x_ev, CLUSTER_PARAMS)
    y_lab, n_y = assign_labels(y_ev, CLUSTER_PARAMS)
    x_labels.append(x_lab)
    y_labels.append(y_lab)
    n_x_clusters += n_x
    n_y_clusters += n_y

print(f"Numero di cluster trovati in X: {n_x_clusters}")
print(f"Numero di cluster trovati in Y: {n_y_clusters}")


Numero di cluster trovati in X: 908
Numero di cluster trovati in Y: 3047


In [30]:

# =============================================================================
# CREAZIONE DEI GRAFICI
# =============================================================================

def events_to_graphs(events, labels):
    """Trasforma eventi e label in grafi PyG per la GNN."""
    data_list = []
    for event, label in zip(events, labels):
        event = np.array(event, dtype=np.float32)
        label = np.array(label, dtype=np.int64)
        x = torch.tensor(event, dtype=torch.float).unsqueeze(1)
        y = torch.tensor(label, dtype=torch.long)
        N = x.shape[0]
        edge_index = torch.tensor(
            [[i, i+1] for i in range(N-1)] + [[i+1, i] for i in range(N-1)],
            dtype=torch.long
        ).t().contiguous()
        data_list.append(Data(x=x, edge_index=edge_index, y=y))
    return data_list

# Split train/val/test
all_events = np.concatenate([x_events, y_events])
all_labels = np.concatenate([x_labels, y_labels])
all_data = list(zip(all_events, all_labels))

train_data, temp_data = train_test_split(all_data, test_size=0.3, random_state=42)
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

train_events, train_labels = zip(*train_data)
val_events, val_labels = zip(*val_data)
test_events, test_labels = zip(*test_data)

train_graphs = events_to_graphs(train_events, train_labels)
val_graphs = events_to_graphs(val_events, val_labels)
test_graphs = events_to_graphs(test_events, test_labels)

train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=32)
test_loader = DataLoader(test_graphs, batch_size=32)


In [31]:

# =============================================================================
# DEFINIZIONE MODELLO GNN
# =============================================================================

channels = 1  # charge

class GNN(nn.Module):
    def __init__(self, in_channels=channels, hidden_channels=64, out_channels=1, num_layers=3, dropout=0.2):
        super().__init__()
        self.pre_mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU()
        )
        self.convs = nn.ModuleList([GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        self.post_mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.pre_mlp(x)
        for conv in self.convs:
            x_res = x
            x = F.relu(conv(x, edge_index))
            x = self.dropout(x)
            x = x + x_res 
        x = self.post_mlp(x)
        return x.view(-1)

In [32]:

# =============================================================================
# INIZIALIZZAZIONE E LOSS
# =============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = GNN(in_channels=1, hidden_channels=32, num_layers=4, dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

pos_weight = torch.tensor([5], dtype=torch.float32).to(device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight = pos_weight)


Using device: cuda


In [33]:

# =============================================================================
# SALVATAGGIO E CARICAMENTO MODELLO
# =============================================================================

def save_checkpoint(model, optimizer, epoch, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

def save_model(model, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Modello finale salvato in {path}")

def load_model(model, optimizer, checkpoint_path, load_checkpoint=True):
    if load_checkpoint and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        print(f"Checkpoint caricato da {checkpoint_path}, riprendo da epoch {start_epoch}")
    else:
        print("Nessun checkpoint caricato.")
        start_epoch = 0
    return model, optimizer, start_epoch

checkpoint_path = "gnn_model/mimega/checkpoint.pt"
model_path = "gnn_model/mimega/model.pt"
model, optimizer, start_epoch = load_model(model, optimizer, checkpoint_path, load_checkpoint=False)


Nessun checkpoint caricato.


In [34]:

# =============================================================================
# TRAINING
# =============================================================================

epochs = 20
for epoch in range(start_epoch, epochs):
    start = time.time()
    model.train()
    total_loss = 0
    total_samples = 0
    for batch in train_loader:
        batch = batch.to(device)
        output = model(batch).squeeze(-1)
        loss = loss_fn(output, batch.y.float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item() * batch.y.size(0)
        total_samples += batch.y.size(0)
    train_loss = total_loss / total_samples

    model.eval()
    val_loss, val_samples = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            output = model(batch).squeeze(-1)
            loss = loss_fn(output, batch.y.float())
            val_loss += loss.item() * batch.y.size(0)
            val_samples += batch.y.size(0)
    val_loss /= val_samples

    print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss:.4f} - Validation loss: {val_loss:.4f} - Time: {time.time()-start:.2f}s")
    save_checkpoint(model, optimizer, epoch, checkpoint_path)


print("Training completo.")
save_model(model, model_path)

Epoch 1/20 - Train loss: 0.2116 - Validation loss: 0.0137 - Time: 3.33s
Epoch 2/20 - Train loss: 0.0141 - Validation loss: 0.0072 - Time: 3.30s
Epoch 3/20 - Train loss: 0.0098 - Validation loss: 0.0071 - Time: 3.28s
Epoch 4/20 - Train loss: 0.0111 - Validation loss: 0.0076 - Time: 3.26s
Epoch 5/20 - Train loss: 0.0081 - Validation loss: 0.0040 - Time: 3.26s
Epoch 6/20 - Train loss: 0.0094 - Validation loss: 0.0049 - Time: 3.28s
Epoch 7/20 - Train loss: 0.0072 - Validation loss: 0.0037 - Time: 3.26s
Epoch 8/20 - Train loss: 0.0065 - Validation loss: 0.0080 - Time: 3.26s
Epoch 9/20 - Train loss: 0.0065 - Validation loss: 0.0038 - Time: 3.25s
Epoch 10/20 - Train loss: 0.0064 - Validation loss: 0.0045 - Time: 3.25s
Epoch 11/20 - Train loss: 0.0062 - Validation loss: 0.0039 - Time: 3.25s
Epoch 12/20 - Train loss: 0.0063 - Validation loss: 0.0116 - Time: 3.26s
Epoch 13/20 - Train loss: 0.0063 - Validation loss: 0.0058 - Time: 3.26s
Epoch 14/20 - Train loss: 0.0059 - Validation loss: 0.0060 -

In [35]:

# =============================================================================
# VALUTAZIONE DEL MODELLO
# =============================================================================

def evaluate_metrics(model, loader, device):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = model(batch).squeeze()
            preds.append((output > 0.5).cpu().numpy())
            labels.append(batch.y.cpu().numpy())
    preds = np.concatenate(preds)
    labels = np.concatenate(labels)
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds, zero_division=0),
        recall_score(labels, preds, zero_division=0),
        f1_score(labels, preds, zero_division=0)
    )

val_acc, val_prec, val_rec, val_f1 = evaluate_metrics(model, val_loader, device)
print(f"Val metrics - Acc: {val_acc:.3f} Prec: {val_prec:.3f} Rec: {val_rec:.3f} F1: {val_f1:.3f}")

test_acc, test_prec, test_rec, test_f1 = evaluate_metrics(model, test_loader, device)
print(f"Test metrics - Acc: {test_acc:.3f} Prec: {test_prec:.3f} Rec: {test_rec:.3f} F1: {test_f1:.3f}")


Val metrics - Acc: 0.999 Prec: 0.678 Rec: 0.992 F1: 0.806
Test metrics - Acc: 0.998 Prec: 0.653 Rec: 0.992 F1: 0.788


In [36]:

# =============================================================================
# VISUALIZZAZIONE RISULTATI SU 20 EVENTI
# =============================================================================

os.makedirs("images/mimega", exist_ok=True)
model.eval()
count = 0

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        for i in range(batch.num_graphs):
            data = batch[i]
            charges = data.x.cpu().numpy().flatten()
            true_labels = data.y.cpu().numpy()
            pred_labels = (torch.sigmoid(model(data)).squeeze().cpu().numpy() > 0.5).astype(int)

            plt.figure(figsize=(10, 5))
            plt.subplot(2, 1, 1)
            plt.bar(np.arange(len(charges)), charges, color='gray', alpha=0.7)
            plt.title('Carica per strip (test event)')
            plt.xlabel('Strip')
            plt.ylabel('Carica')

            plt.subplot(2, 1, 2)
            plt.plot(true_labels, label='Label ricostruite', drawstyle='steps-mid')
            plt.plot(pred_labels, label='Predizione modello', drawstyle='steps-mid', alpha=0.7)
            plt.xlabel('Strip')
            plt.ylabel('Cluster')
            plt.legend()
            plt.title('Confronto: Ricostruito vs Predetto')
            plt.tight_layout()
            plt.savefig(f"images/mimega/test_example_{count+1}.png")
            plt.close()
            count += 1
            if count >= 20:
                break
        if count >= 20:
            break

print("Salvati i primi 20 plot in 'images/mimega'")


Salvati i primi 20 plot in 'images/mimega'
