In [None]:
import numpy as np
import torch
import pickle
import time
import os
%matplotlib inline
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
from tensorboardX import SummaryWriter
import glob
import copy

In [None]:
import os
# Move to the project root exactly once per kernel session.
# The path is relative, so without the guard every re-run of this cell would
# climb two more directories and break the project-local imports below.
if '_PROJECT_ROOT' not in globals():
    os.chdir('../../') # go to root folder of the project
    _PROJECT_ROOT = os.getcwd()
print(os.getcwd())


In [None]:
import pickle

%load_ext autoreload
%autoreload 2

from data.superpixels import SuperPixDatasetDGL 

from data.data import LoadData
from torch.utils.data import DataLoader
from data.superpixels import SuperPixDataset


In [None]:
# Benchmark identifier handed to the project's LoadData helper; the returned
# object is the `dataset` that later cells read sizes and labels from.
DATASET_NAME = "ENZYMES"
dataset = LoadData(DATASET_NAME)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import dgl
import torch.optim as optim
from dgl.nn.pytorch import GATConv
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


class MLPReadout(nn.Module):
    """Feed-forward head whose hidden width is halved at every layer.

    The stack is ``L`` linear layers mapping ``input_dim // 2**l`` to
    ``input_dim // 2**(l + 1)``, each followed by ReLU, and one final linear
    layer mapping ``input_dim // 2**L`` to ``output_dim`` with no activation.

    Args:
        input_dim: Width of the incoming feature vectors.
        output_dim: Width of the produced vectors.
        L: Number of hidden (halving) layers.
    """

    def __init__(self, input_dim, output_dim, L=2):
        super().__init__()
        # Successive widths: input_dim, input_dim // 2, ..., input_dim // 2**L.
        widths = [input_dim // 2 ** depth for depth in range(L + 1)]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
        stack.append(nn.Linear(widths[-1], output_dim, bias=True))
        self.FC_layers = nn.ModuleList(stack)
        self.L = L

    def forward(self, x):
        """Apply the ``L`` ReLU-activated layers, then the linear output layer."""
        hidden = x
        for depth in range(self.L):
            hidden = F.relu(self.FC_layers[depth](hidden))
        return self.FC_layers[self.L](hidden)


class GatedGCNLayer(nn.Module):
    """Gated graph-convolution layer that updates node and edge features together.

    Args:
        input_dim: Width of the incoming node and edge features.
        output_dim: Width of the produced node and edge features.
        dropout: Dropout probability applied to both outputs.
        batch_norm: When truthy, batch-normalise node and edge outputs.
        residual: Requests skip connections; forced to ``False`` when
            ``input_dim != output_dim``.
    """

    def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.batch_norm = batch_norm
        # A skip connection can only be summed when the widths agree.
        self.residual = residual if input_dim == output_dim else False

        # Five projections: A/B/D/E act on node features, C on edge features.
        self.A = nn.Linear(input_dim, output_dim, bias=True)
        self.B = nn.Linear(input_dim, output_dim, bias=True)
        self.C = nn.Linear(input_dim, output_dim, bias=True)
        self.D = nn.Linear(input_dim, output_dim, bias=True)
        self.E = nn.Linear(input_dim, output_dim, bias=True)
        self.bn_node_h = nn.BatchNorm1d(output_dim)
        self.bn_node_e = nn.BatchNorm1d(output_dim)

    def message_func(self, edges):
        """User-defined message: store e_ij = Ce_ij + Dh_src + Eh_dst and emit it with Bh_src.

        Not referenced by ``forward``, which uses DGL built-in functions instead.
        """
        source_value = edges.src['Bh']
        gate_logits = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
        edges.data['e'] = gate_logits
        return {'Bh_j': source_value, 'e_ij': gate_logits}

    def reduce_func(self, nodes):
        """User-defined reduce: Ah_i plus the sigmoid-gated, normalised sum of the Bh_j messages.

        Not referenced by ``forward``, which uses DGL built-in functions instead.
        """
        gates = torch.sigmoid(nodes.mailbox['e_ij'])
        gated_sum = torch.sum(gates * nodes.mailbox['Bh_j'], dim=1)
        gate_total = torch.sum(gates, dim=1) + 1e-6  # epsilon guards the division
        return {'h': nodes.data['Ah'] + gated_sum / gate_total}

    def forward(self, g, h, e):
        """Run one gated convolution on graph ``g``.

        Writes the fields 'h', 'Ah', 'Bh', 'Dh', 'Eh', 'sum_sigma_h',
        'sum_sigma' into ``g.ndata`` and 'e', 'Ce', 'DEh', 'sigma' into
        ``g.edata`` as a side effect.

        Returns:
            Tuple ``(h, e)`` of updated node and edge features.
        """
        node_skip, edge_skip = h, e  # kept for the residual connection

        g.ndata['h'] = h
        g.edata['e'] = e
        for field, projection in (('Ah', self.A), ('Bh', self.B),
                                  ('Dh', self.D), ('Eh', self.E)):
            g.ndata[field] = projection(h)
        g.edata['Ce'] = self.C(e)

        # Edge update: e_ij = Dh_src + Eh_dst + Ce_ij.
        g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
        g.edata['e'] = g.edata['DEh'] + g.edata['Ce']
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])
        # Node update: Ah_i + sum_j(sigma_ij * Bh_j) / (sum_j(sigma_ij) + eps).
        g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
        g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6)

        h_out = g.ndata['h']
        e_out = g.edata['e']

        if self.batch_norm:
            h_out, e_out = self.bn_node_h(h_out), self.bn_node_e(e_out)

        h_out, e_out = F.leaky_relu(h_out), F.leaky_relu(e_out)

        if self.residual:
            h_out, e_out = node_skip + h_out, edge_skip + e_out

        # Node dropout is drawn before edge dropout.
        h_out = F.dropout(h_out, self.dropout, training=self.training)
        e_out = F.dropout(e_out, self.dropout, training=self.training)
        return h_out, e_out

    def __repr__(self):
        return '{}(in_channels={}, out_channels={})'.format(
            type(self).__name__, self.in_channels, self.out_channels)

    
