# In [None]:
########################################
# Single Cell Code for Multi-Seed Training
########################################
import sys
import os 
sys.path.append('../../')
sys.path.append('../')
from time import time
import logging
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch.autograd import Variable
import random
from torch.optim.lr_scheduler import StepLR
import schedulefree

# Import your utilities and models
from gmixupUtils import stat_graph, split_class_graphs, align_graphs
from gmixupUtils import two_graphons_mixup, universal_svd, get_graphon
from gmixupUtils import GIN
from Moment.tools import motifs_to_induced_motifs, lipmlp, train_momentnet
from SIGL.tools import *
import networkx as nx
from torch_geometric.utils import dense_to_sparse
import subprocess as sp
from scipy.special import comb
# Module-level reproducibility setup: seed Python's `random`, NumPy and PyTorch
# (CPU and CUDA) with one fixed value. The main loop at the bottom of the file
# re-seeds `random` and NumPy once per data seed, and PyTorch once per run when
# mixup is enabled.
seed = 21
random.seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# NOTE(review): hash randomisation is presumably fixed when the interpreter
# starts, so this assignment likely only affects child processes - confirm.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set before torch.use_deterministic_algorithms(True) is called below.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
print(f"INFO: Set CUBLAS_WORKSPACE_CONFIG to {os.environ.get('CUBLAS_WORKSPACE_CONFIG')}")
torch.use_deterministic_algorithms(True)

# Directory used by orca() for its temporary edge-list files.
ORCA_DIR = '../../orca/'

def edge_list_reindexed(G):
    """Return the edges of ``G`` expressed with consecutive integer node ids.

    Nodes are numbered 0, 1, 2, ... in the order ``G.nodes()`` yields them
    (keyed by their string form), and each edge from ``G.edges()`` is
    translated through that numbering.

    Args:
        G: Object exposing ``nodes()`` and ``edges()``.

    Returns:
        list[tuple[int, int]]: One ``(source_index, target_index)`` pair per edge.
    """
    position_of = {str(node): position for position, node in enumerate(G.nodes())}
    return [(position_of[str(src)], position_of[str(dst)]) for src, dst in G.edges()]


def orca(graph):
    """Run the external ORCA executable on ``graph`` and return its orbit counts.

    The graph is written as an edge list (header ``"<nodes> <edges>"`` followed
    by one re-indexed edge per line) to a temporary file inside ``ORCA_DIR``.
    The ``orca`` binary in that directory is invoked with graphlet size ``4``
    and its output file is parsed into an integer matrix, one row per output
    line (presumably one line per node - confirm against the ORCA docs).

    Both the input and the output file are removed afterwards, even when the
    binary fails. The output used to go to a fixed ``outputnew.txt`` in the
    working directory, which was never cleaned up and could be clobbered by a
    concurrent run.

    Args:
        graph: Graph exposing ``number_of_nodes()``, ``number_of_edges()``,
            ``nodes()`` and ``edges()`` (e.g. a networkx graph).

    Returns:
        np.ndarray: Integer array with one row per line of ORCA output.

    Raises:
        subprocess.CalledProcessError: If the ORCA binary exits non-zero.
        OSError: If the temporary files cannot be written or read.
    """
    # The file name is still drawn from the global `random` stream so that the
    # sequence of random numbers seen by the rest of the script is unchanged.
    tmp_file_path = os.path.join(ORCA_DIR, f'tmptmp55-{random.random():.4f}.txt')
    out_file_path = tmp_file_path + '.out'

    try:
        with open(tmp_file_path, 'w') as handle:
            handle.write(f'{graph.number_of_nodes()} {graph.number_of_edges()}\n')
            for (u, v) in edge_list_reindexed(graph):
                handle.write(f'{u} {v}\n')

        sp.check_output([os.path.join(ORCA_DIR, 'orca'), '4', tmp_file_path, out_file_path])

        with open(out_file_path, 'r') as handle:
            output = handle.read().strip()
    finally:
        # Best-effort cleanup: a missing file is not an error here.
        for path in (tmp_file_path, out_file_path):
            try:
                os.remove(path)
            except OSError:
                pass

    node_orbit_counts = np.array([list(map(int, node_cnts.split()))
                                  for node_cnts in output.split('\n')])
    return node_orbit_counts


