In [1]:
"""
    IMPORTING LIBS
"""
# NOTE(review): `random` and `argparse` are not used in the cells shown.
import random
import sys

# Add the parent directory of 'src' to the Python path
# NOTE(review): absolute, machine-specific path; consider making it configurable.
sys.path.append('/ranjan/GT')
# Re-import edited project modules (src.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2  

# Do not write .pyc files for modules imported after this point.
sys.dont_write_bytecode = True

import numpy as np
import os
import time
import torch
import glob
import torch.optim as optim
import argparse
import dgl 
from tqdm import tqdm
import json
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
"""
    IMPORTING CUSTOM MODULES/METHODS
"""
from src.data.data import LoadData, partition_graph
from src.data.embedding import mean_pooling, compute_laplacian_positional_embedding, compute_gcn_embeddings
from src.nets.load_net import gnn_model 
# NOTE(review): star import. `set_seed` (used in main) and `plt` / `nx` (used by the
# subgraph plotting code) are not imported anywhere else in this notebook, so they
# presumably come from here -- confirm and import them explicitly.
from src.utils.utils import *
from src.train.trainer import collate_graphs, evaluate_network, train_epoch
# from src.utils.supergraph import  create_DGLSupergraph
from src.utils.supergraph import create_feature_dataset
# from src.configs.config import load_config

# NOTE(review): the three imports below duplicate imports earlier in this cell and are no-ops.
from torch.utils.data import DataLoader
import dgl
import torch



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def gpu_setup(use_gpu, gpu_id):
    """Expose only GPU `gpu_id` to CUDA and choose the torch device to run on.

    The environment variables are always written, in PCI-bus order, even when
    `use_gpu` is falsy. Note that CUDA_VISIBLE_DEVICES only influences CUDA if
    it is set before CUDA is initialised in this process.

    Args:
        use_gpu: whether a GPU should be used when one is available.
        gpu_id: device index written to CUDA_VISIBLE_DEVICES (stringified).

    Returns:
        torch.device("cuda") if CUDA is available and `use_gpu` is truthy,
        otherwise torch.device("cpu").
    """
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(gpu_id),
    })
    # Keep availability check first, exactly as before (it is always queried).
    cuda_selected = torch.cuda.is_available() and use_gpu
    return torch.device("cuda" if cuda_selected else "cpu")


In [3]:
"""
    VIEWING MODEL CONFIG AND PARAMS
"""
def view_model_param(MODEL_NAME, net_params):
    """Build `MODEL_NAME` from `net_params` and return its parameter count.

    The count is the sum, over every tensor yielded by ``model.parameters()``,
    of the product of that tensor's dimensions. The model itself is discarded.
    """
    net = gnn_model(MODEL_NAME, net_params)
    return sum((np.prod(list(weight.data.size())) for weight in net.parameters()), 0)

In [11]:

"""
    TRAINING CODE
"""
def _save_checkpoint(model, root_ckpt_dir, epoch):
    """Save `model`'s weights as ``<root_ckpt_dir>/RUN_/epoch_<epoch>.pkl``.

    Afterwards every ``*.pkl`` in that directory whose epoch number is below
    ``epoch - 1`` is deleted, so only the two most recent checkpoints survive.
    """
    ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "epoch_{}.pkl".format(epoch)))

    for ckpt_file in glob.glob(os.path.join(ckpt_dir, '*.pkl')):
        # Parse N from the base name "epoch_<N>.pkl" only, so underscores in the
        # (timestamped) directory path can never be mistaken for the epoch field.
        stem = os.path.splitext(os.path.basename(ckpt_file))[0]
        if int(stem.split('_')[-1]) < epoch - 1:
            os.remove(ckpt_file)


# Qualitative palette for class colours in the subgraph comparison plots.
_SUBGRAPH_PLOT_COLORS = [
    '#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
    '#ff7f00', '#a65628', '#756bb1', '#636363'
]


