In [None]:
import os
import pandas as pd
import datetime as dt
import numpy as np
import time 
import torch
import pickle
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
from sklearn.metrics import precision_score, accuracy_score,f1_score,roc_auc_score
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Model definition (GraphSAGE encoder + dot-product decoder) and training / evaluation helpers.
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling


class Net(torch.nn.Module):
    """Two-layer GraphSAGE encoder with a dot-product edge decoder for link prediction.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: base width of the hidden layer.
        out_channels: size of the final node embedding.
        heads: width multiplier for the hidden layer. ``SAGEConv`` has no attention
            heads (that argument belongs to ``GATConv``), so the hidden layer is simply
            ``hidden_channels * heads`` wide. This keeps the layer sizes the original
            GAT-style signature implied.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=3):
        super().__init__()
        hidden_width = hidden_channels * heads
        # SAGEConv does not accept ``heads`` (passing it raises a TypeError), and the
        # second layer must consume exactly what the first one produces.
        self.conv1 = SAGEConv(in_channels, hidden_width)
        self.conv2 = SAGEConv(hidden_width, out_channels)

    def encode(self, x, edge_index):
        """Compute node embeddings from features ``x`` by message passing over ``edge_index``."""
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        """Score each candidate edge as the dot product of its two endpoint embeddings.

        Returns one raw logit per column of ``edge_label_index``; apply ``sigmoid``
        to obtain probabilities.
        """
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        """Return every node pair with a positive logit as a ``[2, num_pairs]`` index tensor.

        A positive logit corresponds to a sigmoid probability above 0.5.
        """
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
    

def train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=100,
    checkpoint_path="best_link_prediction.pt",
):
    """Train a link predictor with fresh negatives each epoch and keep the best-validation model.

    Args:
        model: module exposing ``encode(x, edge_index)`` and ``decode(z, edge_label_index)``.
        train_data: graph with ``x``, ``edge_index``, ``num_nodes``, ``edge_label_index``
            and ``edge_label``. Negative supervision edges are sampled here every
            epoch, so ``edge_label`` is expected to hold positives only.
        val_data: graph passed to ``eval_link_predictor`` to compute validation AUC.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to raw decoder logits (e.g. ``BCEWithLogitsLoss``).
        n_epochs: number of full-batch training epochs.
        checkpoint_path: file the whole model is written to with ``torch.save``
            whenever validation AUC improves.

    Returns:
        ``(best_model, val_auc_list, train_loss_list)``:
        - ``best_model`` is an independent copy of ``model`` taken at the epoch with
          the highest validation AUC. It falls back to ``model`` itself if no epoch
          ever improved.
        - ``val_auc_list`` holds one AUC per epoch.
        - ``train_loss_list`` holds one detached scalar loss tensor per epoch.
    """
    import copy

    best_model = model
    best_val_auc = float("-inf")
    best_line = "no epoch improved the validation AUC"
    val_auc_list = []
    train_loss_list = []
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)

        # Sample a fresh set of negatives every epoch, one per positive supervision edge.
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        edge_label = torch.cat([
            train_data.edge_label,
            train_data.edge_label.new_zeros(neg_edge_index.size(1)),
        ], dim=0)

        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()

        val_auc = eval_link_predictor(model, val_data)
        val_auc_list.append(val_auc)
        if val_auc > best_val_auc:
            torch.save(model, checkpoint_path)
            # Snapshot the weights. A plain ``best_model = model`` would alias the live
            # model, which keeps training, so the last epoch would be returned instead
            # of the best one.
            best_model = copy.deepcopy(model)
            best_val_auc = val_auc
            best_line = (f"Epoch: {epoch}\tTrain Loss: {loss.item():.4f}"
                         f"\tVal AUC: {best_val_auc:.4f}")

        if epoch % 10 == 0:
            print(f"Epoch: {epoch:03d}, Train Loss: {loss:.3f}, Val AUC: {val_auc:.3f}")
        # Detach so the stored loss does not keep this epoch's autograd graph referenced.
        train_loss_list.append(loss.detach())
    print("\n\n for best model :\n", best_line)

    return best_model, val_auc_list, train_loss_list


@torch.no_grad()
def eval_link_predictor(model, data):
    """Return the ROC-AUC of ``model``'s edge scores on the labelled edges of ``data``.

    - Puts ``model`` in eval mode.
    - Encodes ``data.x`` over ``data.edge_index``.
    - Scores ``data.edge_label_index`` with sigmoid-squashed decoder outputs.
    - Compares those scores against ``data.edge_label``.

    Gradient tracking is disabled for the whole call by the decorator.
    """
    model.eval()
    embeddings = model.encode(data.x, data.edge_index)
    logits = model.decode(embeddings, data.edge_label_index)
    scores = logits.view(-1).sigmoid()

    y_true = data.edge_label.cpu().numpy()
    y_score = scores.cpu().numpy()
    return roc_auc_score(y_true, y_score)

In [None]:
# User-set parameters for the cross-validation run.
dir_path = "WSJ split data files"            # folder holding WSJ_positive_edges_split_<k>.csv files
dbname = "WSJ"                               # dataset tag used in the saved model filenames
node_feature_dir_path = "WSJ node features"  # folder holding the per-fold node-embedding pickles
epochs = 100                                 # training epochs per fold

# Seed the Python, NumPy and PyTorch RNGs so that the random link splits, negative
# samples and weight initialisation are reproducible across Restart & Run All.
import random

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# 10-fold cross-validation: split k is held out as the test fold, and the positive
# edges of the other nine splits form the graph that is divided into train/val edges.

# Loop-invariant: map raw node ids to contiguous ids once, not on every fold.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open("node_id_into_continuous_node_ids.p", "rb") as fh:
    temp_node_no = pickle.load(fh)

for split_no in range(1, 11):
    test_file_no = split_no
    print("Test no ", test_file_no, " going on")
    print("\n\n\n Test file no : ", test_file_no)

    # Collect positive edges from every split except the held-out one.
    # DataFrame.append was removed in pandas 2.0; concatenate the frames once instead.
    train_df = pd.concat(
        [pd.read_csv(dir_path + "/WSJ_positive_edges_split_" + str(i) + ".csv")
         for i in range(1, 11) if i != test_file_no],
        ignore_index=True,
    )
    train_df = train_df.drop(['layer', 'weight', "sign"], axis=1)

    # Remap node ids and build the [2, num_edges] edge index directly as int64.
    # Routing through torch.Tensor (float32) would silently corrupt ids above 2**24.
    src_ids = [temp_node_no[i] for i in train_df['src']]
    dest_ids = [temp_node_no[i] for i in train_df['dest']]
    final_edge_list = torch.from_numpy(np.array([src_ids, dest_ids], dtype=np.int64))

    # Load this fold's node embeddings (computed without the held-out split).
    # NOTE(review): pickle.load - trusted input only.
    embeddings_path = (node_feature_dir_path
                       + "/node_embeddings_for_pytorch_models_excluding_split_"
                       + str(test_file_no) + ".p")
    with open(embeddings_path, "rb") as fh:
        node_embeddings = pickle.load(fh)

    # Build the training and validation graphs.
    data = Data(x=node_embeddings, edge_index=final_edge_list)

    # NOTE(review): with is_undirected=True, RandomLinkSplit only keeps edges with
    # src <= dest. If the CSVs list each edge once in an arbitrary direction, edges
    # with src > dest are dropped; consider symmetrising edge_index first. Verify.
    split = T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.0,
        is_undirected=True,
        add_negative_train_samples=False,
        neg_sampling_ratio=1.0)
    train_data, val_data, test_data = split(data)

    del train_df, node_embeddings, final_edge_list, data

    # The former leftover `break` here exited before any model was trained.
    # Input width now comes from the embeddings instead of a hard-coded 8
    # (assumes node_embeddings is a [num_nodes, F] tensor - TODO confirm).
    model = Net(train_data.num_features, 3, 2)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    criterion = torch.nn.BCEWithLogitsLoss()

    curr = time.time()
    best_model, _, _ = train_link_predictor(
        model, train_data, val_data, optimizer, criterion, n_epochs=epochs)
    timetaken = time.time() - curr
    print("time taken for training : ", timetaken)
    print("time taken for per epoch : ", timetaken / epochs)

    torch.save(best_model.state_dict(),
               "my_" + str(dbname) + "_SAGE_model_state_dict_" + str(test_file_no))
    torch.save(best_model, "my_" + str(dbname) + "_SAGE_whole_model_" + str(test_file_no))