def count2density(node_orbit_counts, graph_size):
    """Turn per-node orbit counts into nine normalised motif densities.

    The orbit columns are summed over nodes and pooled into nine consecutive
    groups of widths ``[1, 2, 1, 2, 2, 1, 3, 2, 1]``. Each pooled count is
    divided by the motif's node count (2, 3 or 4), then by a per-motif
    normaliser times ``C(graph_size, node count)``.

    Args:
        node_orbit_counts: 2-D array, rows are nodes and columns are orbits
            (the first 15 columns are read).
        graph_size (int): Number of nodes of the graph.

    Returns:
        np.ndarray: Float array of length 9.
    """
    orbits_per_motif = [1, 2, 1, 2, 2, 1, 3, 2, 1]
    motif_order = np.array([2] + [3] * 2 + [4] * 6)
    rewiring_normalizer = [1., 3., 1., 12., 4., 3., 12., 6., 1.]

    # Number of node subsets of each size that the graph contains.
    subsets_of_size = {k: comb(graph_size, k, exact=True) for k in (2, 3, 4)}

    orbit_totals = np.sum(node_orbit_counts, axis=0)

    # Pool the orbit columns belonging to the same motif.
    pooled = np.zeros(9)
    first_orbit = 0
    for motif, width in enumerate(orbits_per_motif):
        pooled[motif] = sum(orbit_totals[first_orbit:first_orbit + width])
        first_orbit += width

    per_motif = pooled / motif_order

    density = np.zeros(9)
    for motif, (norm, order) in enumerate(zip(rewiring_normalizer, motif_order)):
        density[motif] = per_motif[motif] / (norm * subsets_of_size[order])
    return density



def moment_graphon(resolution, net):
    """Evaluate ``net`` on a regular grid and return a symmetric graphon matrix.

    ``net`` is called once on all ``resolution * resolution`` points of a
    regular grid over [0, 1] x [0, 1]. The raw output is made symmetric by
    overwriting the strict upper triangle with the values of the strict lower
    triangle, and the diagonal is set to zero.

    The grid is placed on the device of ``net``'s parameters. If ``net``
    exposes no parameters, the module-level ``device`` is used, as before.
    Previously the function always relied on that global, which is only
    assigned inside the main loop.

    Args:
        resolution (int): Number of grid points per axis.
        net: Callable mapping an ``(N, 2)`` float32 tensor to ``N`` values.

    Returns:
        np.ndarray: ``(resolution, resolution)`` symmetric array with a zero
        diagonal.
    """
    axis = np.linspace(0, 1, resolution)
    X, Y = np.meshgrid(axis, axis)

    try:
        target_device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device  # legacy fallback: module-level `device`

    # Network output
    inputs = torch.tensor(np.stack((X.flatten(), Y.flatten()), axis=1), dtype=torch.float32).to(target_device)
    Z_net = net(inputs).cpu().detach().numpy().reshape(X.shape)
    Z_sym = np.copy(Z_net)

    # Mirror the strict lower triangle into the strict upper triangle:
    # Z_sym.T[i_upper] reads entries (j, i) with j > i, i.e. lower-triangle values.
    i_upper = np.triu_indices(Z_net.shape[0], 1)
    Z_sym[i_upper] = Z_sym.T[i_upper]

    # No self-loops.
    np.fill_diagonal(Z_sym, 0)
    return Z_sym

