In [None]:
# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:90% !important; }</style>"))

In [3]:
import os, sys
import argparse
import os.path as osp
import random
from time import perf_counter
import yaml
import numpy as np
import scipy.sparse as sp
import scipy

import torch
import torch.nn.functional as F
import torch.nn as nn

import logging
import shutil
import ast

from lib_EGNN_Pytorch.models.model_app import RwCL_Model
from lib_EGNN_Pytorch import utils, Post_utils, evaluation
from lib_EGNN_Pytorch.data_preprocessing import Pre_utils

from lib_EGNN_Pytorch.app.RwCL import basic_exec_cluster, basic_exec_link, basic_exec_node_classification
from lib_EGNN_Pytorch.app.RwCL import RwCL_app

from lib_EGNN_Pytorch.app.RwSL import basic_exec_cluster
from lib_EGNN_Pytorch.app.RwSL import RwSL_app
# from lib_EGNN_Pytorch.app.RwCL.multi_exec import run_train_cluster, run_test_cluster
%matplotlib inline  

In [None]:
# Seed NumPy and PyTorch from one shared constant.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): presumably here to quiet TensorFlow's native logging -- confirm it is still needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Root logger: drop whatever handlers are already attached, then install a single
# timestamped stream handler and emit INFO and above.
formatter = logging.Formatter(
    fmt='%(asctime)s (%(levelname)s): %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger()
logger.handlers = []
logger.addHandler(ch)
logger.setLevel(logging.INFO)

In [None]:
# >>>>>>>>>>>>>>>>>>>> Input data locations and run configuration
# Dataset selector: uncomment exactly one of the alternatives below.
data_name = 'cora'
# data_name = 'cite'
# data_name = 'acm'
# data_name = 'dblp'

tkipf_graph_path = "/home/xiangli/projects/tmpdata/GCN/Graph_Clustering/tkipf_gcn_data"
sdcn_data_path = "/home/xiangli/projects/tmpdata/GCN/Graph_Clustering/sdcn/"
workdir = f"/home/xiangli/projects/GCN_program/Workshop_local/EGNN_workdir_results/{data_name}/"

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Each dataset has its own YAML file named config_<dataset>.yaml.
config_file_name = 'config_' + data_name.lower() + '.yaml'
config_file_path = os.path.join('./config_data_RwCL_clustering/', config_file_name)
with open(config_file_path, 'r') as config_stream:
    config = yaml.safe_load(config_stream)

# Values that YAML leaves as strings (e.g. "None") are re-read as Python literals;
# strings that are not valid literals are kept unchanged.
for key, val in config.items():
    if not isinstance(val, str):
        continue
    try:
        config[key] = ast.literal_eval(val)
    except (ValueError, SyntaxError):
        pass

# Re-seed torch from the config; Python's `random` uses a fixed seed of its own.
torch.manual_seed(config['seed'])
random.seed(12345)

print(">>>>>>>>>>>>>>>>>>>> Setting Tuning :")
tune_param_name = "nothing"
# tune_param_name = "n_epochs"

# Labels name the output folders; values are what the tuned parameter is set to.
tune_val_label_list = [1]
# tune_val_label_list = [0.0, 0.1, 0.2, 0.5]

# tune_val_list = [10**(-val) for val in tune_val_label_list]
tune_val_list = list(tune_val_label_list)

# trainer_id_list = [0]
trainer_id_list = [*range(1)]

print(f"Tune param : {tune_param_name} ; with the following values: {tune_val_list}")

print(">>>>>>>>>>>>>>>>>>>> Loading configs :")
for key, val in config.items():
    print(f"{key}    :    {val}")

# =================== keep a copy of the config next to the results ===================
# The copy is written only once; an existing copy is never overwritten.
dest_folder = os.path.dirname(os.path.join(workdir, f"tune_{tune_param_name}/"))
config_copy_path = os.path.join(dest_folder, config_file_name)
if not os.path.exists(config_copy_path):
    os.makedirs(dest_folder, exist_ok=True)
    shutil.copyfile(config_file_path, config_copy_path)

### Use tkipf dataset: Cora, PubMed

In [None]:
# Load the tkipf GCN dataset selected by `data_name` together with its pre-packed
# adjacency array for the Cython GBP precomputation.
dataset_key = data_name.lower()

# Build the packed-data folder from `tkipf_graph_path` instead of repeating the absolute
# root, so changing the dataset root in the config cell cannot leave this path stale.
# The trailing "" keeps the trailing separator of the original path.
Cython_GBP_data_path = os.path.join(
    tkipf_graph_path, "Packed_data", "no_row_normalize", dataset_key, "GBP_input", "")
# Pre_utils.convert_GBP_input_tkipf_gcn(data_name.lower(), tkipf_graph_path, directed = False,
#                                       normalize = False, redo_save = False)

# The loader returns four values; the last one is not used in this notebook.
adj_full, features, labels_full, _ = Pre_utils.load_gcn_tkipf_data(
    tkipf_graph_path, dataset_key, normalize=False, redo_save=False)

# Force C-contiguous float32 features and int64 adjacency data.
# NOTE(review): presumably the layout/dtype the Cython routine requires -- confirm.
features = np.ascontiguousarray(features, dtype=np.float32)
adj_matrix_cython = np.ascontiguousarray(
    np.load(os.path.join(Cython_GBP_data_path, f'{dataset_key}_adj.npy')),
    dtype=np.int64)



### SDCN dataset: Cite, ACM, DBLP

In [None]:
# Prepare the SDCN dataset selected by `data_name`: convert it to the GBP input layout,
# load features / labels / graph, then precompute the GBP features.
dataset_key = data_name.lower()
Cython_GBP_data_path = os.path.join(sdcn_data_path, f"data_for_GBP_input/{dataset_key}/")
Pre_utils.convert_GBP_input_sdcn(data_name, sdcn_data_path, target_path=Cython_GBP_data_path)

# Per the original author's notes: features and labels are numpy arrays,
# and the graph is a scipy.sparse.csr_matrix.
features, labels_full = Pre_utils.load_sdcn_data_func(sdcn_data_path, data_name)
adj_full = Pre_utils.load_sdcn_graph(sdcn_data_path, data_name)

# Force C-contiguous float32 features and int64 adjacency data.
features = np.ascontiguousarray(features, dtype=np.float32)
adj_file = os.path.join(Cython_GBP_data_path, f'{dataset_key}_adj.npy')
adj_matrix_cython = np.ascontiguousarray(np.load(adj_file), dtype=np.int64)

# NOTE(review): every other cell calls Pre_utils.precompute_Cython_GBP_feat with
# add_self_loop=False; confirm that this function name and add_self_loop=True are intended.
gbp_options = dict(
    rwnum=0, directed=False, add_self_loop=True,
    rand_seed=10,
    feats=features, adj_matrix=adj_matrix_cython,
)
features_GBP = Pre_utils.GBP_feat_precomputation(
    data_name, 40, config["alpha"], config["rmax"], config["rrz"], **gbp_options)



### Execute Single Run

In [None]:
# Precompute the GBP features once for a single run (no random walks, undirected graph,
# no added self loops, fixed seed).
gbp_options = dict(
    rwnum=0, directed=False, add_self_loop=False,
    rand_seed=10,
    feats=features, adj_matrix=adj_matrix_cython,
)
features_GBP = Pre_utils.precompute_Cython_GBP_feat(
    data_name, 40, config["alpha"], config["rmax"], config["rrz"], **gbp_options)

# A single run uses the first tuning label / value and the first trainer id.
tune_val_label, tune_val, trainer_id = (
    tune_val_label_list[0],
    tune_val_list[0],
    trainer_id_list[0],
)

### Perform train cluster

In [None]:
# Train one clustering model for the currently selected tuning value / trainer id.
input_data = [features_GBP, labels_full, adj_full]
model = RwCL_Model(config)

# Folder name shared by the checkpoint and validation-metric paths of this run.
run_tag = f"tunelabel_{tune_val_label}_trainer_{trainer_id}"

checkpoint_file_path = os.path.join(
    workdir, f"tune_{tune_param_name}/model_checkpoint/{run_tag}/best_model.pkl")

# Start from a clean slate: create the checkpoint folder when there is no checkpoint yet,
# otherwise delete the stale checkpoint file.
if not os.path.exists(checkpoint_file_path):
    os.makedirs(os.path.dirname(checkpoint_file_path), exist_ok=True)
else:
    print("ckpt file already exists, so removed ...")
    os.remove(checkpoint_file_path)

# Not consumed in this cell; kept because it is a notebook-level name.
val_metric_path = os.path.join(
    workdir, f"tune_{tune_param_name}/val_metric/{run_tag}/val_metric.pkl")

# ==========================  Start the training ==========================
# NOTE(review): the import cell binds `basic_exec_cluster` twice (RwCL first, then RwSL);
# the later import is the one in effect here -- confirm that is intended.
time_training, metric_summary = basic_exec_cluster.train(
    model, config, input_data,
    device=device, checkpoint_file_path=checkpoint_file_path)


### Perform Test cluster

In [None]:
# Evaluate the checkpoint written by the training cell on a freshly constructed model.
model_test = RwCL_Model(config)

checkpoint_file_path = os.path.join(workdir,
                f"tune_{tune_param_name}/model_checkpoint/tunelabel_{tune_val_label}_trainer_{trainer_id}/best_model.pkl")

# Check the checkpoint file itself (not just its folder) and raise a real exception:
# `raise("...")` raises a str, which fails with
# "TypeError: exceptions must derive from BaseException".
if not os.path.isfile(checkpoint_file_path):
    raise FileNotFoundError(f"checkpoint file is missing: {checkpoint_file_path}")

# Not consumed in this cell; kept because it is a notebook-level name.
test_metric_path = os.path.join(workdir,
                f"tune_{tune_param_name}/test_metric/tunelabel_{tune_val_label}_trainer_{trainer_id}/test_metric.pkl")

# >>>>>>>>>>>>>>>>>>>> Start test inference
# Inference is run on the CPU regardless of the training device.
test_time, test_metric = basic_exec_cluster.test(model_test, config, input_data,
                            device = "cpu", checkpoint_file_path = checkpoint_file_path)

### Perform train and test clustering through the `RwCL_framework` interface class

In [None]:
# Build the RwCL framework object from the loaded config; the following cells call its
# train_cluster / test_cluster methods.
obj = RwCL_app.RwCL_framework(config)

In [None]:
# Same training run as the functional cell, driven through the framework object `obj`.
input_data = [features_GBP, labels_full, adj_full]
model = RwCL_Model(config)

# Folder name shared by the checkpoint and validation-metric paths of this run.
run_tag = f"tunelabel_{tune_val_label}_trainer_{trainer_id}"
checkpoint_file_path = os.path.join(
    workdir, f"tune_{tune_param_name}/model_checkpoint/{run_tag}/best_model.pkl")
val_metric_path = os.path.join(
    workdir, f"tune_{tune_param_name}/val_metric/{run_tag}/val_metric.pkl")

# Delete a stale checkpoint if there is one; otherwise make sure its folder exists.
checkpoint_present = os.path.exists(checkpoint_file_path)
if checkpoint_present:
    print("ckpt file already exists, so removed ...")
    os.remove(checkpoint_file_path)
if not checkpoint_present:
    os.makedirs(os.path.dirname(checkpoint_file_path), exist_ok=True)

# ==========================  Start the training ==========================
time_training, metric_summary = obj.train_cluster(
    model, config, input_data,
    device=device, checkpoint_file_path=checkpoint_file_path)


In [None]:
# Evaluate the checkpoint written by the framework training cell on a fresh model.
model_test = RwCL_Model(config)

checkpoint_file_path = os.path.join(workdir,
                f"tune_{tune_param_name}/model_checkpoint/tunelabel_{tune_val_label}_trainer_{trainer_id}/best_model.pkl")

# Check the checkpoint file itself (not just its folder) and raise a real exception:
# `raise("...")` raises a str, which fails with
# "TypeError: exceptions must derive from BaseException".
if not os.path.isfile(checkpoint_file_path):
    raise FileNotFoundError(f"checkpoint file is missing: {checkpoint_file_path}")

# Not consumed in this cell; kept because it is a notebook-level name.
test_metric_path = os.path.join(workdir,
                f"tune_{tune_param_name}/test_metric/tunelabel_{tune_val_label}_trainer_{trainer_id}/test_metric.pkl")

# >>>>>>>>>>>>>>>>>>>> Start test inference
# Inference is run on the CPU regardless of the training device.
test_time, test_metric = obj.test_cluster(model_test, config, input_data,
                            device = "cpu", checkpoint_file_path = checkpoint_file_path)

In [None]:
# Train and test one model per (tuning value, trainer id) pair.
#
# run_train_cluster / run_test_cluster are not bound anywhere else in this notebook (their
# import in the import cell is commented out), so without this import the loop fails with
# NameError on the first iteration.
# NOTE(review): confirm that lib_EGNN_Pytorch.app.RwCL.multi_exec still exists.
from lib_EGNN_Pytorch.app.RwCL.multi_exec import run_train_cluster, run_test_cluster

for tune_val_label, tune_val in zip(tune_val_label_list, tune_val_list):

    # Recomputed for every tuning value; the arguments do not currently depend on tune_val.
    features_GBP = Pre_utils.precompute_Cython_GBP_feat(data_name, 40,
                                config["alpha"], config["rmax"], config["rrz"],
                                rwnum = 0, directed = False, add_self_loop = False,
                                rand_seed = 10,
                                feats = features, adj_matrix = adj_matrix_cython)

    input_data = [features_GBP, labels_full, adj_full]
    for trainer_id in trainer_id_list:
        print(f"Training >>> current tuning the hyper-param: {tune_param_name}; with a value : {tune_val} ; on trainer id: {trainer_id}" )

        # Fresh model for training.
        model = RwCL_Model(config)
        run_train_cluster(data_name, model, input_data, workdir, config,
                          tune_param_name, tune_val_label, tune_val,
                          trainer_id = trainer_id, device = device)

        # Separate fresh model for testing.
        model_test = RwCL_Model(config)
        run_test_cluster(data_name, model_test, input_data, workdir, config,
                         tune_param_name, tune_val_label, tune_val,
                         trainer_id = trainer_id)

### Post-processing

In [None]:
# Draw the validation-metric plots for every (tuning value, trainer id) combination.
for tune_val_label, tune_val in zip(tune_val_label_list, tune_val_list):
    for trainer_id in trainer_id_list:
        status_line = (
            f"Validation Postprocessing >>> current tuning the hyper-param: {tune_param_name}"
            f"; with a value : {tune_val} ; on trainer id: {trainer_id}"
        )
        print(status_line)
        Post_utils.draw_val_metrics(
            workdir, tune_param_name, tune_val_label, trainer_id=trainer_id)

In [None]:
# Produce the aggregated validation metrics twice: first with real_time=False,
# then with real_time=True.
for real_time in (False, True):
    Post_utils.generate_val_all_metric(
        workdir, tune_param_name,
        tune_val_label_list, tune_val_list, trainer_id_list,
        real_time=real_time)

In [None]:
# Build the test-result table, then its raw and summary-statistics plots.
# All three helpers take the same arguments and no trainer is skipped.
report_args = (workdir, tune_param_name,
               tune_val_label_list, tune_val_list, trainer_id_list)

for make_report in (
    Post_utils.generate_test_table,
    Post_utils.plot_raw_tune_test_table,
    Post_utils.plot_stats_test_table,
):
    make_report(*report_args, skip_trainer=[])

### GPU flush

In [None]:
# # free GPU memory
# !(nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9 )
# !(nvidia-smi | grep 'python')