def _plot_subgraph_comparisons(subgraphs, node_counts, predicted_classes, true_labels,
                               out_dir='/ranjan/GT/out/subgraph_prediction/test',
                               max_plots=5):
    """Save true-vs-predicted label figures for the first `max_plots` subgraphs.

    Predictions/labels are sliced per subgraph with the cumulative sums of
    `node_counts`, i.e. this assumes their rows are laid out subgraph by
    subgraph in the order of `subgraphs`.
    NOTE(review): confirm that matches evaluate_network's output ordering.

    Each figure is written to ``<out_dir>/subgraph<i>`` (matplotlib appends the
    extension of its default savefig format). `out_dir` is created if missing.
    """
    num_colors = len(_SUBGRAPH_PLOT_COLORS)

    def color_of(label):
        # Wrap around so class ids beyond the palette size never raise IndexError.
        return _SUBGRAPH_PLOT_COLORS[int(label) % num_colors]

    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    counts = node_counts.cpu().numpy()
    offsets = np.concatenate(([0], np.cumsum(counts)))

    for sg_idx in range(min(max_plots, len(subgraphs))):
        start_idx, end_idx = offsets[sg_idx], offsets[sg_idx + 1]
        sg_pred = predicted_classes[start_idx:end_idx]
        sg_true = true_labels[start_idx:end_idx]

        G = subgraphs[sg_idx].to_networkx().to_undirected()
        pos = nx.spring_layout(G, k=2 / np.sqrt(len(G.nodes())), iterations=50, seed=42)

        fig = plt.figure(figsize=(15, 6))
        panels = ((sg_true, 'Original Labels'), (sg_pred, 'Predicted Labels'))
        for panel_no, (labels, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, panel_no)
            colors = [color_of(lab) for lab in labels.cpu().numpy()]
            nx.draw_networkx_edges(G, pos, edge_color='black', alpha=0.2, width=0.9)
            nx.draw_networkx_nodes(G, pos, node_size=100, node_color=colors)
            plt.title(title)
            plt.axis('off')

        # Legend (on the predicted-labels panel) lists the classes present in the truth.
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                      markerfacecolor=color_of(lab),
                                      label=f'Class {int(lab)}',
                                      markersize=10)
                           for lab in torch.unique(sg_true)]
        plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))

        accuracy = (sg_pred == sg_true).float().mean().item()
        plt.suptitle(f'Subgraph {sg_idx} Comparison\n'
                     f'Accuracy: {accuracy:.2%}')

        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f'subgraph{sg_idx}'), bbox_inches='tight', dpi=300)
        plt.close(fig)


