# In [None]:
########################################
# Single Cell Code for Multi-Seed Training
########################################

import sys
import os 
sys.path.append('../../')
sys.path.append('../')
from time import time
import logging
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch.autograd import Variable
import random
from torch.optim.lr_scheduler import StepLR

# Import your utilities and models
from gmixupUtils import stat_graph, split_class_graphs, align_graphs
from gmixupUtils import two_graphons_mixup, universal_svd, get_graphon
from gmixupUtils import GIN
from Moment.tools import motifs_to_induced_motifs, lipmlp, train_momentnet
from SIGL.tools import *
import networkx as nx
from torch_geometric.utils import dense_to_sparse
import subprocess as sp
from scipy.special import comb

# Directory holding the ORCA orbit-counting tool (the binary is expected at
# ORCA_DIR/orca); the path is relative to the current working directory.
ORCA_DIR = "../../orca/"

def edge_list_reindexed(G):
    """Return the edges of ``G`` as ``(i, j)`` pairs of 0-based node positions.

    Nodes are numbered in the order ``G.nodes()`` yields them. Node identity is
    compared through ``str(node)``, so two nodes with the same string form map to
    the same (last-seen) position.
    """
    position = {str(node): pos for pos, node in enumerate(G.nodes())}
    return [(position[str(a)], position[str(b)]) for a, b in G.edges()]


def orca(graph):
    """Run the ORCA binary on ``graph`` and return its per-node orbit counts.

    The graph is written as an edge list: a ``"n m"`` header, then one ``"u v"`` line
    per edge with nodes re-indexed 0..n-1 by ``edge_list_reindexed``. ORCA is then
    invoked with graphlet size 4, and its output file is parsed into an integer
    array with one row per output line (one per node).

    Both the input and the output file live in a private temporary directory that
    is always removed, even if ORCA fails. This means concurrent runs cannot
    clobber each other's files, and no stray ``outputnew.txt`` is left in the CWD.
    It also means no ``random`` numbers are consumed for file naming, so the
    experiment-seeded RNG stream is left alone.

    Raises:
        subprocess.CalledProcessError: if ORCA exits with a non-zero status.
    """
    import tempfile

    orca_binary = os.path.join(ORCA_DIR, 'orca')
    with tempfile.TemporaryDirectory() as work_dir:
        in_path = os.path.join(work_dir, 'graph.txt')
        out_path = os.path.join(work_dir, 'orbits.txt')

        with open(in_path, 'w') as f:
            f.write(f"{graph.number_of_nodes()} {graph.number_of_edges()}\n")
            for u, v in edge_list_reindexed(graph):
                f.write(f"{u} {v}\n")

        # ORCA's stdout is captured (and discarded) so it does not flood the log.
        sp.check_output([orca_binary, '4', in_path, out_path])

        with open(out_path, 'r') as f:
            output = f.read()

    rows = [line for line in output.strip().split('\n') if line.strip()]
    return np.array([list(map(int, line.split())) for line in rows])


def count2density(node_orbit_counts, graph_size):
    """Convert ORCA per-node orbit counts into 9 induced motif densities.

    Args:
        node_orbit_counts: 2-D array with one row per node and one column per ORCA
            orbit. Only the first 15 columns are used.
        graph_size: number of nodes ``n`` in the graph.

    Returns:
        np.ndarray of 9 floats, one per motif: the edge, the 2 three-node motifs
        and the 6 four-node motifs. Each density is
        (orbit total / motif size) / (labelled copies * C(n, motif size)).
    """
    # Number of consecutive ORCA orbit columns that belong to each motif.
    orbit_groups = (1, 2, 1, 2, 2, 1, 3, 2, 1)
    # Number of nodes in each motif.
    motif_order = np.array([2] + [3] * 2 + [4] * 6)
    # Normaliser per motif (number of labelled copies on its node set).
    labelled_copies = (1., 3., 1., 12., 4., 3., 12., 6., 1.)
    subsets = {k: comb(graph_size, k, exact=True) for k in (2, 3, 4)}

    # Summing over nodes counts every motif occurrence once per member node.
    orbit_totals = np.sum(node_orbit_counts, axis=0)
    bounds = np.cumsum((0,) + orbit_groups)
    per_node_counts = np.array(
        [sum(orbit_totals[lo:hi]) for lo, hi in zip(bounds[:-1], bounds[1:])],
        dtype=float,
    )
    distinct = per_node_counts / motif_order

    return np.array([
        distinct[j] / (labelled_copies[j] * subsets[motif_order[j]])
        for j in range(len(orbit_groups))
    ])