class GNN(nn.Module):
    """Stack of ``GatedGCNLayer`` blocks with optional input embeddings.

    Args:
        net_params: Hyper-parameter dictionary. The keys 'in_dim',
            'in_dim_edge', 'hidden_dim', 'out_dim', 'n_classes', 'dropout',
            'L', 'readout', 'batch_norm', 'residual', 'edge_feat' and 'device'
            are all read.
        update_type: ``1`` selects the ``embedding_h1`` / ``embedding_e1`` pair
            in ``forward``; any other value selects ``embedding_h2`` /
            ``embedding_e2``.
    """

    def __init__(self, net_params, update_type):
        super().__init__()
        in_dim, in_dim_edge, hidden_dim, out_dim = (
            net_params[key] for key in ('in_dim', 'in_dim_edge', 'hidden_dim', 'out_dim'))
        n_classes = net_params['n_classes']  # read but not used by this module
        dropout = net_params['dropout']
        n_layers = net_params['L']
        self.readout = net_params['readout']
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']
        self.edge_feat = net_params['edge_feat']
        self.device = net_params['device']
        self.update_type = update_type

        # Both embedding pairs are always created; forward picks one per update_type.
        self.embedding_h1 = nn.Linear(in_dim, hidden_dim)
        self.embedding_h2 = nn.Linear(hidden_dim, hidden_dim)

        self.embedding_e1 = nn.Linear(in_dim_edge, hidden_dim)
        self.embedding_e2 = nn.Linear(in_dim_edge, hidden_dim)

        # n_layers - 1 hidden-to-hidden layers followed by one hidden-to-out layer.
        conv_stack = [
            GatedGCNLayer(hidden_dim, hidden_dim, dropout, self.batch_norm, self.residual)
            for _ in range(n_layers - 1)
        ]
        conv_stack.append(
            GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual))
        self.layers = nn.ModuleList(conv_stack)

    def forward(self, g, h, e, stage):
        """Embed the inputs when ``stage == 0``, then apply every convolution layer in order.

        Returns:
            Tuple ``(h, e)`` produced by the last layer.
        """
        if stage == 0:
            if self.update_type == 1:
                embed_nodes, embed_edges = self.embedding_h1, self.embedding_e1
            else:
                embed_nodes, embed_edges = self.embedding_h2, self.embedding_e2
            h, e = embed_nodes(h), embed_edges(e)

        for layer in self.layers:
            h, e = layer(g, h, e)
        return h, e

