In [37]:
#!/usr/bin/env python3  
# -*- coding: utf-8 -*- 
#----------------------------------------------------------------------------
"""
Created By  : Sayali Anil Alatkar 
Created Date: 06/01/2023 
version ='1.1'
"""
# ---------------------------------------------------------------------------
# Implementation for NestedGNN:Detecting Malicious Network Activity with Nested Graph Neural Networks 
# (doi: 10.1109/ICC45855.2022.9838698.)

"\nCreated By  : Sayali Anil Alatkar \nCreated Date: 06/01/2023 \nversion ='1.1'\n"

Training GNN has the following steps:

1. Creating graph batches with dataloaders
2. Message passing to learn node embeddings with gcn/gat layers
3. Readout to aggregate node embeddings
4. Train classification loss on learned graph emb

PsychAD contrasts: [link](https://docs.google.com/spreadsheets/d/1EAxMz9oSm4Ht-MFyo3dNo-FziTiqJvFgkQaDE1yxpyg/edit#gid=0)

In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [43]:
import os
import random
import pickle
from collections import OrderedDict, Counter
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F

import warnings
warnings.filterwarnings('ignore')

import dgl
import dgl.data
from dgl.nn import GraphConv
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, average_precision_score

from captum.attr import IntegratedGradients
from functools import partial

import pandas as pd
import numpy as np

#from modified_gnnexplainer import NestedGNNExplainer
from nested_gnn import *

##### Create dataloaders with batches (of 12 graphs) for batched training

In [7]:
with open('./data/dgl_graphs_cc_grn_filtered_train_(204_individuals).pkl', 'rb') as f:
    train_patients = pickle.load(f)

with open('./data/dgl_graphs_cc_grn_filtered_heldout_(204_individuals).pkl', 'rb') as f:
    heldout_patients = pickle.load(f)

In [9]:
# SPLIT TRAIN/TEST
def generate_dataset(new_PsychAD2_GRN_Dataset,PsychAD2_5_CC_Dataset, patient_ids,
                     train_labels,test_labels,train_patient_ids, test_patient_ids,
                    num_cell_types = 12, OUTER_LAYER_BATCH = 10, INNER_LAYER_BATCH = 10*12):
    
    class GRN_train(DGLDataset):
        def __init__(self):
            super().__init__(name="GRN")

        def process(self):
            self.graphs = train_grn_graphs
            self.labels = train_grn_labels

        def __getitem__(self, i):
            return self.graphs[i], self.labels[i]

        def __len__(self):
            return len(self.graphs)
    
    class GRN_test(DGLDataset):
        def __init__(self):
            super().__init__(name="GRN")

        def process(self):
            self.graphs = test_grn_graphs
            self.labels = test_grn_labels


        def __getitem__(self, i):
            return self.graphs[i], self.labels[i]

        def __len__(self):
            return len(self.graphs)

    class CC_train(DGLDataset):
        def __init__(self):
            super().__init__(name="CC")

        def process(self):
            self.graphs = train_cc_graphs
            self.labels = train_labels

        def __getitem__(self, i):
            return self.graphs[i], self.labels[i]

        def __len__(self):
            return len(self.graphs)

    class CC_test(DGLDataset):
        def __init__(self):
            super().__init__(name="CC")

        def process(self):
            self.graphs = test_cc_graphs
            self.labels = test_labels


        def __getitem__(self, i):
            return self.graphs[i], self.labels[i]

        def __len__(self):
            return len(self.graphs)

    #train_patient_ids, test_patient_ids, train_labels, test_labels = train_test_split(patient_ids,labels, test_size=0.1, random_state=42, stratify=labels)

    train_grn_labels = [val for val in train_labels for _ in range(12)]
    test_grn_labels = [val for val in test_labels for _ in range(12)]

    train_dataset_grn, test_dataset_grn={},{}
    train_dataset_cc, test_dataset_cc={},{}

    for pid in patient_ids:
        if pid in train_patient_ids:
            train_dataset_grn[pid] = new_PsychAD2_GRN_Dataset[pid]
            train_dataset_cc[pid] = PsychAD2_5_CC_Dataset[pid]
        else:
            test_dataset_grn[pid] = new_PsychAD2_GRN_Dataset[pid]
            test_dataset_cc[pid] = PsychAD2_5_CC_Dataset[pid]
            
     
    train_grn_graphs = [v1 for k,v in train_dataset_grn.items() for k1,v1 in v.items()]
    test_grn_graphs = [v1 for k,v in test_dataset_grn.items() for k1,v1 in v.items()]

    train_cc_graphs = [v for k,v in train_dataset_cc.items() ]
    test_cc_graphs = [v for k,v in test_dataset_cc.items()]

    train_grn_obj = GRN_train()
    test_grn_obj = GRN_test()

    train_cc_obj = CC_train()
    test_cc_obj = CC_test()
    
    # CREATE DATALOADERS
    train_grn_dataloader = GraphDataLoader(train_grn_obj, batch_size=INNER_LAYER_BATCH, drop_last=False)
    test_grn_dataloader = GraphDataLoader(test_grn_obj, batch_size=INNER_LAYER_BATCH, drop_last=False)

    train_cc_dataloader = GraphDataLoader(train_cc_obj, batch_size=OUTER_LAYER_BATCH, drop_last=False)
    test_cc_dataloader = GraphDataLoader(test_cc_obj, batch_size=OUTER_LAYER_BATCH, drop_last=False)

    return train_grn_dataloader,test_grn_dataloader,train_cc_dataloader,test_cc_dataloader

In [10]:
# train_grn_obj,test_grn_obj,train_cc_obj,test_cc_obj = generate_dataset(new_PsychAD2_GRN_Dataset,PsychAD2_5_CC_Dataset,labels,)

In [29]:
train_grn, train_cc, train_labels = train_patients[0], train_patients[1], train_patients[2]

In [31]:
train_labels["label_binary"] = [1 if x=="AD" else 0 for x in train_labels["label"].to_list()]
train_labels

Unnamed: 0,SubID,label,label_binary
4,M10233,AD,1
6,M10282,AD,1
19,M11371,Control,0
23,M11588,AD,1
24,M11589,Control,0
...,...,...,...
1002,M96977,AD,1
1015,M97728,Control,0
1020,M98107,AD,1
1038,M99645,AD,1


In [32]:
patient_ids = train_labels["SubID"].to_list()
labels = train_labels["label_binary"].to_list()

##### Model for NestedGNN

##### Model initialization

In [33]:
np.random.seed(1234)

In [34]:
# Create the model with given dimensions
in_layer_dim=100
out_layer_dim=4735

h_dims = [256,128,2048,1024,512,256,128,64] # gnn layers = 2 inner + 4 outer
output_dim = 2

#model = NestedGNN(in_layer_dim, out_layer_dim, h_dims, output_dim)
#optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

##### Model Training

In [35]:
num_cell_types = 12
OUTER_LAYER_BATCH = 10
INNER_LAYER_BATCH = OUTER_LAYER_BATCH*num_cell_types
epochs=50
lr = 0.001
gamma=0.9
n_splits=5
inner="GCN" #/GAT
outer="GCN" #/GAT

#### K-Fold Cross Validation

In [42]:
kf = KFold(n_splits=7,random_state=1, shuffle=True)

patient_ids = np.array(patient_ids)
labels = np.array(labels)

models={}

for i,kfsplit in enumerate(kf.split(patient_ids)):
    print (f"\nFold {i+1}")
    
    train_patient_ids, test_patient_ids = patient_ids[kfsplit[0]], patient_ids[kfsplit[1]]
    train_labels, test_labels = labels[kfsplit[0]], labels[kfsplit[1]]
        
    train_grn_dataloader,test_grn_dataloader,train_cc_dataloader,test_cc_dataloader=generate_dataset(
        train_grn, train_cc, patient_ids, train_labels, test_labels,
        train_patient_ids, test_patient_ids, num_cell_types = num_cell_types, 
        OUTER_LAYER_BATCH = OUTER_LAYER_BATCH, INNER_LAYER_BATCH = INNER_LAYER_BATCH)
    
    model = NestedGNN(in_layer_dim, out_layer_dim, h_dims, output_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

    for epoch in range(epochs):
        model.train()
        num_correct = 0
        num_tests = 0
        total_loss = 0
        total_examples = 0
        y_pred = []
        y_true = []
        ep_loss=[]
        for (batched_grn_graph, labels_grn), (batched_cc_graph, labels_cc) in \
            zip(train_grn_dataloader, train_cc_dataloader):

            optimizer.zero_grad()
            pred = model(batched_grn_graph,batched_cc_graph, 
                         batched_grn_graph.ndata["x"].float(), batched_cc_graph.ndata["x"].float())

            # TRY WEIGHTED LOSS FUNCTION:
            #Weight = [0.1, 0.9]
            #args.loss_fn = nn.CrossEntropyLoss(torch.tensor(weight).to(device).float())
            
            loss = F.cross_entropy(pred, torch.tensor(labels_cc, dtype=torch.long))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            num_correct += (pred.argmax(1) == labels_cc).sum().item()
            num_tests += len(labels_cc)

            total_loss+=loss.detach().numpy()
            total_examples += batched_cc_graph.batch_size
            ep_loss.append(total_loss/total_examples)

            y_true += labels_cc
            y_pred += pred.argmax(1)
        scheduler.step()
        
        if epoch%10 == 0 or epoch == epochs-1:
            print ("Epoch: {}".format(epoch))
            
            train_loss = round(np.mean(ep_loss),5)
            train_bacc = round(balanced_accuracy_score(y_true, y_pred),5)
            print (f"Trainloss: {train_loss}\tTrain BACC: {train_bacc}")
            #print("Train accuracy (Unbalanced):", num_correct / num_tests)
            
            model.eval()
            num_correct = 0
            num_tests = 0
            total_loss = 0
            total_examples = 0
            y_pred = []
            y_true = []
            b_loss=[]
            for (batched_grn_graph, labels_grn), (batched_cc_graph, labels_cc) in \
                zip(test_grn_dataloader, test_cc_dataloader):

                pred = model(batched_grn_graph,batched_cc_graph, 
                             batched_grn_graph.ndata["x"].float(), batched_cc_graph.ndata["x"].float())

                loss = F.cross_entropy(pred, torch.tensor(labels_cc, dtype=torch.long))
                num_correct += (pred.argmax(1) == labels_cc).sum().item()
                num_tests += len(labels_cc)
                y_true += labels_cc
                y_pred += pred.argmax(1)
                total_loss+=loss.detach().numpy()
                total_examples += batched_cc_graph.batch_size
                b_loss.append(round(total_loss/total_examples,2))
                #print ("Batch Loss:{}".format(round(total_loss/total_examples,2)))
            val_loss = round(total_loss/total_examples,5)
            val_bacc = round(balanced_accuracy_score(y_true, y_pred),5)
            print(f"Val Loss::{val_loss}\tVal BACC:{val_bacc}")
            
            if epoch == epochs-1:
                auprc = round(average_precision_score(y_true, y_pred),5)
                print (f"AUPRC on Validation set:{auprc}")
                models[i+1] = {"model": model.state_dict(), 
                               "Train BACC": train_bacc, "Train loss": train_loss,
                               "Val BACC": val_bacc, "Val loss": val_loss,
                               "Val AUPRC":auprc}


Fold 1
Epoch: 0
Trainloss: 5993660351464.699	Train BACC: 0.54441
Val Loss::2057930224435.2334	Val BACC:0.5
Epoch: 10
Trainloss: 46708090916.06448	Train BACC: 0.48444
Val Loss::129801877367.50142	Val BACC:0.5
Epoch: 20
Trainloss: 0.05509	Train BACC: 0.5
Val Loss::0.05731	Val BACC:0.5
Epoch: 30
Trainloss: 0.04887	Train BACC: 0.57937
Val Loss::0.05296	Val BACC:0.5625
Epoch: 40
Trainloss: 0.04386	Train BACC: 0.68899
Val Loss::0.05437	Val BACC:0.5
Epoch: 49
Trainloss: 0.03869	Train BACC: 0.6965
Val Loss::0.04899	Val BACC:0.68182
AUPRC on Validation set:0.81344

Fold 2
Epoch: 0
Trainloss: 610689212453.105	Train BACC: 0.4721
Val Loss::113913364480.06589	Val BACC:0.5
Epoch: 10
Trainloss: 2233160783.70787	Train BACC: 0.49265
Val Loss::0.07207	Val BACC:0.5
Epoch: 20
Trainloss: 0.0439	Train BACC: 0.57871
Val Loss::0.08372	Val BACC:0.5
Epoch: 30
Trainloss: 65864325652.8621	Train BACC: 0.58569
Val Loss::330115495.77125	Val BACC:0.5
Epoch: 40
Trainloss: 13656936917.41693	Train BACC: 0.78196
Val Los

In [38]:
avg_prc = 0
for i in models:
    avg_prc += models[i]["Val AUPRC"]
avg_prc/len(models)

0.79035

In [39]:
avg_bacc = 0
for i in models:
    avg_bacc += models[i]["Val BACC"]
avg_bacc/len(models)

0.614924

In [40]:
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
tn, fp, fn, tp 

(3, 4, 0, 33)

In [41]:
confusion_matrix(y_true, y_pred)

array([[ 3,  4],
       [ 0, 33]])

In [45]:
num_params = sum(p.numel() for p in model.parameters())
num_params

12775042

In [191]:
# SAVE MODEL DICTIONARIES
model_file = "models"+"_cv"+str(n_splits)+"_epochs"+str(epochs)+"_lr"+str(lr)+"_gamma"+str(gamma)+\
            "_outerbatch"+str(OUTER_LAYER_BATCH)+"_inner"+inner+"_outer"+outer+\
            "_numofcelltypes"+str(num_cell_types)+".pkl"

In [192]:
model_file

'models_cv5_epochs50_lr0.001_gamma0.9_outerbatch10_innerGCN_outerGCN_numofcelltypes12.pkl'

In [193]:
MODEL_FOLDER="./models/"

In [194]:
with open(MODEL_FOLDER+model_file, 'wb') as f:
    pickle.dump(models, f)

In [195]:
model

NestedGNN(
  (conv1): GraphConv(in=100, out=256, normalization=both, activation=None)
  (conv2): GraphConv(in=256, out=128, normalization=both, activation=None)
  (conv3): GraphConv(in=4863, out=2048, normalization=both, activation=None)
  (conv4): GraphConv(in=2048, out=1024, normalization=both, activation=None)
  (conv5): GraphConv(in=1024, out=512, normalization=both, activation=None)
  (conv6): GraphConv(in=512, out=256, normalization=both, activation=None)
  (classify): Linear(in_features=256, out_features=2, bias=True)
)