# In [None]:
########################################
# Single Cell Code for Multi-Seed Training
########################################

import sys
import os 
sys.path.append('../../')
sys.path.append('../')
from time import time
import logging
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch.autograd import Variable
import random
from torch.optim.lr_scheduler import StepLR

# Import your utilities and models
from gmixupUtils import stat_graph, split_class_graphs, align_graphs
from gmixupUtils import two_graphons_mixup, universal_svd, get_graphon
from gmixupUtils import GIN
from Moment.tools import motifs_to_induced_motifs, orca, count2density
from SIGL.tools import *
from Moment.trainMoment import train_Moment
import networkx as nx

# Set up logging on the root logger: DEBUG level, date-stamped messages, and a
# single console handler. The handler is only attached when the root logger has
# none yet, so re-running the cell does not duplicate every message.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    fmt='%(asctime)s - %(levelname)s: - %(message)s',
    datefmt='%Y-%m-%d',
)
if len(logger.handlers) == 0:
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    ch.setLevel(logging.DEBUG)
    logger.addHandler(ch)

# matplotlib is very chatty at DEBUG; only let its warnings and errors through.
logging.getLogger('matplotlib').setLevel(logging.WARNING)


########################################
# Helper Function Definitions
########################################

def prepare_dataset_x(dataset):
    """Attach degree-based node features to every graph when none exist.

    Nothing is changed unless ``dataset[0].x`` is None. Otherwise, for every
    graph, ``num_nodes`` is reset to ``max(edge_index) + 1`` (when the graph has
    edges) and each node's degree -- its number of occurrences in
    ``edge_index[0]`` -- becomes its feature:

    * if the largest degree over the whole dataset is below 2000, ``x`` is the
      float one-hot encoding of the degree with ``max_degree + 1`` columns;
    * otherwise ``x`` is the degree standardised with the dataset-wide mean and
      standard deviation, shape ``(num_nodes, 1)``.

    Graphs are modified in place and the same list is returned.
    """
    logger.info("Prepare dataset x")
    if dataset[0].x is not None:
        return dataset
    logger.info("dataset[0].x is None")

    max_degree = 0
    degs_all = []
    for data in dataset:
        if data.edge_index.numel() > 0:
            # Nodes with an index above every edge endpoint are dropped, so the
            # node count matches the feature rows built below.
            data.num_nodes = int(torch.max(data.edge_index)) + 1
        # Passing num_nodes guarantees one degree entry per node, even when the
        # highest-indexed node only appears as an edge target.
        d = degree(data.edge_index[0], num_nodes=data.num_nodes, dtype=torch.long)
        degs_all.append(d)
        if d.numel() > 0:
            max_degree = max(max_degree, d.max().item())

    # Reuse the degrees computed above instead of recomputing them per graph.
    if max_degree < 2000:
        for data, degs in zip(dataset, degs_all):
            data.x = F.one_hot(degs, num_classes=max_degree + 1).to(torch.float)
    else:
        # Too many distinct degrees for a one-hot encoding: use a single
        # standardised scalar feature instead.
        deg = torch.cat(degs_all, dim=0).to(torch.float)
        mean, std = deg.mean().item(), deg.std().item()
        for data, degs in zip(dataset, degs_all):
            data.x = ((degs - mean) / std).view(-1, 1)
    return dataset


def prepare_dataset_onehot_y(dataset):
    """Replace each graph's integer label with a one-hot float vector.

    The distinct labels are sorted and mapped to consecutive indices
    ``0..C-1`` (``C`` = number of distinct labels). For labels that already
    are ``0..C-1`` this is the identity. Labels such as ``{1, 2}`` or
    ``{-1, 1}``, which previously made ``F.one_hot`` raise, are now encoded
    too. Each ``data.y`` must hold a single value (e.g. shape ``(1,)``); it
    becomes a float tensor of shape ``(C,)``.

    Graphs are modified in place and the same list is returned.
    """
    labels = sorted({int(data.y) for data in dataset})
    label_to_index = {label: idx for idx, label in enumerate(labels)}
    num_classes = len(labels)
    for data in dataset:
        # A 0-dim index tensor makes one_hot return shape (C,) directly.
        idx = torch.tensor(label_to_index[int(data.y)], device=data.y.device)
        data.y = F.one_hot(idx, num_classes=num_classes).to(torch.float)
    return dataset