def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs, graph, node_labels, node_counts, subgraphs):
    """Train `MODEL_NAME` on `dataset`, log/checkpoint every epoch and report results.

    Args:
        MODEL_NAME: model identifier passed to ``gnn_model``.
        dataset: feature dataset used for the train, val and test loaders alike.
        params: optimisation settings; reads 'init_lr', 'weight_decay',
            'lr_reduce_factor', 'lr_schedule_patience', 'batch_size', 'epochs',
            'min_lr' and 'max_time' (in hours).
        net_params: model settings; must contain 'device' and 'total_param'.
        dirs: (root_log_dir, root_ckpt_dir, write_file_name, write_config_file).
        graph: DGL graph whose ndata holds 'train_mask', 'val_mask', 'test_mask'.
        node_labels, node_counts: forwarded to train_epoch / evaluate_network;
            node_counts also slices per-subgraph predictions for the plots.
        subgraphs: partitioned subgraphs; the first few are plotted.

    Side effects: writes the config and results ``.txt`` files, TensorBoard
    logs, rolling checkpoints and comparison figures, and prints a summary.
    """
    start0 = time.time()
    per_epoch_time = []

    # NOTE(review): hard-coded; only used in the config/results text files.
    DATASET_NAME = 'Cora'

    train_mask = graph.ndata['train_mask']
    val_mask = graph.ndata['val_mask']
    test_mask = graph.ndata['test_mask']

    # All three loaders iterate the same dataset; the split is expressed by the
    # mask handed to train_epoch / evaluate_network (presumably used there to
    # select the scored nodes -- confirm in src.train.trainer).
    trainset = valset = testset = dataset

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    writer = SummaryWriter(log_dir=os.path.join(root_log_dir, "RUN_" + str(0)))

    model = gnn_model(MODEL_NAME, net_params).to(device)

    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=params['lr_reduce_factor'],
                                                     patience=params['lr_schedule_patience'],
                                                     verbose=True)

    train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, collate_fn=collate_graphs)
    val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, collate_fn=collate_graphs)
    test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, collate_fn=collate_graphs)

    # Bound up front so the final evaluation below still works when the loop
    # never completes an iteration (epochs == 0 or an immediate Ctrl+C).
    epoch = 0

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            for epoch in t:
                t.set_description('Epoch %d' % epoch)
                start = time.time()

                epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                    model, optimizer, device, train_loader, epoch, train_mask, node_labels, node_counts)
                epoch_val_loss, epoch_val_acc = evaluate_network(
                    model, device, val_loader, epoch, val_mask, node_labels, node_counts, phase="val")
                _, epoch_test_acc = evaluate_network(
                    model, device, test_loader, epoch, test_mask, node_labels, node_counts, phase="test")

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                              train_acc=epoch_train_acc, val_acc=epoch_val_acc,
                              test_acc=epoch_test_acc)

                per_epoch_time.append(time.time() - start)

                _save_checkpoint(model, root_ckpt_dir, epoch)

                scheduler.step(epoch_val_loss)

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - start0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    _, test_acc = evaluate_network(model, device, test_loader, epoch, test_mask, node_labels, node_counts)
    _, train_acc = evaluate_network(model, device, train_loader, epoch, train_mask, node_labels, node_counts)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    # NOTE(review): `epoch` is the last epoch *index*, not the number of epochs run.
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - start0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    writer.close()

    # Write the results in out_dir/results folder
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
          .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                  test_acc, train_acc, epoch, (time.time()-start0)/3600, np.mean(per_epoch_time)))

    # Per-node test predictions for the subgraph comparison.
    test_node_predictions, test_batch_labels = evaluate_network(
        model, device, test_loader, epoch, test_mask, node_labels, node_counts,
        phase="test", CompareSubgraphFlag=True)
    predicted_classes = torch.argmax(test_node_predictions, dim=1)

    # Vectorised count of correct predictions (no per-element Python loop).
    match_count = (predicted_classes == test_batch_labels).sum().item()
    print("match count  = ",match_count)

    _plot_subgraph_comparisons(subgraphs, node_counts, predicted_classes, test_batch_labels)

    print("\n Subgraph Comparison Completed\n")
        

In [12]:

from argparse import Namespace


def _apply_overrides(target, args, casts):
    """Copy each non-None ``args.<key>`` into ``target[key]``, converted by ``casts[key]``.

    Keys are processed in the order of `casts`; keys whose argument is None
    leave `target` untouched.
    """
    for key, cast in casts.items():
        value = getattr(args, key)
        if value is not None:
            target[key] = cast(value)


def _is_true_str(value):
    """Parse the string-valued boolean flags used here: only 'True' is truthy."""
    return value == 'True'