In [None]:
class GIGLayer(nn.Module):
    """One layer made of three chained ``GNN`` modules: SGU, GGU part 1, GGU part 2.

    Args:
        net_params: Base hyper-parameter dictionary; it is copied per module
            and never modified.
        device: Stored on the instance as ``self.device``.
        is_last_layer: When true, ``forward`` returns the ``dgl.mean_nodes``
            readout of the SGU graph instead of the per-node state.
    """

    def __init__(self, net_params, device, is_last_layer):
        super().__init__()
        self.device = device
        self.is_last_layer = is_last_layer

        # SGU uses two convolution layers on the original input width; each GGU
        # part uses one layer and takes hidden_dim-wide node inputs.
        sgu_cfg = dict(net_params, L=2)
        ggu_p1_cfg = dict(net_params, L=1, in_dim=net_params['hidden_dim'])
        ggu_p2_cfg = dict(net_params, L=1, in_dim=net_params['hidden_dim'])

        self.SGU = GNN(sgu_cfg, 1)
        self.GGU_Part1 = GNN(ggu_p1_cfg, 2)
        self.GGU_Part2 = GNN(ggu_p2_cfg, 2)

    def forward(self, SGU_graph, GGU_P1_graph, GGU_P2_graph, h, SGU_e, GGU_P1_e, GGU_P2_e, stage):
        """Feed node features through SGU, then GGU part 1, then GGU part 2.

        Returns:
            For the last layer: the ``dgl.mean_nodes`` readout of the final
            node features on ``SGU_graph`` (which also get stored in
            ``SGU_graph.ndata['h']``). Otherwise: the three graphs, the final
            node features and the three edge-feature tensors.
        """
        h_sgu, e_sgu = self.SGU(SGU_graph, h, SGU_e, stage)
        h_p1, e_p1 = self.GGU_Part1(GGU_P1_graph, h_sgu, GGU_P1_e, stage)
        h_p2, e_p2 = self.GGU_Part2(GGU_P2_graph, h_p1, GGU_P2_e, stage)

        if self.is_last_layer:
            SGU_graph.ndata['h'] = h_p2
            return dgl.mean_nodes(SGU_graph, 'h')
        return SGU_graph, GGU_P1_graph, GGU_P2_graph, h_p2, e_sgu, e_p1, e_p2