def mixup_cross_entropy_loss(input, target, size_average=True):
    """Cross-entropy against soft (mixup) targets.

    Returns ``-sum(input * target)``. When ``input`` holds log-probabilities,
    this is the cross-entropy summed over the batch. With ``size_average``
    (the default), the sum is divided by the batch size ``input.size(0)``.

    Raises AssertionError when the two tensors differ in shape or are not
    tensors.
    """
    assert input.size() == target.size()
    assert all(isinstance(t, Variable) for t in (input, target))
    summed = -(input * target).sum()
    if not size_average:
        return summed
    return summed / input.size(0)


def train(model, train_loader, optimizer, device, num_classes):
    """Run one training epoch and return the per-graph mean training loss.

    Each batch is moved to ``device``. The model output is scored against the
    (possibly soft) labels reshaped to ``(-1, num_classes)`` using
    ``mixup_cross_entropy_loss``, and one optimizer step is taken. Batch losses
    are weighted by the number of graphs in the batch.
    """
    model.train()
    running_loss, seen_graphs = 0, 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        scores = model(batch.x, batch.edge_index, batch.batch)
        targets = batch.y.view(-1, num_classes)
        batch_loss = mixup_cross_entropy_loss(scores, targets)
        batch_loss.backward()
        graphs_in_batch = batch.num_graphs
        running_loss += batch_loss.item() * graphs_in_batch
        seen_graphs += graphs_in_batch
        optimizer.step()
    return running_loss / seen_graphs


def test(model, loader, device, num_classes):
    """Evaluate ``model`` on ``loader`` and return ``(accuracy, mean_loss)``.

    A prediction is correct when the arg-max of the model output equals the
    arg-max of the one-hot label (labels reshaped to ``(-1, num_classes)``).
    The loss is ``mixup_cross_entropy_loss``, weighted by the number of graphs
    per batch and averaged over all graphs.

    Runs under ``torch.no_grad()``: evaluation never calls ``backward``, so
    building autograd graphs for every batch would only waste memory and time.
    """
    model.eval()
    correct = 0
    total = 0
    loss_all = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.batch)
            pred = output.max(dim=1)[1]
            y = data.y.view(-1, num_classes)
            loss_all += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
            # Convert one-hot labels to class indices for comparison
            y_labels = y.max(dim=1)[1]
            correct += pred.eq(y_labels).sum().item()
            total += data.num_graphs
    return correct / total, loss_all / total


########################################
# Default Parameters and Seed List Setup
########################################

# ---- Dataset / model selection ----
data_path = "./"
dataset_name = "REDDIT-BINARY"
model_name = "GIN"

# ---- Optimisation ----
num_epochs = 400
batch_size = 128
learning_rate = 0.01
num_hidden = 64

# ---- G-Mixup augmentation ----
gmixup = True
lam_range = [0.1, 0.2]   # mixing weights are drawn uniformly from [low, high]
aug_ratio = 0.2          # fraction of graphs used for estimation / synthetic-to-train ratio
aug_num = 10             # number of distinct mixing weights drawn
ge = "USVT"              # graphon estimator: "ISGL", "Moment", "IGNR"; anything else -> align + SVD
n_epochs_inr = 20        # training epochs for the ISGL / IGNR estimators
log_screen = True

# Edge lists of small motifs, grouped by node count (2, 3, 4 and 5 nodes).
Es = [
    # 2 nodes
    [
        [(0, 1)],
    ],
    # 3 nodes
    [
        [(0, 1), (1, 2)],
        [(0, 1), (0, 2), (1, 2)],
    ],
    # 4 nodes
    [
        [(0, 1), (1, 2), (2, 3)],
        [(0, 1), (0, 2), (0, 3)],
        [(0, 1), (0, 2), (1, 3), (2, 3)],
        [(0, 1), (0, 2), (0, 3), (1, 2)],
        [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)],
        [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)],
    ],
    # 5 nodes
    [
        [(0, 1), (1, 2), (2, 3), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4)],
        [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)],
        [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)],
        [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)],
        [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],
        [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)],
        [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)],
    ],
]