def moment_graphon(resolution, net, device=None):
    """Evaluate ``net`` on a ``resolution x resolution`` grid over [0, 1]^2.

    Args:
        resolution: number of grid points per axis.
        net: callable mapping a float32 tensor of shape (N, 2) with (x, y) points
            to N outputs (presumably shape (N, 1); anything with N elements
            reshapes correctly).
        device: device on which the grid tensor is built. By default this is the
            device of ``net``'s first parameter, or CPU when ``net`` has none.
            Previously this silently relied on a module-level global ``device``.

    Returns:
        np.ndarray of shape (resolution, resolution). Row index follows y and
        column index follows x. The upper triangle is overwritten with the
        mirrored lower triangle, and the diagonal is set to 0.
    """
    if device is None:
        params = getattr(net, 'parameters', None)
        first = next(iter(params()), None) if callable(params) else None
        device = first.device if first is not None else torch.device('cpu')

    x = np.linspace(0, 1, resolution)
    y = np.linspace(0, 1, resolution)
    X, Y = np.meshgrid(x, y)

    inputs = torch.tensor(np.stack((X.flatten(), Y.flatten()), axis=1),
                          dtype=torch.float32).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        Z_net = net(inputs).detach().cpu().numpy().reshape(X.shape)

    Z_sym = np.copy(Z_net)
    # Mirror the lower triangle into the upper triangle: for j > i, Z[i, j] = Z[j, i].
    i_upper = np.triu_indices(Z_net.shape[0], 1)
    Z_sym[i_upper] = Z_sym.T[i_upper]

    np.fill_diagonal(Z_sym, 0)
    return Z_sym

def train_Moment(moments, weights, model_name, max_retries=None):
    """Fit an implicit graphon network whose motif densities match ``moments``.

    A fresh network is built and trained with ``train_momentnet``. The attempt is
    accepted once the final loss is more than 0.1% below the first recorded loss;
    otherwise training restarts from a newly initialised network.

    Args:
        moments: Target motif densities (array-like). In this notebook this is a
            mix of length-9 vectors from ``count2density``. It is converted with
            ``torch.tensor`` and moved to the training device.
        weights: Forwarded unchanged to ``train_momentnet``.
        model_name: ``'SIREN'`` (``SirenNet``) or ``'MLP'`` (``lipmlp``).
        max_retries: Optional cap on the number of training attempts. ``None``
            (the default) keeps the original behaviour of retrying until the
            improvement criterion is met. When the cap is reached, the last trained
            model is returned.

    Returns:
        The trained network, on ``cuda:2`` when CUDA is available, else on CPU.

    Raises:
        ValueError: if ``model_name`` is neither ``'SIREN'`` nor ``'MLP'``.
    """
    if model_name not in ('SIREN', 'MLP'):
        raise ValueError(f"Unknown model_name {model_name!r}; expected 'SIREN' or 'MLP'")

    # Training hyper-parameters.
    epochs = 2000
    patience = 600
    N = 20000
    # Es below lists 1 two-node, 2 three-node and 6 four-node motifs first (9 in total).
    num_motifs = 9
    num_layers = 5
    w0 = 9.
    if model_name == 'MLP':
        lr = 1e-3
        hid_dim = 64
    else:
        lr = 1e-4
        hid_dim = 5 * [96]

    # Motif edge lists grouped by node count (2, 3, 4, 5 nodes).
    Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],
                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],
      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,
       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],
       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
       ]]

    induced_list = motifs_to_induced_motifs(Es)
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

    real_moments = torch.tensor(moments).to(device)
    print("Density counting is finished")
    print(real_moments)

    attempt = 0
    while True:
        attempt += 1
        if model_name == 'SIREN':
            model = SirenNet(2, hid_dim, 1, num_layers=num_layers, w0=w0, w0_initial=30.).train().to(device)
        else:
            model = lipmlp([2, hid_dim, 1]).train().to(device)
        losses = train_momentnet(model, induced_list[:num_motifs], real_moments, 4, N, epochs, patience, lr, device, weights)

        try:
            improvement = (losses[0] - losses[-1]) / losses[0]
            if improvement > 1e-3:
                break
        except Exception as err:  # degenerate loss history (empty, zero first loss, ...): retry
            print(f"Could not evaluate loss history ({err!r}); retraining")
        else:
            print(improvement)

        if max_retries is not None and attempt >= max_retries:
            print(f"train_Moment: no sufficient improvement after {attempt} attempts; returning last model")
            break

    return model