In [None]:
class GIGNet(nn.Module):
    """Graph-in-graph classifier: a stack of ``GIGLayer`` blocks followed by an MLP readout.

    Every input graph is extended with one extra "proxy" vertex. Three graphs
    are derived from a batch (see ``GDG``) and passed through the layers; the
    last layer returns a per-graph vector that ``MLP_layer`` maps to class scores.

    Args:
        net_params: Hyper-parameter dictionary forwarded to every ``GIGLayer``;
            'out_dim' and 'n_classes' size the readout MLP.
        n_layers: Total number of ``GIGLayer`` blocks (the last one performs the readout).
        num_neighbors: Number of proxy vertices each proxy vertex is linked to.
        device: Device on which all derived tensors and graphs are created.
    """

    def __init__(self, net_params, n_layers, num_neighbors, device):
        super().__init__()
        self.device = device
        self.num_neighbors = num_neighbors
        self.layers = nn.ModuleList([GIGLayer(net_params, device, False) for _ in range(n_layers-1)])
        self.layers.append(GIGLayer(net_params, device, True))
        self.MLP_layer = MLPReadout(net_params['out_dim'], net_params['n_classes'])

    #Obtain global neighbors
    def global_neighbors(self, h, num_global_neighbors):
        """Return, for every row of ``h``, the indices of its most similar rows.

        Similarity is the dot product between rows. The diagonal is set to zero
        rather than masked out, so a row can still select itself when all of
        its other similarities are negative.

        Args:
            h: 2-D tensor with one feature row per proxy vertex.
            num_global_neighbors: Number of indices to return per row.

        Returns:
            Long tensor of shape ``(h.shape[0], num_global_neighbors)`` on ``self.device``.

        Raises:
            ValueError: If more neighbours are requested than there are rows.
        """
        h = h.to(self.device)
        if num_global_neighbors > h.shape[0]:
            raise ValueError(
                "num_neighbors ({}) exceeds the number of graphs in the batch ({})".format(
                    num_global_neighbors, h.shape[0]))
        si = torch.einsum('i j , j k -> i k', h, h.transpose(0, 1))
        si = si - torch.diag_embed(torch.diag(si))  # zero the self-similarity entries
        _, target_inds = si.topk(k=num_global_neighbors, dim=1, largest=True)
        return target_inds.to(self.device)

    def _proxy_edges(self, proxy_inds, target_inds, num_nodes):
        """Build the edge list that links proxy vertices to their global neighbours.

        For proxy ``i`` with neighbours ``n_1..n_k`` the edges are emitted as
        ``i->n_1 .. i->n_k`` followed by ``n_1->i .. n_k->i``; blocks follow the
        order of ``proxy_inds``. A final self-loop on vertex ``num_nodes - 1``
        makes the resulting graph span ``num_nodes`` vertices.

        Args:
            proxy_inds: 1-D tensor with the batched-graph index of every proxy vertex.
            target_inds: ``(len(proxy_inds), k)`` tensor of positions into ``proxy_inds``.
            num_nodes: Number of vertices the resulting graph must have.

        Returns:
            Tuple ``(src, tgt)`` of long tensors on ``self.device``.
        """
        # Both tensors must share a device; indexing a CPU tensor with CUDA
        # indices raises in PyTorch.
        proxy_inds = proxy_inds.to(self.device).long()
        target_inds = target_inds.to(self.device)
        k = target_inds.shape[1]
        centres = proxy_inds.unsqueeze(1).expand(-1, k)
        neighbours = proxy_inds[target_inds]
        src = torch.cat((centres, neighbours), dim=1).reshape(-1)
        tgt = torch.cat((neighbours, centres), dim=1).reshape(-1)
        last_node = proxy_inds.new_tensor([num_nodes - 1])
        return torch.cat((src, last_node)), torch.cat((tgt, last_node))

    #GIG data generation module
    def GDG(self, graphs, indices, edge_lengths, node_lengths):
        """Derive the SGU, GGU part 1 and GGU part 2 graphs from a batched graph.

        Args:
            graphs: Batched DGL graph whose ``ndata['feat']`` and ``edata['feat']``
                are 2-D feature tensors.
            indices: Per-graph ``[src, dst]`` edge index tensors (local numbering).
            edge_lengths: Per-graph edge counts.
            node_lengths: Per-graph vertex counts.

        Returns:
            ``(SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, SGU_k, GGU_P1_k, GGU_P2_k)``:
            the three derived graphs followed by their edge features.
        """
        node_feats = graphs.ndata['feat'].to(self.device)
        edge_feats = graphs.edata['feat'].to(self.device)
        # Feature width comes from the data instead of a hard-coded constant.
        edge_dim = edge_feats.shape[1]

        node_index = 0
        edge_index = 0
        proxy_offset = 0
        SGU_graphs = []
        GGU_P2_graphs = []
        proxy_inds = []
        proxy_vertices = []
        #graph construction in DGL for SGU and GGU part 2
        for (edge_src, edge_dst), edge_length, node_length in zip(indices, edge_lengths, node_lengths):
            #Vertex and edge features from current sub-graph
            h = node_feats[node_index:node_index+node_length, :]
            e = edge_feats[edge_index:edge_index+edge_length, :]

            # Local vertex ids 0..n-1; the proxy vertex takes id n.
            local_nodes = torch.arange(node_length, device=self.device)
            proxy_id = torch.full((node_length,), node_length, dtype=torch.long, device=self.device)

            #SGU graph: original edges plus one edge from every vertex to the proxy
            src1 = torch.cat((edge_src.to(self.device), local_nodes))
            tgt1 = torch.cat((edge_dst.to(self.device), proxy_id))
            SGU_g = dgl.graph((src1, tgt1)).to(self.device)

            #GGU part 2 graph: one edge from the proxy back to every vertex
            GGU_P2_g = dgl.graph((proxy_id, local_nodes)).to(self.device)

            #Proxy vertex feature initialization: mean of the vertex features
            proxy_vertex = h.mean(dim=0, keepdim=True)
            proxy_vertices.append(proxy_vertex)

            #Vertex and edge features construction (proxy edges get all-ones features)
            node_features = torch.cat((h, proxy_vertex), dim=0)
            proxy_edge_features = torch.ones(h.shape[0], edge_dim, device=self.device)
            SGU_g.ndata['feat'] = node_features
            GGU_P2_g.ndata['feat'] = node_features
            SGU_g.edata['feat'] = torch.cat((e, proxy_edge_features))
            GGU_P2_g.edata['feat'] = proxy_edge_features

            #Update indexes
            node_index += node_length
            edge_index += edge_length

            #Append graphs into corresponding sets
            SGU_graphs.append(SGU_g)
            GGU_P2_graphs.append(GGU_P2_g)

            # Position of this graph's proxy vertex inside the batched graph:
            # every earlier graph contributes its vertices plus one proxy.
            proxy_offset += node_length
            proxy_inds.append(proxy_offset)
            proxy_offset += 1

        #Combine single graphs to form a new DGL graph
        SGU_graphs = dgl.batch(SGU_graphs).to(self.device)
        GGU_P2_graphs = dgl.batch(GGU_P2_graphs).to(self.device)

        #Obtain edge features of SGU and GGU part 2
        SGU_k = SGU_graphs.edata['feat']
        GGU_P2_k = GGU_P2_graphs.edata['feat']

        #graph construction in DGL for GGU part 1
        proxy_inds = torch.tensor(proxy_inds, dtype=torch.long, device=self.device)
        proxy_vertices = torch.cat(proxy_vertices, dim=0)
        target_inds = self.global_neighbors(proxy_vertices, self.num_neighbors)
        global_src, global_tgt = self._proxy_edges(proxy_inds, target_inds, SGU_graphs.num_nodes())

        #Form new DGL graph for GGU part 1 (all-ones edge features, one row per edge)
        GGU_P1_graphs = dgl.graph((global_src, global_tgt)).to(self.device)
        GGU_P1_k = torch.ones(global_src.shape[0], edge_dim, device=self.device)

        return SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, SGU_k, GGU_P1_k, GGU_P2_k

    def forward(self, graphs):
        """Compute class scores for every graph in the batched DGL graph ``graphs``."""
        indices = []
        edge_lengths = []
        node_lengths = []
        for graph in dgl.unbatch(graphs):
            adj = graph.adjacency_matrix(transpose=False)._indices()
            indices.append([adj[0], adj[1]])
            edge_lengths.append(graph.num_edges())
            node_lengths.append(graph.num_nodes())

        #GIG data
        SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, SGU_k, GGU_P1_k, GGU_P2_k = \
        self.GDG(graphs, indices, edge_lengths, node_lengths)
        h = SGU_graphs.ndata['feat']

        #Forward propagation; the layer index is passed as `stage`
        last = len(self.layers) - 1
        for ind, conv in enumerate(self.layers):
            outputs = conv(SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, h, SGU_k, GGU_P1_k, GGU_P2_k, ind)
            if ind < last:
                SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, h, SGU_k, GGU_P1_k, GGU_P2_k = outputs
            else:
                best_scores = outputs
        return self.MLP_layer(best_scores)

    def loss(self, pred, label):
        """Cross-entropy loss between the scores ``pred`` and the class indices ``label``."""
        criterion = nn.CrossEntropyLoss().to(self.device)
        loss = criterion(pred, label)
        return loss