induced_list = motifs_to_induced_motifs(Es)

# One run per seed: 1314, 11314, 21314, ..., 71314
seed_list = [1314 + 10000 * k for k in range(8)]

# Per-seed metrics, summarised after all seeds finish
all_best_test_acc = []
all_last10_avg_acc = []

########################################
# Main Loop: Iterate over Seed List
########################################

def compute_moment_scores(moment_matrix):
    """
    Compute a per-moment discrimination score across classes.

    Parameters:
        moment_matrix (np.ndarray): 2D array of shape (n_classes, m_moments);
                                    row i holds the averaged moments of class i.

    Returns:
        std_scores (np.ndarray): Standard deviation across classes, per moment.
        range_scores (np.ndarray): Range (max - min) across classes, per moment.
        avg_pairwise (np.ndarray): Mean absolute difference over all class pairs,
                                   per moment. All zeros when there are fewer
                                   than two classes, because no pairs exist
                                   (previously this produced NaN).
    """
    moment_matrix = np.asarray(moment_matrix, dtype=float)
    std_scores = np.std(moment_matrix, axis=0)
    range_scores = np.ptp(moment_matrix, axis=0)  # peak-to-peak = max - min

    n, m = moment_matrix.shape
    # Every unordered class pair (i, k) with i < k.
    rows_i, rows_k = np.triu_indices(n, k=1)
    if rows_i.size == 0:
        avg_pairwise = np.zeros(m)
    else:
        avg_pairwise = np.abs(moment_matrix[rows_i] - moment_matrix[rows_k]).mean(axis=0)
    return std_scores, range_scores, avg_pairwise