def visualize_nn_graphon(model, resolution=100, device=None, title="Neural Network Graphon Visualization"):
    """
    Draw the graphon represented by a neural network as a heatmap.

    The network is evaluated on a ``resolution`` x ``resolution`` grid over
    the unit square. The result is symmetrised by averaging with its
    transpose, its diagonal is zeroed, values are clipped to [0, 1], and the
    matrix is shown with a colour scale fixed to [0, 1].

    Args:
        model (torch.nn.Module): Network taking an (N, 2) tensor of (x, y)
                                 coordinates and returning N graphon values.
        resolution (int, optional): Grid points per axis. Defaults to 100.
        device (optional): Device to run on. When None, the device of the
                           model's first parameter is used, falling back to
                           the CPU if the model exposes no parameters.
        title (str, optional): Plot title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            # The model has a parameters() iterator, but it is empty.
            device = torch.device('cpu')
            logger.info("Could not infer device from model, defaulting to CPU.")
        except AttributeError:
            # The object has no parameters() at all.
            device = torch.device('cpu')
            logger.info("Model does not have parameters, defaulting to CPU.")

    model.eval()
    model.to(device)

    # Regular grid over the unit square, flattened to shape (resolution**2, 2).
    ticks = np.linspace(0, 1, resolution)
    X, Y = np.meshgrid(ticks, ticks)
    grid = torch.tensor(np.stack((X.flatten(), Y.flatten()), axis=1), dtype=torch.float32).to(device)

    with torch.no_grad():
        raw_values = model(grid)
    W_net = raw_values.cpu().numpy().reshape(resolution, resolution)

    # Symmetrise, remove self-loops, and keep the values in probability range.
    W_sym = (W_net + W_net.T) / 2.0
    np.fill_diagonal(W_sym, 0)
    W_sym = np.clip(W_sym, 0, 1)

    plt.figure(figsize=(8, 6.5))
    # origin='upper' puts row 0 at the top, matching matrix indexing.
    heatmap = plt.imshow(W_sym, extent=[0, 1, 1, 0], origin='upper', cmap='viridis', vmin=0, vmax=1)

    plt.title(title)
    plt.xlabel("x")
    plt.ylabel("y")

    colorbar = plt.colorbar(heatmap, fraction=0.046, pad=0.04)
    colorbar.set_label("W(x,y) - Edge Probability")
    colorbar.set_ticks(np.linspace(0, 1, 6))

    plt.show()

def train_Moment(moments, weights, model_name):
    """
    Fit a neural graphon to a vector of motif densities and return the network.

    A fresh network is trained with ``train_momentnet`` against entries
    ``2:9`` of ``moments`` (the first two entries are not used as targets).
    Training is repeated with a newly initialised network until the relative
    loss decrease ``(losses[0] - losses[-1]) / losses[0]`` exceeds ``1e-3``.
    There is no upper bound on the number of restarts.

    Args:
        moments: Array-like of motif densities, convertible by ``torch.tensor``.
        weights: Unused; kept so existing callers keep working.
        model_name (str): ``'SIREN'`` or ``'MLP'``.

    Returns:
        The trained model, located on the training device.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names
            (previously this surfaced as an unbound-variable error).
    """
    if model_name not in ('SIREN', 'MLP'):
        raise ValueError(f"Unknown model_name {model_name!r}; expected 'SIREN' or 'MLP'.")

    epochs = 2000
    num_motifs = 9

    # Per-architecture hyper-parameters.
    if model_name == 'SIREN':
        lr = 9.26e-05
        N = 30000
        patience = 700
        hid_dim = 6 * [96]
        num_layers = 6
        w0 = 9.0
    else:
        lr = 1e-4
        N = 20000
        patience = 1000
        hid_dim = 64

    # Edge lists of the motifs, grouped by node count (2, 3, 4 and 5 nodes).
    Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],
                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],
      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,
       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],
       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
       ]]

    induced_list = motifs_to_induced_motifs(Es)
    # NOTE(review): the GPU index is hard-coded; a machine with CUDA but fewer
    # than three GPUs will fail here.
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

    real_moments = torch.tensor(moments).to(device)

    while True:
        # Every attempt starts from a freshly initialised network.
        if model_name == 'SIREN':
            model = SirenNet(2, hid_dim, 1, num_layers=num_layers, w0=w0, w0_initial=30.).train().to(device)
        else:
            model = lipmlp([2, hid_dim, 1]).train().to(device)

        losses = train_momentnet(model, induced_list[2:num_motifs], real_moments[2:num_motifs], 4, N, epochs, patience, lr, device, 1)

        try:
            improvement = (losses[0] - losses[-1]) / losses[0]
        except (IndexError, TypeError, ZeroDivisionError) as err:
            # No usable loss history: train again. Other exceptions are no
            # longer swallowed, so real failures cannot loop forever.
            print(f"train_Moment: unusable loss history ({err!r}); retraining")
            continue

        if improvement > 1e-3:
            break
        print(improvement)

    return model


# Set up logging on the root logger: DEBUG level, messages prefixed with the
# date (no time of day) and the level name.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')
# Only attach a handler when none exists yet, so re-running this code in the
# same interpreter does not add a second handler.
if not logger.handlers:
    # Optionally, log to the screen
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
# The root logger is at DEBUG; restrict matplotlib's logger to WARNING and above.
logging.getLogger('matplotlib').setLevel(logging.WARNING)


########################################
# Helper Function Definitions
########################################

def prepare_dataset_x(dataset):
    """Attach degree-based node features to graphs that have none.

    Nothing is changed when the first graph already has ``x``. Otherwise each
    graph gets ``num_nodes = max node index + 1`` and a feature matrix built
    from the degrees of ``edge_index[0]``:

    * one-hot degree vectors when the largest degree is below 2000;
    * a single standardised degree column otherwise.

    Args:
        dataset: Indexable collection of graph objects with ``x`` and
            ``edge_index``.

    Returns:
        The same ``dataset`` object, with its graphs modified in place.
    """
    logger.info("Prepare dataset x")
    if dataset[0].x is not None:
        return dataset

    logger.info("dataset[0].x is None")
    per_graph_degrees = []
    max_degree = 0
    for data in dataset:
        node_degrees = degree(data.edge_index[0], dtype=torch.long)
        per_graph_degrees.append(node_degrees)
        max_degree = max(max_degree, node_degrees.max().item())
        data.num_nodes = int(torch.max(data.edge_index)) + 1

    if max_degree < 2000:
        width = max_degree + 1
        for data, node_degrees in zip(dataset, per_graph_degrees):
            data.x = F.one_hot(node_degrees, num_classes=width).to(torch.float)
        return dataset

    # Very large degrees: fall back to one z-scored degree column.
    pooled = torch.cat(per_graph_degrees, dim=0).to(torch.float)
    mean, std = pooled.mean().item(), pooled.std().item()
    for data, node_degrees in zip(dataset, per_graph_degrees):
        data.x = ((node_degrees - mean) / std).view(-1, 1)
    return dataset


def prepare_dataset_onehot_y(dataset):
    """Replace each graph's integer label by a float one-hot vector, in place.

    The vector length is the number of distinct labels found in ``dataset``.
    ``data.y`` is expected to be a one-element integer tensor; the leading
    dimension added by ``F.one_hot`` is dropped.

    Returns:
        The same ``dataset`` object.
    """
    distinct_labels = {int(item.y) for item in dataset}
    width = len(distinct_labels)
    for item in dataset:
        encoded = F.one_hot(item.y, num_classes=width)
        item.y = encoded.to(torch.float)[0]
    return dataset


def mixup_cross_entropy_loss(input, target, size_average=True):
    """Cross entropy against soft (mixed) targets.

    Computes ``-sum(input * target)`` over all elements; with
    ``size_average`` the sum is divided by the size of the first dimension.
    ``input`` and ``target`` must have identical shapes.
    """
    assert input.size() == target.size()
    assert isinstance(input, Variable) and isinstance(target, Variable)
    total = -(input * target).sum()
    if size_average:
        return total / input.size(0)
    return total


def train(model, train_loader, optimizer, device, num_classes):
    """Run one optimisation pass over ``train_loader``.

    Each batch's labels are reshaped to ``(-1, num_classes)`` and scored with
    ``mixup_cross_entropy_loss``.

    Returns:
        float: Loss averaged over all graphs seen, each batch weighted by its
        number of graphs.
    """
    model.train()
    weighted_loss = 0
    graphs_seen = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.batch)
        targets = batch.y.view(-1, num_classes)
        batch_loss = mixup_cross_entropy_loss(logits, targets)
        batch_loss.backward()
        optimizer.step()
        weighted_loss += batch_loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs
    return weighted_loss / graphs_seen


def test(model, loader, device, num_classes):
    """Evaluate ``model`` on ``loader``.

    The whole pass runs under ``torch.no_grad()`` so that no autograd graph is
    built during evaluation; the returned numbers are unchanged.

    Args:
        model: Network called as ``model(x, edge_index, batch)``.
        loader: Iterable of batches with ``x``, ``edge_index``, ``batch``,
            ``y`` and ``num_graphs``.
        device: Device the batches are moved to.
        num_classes (int): Width of the one-hot / mixed label vectors.

    Returns:
        tuple[float, float]: ``(accuracy, mean loss)`` over all graphs. The
        true class of a graph is the arg-max of its label vector.

    Raises:
        ZeroDivisionError: If ``loader`` yields no graphs.
    """
    model.eval()
    correct = 0
    total = 0
    loss_all = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.batch)
            pred = output.max(dim=1)[1]
            y = data.y.view(-1, num_classes)
            loss_all += mixup_cross_entropy_loss(output, y).item() * data.num_graphs

            # Convert one-hot to class indices for comparison
            y_labels = y.max(dim=1)[1]
            correct += pred.eq(y_labels).sum().item()
            total += data.num_graphs
    return correct / total, loss_all / total



def two_moments_mixup(two_moments, la=0.5, num_sample=20, ge='ISGL', resolution=None):
    """Sample synthetic graphs from the mixture of two classes' motif moments.

    The labels and moment vectors of the two classes are linearly mixed with
    weight ``la``. A graphon network is fitted to the mixed moments via
    ``train_Moment`` (always with the ``"MLP"`` model), discretised with
    ``moment_graphon``, and ``num_sample`` graphs with 30-60 nodes are drawn
    from it with ``simulate_graphs``. Isolated nodes are removed from every
    sample.

    Args:
        two_moments: Two ``[label, moments]`` pairs; ``label`` and ``moments``
            are NumPy arrays.
        la (float): Mixing weight of the first class.
        num_sample (int): Number of graphs to generate.
        ge: Unused; kept so existing callers keep working.
        resolution (int): Grid resolution passed to ``moment_graphon``.
            Required despite the ``None`` default.

    Returns:
        list: ``num_sample`` PyG ``Data`` objects (``y``, ``edge_index``,
        ``num_nodes``). An empty list is returned, discarding any graphs
        already generated, as soon as one sample needs more than 21 draws to
        produce a non-empty graph.

    Raises:
        ValueError: If ``resolution`` is None. Previously this only failed
            inside ``np.linspace`` after the expensive training step.
    """
    # `Data` is not imported at the top of this file; import it here so the
    # function does not depend on a star-import happening to provide it.
    from torch_geometric.data import Data

    if resolution is None:
        raise ValueError("two_moments_mixup requires an integer `resolution`.")

    label = la * two_moments[0][0] + (1 - la) * two_moments[1][0]
    sample_graph_label = torch.from_numpy(label).type(torch.float32)

    new_moment = la * two_moments[0][1] + (1 - la) * two_moments[1][1]

    trained_model = train_Moment(new_moment, 0, "MLP")
    graphon = moment_graphon(resolution, trained_model)

    sample_graphs = []
    for _ in range(num_sample):
        failures = 0
        while True:
            if failures > 20:
                return []
            num_nodes = random.randint(30, 60)
            # NOTE(review): both seeds are constant; whether repeated calls
            # can still yield different graphs depends on simulate_graphs.
            sample_graph = simulate_graphs(graphon, seed_gsize =19, seed_edge =19, num_graphs = 1, num_nodes = num_nodes, graph_size = 'fixed', offset =0)[0]

            # Drop isolated nodes: rows, then columns, whose sum is zero.
            sample_graph = sample_graph[sample_graph.sum(axis=1) != 0]
            sample_graph = sample_graph[:, sample_graph.sum(axis=0) != 0]

            A = torch.from_numpy(sample_graph)
            edge_index, _ = dense_to_sparse(A)
            num_nodes = sample_graph.shape[0]

            if num_nodes == 0:
                print('num_nodes is 0')
                failures += 1
                continue

            # continue if the graph is empty in terms of edges
            if edge_index.shape[1] == 0:
                print('edge_index is 0')
                failures += 1
                continue

            break

        pyg_graph = Data()
        pyg_graph.y = sample_graph_label
        pyg_graph.edge_index = edge_index
        pyg_graph.num_nodes = num_nodes
        sample_graphs.append(pyg_graph)

    return sample_graphs

########################################
# Default Parameters and Seed List Setup
########################################

# Default parameters
# Experiment configuration read by the main loop below.
data_path      = "./"           # root directory for the TUDataset download/cache
dataset_name   = "IMDB-MULTI"
model_name     = "GIN"          # the main loop only knows how to build "GIN"
num_epochs     = 400            # the epoch loop runs over range(1, num_epochs)
batch_size     = 128
learning_rate  = 0.001
num_hidden     = 64
lam_range      = [0.1, 0.4]     # mixing weights are drawn uniformly from this interval
aug_ratio      = 0.2            # with aug_num, sets graphs generated per mixing weight
aug_num        = 10             # number of mixing weights drawn per run
ge             = "Moment"   # options: "ISGL", "Moment", "IGNR", etc.
log_screen     = True           # not referenced in the code visible in this file
gmixup         = True           # enables the mixup augmentation branch
n_epochs_inr   = 20             # not referenced in the code visible in this file

# Edge lists of the motifs, grouped by number of nodes: 1 motif on 2 nodes,
# 2 on 3 nodes, 6 on 4 nodes and 21 on 5 nodes. Each motif is a list of
# (u, v) node-index pairs.
Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],
                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],
      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,
       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],
       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],
       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
       ]]

# Converted motif list. train_Moment builds its own copy from an identical
# literal rather than reading this module-level value.
induced_list = motifs_to_induced_motifs(Es)

# The list of data seeds is defined just before the main loop.

# To store metrics across seeds
all_best_test_acc = []
all_last10_avg_acc = []

def plot_moment_vectors(moment_matrix, title="Moment Vectors"):
    """
    Draw one line per class showing its moment vector.

    Entries that are not ``[label, moments]`` pairs, or whose moments are not
    a non-empty 1-D sequence, are reported on stdout and skipped.

    Args:
        moment_matrix (list): Items of the form [label, real_moments].
        title (str): Title of the figure.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn
        when ``moment_matrix`` is empty.
    """
    if not moment_matrix:
        print("The moment_matrix is empty. Nothing to plot.")
        return

    plt.figure(figsize=(10, 6))

    for entry in moment_matrix:
        if len(entry) != 2:
            print(f"Skipping invalid item: {entry}. Expected [label, real_moments].")
            continue

        label, real_moments = entry
        values = np.array(real_moments)

        if values.ndim != 1:
            print(f"Skipping label '{label}': moments_vector is not 1-dimensional.")
            continue
        if values.size == 0:
            print(f"Skipping label '{label}': moments_vector is empty.")
            continue

        # The x axis is simply the position of each moment in the vector.
        positions = np.arange(len(values))
        plt.plot(positions, values, marker='o', linestyle='-', label=str(label))

    plt.title(title)
    plt.xlabel("Moment Index")
    plt.ylabel("Moment Value")

    # A legend only makes sense when at least one line was drawn.
    drawn_handles, _ = plt.gca().get_legend_handles_labels()
    if drawn_handles:
        plt.legend()
    else:
        print("No valid data was plotted, so no legend will be shown.")

    plt.grid(True)
    plt.tight_layout()
    plt.show()

########################################
# Main Loop: Iterate over Seed List
########################################
# Data seeds: each one fixes the shuffle/split and the mixup sampling of a run.
seeds = [61314, 1314, 11314, 21314, 31314, 41314, 51314, 71314]


def compute_moment_scores(moment_matrix):
    """
    Compute discrimination scores for each moment (column) of a matrix.

    Not called by the loop below; kept as an analysis helper. It used to be
    re-defined on every loop iteration.

    Parameters:
        moment_matrix (np.ndarray): Array of shape (n_classes, m_moments),
                                    one row per class.

    Returns:
        std_scores (np.ndarray): Standard deviation of each moment across classes.
        range_scores (np.ndarray): Max minus min of each moment across classes.
        avg_pairwise (np.ndarray): Mean absolute difference over all class pairs.
    """
    std_scores = np.std(moment_matrix, axis=0)
    range_scores = np.ptp(moment_matrix, axis=0)

    n, m = moment_matrix.shape
    avg_pairwise = np.zeros(m)
    for j in range(m):
        pairwise_diffs = [abs(moment_matrix[i, j] - moment_matrix[k, j])
                          for i in range(n) for k in range(i + 1, n)]
        avg_pairwise[j] = np.mean(pairwise_diffs)

    return std_scores, range_scores, avg_pairwise


for seed in seeds:
    for seed_torch in [17]:
        logger.info("==========================================")
        logger.info(f"Starting training for seed {seed}")

        # This data seed is paired with a different torch seed.
        if seed == 61314:
            seed_torch = 4

        # Seed the data pipeline (shuffle, lambda draws, graph sampling).
        np.random.seed(seed)
        random.seed(seed)

        device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Running on device: {device}")

        # Load dataset. It is reloaded for every seed because labels and
        # features are modified in place below.
        path = osp.join(data_path, dataset_name)
        dataset = TUDataset(path, name=dataset_name)
        dataset = list(dataset)

        # Reshape labels
        for graph in dataset:
            graph.y = graph.y.view(-1)
        dataset = prepare_dataset_onehot_y(dataset)

        # Log global dataset statistics
        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset)
        logger.info(f"Dataset {dataset_name}: {len(dataset)} graphs, avg nodes: {avg_num_nodes}")

        # Shuffle and split dataset: 70% train, 10% validation, 20% test
        random.shuffle(dataset)
        train_nums = int(len(dataset) * 0.7)
        train_val_nums = int(len(dataset) * 0.8)

        # Log training subset statistics
        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset[:train_nums])
        logger.info(f"Training graphs: avg nodes: {avg_num_nodes}, max nodes: {max_num_nodes}")

        resolution = 10*int(median_num_nodes)
        logger.info(f"Resolution set to: {resolution}")

        # Augment dataset using graphon mixup if enabled
        if gmixup:
            # The motif table `Es` / `induced_list` is no longer rebuilt here:
            # the module-level definitions above are identical and nothing in
            # this branch read the local copies.
            torch.manual_seed(seed_torch)
            torch.cuda.manual_seed(seed_torch)
            torch.cuda.manual_seed_all(seed_torch)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
            torch.use_deterministic_algorithms(True)
            num_sample = int(train_nums * aug_ratio / aug_num)
            lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))
            new_graph = []
            for lam in lam_list:
                logger.info(f"lam: {lam}")
                logger.info(f"num_sample: {num_sample}")

                class_graphs = split_class_graphs(dataset[:train_nums])
                # pick 10 graphs from each class
                class_graphs = [(label, random.sample(graphs, min(10, len(graphs)))) for label, graphs in class_graphs]
                moment_matrix = []
                for label, graphs in class_graphs:
                    estimated_densities = np.zeros(9)
                    used_graphs = 0
                    for adjacency in graphs:
                        graph = nx.from_numpy_array(adjacency)
                        # only if number of nodes is at least 6
                        if graph.number_of_nodes() < 6:
                            continue
                        node_orbit_counts = orca(graph)
                        estimated_densities += count2density(node_orbit_counts, graph.number_of_nodes())
                        used_graphs += 1
                    # Average over the graphs that actually contributed; dividing
                    # by len(graphs) biased the moments towards zero whenever
                    # small graphs were skipped.
                    real_moments = estimated_densities / max(used_graphs, 1)
                    moment_matrix.append([label, real_moments])

                # Mix the two classes with the smallest and largest first moment.
                all_moments = sorted(moment_matrix, key=lambda x: x[1][0])
                two_moments = [all_moments[0], all_moments[-1]]

                new_graph += two_moments_mixup(two_moments, la=lam, num_sample=num_sample, ge="USVT", resolution=resolution)

            # two_moments_mixup may return nothing; only summarise real output.
            if new_graph:
                avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(new_graph)
                logger.info(f"New graphs: avg nodes: {avg_num_nodes}, min nodes: {min_num_nodes}, max nodes: {max_num_nodes}")
            else:
                logger.warning("Mixup produced no graphs; training on the original data only.")
            logger.info(f"Real augmentation ratio: {len(new_graph)/train_nums}")
            # Synthetic graphs go to the front so they fall into the training slice.
            dataset = new_graph + dataset
            train_nums += len(new_graph)
            train_val_nums += len(new_graph)

        # Prepare node features for whole dataset
        dataset = prepare_dataset_x(dataset)
        logger.info(f"Dataset feature shape: {dataset[0].x.shape}")
        logger.info(f"Dataset label shape: {dataset[0].y.shape}")

        num_features = dataset[0].x.shape[1]
        num_classes = dataset[0].y.shape[0]

        for data in dataset:
            if not hasattr(data, 'edge_attr') or data.edge_attr is None:
                data.edge_attr = torch.ones((data.edge_index.size(1), 3))

        # Split dataset into train, validation and test
        train_dataset = dataset[:train_nums]
        random.shuffle(train_dataset)
        val_dataset = dataset[train_nums:train_val_nums]
        test_dataset = dataset[train_val_nums:]

        logger.info(f"Train dataset size: {len(train_dataset)}")
        logger.info(f"Validation dataset size: {len(val_dataset)}")
        logger.info(f"Test dataset size: {len(test_dataset)}")

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        # Instantiate the model
        if model_name == "GIN":
            model_instance = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)
        else:
            logger.info("No valid model specified.")
            continue

        optimizer = schedulefree.AdamWScheduleFree(model_instance.parameters(), lr = learning_rate)
        optimizer.train()
        scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

        # Train the model
        last_10_epoch_acc = []  # To record test acc for the last 10 epochs
        best_val_acc = 0
        best_test_acc = 0
        best_epoch = None

        for epoch in range(1, num_epochs):
            # Schedule-free optimizers keep separate train/eval parameter
            # states and must be switched explicitly around evaluation.
            optimizer.train()
            train_loss = train(model_instance, train_loader, optimizer, device, num_classes)

            optimizer.eval()
            val_acc, val_loss = test(model_instance, val_loader, device, num_classes)
            test_acc, test_loss = test(model_instance, test_loader, device, num_classes)
            scheduler.step()

            if val_acc >= best_val_acc and epoch > 10:
                best_val_acc = val_acc
                best_test_acc = test_acc
                best_epoch = epoch

            # The loop ends at num_epochs - 1, so the last ten epochs are
            # num_epochs - 10 .. num_epochs - 1 (">" collected only nine).
            if epoch >= num_epochs - 10:
                last_10_epoch_acc.append(test_acc)

            if epoch % 30 == 0:
                print('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Val Acc: {:.6f}, Test Acc: {:.6f}'.format(
                    epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

        avg_last10 = np.mean(last_10_epoch_acc)

        print("####################################################")
        print(f"Seed {seed_torch}: Last 10 epochs average test acc: {np.round(avg_last10,3)}")
        print(f"Seed {seed_torch}: Best test acc (based on validation) = {best_test_acc} at epoch {best_epoch}")
        print("####################################################")

        # Record the run here, inside the inner loop, so a skipped run
        # (`continue` above) never re-appends stale or undefined results.
        all_best_test_acc.append(best_test_acc)
        all_last10_avg_acc.append(avg_last10)

print("==========================================")
print("Overall Results Across Seeds:")
print(f"Best Test Accuracy: Mean = {np.mean(all_best_test_acc):.3f}, Std = {np.std(all_best_test_acc):.3f}")
print(f"Last 10 Epoch Test Accuracy Average: Mean = {np.mean(all_last10_avg_acc):.3f}, Std = {np.std(all_last10_avg_acc):.3f}")