In [None]:
# Device stored in the hyper-parameter dictionary and later used for training.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of distinct graph labels in the whole dataset.
num_classes = len(np.unique(dataset.all.graph_labels))

# Hyper-parameters; input widths are read from the first graph of the dataset.
net_params = {
    'device': device,
    'gated': True,
    'in_dim': dataset.all.graph_lists[0].ndata['feat'][0].shape[0],
    'in_dim_edge': dataset.all.graph_lists[0].edata['feat'][0].shape[0],
    'residual': True,
    'hidden_dim': 100,
    'out_dim': 100,
    'n_classes': num_classes,
    'n_heads': -1,
    'L': 4,
    'readout': 'mean',
    'layer_norm': True,
    'batch_norm': True,
    'in_feat_dropout': 0.0,
    'dropout': 0.1,
    'edge_feat': True,
    'self_loop': False,
}

# The module-level `device` is rebound here without an explicit GPU index;
# net_params['device'] keeps the value assigned above.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.cuda.manual_seed(41)

In [None]:
def evaluate_network(model, device, data_loader, epoch):
    """Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    Args:
        model: Object exposing ``eval()``, ``forward(graphs)`` and ``loss(scores, labels)``.
        device: Device the batches are moved to.
        data_loader: Iterable yielding ``(batch_graphs, batch_labels)`` pairs.
        epoch: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(mean_loss, acc)``: the loss averaged over batches and the
        number of correct predictions divided by the number of labels seen.

    Raises:
        ValueError: If ``data_loader`` yields no samples (for example when
            ``drop_last=True`` discards the only, incomplete batch).
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0.0
    num_batches = 0
    num_samples = 0
    with torch.no_grad():
        for batch_graphs, batch_labels in data_loader:
            batch_graphs = batch_graphs.to(device)
            batch_labels = batch_labels.to(device)

            batch_scores = model.forward(batch_graphs)
            loss = model.loss(batch_scores, batch_labels)
            total_loss += loss.detach().item()
            total_correct += accuracy(batch_scores, batch_labels)
            num_batches += 1
            num_samples += batch_labels.size(0)

    if num_samples == 0:
        # Previously this surfaced as UnboundLocalError / ZeroDivisionError.
        raise ValueError("evaluate_network received a data_loader that yields no samples")
    return total_loss / num_batches, total_correct / num_samples


def train_epoch(model, optimizer, device, data_loader, epoch):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Object exposing ``train()``, ``forward(graphs)`` and ``loss(scores, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        data_loader: Iterable yielding ``(batch_graphs, batch_labels)`` pairs.
        epoch: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(mean_loss, acc, optimizer)``: the loss averaged over batches,
        the number of correct predictions divided by the number of labels
        seen, and the optimizer that was passed in.

    Raises:
        ValueError: If ``data_loader`` yields no samples (for example when
            ``drop_last=True`` discards the only, incomplete batch).
    """
    model.train()
    total_loss = 0.0
    total_correct = 0.0
    num_batches = 0
    num_samples = 0
    for batch_graphs, batch_labels in data_loader:
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.detach().item()
        total_correct += accuracy(batch_scores, batch_labels)
        num_batches += 1
        num_samples += batch_labels.size(0)

    if num_samples == 0:
        # Previously this surfaced as UnboundLocalError / ZeroDivisionError.
        raise ValueError("train_epoch received a data_loader that yields no samples")
    return total_loss / num_batches, total_correct / num_samples, optimizer