for seed in seed_list:
    logger.info("==========================================")
    logger.info(f"Starting training for seed {seed}")

    # Set seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Running on device: {device}")

    # Load the dataset (reloaded from disk for every seed)
    path = osp.join(data_path, dataset_name)
    dataset = TUDataset(path, name=dataset_name)
    dataset = list(dataset)

    # Flatten labels to 1-D, then convert them to one-hot float vectors
    for graph in dataset:
        graph.y = graph.y.view(-1)
    dataset = prepare_dataset_onehot_y(dataset)

    # Log global dataset statistics
    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset)
    logger.info(f"Dataset {dataset_name}: {len(dataset)} graphs, avg nodes: {avg_num_nodes}")

    # Shuffle, then split by position: first 70% train, next 10% validation,
    # last 20% test.
    random.shuffle(dataset)
    train_nums = int(len(dataset) * 0.7)
    train_val_nums = int(len(dataset) * 0.8)

    # Log training subset statistics
    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset[:train_nums])
    logger.info(f"Training graphs: avg nodes: {avg_num_nodes}, max nodes: {max_num_nodes}")

    resolution = int(median_num_nodes)
    logger.info(f"Resolution set to: {resolution}")

    # Augment dataset using graphon mixup if enabled
    if gmixup:
        class_graphs = split_class_graphs(dataset[:train_nums])

        ###########################################
        # Diagnostic: per-class average motif densities, and how strongly each
        # one separates the classes. The results (moment_matrix, std_scores,
        # range_scores, avg_pairwise) are logged and kept as notebook globals;
        # they do not feed into the augmentation below.
        moment_matrix = []
        for label, graphs in class_graphs:
            # NOTE(review): assumes count2density returns 9 densities per graph — confirm
            estimated_densities = np.zeros(9)
            nx_graphs = [nx.from_numpy_array(graph) for graph in graphs]
            for graph in nx_graphs:
                node_orbit_counts = orca(graph)
                density = count2density(node_orbit_counts, graph.number_of_nodes())
                estimated_densities += density
            real_moments = estimated_densities / len(graphs)
            moment_matrix.append(real_moments)
        moment_matrix = np.array(moment_matrix)

        std_scores, range_scores, avg_pairwise = compute_moment_scores(moment_matrix)
        logger.info(f"Moment std scores: {std_scores}")
        logger.info(f"Moment range scores: {range_scores}")
        logger.info(f"Moment avg pairwise scores: {avg_pairwise}")

        ############################################
        # Estimate one graphon per class from a random subset of its graphs
        graphons = []
        start_time = time()
        for label, graphs in class_graphs:
            logger.info(f"Label {label}: {len(graphs)} graphs")
            num_estimate = int(aug_ratio * len(graphs))
            inr_graphs = random.sample(graphs, num_estimate)
            logger.info(f"Selected {len(inr_graphs)} graphs for graphon estimation for label {label}")

            if ge == "ISGL":
                logger.info("Using ISGL for graphon estimation")
                gnn_dim_hidden = [8]
                epoch_show = int(n_epochs_inr / 5)
                inr_dim_hidden = [20, 20]
                batch_size_inr = 1024
                inr_lr = 0.01
                inr_w = 10
                # coords_prediction / graph2XY / train_graphon are presumably
                # provided by `from SIGL.tools import *` — confirm
                model_ISGL_0, _ = coords_prediction(inr_dim_hidden, gnn_dim_hidden, int(2*n_epochs_inr), epoch_show, inr_w, inr_graphs, inr_lr)
                X_all_0, y_all_0, w_all_0 = graph2XY(inr_graphs, model_ISGL_0)
                logger.info("Number of datapoints for graphon estimation: {}".format(X_all_0.shape[0]))
                trained_inr_0 = train_graphon(inr_dim_hidden, inr_w, X_all_0, y_all_0, w_all_0, n_epochs_inr, epoch_show, inr_lr, batch_size_inr)
                graphon = get_graphon(100, trained_inr_0, coords=None)
                graphons.append((label, graphon, trained_inr_0))

            elif ge == "Moment":
                logger.info("Using Moment for graphon estimation")
                nx_graphs = [nx.from_numpy_array(graph) for graph in inr_graphs]
                trained_model = train_Moment(nx_graphs, 0, "MLP")
                graphon = get_graphon(100, trained_model, coords=None)
                graphons.append((label, graphon, trained_model))

            elif ge == "IGNR":
                logger.info("Using IGNR for graphon estimation")
                gl_mlp = IGNR_pg_wrapper([20,20,20], w0=30)
                loss = gl_mlp.train(inr_graphs, K='input', n_epoch=n_epochs_inr, f_sample='fixed')
                W1 = gl_mlp.get_W(100)
                graphons.append((label, W1, gl_mlp))

            else:
                # Fallback (e.g. "USVT"): align graphs and compute SVD-based graphon
                align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(inr_graphs, padding=True, N=resolution)
                logger.info(f"Aligned graph shape: {align_graphs_list[0].shape}")
                graphon = universal_svd(align_graphs_list, threshold=0.2)
                logger.info(f"Graphon shape: {graphon.shape}")
                graphons.append((label, graphon, None))

        end_time = time()
        num_classes = len(graphons)
        logger.info(f"Graphon estimation time per class: {(end_time - start_time)/num_classes} s")

        # NOTE(review): `plt` is never imported explicitly in this cell; it presumably
        # comes from `from SIGL.tools import *` — confirm, or import matplotlib.pyplot.
        plt.figure(figsize=(int(1 + 3*num_classes), 3))
        c = 1
        for label, graphon, _ in graphons:
            print(f"graphon info: label:{label}; mean: {graphon.mean()}, shape, {graphon.shape}")
            plt.subplot(1, num_classes, c)
            plt.imshow(graphon, cmap='hot', extent=[0, 1, 0, 1])
            plt.xticks([])
            plt.yticks([])
            plt.title(r"Class " + str(c))
            c += 1
        plt.tight_layout()
        plt.show()

        # Sample synthetic graphs by mixing two random class graphons per lambda
        num_sample = int(train_nums * aug_ratio / aug_num)
        lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))
        new_graph = []
        # Node-count range for sampled graphs; it does not change across lambdas.
        upper_bound = 600
        lower_bound = max(300, min_num_nodes)
        for lam in lam_list:
            logger.info(f"lam: {lam}")
            logger.info(f"num_sample: {num_sample}")
            two_graphons = random.sample(graphons, 2)
            new_graph += two_graphons_mixup(two_graphons, la=lam, num_sample=num_sample, ge=ge, resolution=[lower_bound, upper_bound])
            # Guard: mixup may yield no graphs (e.g. num_sample == 0)
            if new_graph:
                logger.info(f"New graph label: {new_graph[-1].y}")

        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(new_graph)
        logger.info(f"New graphs: avg nodes: {avg_num_nodes}, min nodes: {min_num_nodes}, max nodes: {max_num_nodes}")
        logger.info(f"Real augmentation ratio: {len(new_graph)/train_nums}")
        # Prepend the synthetic graphs so they fall inside the training slice,
        # and shift both split boundaries by the same amount.
        dataset = new_graph + dataset
        train_nums += len(new_graph)
        train_val_nums += len(new_graph)

    # Prepare node features for whole dataset
    dataset = prepare_dataset_x(dataset)
    logger.info(f"Dataset feature shape: {dataset[0].x.shape}")
    logger.info(f"Dataset label shape: {dataset[0].y.shape}")

    num_features = dataset[0].x.shape[1]
    num_classes = dataset[0].y.shape[0]

    # Give every graph an all-ones edge_attr with 3 columns if it has none
    for data in dataset:
        if not hasattr(data, 'edge_attr') or data.edge_attr is None:
            data.edge_attr = torch.ones((data.edge_index.size(1), 3))

    # Split dataset into train, validation and test
    train_dataset = dataset[:train_nums]
    random.shuffle(train_dataset)
    val_dataset = dataset[train_nums:train_val_nums]
    test_dataset = dataset[train_val_nums:]

    logger.info(f"Train dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")
    logger.info(f"Test dataset size: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Instantiate the model
    if model_name == "GIN":
        model_instance = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)
    else:
        logger.info("No valid model specified.")
        continue

    optimizer = torch.optim.Adam(model_instance.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

    # Train the model
    last_10_epoch_acc = []  # test acc for each of the final 10 epochs
    best_val_acc = 0
    best_test_acc = 0
    best_epoch = 0

    # Epochs are numbered 1..num_epochs inclusive, so exactly num_epochs epochs
    # run and the `epoch > num_epochs - 10` window covers exactly the last 10.
    for epoch in range(1, num_epochs + 1):
        train_loss = train(model_instance, train_loader, optimizer, device, num_classes)
        val_acc, val_loss = test(model_instance, val_loader, device, num_classes)
        test_acc, test_loss = test(model_instance, test_loader, device, num_classes)
        scheduler.step()

        # Model selection on validation accuracy; ties prefer the later epoch
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            best_epoch = epoch

        if epoch > num_epochs - 10:
            last_10_epoch_acc.append(test_acc)

        if epoch % 30 == 0:
            print('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Val Acc: {:.6f}, Test Acc: {:.6f}'.format(
                epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

    avg_last10 = np.mean(last_10_epoch_acc)
    print(f"Seed {seed}: Last 10 epochs average test acc: {np.round(avg_last10,3)}")
    print(f"Seed {seed}: Best test acc (based on validation) = {best_test_acc} at epoch {best_epoch}")

    all_best_test_acc.append(best_test_acc)
    all_last10_avg_acc.append(avg_last10)

########################################
# After All Seeds: Print Final Metrics
########################################
# Summarise per-seed results. np.std is the population std (ddof=0).
print("==========================================")
print("Overall Results Across Seeds:")
if all_best_test_acc:
    print(f"Best Test Accuracy: Mean = {np.mean(all_best_test_acc):.3f}, Std = {np.std(all_best_test_acc):.3f}")
    print(f"Last 10 Epoch Test Accuracy Average: Mean = {np.mean(all_last10_avg_acc):.3f}, Std = {np.std(all_last10_avg_acc):.3f}")
else:
    # Every seed was skipped (e.g. unknown model_name), so there is nothing to average.
    print("No seed completed training; no metrics to report.")