# Set up logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')
if not logger.handlers:
    # Optionally, log to the screen
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
logging.getLogger('matplotlib').setLevel(logging.WARNING)


########################################
# Helper Function Definitions
########################################

def prepare_dataset_x(dataset):
    """Attach degree-based node features when the dataset has none.

    Only ``dataset[0].x`` is inspected. If it is ``None``, every graph's ``x`` is
    overwritten:

    * If the maximum degree is below 2000, ``x`` is a float one-hot encoding of
      the node degree (``max_degree + 1`` columns).
    * Otherwise ``x`` is the z-scored degree as a single column.

    For graphs with edges, ``num_nodes`` is reset to ``max(edge_index) + 1``.
    Degrees are computed over exactly ``num_nodes`` nodes, so ``x`` always has one
    row per node, including for edgeless graphs, which previously crashed
    ``torch.max``.

    Returns:
        The same list, mutated in place.
    """
    logger.info("Prepare dataset x")
    if not dataset:
        return dataset
    if dataset[0].x is None:
        logger.info("dataset[0].x is None")
        max_degree = 0
        degs_all = []
        for data in dataset:
            if data.edge_index.numel() > 0:
                data.num_nodes = int(torch.max(data.edge_index)) + 1
            d = degree(data.edge_index[0], num_nodes=data.num_nodes, dtype=torch.long)
            degs_all.append(d)
            if d.numel() > 0:
                max_degree = max(max_degree, d.max().item())

        # Degrees are reused from the first pass instead of being recomputed.
        if max_degree < 2000:
            for data, degs in zip(dataset, degs_all):
                data.x = F.one_hot(degs, num_classes=max_degree + 1).to(torch.float)
        else:
            deg = torch.cat(degs_all, dim=0).to(torch.float)
            mean, std = deg.mean().item(), deg.std().item()
            for data, degs in zip(dataset, degs_all):
                data.x = ((degs - mean) / std).view(-1, 1)
    return dataset


def prepare_dataset_onehot_y(dataset):
    """Replace every graph's integer label ``data.y`` with a float one-hot vector.

    The number of classes is ``max(label) + 1``, so labels that are not
    contiguous (e.g. {0, 2}) still encode correctly. Counting distinct labels,
    as before, raised an error for such label sets. For the usual 0..k-1 labels
    the result is unchanged.

    ``data.y`` is expected to be a one-element long tensor (e.g. shape (1,)).
    ``F.one_hot`` then yields shape (1, C), and ``[0]`` drops that extra dimension.

    Returns:
        The same list, mutated in place. An empty list is returned unchanged.

    Raises:
        ValueError: if any label is negative.
    """
    labels = [int(data.y) for data in dataset]
    if not labels:
        return dataset
    if min(labels) < 0:
        raise ValueError(f"Class labels must be non-negative, got {min(labels)}")
    num_classes = max(labels) + 1
    for data in dataset:
        data.y = F.one_hot(data.y, num_classes=num_classes).to(torch.float)[0]
    return dataset


def mixup_cross_entropy_loss(input, target, size_average=True):
    """Cross entropy against soft (mixup) targets.

    ``input`` is expected to contain log-probabilities and ``target`` a
    same-shaped matrix of class weights. The loss is ``-sum(input * target)``,
    divided by the batch size (``input.size(0)``) when ``size_average`` is true.
    """
    assert input.size() == target.size()
    assert isinstance(input, Variable) and isinstance(target, Variable)
    total = -(input * target).sum()
    if not size_average:
        return total
    return total / input.size(0)