def accuracy(scores, targets):
    """Return how many rows of ``scores`` have their arg-max equal to ``targets``.

    The result is a count (Python float), not a ratio; callers divide by the
    number of samples themselves.
    """
    predicted = torch.argmax(scores.detach(), dim=1)
    hits = predicted.eq(targets)
    return hits.float().sum().item()

In [None]:
# Cross-validated training run: one model is trained per split; per-split
# accuracies are collected and summarised at the end of the cell.
N_SPLITS = 10
N_EPOCHS = 250
BATCH_SIZE = 16
SEED = 41

avg_test_acc = []
avg_train_acc = []
avg_convergence_epochs = []

t0 = time.time()
per_epoch_time = []

dataset = LoadData("ENZYMES")

root_log_dir = os.path.join('/root/tf-logs/', "DATA_ENZYMES")
model_store_dir = os.path.join('/root/models/', "ENZYMES_MODEL")  # only used by the commented-out save below
device = net_params['device']
final_best = 0

try:
    for split_number in range(N_SPLITS):
        log_dir = os.path.join(root_log_dir, "RUN_69" + str(split_number))
        writer = SummaryWriter(log_dir=log_dir)
        # Every split opens its own writer, so close it when the split ends,
        # including when the run is interrupted.
        try:
            # setting seeds
            np.random.seed(SEED)
            torch.manual_seed(SEED)
            if device.type == 'cuda':
                torch.cuda.manual_seed(SEED)

            print("RUN NUMBER: ", split_number)
            trainset, valset, testset = dataset.train[split_number], dataset.val[split_number], dataset.test[split_number]
            print("Training Graphs: ", len(trainset))
            print("Validation Graphs: ", len(valset))
            print("Test Graphs: ", len(testset))
            print("Number of Classes: ", net_params['n_classes'])

            model = GIGNet(net_params, 2, 14, device)
            model = model.to(device)
            optimizer = optim.AdamW(model.parameters(), lr=0.0012, weight_decay=0.14)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                             factor=0.5,
                                                             patience=20,
                                                             verbose=True)
            # Start below any reachable accuracy so the first epoch always
            # provides a model to evaluate after the loop.
            best_split_model = None
            best_split_score = -1.0
            epoch_train_losses, epoch_val_losses = [], []
            epoch_train_accs, epoch_val_accs = [], []

            train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=dataset.collate)
            val_loader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, collate_fn=dataset.collate)
            test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, collate_fn=dataset.collate)

            with tqdm(range(N_EPOCHS)) as t:
                for epoch in t:

                    t.set_description('Epoch %d' % epoch)

                    start = time.time()

                    epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch)

                    epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader, epoch)
                    # Evaluated once per epoch; the value is reused for logging
                    # and for best-model tracking.
                    _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)

                    epoch_train_losses.append(epoch_train_loss)
                    epoch_val_losses.append(epoch_val_loss)
                    epoch_train_accs.append(epoch_train_acc)
                    epoch_val_accs.append(epoch_val_acc)

                    writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                    writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                    writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                    writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                    writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                    # NOTE(review): the best model is chosen by *test* accuracy,
                    # so the reported test numbers are optimistically biased.
                    # Confirm this matches the intended protocol, or select on
                    # epoch_val_acc instead.
                    if epoch_test_acc > best_split_score:
                        best_split_score = epoch_test_acc
                        best_split_model = copy.deepcopy(model)

                    t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
                                  train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                                  train_acc=epoch_train_acc, val_acc=epoch_val_acc,
                                  test_acc=epoch_test_acc)

                    per_epoch_time.append(time.time()-start)

                    scheduler.step(epoch_val_loss)

                    # NOTE(review): a positive learning rate halved by the
                    # scheduler never drops below 0.0, so this early exit is
                    # effectively disabled; use a positive threshold to enable it.
                    if optimizer.param_groups[0]['lr'] < 0.0:
                        print("\n!! LR EQUAL TO MIN LR SET.")
                        break

            _, test_acc = evaluate_network(best_split_model, device, test_loader, epoch)
            _, train_acc = evaluate_network(best_split_model, device, train_loader, epoch)
            avg_test_acc.append(test_acc)
            avg_train_acc.append(train_acc)
            avg_convergence_epochs.append(epoch)
            if test_acc > final_best:
                final_best = test_acc
#             torch.save(model.state_dict(), '{}.pkl'.format(model_store_dir + "/epoch_" + str(epoch)))

            print("Test Accuracy: {:.4f}".format(test_acc))
            print("Train Accuracy: {:.4f}".format(train_acc))
        finally:
            writer.close()

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early because of KeyboardInterrupt')

print("TEST ACCURACY [FINAL RESULT] {:.4f}".format(final_best))
print("TOTAL TIME TAKEN: {:.4f}hrs".format((time.time()-t0)/3600))
if per_epoch_time:
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
if avg_test_acc:
    print("AVG CONVERGENCE Time (Epochs): {:.4f}".format(np.mean(np.array(avg_convergence_epochs))))
    # Final test accuracy value averaged over 10-fold
    print("""\n\n\nFINAL RESULTS\n\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}"""
          .format(np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100))
    print("\nAll splits Test Accuracies:\n", avg_test_acc)
    print("""\n\n\nFINAL RESULTS\n\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}"""
          .format(np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100))
    print("\nAll splits Train Accuracies:\n", avg_train_acc)
else:
    # Averages over empty lists would only produce NaN and runtime warnings.
    print("No split completed; averaged statistics are unavailable.")