In [None]:
import os
import time
import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

class GCNTrainer:
    """Train and evaluate an encoder/decoder link-prediction model.

    The wrapped ``model`` must expose ``encode(x, edge_index)`` and
    ``decode(z, edge_label_index)``. The output of ``decode`` is treated as
    raw logits: :meth:`test` applies ``sigmoid`` to it, and :meth:`train`
    thresholds the sigmoid probabilities at 0.5 when computing accuracy.
    """

    class EarlyStoppingConfig:
        """Settings for early stopping on validation AUC.

        Args:
            enabled: Whether training is allowed to stop early.
            patience: Number of consecutive non-improving epochs tolerated
                before stopping.
            threshold: Minimum increase of validation AUC over the best value
                seen so far that counts as an improvement.
        """

        def __init__(self, enabled=False, patience=10, threshold=1e-2):
            self.enabled = enabled
            self.patience = patience
            self.threshold = threshold

    class CorruptionConfig:
        """Settings for corrupting the message-passing graph before encoding.

        Args:
            enabled: Whether to corrupt the graph.
            percentage: Forwarded to ``corruptNetwork``; presumably the
                fraction of the network to corrupt — TODO confirm against
                that function's definition.
        """

        def __init__(self, enabled=False, percentage=0.1):
            self.enabled = enabled
            self.percentage = percentage

    def __init__(self, model, optimizer, criterion, device='cpu'):
        """Store the training components and move ``model`` to ``device``.

        Args:
            model: Module with ``encode``/``decode`` methods (see class doc).
            optimizer: Optimizer over ``model``'s parameters.
            criterion: Loss called as ``criterion(logits, labels)``.
            device: Device the model and the batch tensors are moved to.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.model.to(device)
        # Set by train_model(save_best_model=True) to an independent copy of
        # the parameters from the epoch with the highest validation AUC.
        self.best_model_state_dict = None

    @staticmethod
    def _accuracy_from_logits(logits, labels):
        """Return the fraction of ``sigmoid(logits) > 0.5`` predictions equal to ``labels``.

        Args:
            logits: Tensor of raw decoder scores.
            labels: Tensor of 0/1 targets, same shape and device as ``logits``.

        Returns:
            Accuracy as a Python float.
        """
        predictions = (logits.sigmoid() > 0.5).float()
        return (predictions == labels).float().mean().item()

    @staticmethod
    def _snapshot_state_dict(model):
        """Return a deep copy of ``model.state_dict()``.

        ``state_dict()`` returns tensors sharing storage with the live
        parameters, so without a copy later optimizer steps would overwrite
        the "saved" values.
        """
        import copy

        return copy.deepcopy(model.state_dict())

    def _apply_corruption(self, g_train, corruption_config):
        """Corrupt the network if enabled.

        Args:
            g_train: networkx graph to corrupt (ignored when disabled).
            corruption_config: A :class:`CorruptionConfig`.

        Returns:
            The edge list produced by ``graphToEdgelist`` on the corrupted
            graph when corruption is enabled, otherwise ``None``.
        """
        if corruption_config.enabled:
            g_corrupted = corruptNetwork(g_train, corruption_config.percentage)
            print("The network has been corrupted.")
            return graphToEdgelist(g_corrupted)
        return None

    def _early_stopping(self, val_auc, best_val_auc, early_stopping_counter, early_stopping_config):
        """Update the early-stopping state for one epoch.

        An epoch counts as an improvement only if ``val_auc`` exceeds
        ``best_val_auc`` by more than ``early_stopping_config.threshold``. In
        that case the best value is updated and the counter reset. Otherwise
        the counter is incremented.

        Returns:
            ``(stop, best_val_auc, early_stopping_counter)``. ``stop`` is True
            only when early stopping is enabled and the counter has reached
            ``patience``.
        """
        if val_auc - best_val_auc > early_stopping_config.threshold:
            best_val_auc = val_auc
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_config.enabled and early_stopping_counter >= early_stopping_config.patience:
            print("Early stopping triggered.")
            return True, best_val_auc, early_stopping_counter

        return False, best_val_auc, early_stopping_counter

    def _perform_negative_sampling(self, edge_index, num_nodes, num_neg_samples):
        """Draw ``num_neg_samples`` negative edges with ``method='sparse'``.

        Delegates to ``negative_sampling`` (presumably
        ``torch_geometric.utils.negative_sampling`` — confirm the import).
        """
        neg_edge_index = negative_sampling(
            edge_index=edge_index, num_nodes=num_nodes,
            num_neg_samples=num_neg_samples, method='sparse')

        return neg_edge_index

    def train(self, train_data, g_train, corruption_config=None, shuffle=False):
        """Run a single optimization step on the training split.

        Args:
            train_data: Split exposing ``x``, ``edge_index``,
                ``edge_label_index``, ``edge_label`` and ``num_nodes``
                (presumably a PyG ``Data`` object).
            g_train: networkx graph of the training edges. It is only read
                when corruption is enabled, so it may be ``None`` otherwise.
            corruption_config: A :class:`CorruptionConfig`. ``None`` means
                corruption is disabled.
            shuffle: Whether to shuffle the positive+negative supervision
                edges before decoding.

        Returns:
            ``(loss, accuracy, auc)`` for this step.
        """
        if corruption_config is None:
            corruption_config = self.CorruptionConfig()

        self.model.train()
        self.optimizer.zero_grad()

        # (Possibly) corrupt the message-passing graph.
        edge_index = self._apply_corruption(g_train, corruption_config)
        if edge_index is None:
            edge_index = train_data.edge_index

        z = self.model.encode(train_data.x.to(self.device), edge_index.to(self.device))

        # Draw as many negative edges as there are supervision edges.
        neg_edge_index = self._perform_negative_sampling(
            edge_index=edge_index,
            num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1),
        )

        edge_label_index = torch.cat([train_data.edge_label_index, neg_edge_index], dim=-1)
        edge_label = torch.cat(
            [train_data.edge_label, train_data.edge_label.new_zeros(neg_edge_index.size(1))],
            dim=0,
        )

        if shuffle:
            perm = torch.randperm(edge_label_index.size(1))
            edge_label_index = edge_label_index[:, perm]
            edge_label = edge_label[perm]

        out = self.model.decode(z, edge_label_index.to(self.device)).view(-1)
        edge_label = edge_label.to(self.device)

        loss = self.criterion(out, edge_label)
        loss.backward()
        self.optimizer.step()

        # `out` holds logits (test() applies sigmoid to the same decoder
        # output), so threshold probabilities, not raw logits, at 0.5.
        detached_out = out.detach()
        accuracy = self._accuracy_from_logits(detached_out, edge_label)
        # AUC is rank-based, so logits give the same value as probabilities.
        auc = roc_auc_score(edge_label.cpu().numpy(), detached_out.cpu().numpy())

        return loss.item(), accuracy, auc

    @torch.no_grad()
    def test(self, data, corruption_config=None, full_output=False):
        """Evaluate the model on a split without gradient tracking.

        Args:
            data: Split exposing ``x``, ``edge_index``, ``edge_label_index``
                and ``edge_label``.
            corruption_config: A :class:`CorruptionConfig`. ``None`` means
                corruption is disabled.
            full_output: If True, also return the predicted probabilities and
                the labels as NumPy arrays.

        Returns:
            ``(accuracy, auc)``, or ``(accuracy, auc, probabilities, labels)``
            when ``full_output`` is True.
        """
        if corruption_config is None:
            corruption_config = self.CorruptionConfig()

        self.model.eval()

        edge_index = data.edge_index
        if corruption_config.enabled:
            # Build the networkx graph only when corruption actually needs it.
            graph = nx.from_edgelist(data.edge_index.t().tolist())
            edge_index = self._apply_corruption(graph, corruption_config)

        z = self.model.encode(data.x.to(self.device), edge_index.to(self.device))
        out = self.model.decode(z, data.edge_label_index.to(self.device)).view(-1).sigmoid()

        out_cpu = out.cpu().numpy()
        label_cpu = data.edge_label.cpu().numpy()

        accuracy = ((out_cpu > 0.5).astype(np.float32) == label_cpu).mean()
        auc = roc_auc_score(label_cpu, out_cpu)

        if full_output:
            return accuracy, auc, out_cpu, label_cpu
        return accuracy, auc

    def train_model(self, train_data, val_data, test_data,
                    epochs=100, early_stopping_config=None,
                    corruption_config=None, save_dir='data', run_timestamp=None, save_best_model=False):
        """Train for up to ``epochs`` epochs, validating after each one.

        Args:
            train_data: Training split (see :meth:`train`).
            val_data: Validation split (see :meth:`test`).
            test_data: Accepted for interface compatibility; not used here.
            epochs: Maximum number of epochs.
            early_stopping_config: An :class:`EarlyStoppingConfig`. ``None``
                means early stopping is disabled.
            corruption_config: A :class:`CorruptionConfig`, applied to both
                training and validation. ``None`` means disabled.
            save_dir: Directory created if missing.
            run_timestamp: Identifier returned with the histories. If
                ``None``, a ``%Y%m%d-%H%M%S`` timestamp is generated.
            save_best_model: If True, keep a copy of the parameters from the
                epoch with the highest validation AUC in
                ``self.best_model_state_dict``.

        Returns:
            ``(train_acc, train_auc, val_acc, val_auc, loss, run_timestamp)``.
            Each history is a NumPy array truncated to the number of epochs
            actually run, which may be fewer than ``epochs`` after early
            stopping and is zero when ``epochs`` is 0.
        """
        if early_stopping_config is None:
            early_stopping_config = self.EarlyStoppingConfig()
        if corruption_config is None:
            corruption_config = self.CorruptionConfig()

        # networkx view of the training graph, only needed for corruption.
        g_train = (nx.from_edgelist(train_data.edge_index.t().tolist())
                   if corruption_config.enabled else None)

        os.makedirs(save_dir, exist_ok=True)

        # Respect a caller-supplied identifier; otherwise generate one.
        if run_timestamp is None:
            run_timestamp = time.strftime("%Y%m%d-%H%M%S")

        train_acc_history, train_auc_history, val_acc_history, val_auc_history, loss_history = [
            np.zeros(epochs) for _ in range(5)
        ]

        best_val_auc = 0
        early_stopping_counter = 0
        best_snapshot_auc = -np.inf
        total_train_time = 0.0
        total_val_time = 0.0
        completed_epochs = 0

        for epoch in tqdm(range(1, epochs + 1)):
            idx = epoch - 1

            start = time.time()
            loss, train_accuracy, train_auc = self.train(train_data, g_train, corruption_config=corruption_config)
            total_train_time += time.time() - start

            start = time.time()
            val_acc, val_auc = self.test(val_data, corruption_config=corruption_config)
            total_val_time += time.time() - start

            loss_history[idx] = loss
            train_acc_history[idx] = train_accuracy
            train_auc_history[idx] = train_auc
            val_acc_history[idx] = val_acc
            val_auc_history[idx] = val_auc
            completed_epochs = epoch

            # Snapshot before the early-stopping check so that the stopping
            # epoch is also kept if it has the highest AUC.
            if save_best_model and val_auc > best_snapshot_auc:
                best_snapshot_auc = val_auc
                self.best_model_state_dict = self._snapshot_state_dict(self.model)

            stop, best_val_auc, early_stopping_counter = self._early_stopping(
                val_auc, best_val_auc, early_stopping_counter, early_stopping_config)
            if stop:
                print(f"Early stopping at epoch {epoch}.")
                break

        print(f"Total training time: {total_train_time:.4f} seconds.")
        print(f"Total validation time: {total_val_time:.4f} seconds.")

        n = completed_epochs
        return (train_acc_history[:n], train_auc_history[:n],
                val_acc_history[:n], val_auc_history[:n],
                loss_history[:n], run_timestamp)