def main():
    """Load the config, apply overrides, build subgraph features and train.

    Overrides come from the in-notebook `args` Namespace (stand-in for a CLI);
    any non-None field replaces the matching config value.
    """
    args = Namespace(config='src/configs/default_config.json', gpu_id=0, model=None, dataset='Cora', out_dir='out/', seed=None, epochs=None, batch_size=None, init_lr=None, lr_reduce_factor=None, lr_schedule_patience=None, min_lr=None, weight_decay=None, print_epoch_interval=None, L=None, hidden_dim=None, out_dim=None, residual=None, edge_feat=None, readout=None, n_heads=None, in_feat_dropout=None, dropout=None, layer_norm=None, batch_norm=None, self_loop=None, max_time=None, pos_enc_dim=None, lap_pos_enc=None, wl_pos_enc=None)

    # NOTE(review): args.config is ignored; the config path below is hard-coded.
    with open('/ranjan/GT/src/configs/default_config.json') as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])

    # model, dataset, out_dir: an explicit argument wins over the config value.
    MODEL_NAME = args.model if args.model is not None else config['model']
    DATASET_NAME = args.dataset if args.dataset is not None else config['dataset']
    out_dir = args.out_dir if args.out_dir is not None else config['out_dir']

    # optimisation parameters
    params = config['params']
    _apply_overrides(params, args, {
        'seed': int, 'epochs': int, 'batch_size': int, 'init_lr': float,
        'lr_reduce_factor': float, 'lr_schedule_patience': int, 'min_lr': float,
        'weight_decay': float, 'print_epoch_interval': int, 'max_time': float,
    })

    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    _apply_overrides(net_params, args, {
        'L': int, 'hidden_dim': int, 'out_dim': int,
        'residual': _is_true_str, 'edge_feat': _is_true_str,
        'readout': lambda value: value, 'n_heads': int,
        'in_feat_dropout': float, 'dropout': float,
        'layer_norm': _is_true_str, 'batch_norm': _is_true_str, 'self_loop': _is_true_str,
        # Each positional-encoding flag reads its own argument (previously both
        # read a non-existent `args.pos_enc` and raised AttributeError).
        'lap_pos_enc': _is_true_str, 'pos_enc_dim': int, 'wl_pos_enc': _is_true_str,
    })

    # The model consumes the GCN output embeddings, so its input width is the GCN output width.
    net_params['in_dim'] = config['gcn']['output_dim']

    # One timestamp for the whole run, so log/checkpoint/result/config names always match.
    run_tag = (MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id'])
               + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y'))
    root_log_dir = out_dir + 'logs/' + run_tag
    root_ckpt_dir = out_dir + 'checkpoints/' + run_tag
    write_file_name = out_dir + 'results/result_' + run_tag
    write_config_file = out_dir + 'configs/config_' + run_tag
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    os.makedirs(out_dir + 'results', exist_ok=True)
    os.makedirs(out_dir + 'configs', exist_ok=True)

    set_seed(params['seed'])

    graph = LoadData(DATASET_NAME)  # DGL data loading
    subgraphs = partition_graph(graph, num_parts=config['data']['num_parts'])

    # Per-subgraph pooled embeddings plus concatenated node labels / node counts.
    subgraph_embeddings, lpe_embeddings = [], []
    node_labels, node_counts = [], []
    for subgraph in subgraphs:
        gcn_embeddings = compute_gcn_embeddings(
            subgraph,
            input_dim=config['gcn']['input_dim'],
            hidden_dim=config['gcn']['hidden_dim'],
            output_dim=config['gcn']['output_dim']
        )
        lpe = compute_laplacian_positional_embedding(subgraph, embedding_dim=config['gcn']['output_dim'])

        subgraph_embeddings.append(mean_pooling(gcn_embeddings))
        lpe_embeddings.append(mean_pooling(lpe))

        node_labels.append(subgraph.ndata['label'])
        node_counts.append(subgraph.number_of_nodes())

    # Stack and move to device
    subgraph_embeddings = torch.stack(subgraph_embeddings).to(device)
    lpe_embeddings = torch.stack(lpe_embeddings).to(device)
    node_labels = torch.cat(node_labels, dim=0).to(device)
    node_counts = torch.tensor(node_counts).to(device)

    # Derive the class count from the labels instead of hard-coding Cora's 7.
    net_params['n_classes'] = int(node_labels.max().item()) + 1

    # add embeddings
    combined_embedding = subgraph_embeddings + lpe_embeddings
    dataset = create_feature_dataset(combined_embedding)

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs, graph, node_labels, node_counts, subgraphs)


if __name__ == "__main__":
    main()

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Convert a graph into a bidirected graph: 0.001 seconds, peak memory: 12.766 GB
Construct multi-constraint weights: 0.000 seconds, peak memory: 12.766 GB
Metis partitioning: 0.012 seconds, peak memory: 12.766 GB


[16:00:01] /opt/dgl/src/graph/transform/metis_partition_hetero.cc:89: Partition a graph with 2708 nodes and 10556 edges into 100 parts and get 1497 edge cuts
Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 11.16it/s, lr=0.0007, test_acc=27.9, time=0.0771, train_acc=1.43, train_loss=1.91, val_acc=5.8, val_loss=1.88]



Test Accuracy: 27.9000
Train Accuracy: 1.4286
Convergence Time (Epochs): 0.0000
TOTAL TIME TAKEN: 0.1960s
AVG TIME PER EPOCH: 0.0783s
match count  =  279

 Subgraph Comparison Completed



In [None]:
# (Leftover scratch cell removed: it only evaluated `1+1` and took no part in the pipeline.)