In [None]:
# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:90% !important; }</style>"))

In [1]:
import os, sys
import argparse
import os.path as osp
import random
from time import perf_counter
import yaml
import numpy as np
import scipy.sparse as sp
import scipy

import torch
import torch.nn.functional as F
import torch.nn as nn

import logging
import shutil
import ast

from lib_EGNN_Pytorch.models.model_app import RwCL_Model
from lib_EGNN_Pytorch import utils, Post_utils, evaluation
from lib_EGNN_Pytorch.data_preprocessing import Pre_utils

from lib_EGNN_Pytorch.app.RwCL import basic_exec_link, basic_exec_node_classification
from lib_EGNN_Pytorch.app.RwCL import RwCL_app

# from lib_EGNN_Pytorch.app.RwSL import basic_exec_cluster
# from lib_EGNN_Pytorch.app.RwSL import RwSL_app
# from lib_EGNN_Pytorch.app.RwCL.multi_exec import run_train_cluster, run_test_cluster
%matplotlib inline  

In [2]:
# Global seed for numpy and torch (torch is reseeded later from the config file).
SEED = 42
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Root logger: drop any pre-existing handlers and emit through one stream handler
# with a timestamped "time (LEVEL): message" format.
logger = logging.getLogger()
logger.handlers.clear()
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(
    logging.Formatter(
        fmt='%(asctime)s (%(levelname)s): %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
)
logger.addHandler(_stream_handler)
logger.setLevel('INFO')

# Silence TensorFlow C++ logging in case TF gets imported indirectly.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [3]:
# >>>>>>>>>>>>>>>>>>>> Setting input data and configurations:
# The data / results roots default to the original author's machine. Override them with the
# EGNN_GRAPH_DATA_ROOT and EGNN_WORKDIR_ROOT environment variables to run elsewhere
# (with the defaults, every derived path is identical to the previous hard-coded strings).
graph_data_root = os.environ.get(
    "EGNN_GRAPH_DATA_ROOT", "/home/xiangli/projects/tmpdata/GCN/Graph_Clustering")
tkipf_graph_path = os.path.join(graph_data_root, "tkipf_gcn_data")
shchur_graph_path = os.path.join(graph_data_root, "shchur_gnnbenchmark_data", "npz")
sdcn_data_path = os.path.join(graph_data_root, "sdcn", "")

data_name = 'cora'
# data_name = 'cite'

workdir_root = os.environ.get(
    "EGNN_WORKDIR_ROOT", "/home/xiangli/projects/GCN_program/Workshop_local/EGNN_workdir_results")
# Trailing "" keeps the trailing separator the original f-string had.
workdir = os.path.join(workdir_root, "RwCL", data_name, "link", "")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-dataset YAML config, looked up relative to the notebook's working directory.
config_file_name = f'config_{data_name.lower()}.yaml'
config_file_path = os.path.join('./config_data_RwCL_link/', config_file_name)
with open(config_file_path, 'r') as config_file:
    config = yaml.safe_load(config_file)
    
def _literal_or_original(text):
    """Return ``ast.literal_eval(text)``, or ``text`` unchanged if it is not a Python literal.

    YAML leaves Python spellings such as ``None`` as plain strings; values that are not
    valid literals (e.g. ``'256-128'``) are kept as strings.
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


# Rewrite only the string-valued entries; other values are left as YAML produced them.
config.update({
    name: _literal_or_original(value)
    for name, value in config.items()
    if type(value) is str
})

# NOTE(review): torch is reseeded from the config seed here (overriding SEED from the first
# cell), while `random` uses a fixed 12345 and numpy keeps SEED — confirm this is intended.
torch.manual_seed(config['seed'])
random.seed(12345)
        
print(">>>>>>>>>>>>>>>>>>>> Setting Tuning :")
# Sweep definition. Example alternatives:
#   tune_param_name = "n_epochs"
#   tune_val_label_list = [0.0, 0.1, 0.2, 0.5]
#   tune_val_list = [10**(-val) for val in tune_val_label_list]   # labels as negative powers of ten
tune_param_name = "nothing"
tune_val_label_list = [1]
tune_val_list = list(tune_val_label_list)  # here the labels are used directly as the values

num_trainers = 1
trainer_id_list = list(range(num_trainers))

print(f"Tune param : {tune_param_name} ; with the following values: {tune_val_list}")

# Echo the effective (post-processed) configuration.
print(">>>>>>>>>>>>>>>>>>>> Loading configs :")
for cfg_name, cfg_value in config.items():
    print(f"{cfg_name}    :    {cfg_value}")
    
# =================== copy the config file: =======================================
# Keep a copy of the YAML next to this sweep's results; an existing copy is not overwritten.
dest_folder = os.path.dirname(os.path.join(workdir, f"tune_{tune_param_name}/"))
config_copy_path = os.path.join(dest_folder, config_file_name)
if os.path.exists(config_copy_path):
    pass  # already recorded for this sweep
else:
    os.makedirs(dest_folder, exist_ok=True)
    shutil.copyfile(config_file_path, config_copy_path)

>>>>>>>>>>>>>>>>>>>> Setting Tuning :
Tune param : nothing ; with the following values: [1]
>>>>>>>>>>>>>>>>>>>> Loading configs :
data_name    :    cora
n_clusters    :    7
feat_dim    :    1433
seed    :    39788
lr    :    0.0001
enc_arch    :    256-128
mlp_arch    :    256-
num_proj_hidden    :    512
tau    :    1.0
drop_feature_rate_1    :    0.08
view_num    :    3
num_epochs    :    300
weight_decay    :    0.02
batch_size_train    :    512
eval_display    :    10
loss_batch_size    :    0
alpha    :    0.1
rmax    :    1e-06
rrz    :    0.4
batchnorm    :    False
dropout_rate    :    0.1


### Use tkipf dataset: Cora, PubMed

In [4]:
# The GBP (cython) inputs are cached below the tkipf data folder, so derive the path from
# tkipf_graph_path instead of repeating its absolute prefix (keeps the two in sync).
# The trailing "" preserves the trailing separator of the original literal path.
Cython_GBP_data_path = os.path.join(tkipf_graph_path, "Packed_data", "no_row_normalize",
                                    data_name.lower(), "cython_GBP_input", "")
link_adj_GBP_file = os.path.join(Cython_GBP_data_path, "link_prediction", f'{data_name.lower()}_link_np_array.npz')

# Load the raw (un-normalized) tkipf graph; the fourth return value is not used here.
adj_full, features, labels_full, _ = Pre_utils.load_gcn_tkipf_data(tkipf_graph_path,
                                                        data_name.lower(), normalize = False, redo_save = False)

# Load (or build and cache) the link-prediction inputs with val_frac=0.05 / test_frac=0.1.
# adj_matrix_cython feeds the GBP precomputation below; presumably it excludes the held-out
# val/test edges — confirm in Pre_utils.load_Cython_GBP_input_link.
features, adj_matrix_cython, val_edges, val_edges_false, test_edges, test_edges_false, _ = \
    Pre_utils.load_Cython_GBP_input_link(data_name.lower(), features, adj_full,
                adj_GBP_file_path = link_adj_GBP_file, directed = False, redo_save = False, val_frac = 0.05, test_frac = 0.1)

# Sanity check: total weight on the diagonal of adj_full (self-loop entries).
print(adj_full.diagonal().sum())

Packed data already exists at: /home/xiangli/projects/tmpdata/GCN/Graph_Clustering/tkipf_gcn_data/Packed_data/no_row_normalize/cora, LOADING...
Loading the pre-existent adj_GBP_file at : /home/xiangli/projects/tmpdata/GCN/Graph_Clustering/tkipf_gcn_data/Packed_data/no_row_normalize/cora/cython_GBP_input/link_prediction/cora_link_np_array.npz
0.0


### Execute a single run (first tuning value, first trainer)

In [5]:
# Precompute GBP propagated features from the link-prediction adjacency.
# NOTE(review): the meaning of the positional 40 and of rwnum / rand_seed is defined in
# Pre_utils.precompute_Cython_GBP_feat — not documented here.
features_GBP = Pre_utils.precompute_Cython_GBP_feat(
    data_name,
    40,
    config["alpha"],
    config["rmax"],
    config["rrz"],
    rwnum=0,
    directed=False,
    add_self_loop=False,
    rand_seed=10,
    feats=features,
    adj_matrix=adj_matrix_cython,
)

# A single run: first tuning value and first trainer id (used to name checkpoint/metric paths).
tune_val_label, tune_val, trainer_id = tune_val_label_list[0], tune_val_list[0], trainer_id_list[0]

Total pre-computation time cost is 31.886870659000124 seconds! 


### Train the link-prediction model

In [None]:
input_data = [features_GBP, adj_full, val_edges, val_edges_false, test_edges, test_edges_false]

model = RwCL_Model(config)

# Every artifact of this run lives under tune_<param>/<artifact>/tunelabel_<label>_trainer_<id>/.
_run_dir = f"tune_{tune_param_name}"
_run_tag = f"tunelabel_{tune_val_label}_trainer_{trainer_id}"

checkpoint_file_path = os.path.join(workdir, f"{_run_dir}/model_checkpoint/{_run_tag}/best_model.pkl")
val_metric_path = os.path.join(workdir, f"{_run_dir}/val_metric/{_run_tag}/val_metric.pkl")

# Start clean: create the checkpoint folder for a new run, or drop a stale checkpoint.
if not os.path.exists(checkpoint_file_path):
    os.makedirs(os.path.dirname(checkpoint_file_path), exist_ok=True)
else:
    print("ckpt file already exists, so removed ...")
    os.remove(checkpoint_file_path)

# ==========================  Start the training ==========================
time_training, metric_summary = basic_exec_link.train(model, config, input_data, device = device, checkpoint_file_path = checkpoint_file_path)


### Test the link-prediction model

In [None]:
# Fresh model instance for inference; presumably basic_exec_link.test restores the
# weights from checkpoint_file_path — confirm in basic_exec_link.
model_test = RwCL_Model(config)

checkpoint_file_path = os.path.join(workdir,
                f"tune_{tune_param_name}/model_checkpoint/tunelabel_{tune_val_label}_trainer_{trainer_id}/best_model.pkl")

# Fail fast with a real exception. `raise("...")` only produced a TypeError in Python 3
# (a str is not an exception), and checking the parent directory alone missed runs where
# best_model.pkl was never written.
if not os.path.isfile(checkpoint_file_path):
    raise FileNotFoundError(f"checkpoint file is missing: {checkpoint_file_path}")

test_metric_path = os.path.join(workdir,
                f"tune_{tune_param_name}/test_metric/tunelabel_{tune_val_label}_trainer_{trainer_id}/test_metric.pkl")

# >>>>>>>>>>>>>>>>>>>> Start test inference
test_time, test_metric = basic_exec_link.test(model_test, config, input_data, 
                            device = "cpu", checkpoint_file_path = checkpoint_file_path)

### Train and test link prediction through the `RwCL_framework` interface class

In [6]:
# High-level RwCL interface object; its train_link / test_link methods are used by the next two cells.
obj = RwCL_app.RwCL_framework(config)

In [7]:
input_data = [features_GBP, adj_full, val_edges, val_edges_false, test_edges, test_edges_false]

model = RwCL_Model(config)

# Folder holding this run's best checkpoint:
#   <workdir>/tune_<param>/model_checkpoint/tunelabel_<label>_trainer_<id>
checkpoint_dir = os.path.join(workdir,
                f"tune_{tune_param_name}/model_checkpoint/tunelabel_{tune_val_label}_trainer_{trainer_id}")
checkpoint_file_path = f"{checkpoint_dir}/best_model.pkl"

# Start clean: drop a stale checkpoint, or create the folder for a new one.
if os.path.exists(checkpoint_file_path):
    print("ckpt file already exists, so removed ...")
    os.remove(checkpoint_file_path)
else:
    os.makedirs(checkpoint_dir, exist_ok=True)

val_metric_path = os.path.join(
    workdir,
    f"tune_{tune_param_name}/val_metric/tunelabel_{tune_val_label}_trainer_{trainer_id}/val_metric.pkl",
)

# ==========================  Start the training ==========================
time_training, metric_summary = obj.train_link(model, config, input_data, device = device, checkpoint_file_path = checkpoint_file_path)


2022-04-21 19:43:53 (INFO): Epoch   10 | total train loss: 6.337 | Trained local batch number: 4 | train time: 0.663s
2022-04-21 19:43:54 (INFO): Epoch   20 | total train loss: 6.266 | Trained local batch number: 2 | train time: 0.886s
2022-04-21 19:43:55 (INFO): Epoch   30 | total train loss: 4.967 | Trained local batch number: 6 | train time: 1.106s
2022-04-21 19:43:55 (INFO): Epoch   40 | total train loss: 6.222 | Trained local batch number: 4 | train time: 1.319s
2022-04-21 19:43:56 (INFO): Epoch   50 | total train loss: 6.209 | Trained local batch number: 2 | train time: 1.528s
2022-04-21 19:43:56 (INFO): Epoch   60 | total train loss: 4.897 | Trained local batch number: 6 | train time: 1.745s
2022-04-21 19:43:57 (INFO): Epoch   70 | total train loss: 6.166 | Trained local batch number: 4 | train time: 1.958s
2022-04-21 19:43:58 (INFO): Epoch   80 | total train loss: 6.160 | Trained local batch number: 2 | train time: 2.179s
2022-04-21 19:43:58 (INFO): Epoch   90 | total train los

In [8]:
# Fresh model instance for inference; presumably obj.test_link restores the weights
# from checkpoint_file_path — confirm in RwCL_app.RwCL_framework.
model_test = RwCL_Model(config)

checkpoint_file_path = os.path.join(workdir,
                f"tune_{tune_param_name}/model_checkpoint/tunelabel_{tune_val_label}_trainer_{trainer_id}/best_model.pkl")

# Fail fast with a real exception. `raise("...")` only produced a TypeError in Python 3
# (a str is not an exception), and checking the parent directory alone missed runs where
# best_model.pkl was never written.
if not os.path.isfile(checkpoint_file_path):
    raise FileNotFoundError(f"checkpoint file is missing: {checkpoint_file_path}")

test_metric_path = os.path.join(workdir,
                f"tune_{tune_param_name}/test_metric/tunelabel_{tune_val_label}_trainer_{trainer_id}/test_metric.pkl")

# >>>>>>>>>>>>>>>>>>>> Start test inference
test_time, test_metric = obj.test_link(model_test, config, input_data, 
                            device = "cpu", checkpoint_file_path = checkpoint_file_path)

2022-04-21 19:45:08 (INFO): Test metrics: | auc_score : 0.9076905904676862 | ap_score : 0.8981485826361854 | test time: 0.197s


### GPU flush

In [None]:
# # free GPU memory
# !(nvidia-smi | grep 'python' | awk '{ print $5 }' | xargs -n1 kill -9 )
# !(nvidia-smi | grep 'python')