def train(model, train_loader, optimizer, device, num_classes):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        The training loss averaged per graph (each batch loss is weighted by its
        ``num_graphs``).
    """
    model.train()
    weighted_loss = 0
    seen_graphs = 0
    for graphs in train_loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        logits = model(graphs.x, graphs.edge_index, graphs.batch)
        targets = graphs.y.view(-1, num_classes)
        loss = mixup_cross_entropy_loss(logits, targets)
        loss.backward()
        weighted_loss += loss.item() * graphs.num_graphs
        seen_graphs += graphs.num_graphs
        optimizer.step()
    return weighted_loss / seen_graphs


def test(model, loader, device, num_classes):
    """Evaluate ``model`` on ``loader``.

    Runs under ``torch.no_grad()`` so that evaluation does not build autograd
    graphs, which previously wasted memory and time on every epoch.

    Returns:
        ``(accuracy, loss)`` where accuracy compares the arg-max prediction with the
        arg-max of the one-hot/soft target, and loss is the mixup cross entropy
        averaged per graph.
    """
    model.eval()
    correct = 0
    total = 0
    loss_all = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.batch)
            pred = output.max(dim=1)[1]
            y = data.y.view(-1, num_classes)
            loss_all += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
            # Convert one-hot/soft targets to class indices for comparison.
            y_labels = y.max(dim=1)[1]
            correct += pred.eq(y_labels).sum().item()
            total += data.num_graphs
    return correct / total, loss_all / total



def two_moments_mixup(two_moments, la=0.5, num_sample=20, ge='ISGL', resolution=None):
    """Generate synthetic graphs from a convex mix of two classes' motif moments.

    The labels and moment vectors of the two classes are mixed with weight
    ``la``. A SIREN graphon network is fitted to the mixed moments
    (``train_Moment``) and evaluated on a grid (``moment_graphon``). Graphs are
    then sampled from it with ``simulate_graphs``.

    Args:
        two_moments: two ``[label, moments]`` pairs. ``label`` must be a NumPy
            array (it goes through ``torch.from_numpy``).
        la: weight of the first pair; the second gets ``1 - la``.
        num_sample: number of graphs to draw. Draws that come out empty are
            skipped, so fewer graphs may be returned.
        ge: unused; kept for call compatibility.
        resolution: graphon grid resolution, which is also the upper bound on the
            node count of each sample (drawn uniformly from [6, resolution]).

    Returns:
        A list of ``torch_geometric.data.Data`` with ``y``, ``edge_index`` and
        ``num_nodes`` set.

    Raises:
        ValueError: if ``resolution`` is None or below 6. This is checked up front
            instead of after the expensive model training.
    """
    from torch_geometric.data import Data

    if resolution is None or resolution < 6:
        raise ValueError(f"resolution must be an integer >= 6, got {resolution!r}")

    label = la * two_moments[0][0] + (1 - la) * two_moments[1][0]
    sample_graph_label = torch.from_numpy(label).type(torch.float32)
    new_moment = la * two_moments[0][1] + (1 - la) * two_moments[1][1]

    trained_model = train_Moment(new_moment, 0, "SIREN")
    graphon = moment_graphon(resolution, trained_model)

    sample_graphs = []
    for _ in range(num_sample):
        num_nodes = random.randint(6, resolution)
        # Fresh seeds per draw, taken from the experiment-seeded global RNG. The
        # previous constant seeds (123, 123) presumably made every draw of the same
        # size identical.
        seed_gsize = random.randint(0, 2**31 - 1)
        seed_edge = random.randint(0, 2**31 - 1)
        sample_graph = simulate_graphs(graphon, seed_gsize=seed_gsize, seed_edge=seed_edge,
                                       num_graphs=1, num_nodes=num_nodes,
                                       graph_size='fixed', offset=0)[0]

        A = torch.from_numpy(sample_graph)
        edge_index, _ = dense_to_sparse(A)
        num_nodes = sample_graph.shape[0]

        if num_nodes == 0:
            print('num_nodes is 0')
            continue

        # Skip graphs without any edge.
        if edge_index.shape[1] == 0:
            print('edge_index is 0')
            continue

        pyg_graph = Data()
        pyg_graph.y = sample_graph_label
        pyg_graph.edge_index = edge_index
        pyg_graph.num_nodes = num_nodes
        sample_graphs.append(pyg_graph)

    return sample_graphs

########################################
# Default Parameters and Seed List Setup
########################################

# ---- Default parameters (module-level globals read by the seed loop) ----
# Data and classifier
data_path = "./"
dataset_name = "AIDS"
model_name = "GIN"
num_hidden = 64

# Optimisation
num_epochs = 400
batch_size = 128
learning_rate = 0.01

# Mixup augmentation
gmixup = True
lam_range = [0.1, 0.2]
aug_ratio = 0.2
aug_num = 10
ge = "Moment"   # options: "ISGL", "Moment", "IGNR", etc.
n_epochs_inr = 20

# Miscellaneous
log_screen = True

# Motif edge lists grouped by node count:
#   Es[0]: 1 two-node motif (a single edge),
#   Es[1]: 2 three-node motifs,
#   Es[2]: 6 four-node motifs,
#   Es[3]: 21 five-node motifs.
# Each motif is a list of (u, v) edges on nodes 0..k-1.
# NOTE(review): the same literal is also written out inside train_Moment; keep the
# copies in sync by hand.
Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],
                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],
      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,
       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],
       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
       ]]

# Induced-motif form of Es, built by Moment.tools.motifs_to_induced_motifs. Its exact
# structure is not visible here (TODO confirm against Moment.tools).
induced_list = motifs_to_induced_motifs(Es)

# Experiment seeds: 1314, 11314, ..., 71314 (eight runs spaced by 10000).
seed_list = [1314 + 10000 * k for k in range(8)]

# Per-seed metrics, appended to by the training loop.
all_best_test_acc, all_last10_avg_acc = [], []

########################################
# Main Loop: Iterate over Seed List
########################################

def compute_moment_scores(moment_matrix):
    """
    Computes discrimination scores for each moment (column) in a matrix.

    Parameters:
        moment_matrix (np.ndarray): A 2D NumPy array of shape (n_classes, m_moments)
                                    where each row corresponds to a class.

    Returns:
        std_scores (np.ndarray): Standard deviation scores for each moment.
        range_scores (np.ndarray): Range (max-min) scores for each moment.
        avg_pairwise (np.ndarray): Average pairwise difference scores for each moment.
    """
    # Method 1: Standard deviation across classes for each moment.
    std_scores = np.std(moment_matrix, axis=0)

    # Method 2: Range (maximum minus minimum) for each moment.
    range_scores = np.ptp(moment_matrix, axis=0)  # np.ptp returns the range (peak to peak)

    # Method 3: Average pairwise absolute differences between classes for each moment.
    n, m = moment_matrix.shape
    avg_pairwise = np.zeros(m)
    for j in range(m):
        pairwise_diffs = []
        for i in range(n):
            for k in range(i + 1, n):
                diff = abs(moment_matrix[i, j] - moment_matrix[k, j])
                pairwise_diffs.append(diff)
        avg_pairwise[j] = np.mean(pairwise_diffs)

    return std_scores, range_scores, avg_pairwise


def _estimate_class_moments(class_graphs, min_nodes=6):
    """Return ``[label, mean motif-density vector]`` for each ``(label, graphs)`` class.

    ``graphs`` are presumably dense adjacency arrays (they go through
    ``nx.from_numpy_array``). Graphs with fewer than ``min_nodes`` nodes are
    skipped. The mean is taken over the graphs actually counted; previously it
    divided by all graphs, which biased densities towards zero. A class with no
    countable graph gets an all-zero vector.
    """
    moment_matrix = []
    for label, graphs in class_graphs:
        density_sum = np.zeros(9)
        n_counted = 0
        for adjacency in graphs:
            graph = nx.from_numpy_array(adjacency)
            if graph.number_of_nodes() < min_nodes:
                continue
            node_orbit_counts = orca(graph)
            density_sum += count2density(node_orbit_counts, graph.number_of_nodes())
            n_counted += 1
        mean_density = density_sum / n_counted if n_counted else density_sum
        moment_matrix.append([label, mean_density])
    return moment_matrix


for seed in seed_list:
    logger.info("==========================================")
    logger.info(f"Starting training for seed {seed}")

    # Set seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Kept as a module-level global: helper functions in this notebook may read `device`.
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Running on device: {device}")

    # Load dataset
    path = osp.join(data_path, dataset_name)
    dataset = list(TUDataset(path, name=dataset_name))

    # Reshape labels to 1-D, then one-hot encode them
    for graph in dataset:
        graph.y = graph.y.view(-1)
    dataset = prepare_dataset_onehot_y(dataset)

    # Log global dataset statistics
    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset)
    logger.info(f"Dataset {dataset_name}: {len(dataset)} graphs, avg nodes: {avg_num_nodes}")

    # Shuffle and split: first 70% train, next 10% validation, last 20% test
    random.shuffle(dataset)
    train_nums = int(len(dataset) * 0.7)
    train_val_nums = int(len(dataset) * 0.8)

    # Log training subset statistics
    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset[:train_nums])
    logger.info(f"Training graphs: avg nodes: {avg_num_nodes}, max nodes: {max_num_nodes}")

    resolution = int(avg_num_nodes)
    logger.info(f"Resolution set to: {resolution}")

    # Augment the training split with moment-mixup graphs if enabled
    if gmixup:
        class_graphs = split_class_graphs(dataset[:train_nums])
        moment_matrix = _estimate_class_moments(class_graphs)

        num_sample = int(train_nums * aug_ratio / aug_num)
        lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))
        new_graph = []
        for lam in lam_list:
            logger.info(f"lam: {lam}")
            logger.info(f"num_sample: {num_sample}")
            two_moments = random.sample(moment_matrix, 2)

            new_graph += two_moments_mixup(two_moments, la=lam, num_sample=num_sample, ge="USVT", resolution=resolution)
            if new_graph:
                logger.info(f"New graph label: {new_graph[-1].y}")

        if new_graph:
            avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(new_graph)
            logger.info(f"New graphs: avg nodes: {avg_num_nodes}, min nodes: {min_num_nodes}, max nodes: {max_num_nodes}")
        else:
            logger.warning("Mixup produced no graphs; training on the original data only.")
        logger.info(f"Real augmentation ratio: {len(new_graph)/train_nums}")
        # Synthetic graphs are prepended, so they fall inside the training split.
        dataset = new_graph + dataset
        train_nums += len(new_graph)
        train_val_nums += len(new_graph)

    # Prepare node features for whole dataset
    dataset = prepare_dataset_x(dataset)
    logger.info(f"Dataset feature shape: {dataset[0].x.shape}")
    logger.info(f"Dataset label shape: {dataset[0].y.shape}")

    num_features = dataset[0].x.shape[1]
    num_classes = dataset[0].y.shape[0]

    for data in dataset:
        if not hasattr(data, 'edge_attr') or data.edge_attr is None:
            data.edge_attr = torch.ones((data.edge_index.size(1), 3))

    # Split dataset into train, validation and test
    train_dataset = dataset[:train_nums]
    random.shuffle(train_dataset)
    val_dataset = dataset[train_nums:train_val_nums]
    test_dataset = dataset[train_val_nums:]

    logger.info(f"Train dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")
    logger.info(f"Test dataset size: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Instantiate the model
    if model_name == "GIN":
        model_instance = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)
    else:
        logger.info("No valid model specified.")
        continue

    optimizer = torch.optim.Adam(model_instance.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

    # Train the model
    last_10_epoch_acc = []  # test acc of the final 10 epochs
    best_val_acc = 0
    best_test_acc = 0
    best_epoch = 0

    # Epochs run 1..num_epochs inclusive (previously one epoch short, which also
    # left only 9 entries in last_10_epoch_acc).
    for epoch in range(1, num_epochs + 1):
        train_loss = train(model_instance, train_loader, optimizer, device, num_classes)
        val_acc, val_loss = test(model_instance, val_loader, device, num_classes)
        test_acc, test_loss = test(model_instance, test_loader, device, num_classes)
        scheduler.step()

        # `>=` prefers the later epoch on validation ties.
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            best_epoch = epoch

        if epoch > num_epochs - 10:
            last_10_epoch_acc.append(test_acc)

        if epoch % 30 == 0:
            print('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Val Acc: {:.6f}, Test Acc: {:.6f}'.format(
                epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

    avg_last10 = np.mean(last_10_epoch_acc)
    print(f"Seed {seed}: Last 10 epochs average test acc: {np.round(avg_last10,3)}")
    print(f"Seed {seed}: Best test acc (based on validation) = {best_test_acc} at epoch {best_epoch}")

    all_best_test_acc.append(best_test_acc)
    all_last10_avg_acc.append(avg_last10)

########################################
# After All Seeds: Print Final Metrics
########################################
# Summarise per-seed metrics (mean and population std across seeds).
print("==========================================")
print("Overall Results Across Seeds:")
if all_best_test_acc:
    print(f"Best Test Accuracy: Mean = {np.mean(all_best_test_acc):.3f}, Std = {np.std(all_best_test_acc):.3f}")
    print(f"Last 10 Epoch Test Accuracy Average: Mean = {np.mean(all_last10_avg_acc):.3f}, Std = {np.std(all_last10_avg_acc):.3f}")
else:
    # Avoid np.mean([]) -> nan plus a RuntimeWarning when no seed finished.
    print("No seed finished training; nothing to summarise.")
