In [None]:
import numpy as np
import torch
import pickle
import time
import os
%matplotlib inline
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
from tensorboardX import SummaryWriter
import glob
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
from scipy import sparse as sp

import warnings
# NOTE(review): this silences every warning for the whole kernel session (including
# deprecation warnings from torch/dgl/scipy); consider narrowing it with
# `category=` / `module=` filters so real problems stay visible.
warnings.filterwarnings("ignore")

In [None]:
import os

# Move two directories above the directory the kernel started in (presumably the
# repository root, since the next cell imports project-local `data.*` modules).
# The starting directory is remembered on the first run, so re-executing this cell
# is idempotent instead of climbing one more level each time.
NOTEBOOK_DIR = globals().setdefault('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))
print(os.getcwd())


In [None]:
import pickle

%load_ext autoreload
%autoreload 2

from data.superpixels import SuperPixDatasetDGL 

from data.data import LoadData
from torch.utils.data import DataLoader
from data.superpixels import SuperPixDataset


In [None]:
class DotDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes (``d.key``)."""

    def __init__(self, **kwds):
        dict.__init__(self, **kwds)
        # Aliasing the instance __dict__ to the mapping itself makes attribute access
        # and item access share a single namespace.
        self.__dict__ = self

In [None]:
# Load the SBM_PATTERN benchmark via the project's loader and expose its three splits.
DATASET_NAME = 'SBM_PATTERN'
dataset = LoadData(DATASET_NAME)
trainset = dataset.train
valset = dataset.val
testset = dataset.test

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import dgl
import torch.optim as optim
from bisect import bisect_right
from torch.optim.lr_scheduler import CosineAnnealingLR



class MLPReadout(nn.Module):
    """Per-node MLP head whose hidden layers halve the width at every step.

    Hidden layer ``l`` maps ``input_dim // 2**l -> input_dim // 2**(l+1)`` for
    ``l < L``; a final linear layer maps ``input_dim // 2**L -> output_dim``.
    Hidden layers use leaky ReLU; the output layer returns raw logits.
    """

    def __init__(self, input_dim, output_dim, L=3):
        super().__init__()
        widths = [input_dim // 2 ** depth for depth in range(L + 1)]
        hidden = [
            nn.Linear(w_in, w_out, bias=True)
            for w_in, w_out in zip(widths[:-1], widths[1:])
        ]
        self.FC_layers = nn.ModuleList(hidden + [nn.Linear(widths[-1], output_dim, bias=True)])
        self.L = L

    def act(self, x):
        # Mish activation: x * tanh(softplus(x)). Not used by forward().
        return x * torch.tanh(F.softplus(x))

    def forward(self, x):
        out = x
        for layer in self.FC_layers[:self.L]:
            out = F.leaky_relu(layer(out))
        return self.FC_layers[self.L](out)


class GatedGCNLayer(nn.Module):
    """Gated graph-convolution layer that updates both node and edge features.

    Node features ``h`` go through four linear maps (A, B, D, E) and edge features
    ``e`` through C. Each edge gets a pre-activation ``Ce + Dh_src + Eh_dst``.
    Its sigmoid gates the aggregation of ``Bh_src`` at the destination node, which is
    normalised by the sum of incoming gates. The edge pre-activation is also returned
    as the new edge feature.

    Args:
        input_dim: width of incoming node/edge features.
        output_dim: width of outgoing node/edge features.
        dropout: dropout probability applied to both outputs.
        batch_norm: whether to batch-normalise node and edge outputs.
        residual: add the inputs back to the outputs (forced off when widths differ).
    """

    def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.batch_norm = batch_norm
        # A residual sum is only shape-valid when input and output widths agree.
        self.residual = residual if input_dim == output_dim else False

        # Registered in the order A, B, C, D, E (same parameter init order / state_dict keys).
        for name in 'ABCDE':
            setattr(self, name, nn.Linear(input_dim, output_dim, bias=True))
        self.bn_node_h = nn.BatchNorm1d(output_dim)
        self.bn_node_e = nn.BatchNorm1d(output_dim)

    def act(self, x):
        # Mish activation: x * tanh(softplus(x)).
        return x * torch.tanh(F.softplus(x))

    def message_func(self, edges):
        # Explicit-UDF form of the message step (forward() uses DGL built-ins instead).
        gate_pre = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
        edges.data['e'] = gate_pre
        return {'Bh_j': edges.src['Bh'], 'e_ij': gate_pre}

    def reduce_func(self, nodes):
        # Explicit-UDF form of the gated, gate-normalised aggregation.
        gates = torch.sigmoid(nodes.mailbox['e_ij'])
        weighted = torch.sum(gates * nodes.mailbox['Bh_j'], dim=1)
        normaliser = torch.sum(gates, dim=1) + 1e-6
        return {'h': nodes.data['Ah'] + weighted / normaliser}

    def forward(self, g, h, e):
        h_in, e_in = h, e

        g.ndata['h'] = h
        for key, proj in (('Ah', self.A), ('Bh', self.B), ('Dh', self.D), ('Eh', self.E)):
            g.ndata[key] = proj(h)
        g.edata['e'] = e
        g.edata['Ce'] = self.C(e)

        # Edge gate pre-activation: C e_ij + D h_src + E h_dst.
        g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
        g.edata['e'] = g.edata['DEh'] + g.edata['Ce']
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])
        # Gated sum of B h_src, normalised by the sum of incoming gates.
        g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
        g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6)

        h_out = g.ndata['h']
        e_out = g.edata['e']

        if self.batch_norm:
            h_out = self.bn_node_h(h_out)
            e_out = self.bn_node_e(e_out)

        h_out = self.act(h_out)
        e_out = self.act(e_out)

        if self.residual:
            h_out = h_in + h_out
            e_out = e_in + e_out

        h_out = F.dropout(h_out, self.dropout, training=self.training)
        e_out = F.dropout(e_out, self.dropout, training=self.training)
        return h_out, e_out

    def __repr__(self):
        return f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels})'

    
class GNN(nn.Module):
    """Stack of ``GatedGCNLayer``s with an optional additive positional-encoding embedding.

    Args:
        net_params: hyper-parameter dict. Used here: ``hidden_dim``, ``out_dim``,
            ``dropout``, ``L`` (number of layers), ``batch_norm``, ``residual``,
            ``pos_enc`` and, when ``pos_enc`` is truthy, ``pos_enc_dim``. The keys
            ``in_dim``, ``n_classes``, ``readout``, ``edge_feat`` and ``device`` are also
            read (so they must be present) but do not affect the computation.
        update_type: tag stored on the instance; not used by ``forward``.
    """

    def __init__(self, net_params, update_type):
        super().__init__()
        in_dim = net_params['in_dim']          # read for config parity; unused
        hidden_dim = net_params['hidden_dim']
        out_dim = net_params['out_dim']
        n_classes = net_params['n_classes']    # read for config parity; unused
        dropout = net_params['dropout']
        n_layers = net_params['L']
        self.readout = net_params['readout']
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']
        self.edge_feat = net_params['edge_feat']
        self.device = net_params['device']
        self.update_type = update_type
        self.pos_enc = net_params['pos_enc']
        if self.pos_enc:
            self.embedding_pos_enc = nn.Linear(net_params['pos_enc_dim'], hidden_dim)

        # (n_layers - 1) hidden->hidden layers followed by one hidden->out layer.
        widths = [hidden_dim] * (n_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            GatedGCNLayer(hidden_dim, width, dropout, self.batch_norm, self.residual)
            for width in widths
        )

    def forward(self, g, h, e, stage, h_pos_enc=None):
        # `stage` is accepted for interface compatibility and not used here.
        if self.pos_enc:
            h = h + self.embedding_pos_enc(h_pos_enc.float())
        for layer in self.layers:
            h, e = layer(g, h, e)
        return h, e

In [None]:
class GIGLayer(nn.Module):
    """One graph-in-graph block: an SGU network followed by two GGU networks.

    The SGU uses 10 gated layers; each GGU part uses 3. On stage 0 the shared input
    projections (``embedding_*``) are applied before each sub-network. Non-final
    layers return the graphs plus the new node features and three edge-feature
    tensors. The final layer returns MLP logits for the first ``original_length``
    node rows.
    """

    def __init__(self, net_params, device, is_last_layer, embedding_h1, embedding_h2, embedding_e1, embedding_e2):
        super().__init__()
        self.device = device
        self.is_last_layer = is_last_layer

        def _params_with(n_layers, in_dim=None):
            # Independent shallow copy of the config with the depth (and optionally in_dim) overridden.
            cfg = net_params.copy()
            cfg['L'] = n_layers
            if in_dim is not None:
                cfg['in_dim'] = in_dim
            return cfg

        hidden = net_params['hidden_dim']
        self.SGU = GNN(_params_with(10), 1)
        self.GGU_Part1 = GNN(_params_with(3, hidden), 2)
        self.GGU_Part2 = GNN(_params_with(3, hidden), 2)
        self.MLP_layer = MLPReadout(net_params['out_dim'], net_params['n_classes'])
        # Shared with the owning network (the same modules are handed to every layer).
        self.embedding_h1 = embedding_h1
        self.embedding_h2 = embedding_h2
        self.embedding_e1 = embedding_e1
        self.embedding_e2 = embedding_e2

    def forward(self, SGU_graph, GGU_P1_graph, GGU_P2_graph, h, SGU_e, GGU_P1_e, GGU_P2_e, original_length, stage, batch_pos_enc):
        first_stage = stage == 0

        if first_stage:
            h, SGU_e = self.embedding_h1(h), self.embedding_e1(SGU_e)
        h1, e1 = self.SGU(SGU_graph, h, SGU_e, stage, batch_pos_enc)

        if first_stage:
            h1, GGU_P1_e = self.embedding_h2(h1), self.embedding_e2(GGU_P1_e)
        h2, e2 = self.GGU_Part1(GGU_P1_graph, h1, GGU_P1_e, stage, batch_pos_enc)

        if first_stage:
            h2, GGU_P2_e = self.embedding_h2(h2), self.embedding_e2(GGU_P2_e)
        h3, e3 = self.GGU_Part2(GGU_P2_graph, h2, GGU_P2_e, stage, batch_pos_enc)

        if self.is_last_layer:
            # Only the original (non-proxy) nodes are scored.
            return self.MLP_layer(h3[:original_length, :])
        return SGU_graph, GGU_P1_graph, GGU_P2_graph, h3, e1, e2, e3

In [None]:
class GIGNet(nn.Module):
    """Graph-in-graph (GIG) network for node classification on a batch of DGL graphs.

    ``forward`` embeds the integer node features, augments the batch with two proxy
    vertices per graph (see ``GDG``), runs the stacked ``GIGLayer`` blocks over the
    resulting SGU / GGU graphs and returns logits for the original nodes only.

    Args:
        net_params: hyper-parameter dict; this class reads ``in_dim``, ``hidden_dim``,
            ``in_dim_node`` and ``pos_enc_dim`` and forwards the dict to every GIGLayer.
        n_classes: number of target classes (sizes the class weights in ``loss``).
        n_layers: number of GIGLayer blocks; only the last one applies the MLP readout.
        device: torch device used for every tensor built inside the model.
    """
    def __init__(self, net_params, n_classes, n_layers, device):
        super().__init__()
        self.device = device
        self.n_classes = n_classes
        # Input projections shared by every GIGLayer (the same module objects are passed
        # to each layer); a GIGLayer only applies them when its `stage` argument is 0.
        self.embedding_h1 = nn.Linear(net_params['in_dim'], net_params['hidden_dim'])
        self.embedding_h2 = nn.Linear(net_params['hidden_dim'], net_params['hidden_dim'])
        # Edge features built in GDG have a single channel.
        self.embedding_e1 = nn.Linear(1, net_params['hidden_dim'])
        self.embedding_e2 = nn.Linear(1, net_params['hidden_dim'])
        self.layers = nn.ModuleList([GIGLayer(net_params, device, False, self.embedding_h1, self.embedding_h2, self.embedding_e1, self.embedding_e2) for _ in range(n_layers-1)])
        self.layers.append(GIGLayer(net_params, device, True, self.embedding_h1, self.embedding_h2, self.embedding_e1, self.embedding_e2))
        # Lookup table for the integer node-feature ids stored in ndata['feat'].
        self.embedding_h = nn.Embedding(net_params['in_dim_node'], net_params['in_dim'])
        self.in_dim = net_params['in_dim']
        self.pos_enc_dim = net_params['pos_enc_dim']

    def positional_encoding(self, g, pos_enc_dim):
        """Return ``pos_enc_dim`` Laplacian-eigenvector positional encodings for ``g``.

        Builds ``L = I - N A N`` with ``N = diag(in_degree.clip(1) ** -0.5)``, takes the
        ``pos_enc_dim + 1`` eigenpairs with smallest real part, sorts them by eigenvalue
        and drops the first one. Returns a float tensor of shape
        ``(num_nodes, pos_enc_dim)`` on ``self.device``.
        """
        # A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
        # NOTE(review): A is a dense torch tensor here while N is a scipy sparse matrix;
        # confirm `N * A * N` really computes the matrix product N @ A @ N (the
        # commented-out line above built a scipy adjacency instead).
        A = g.adj_external().to_dense()
        N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
        L = sp.eye(g.number_of_nodes()) - N * A * N
        
        # NOTE(review): eigs returns complex arrays; `.float()` below keeps only the real part.
        EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR', tol=1e-2) # for 40 PEs
        EigVec = EigVec[:, EigVal.argsort()] # increasing order
        return torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).to(self.device).float() 
    
    
    
    def proxy_edge_construction(self, proxy_vertex, sub_vertex, num_proxy_edges):
        """Rank the rows of ``sub_vertex`` by dot-product similarity to ``proxy_vertex``.

        ``proxy_vertex`` is expected as a ``(1, D)`` row (the caller reshapes it so) and
        ``sub_vertex`` as ``(n, D)``. Returns ``(least_similar_idx, most_similar_idx)``,
        each holding ``num_proxy_edges`` row indices. Note the order: non-similar first.
        """
        sub_vertex = sub_vertex.permute(1,0).to(self.device)
        si = torch.einsum('i j , j k -> i k', proxy_vertex, sub_vertex).to(self.device)
        _, target_inds_sim = si.topk(k=num_proxy_edges, dim=1, largest=True)
        _, target_inds_nonsim = si.topk(k=num_proxy_edges, dim=1, largest=False)
        return target_inds_nonsim.squeeze(0).to(self.device), target_inds_sim.squeeze(0).to(self.device)

    def global_neighbors(self, h, num_global_neighbors):
        """Return ``(most_similar_idx, least_similar_idx)`` rows for every row of ``h``.

        Similarity is the pairwise dot product with the diagonal set to zero (self is not
        excluded, so a row can still select itself when other scores are on the wrong
        side of 0). Each result has shape ``(rows, num_global_neighbors)``. Note the
        order: similar first (the opposite of ``proxy_edge_construction``).
        """
        h = h.to(self.device)
        si = torch.einsum('i j , j k -> i k', h, h.transpose(0, 1)).to(self.device)
        diag = torch.diag(si).to(self.device)
        a_diag = torch.diag_embed(diag).to(self.device)
        si = (si - a_diag).to(self.device)
        _, target_inds_sim = si.topk(k=num_global_neighbors, dim=1, largest=True)
        _, target_inds_nonsim = si.topk(k=num_global_neighbors, dim=1, largest=False)
        return target_inds_sim.to(self.device), target_inds_nonsim.to(self.device)
    
    
    def GDG(self, graphs, indices, edge_lengths, node_lengths, batch_pos_enc):
        """Build the augmented node set and the SGU / GGU graphs for one batch.

        With ``N`` original nodes and ``B`` graphs, graph ``i`` gets two proxy vertices
        at node ids ``N + 2i`` and ``N + 2i + 1``. The (B) proxy rows built below are
        appended to the node features in that order.

        - SGU graph: the original edges plus, per graph, edges from its
          ``int(0.1 * n_i)`` least-similar nodes to proxy ``N + 2i``.
        - GGU part-2 graph: per graph, edges from proxy ``N + 2i + 1`` to its
          ``int(0.1 * n_i)`` most-similar nodes, plus one edge ``N + 2i -> N + 2i + 1``.
        - GGU part-1 graph: bidirectional edges between the ``N + 2i + 1`` proxies and
          their most / least similar peers.

        A compact copy of the part-1 connectivity on ``2B`` nodes is used only to compute
        extra positional encodings, which are appended to ``batch_pos_enc``.

        Returns:
            ``(h, SGU_g, GGU_P1_g, GGU_P2_g, SGU_k, GGU_P1_k, GGU_P2_k, batch_pos_enc)``
            where ``h`` is the original features followed by the proxy rows and the
            ``*_k`` tensors are one-column edge features for the three graphs.
        """
        global_src = torch.zeros((0)).to(self.device)
        global_tgt = torch.zeros((0)).to(self.device)
        global_src_pos = torch.zeros((0)).to(self.device)
        global_tgt_pos = torch.zeros((0)).to(self.device)
        src1 = torch.zeros((0)).to(self.device)
        tgt1 = torch.zeros((0)).to(self.device)
        src3 = torch.zeros((0)).to(self.device)
        tgt3 = torch.zeros((0)).to(self.device)
        node_index = 0
        # Proxy node ids start right after the last original node.
        current_ind = graphs.ndata['feat'].shape[0]
        proxy_inds = []
        proxy_vertices = torch.zeros((0)).to(self.device)
        proxy_vertices_gi_lo = torch.zeros((0)).to(self.device)
        proxy_edge_counter = 0
        for ind, indice in enumerate(indices):
            node_length = node_lengths[ind]
            # NOTE(review): edge_length is read but never used.
            edge_length = edge_lengths[ind]
            
            h = graphs.ndata['feat'][node_index:node_index+node_length, :].to(self.device)
            
            # NOTE(review): nn.init.xavier_uniform_ works in place and returns its argument,
            # so the mean computed here is overwritten, and proxy_vertex,
            # proxy_vertex_li_go and proxy_vertex_gi_lo all alias one randomly
            # initialised tensor (both appended rows are identical). Confirm intended.
            proxy_vertex = h.mean(dim=0).reshape(1, self.in_dim).to(self.device)
            proxy_vertex_li_go = nn.init.xavier_uniform_(proxy_vertex)
            proxy_vertex_gi_lo = nn.init.xavier_uniform_(proxy_vertex)
            proxy_vertices = torch.cat((proxy_vertices, proxy_vertex_li_go), dim=0)
            proxy_vertices = torch.cat((proxy_vertices, proxy_vertex_gi_lo), dim=0)
            proxy_vertices_gi_lo = torch.cat((proxy_vertices_gi_lo, proxy_vertex_gi_lo), dim=0)
            
            # 10% of the graph's nodes get a proxy edge; graphs with < 10 nodes get none.
            proxy_node_length = int(node_length * 0.1)
            proxy_edge_counter += proxy_node_length
            
            
            src1_sub, tgt3_sub = self.proxy_edge_construction(proxy_vertex, h, proxy_node_length)
            # Shift local row indices to batch-level node ids.
            src1_sub = src1_sub + node_index
            tgt3_sub = tgt3_sub + node_index
            tgt1_sub = torch.full([proxy_node_length,], current_ind).to(self.device)   
            
            # GGU part-2 edges: proxy (current_ind+1) -> similar nodes, plus current_ind -> current_ind+1.
            src3_sub = torch.cat((torch.full([proxy_node_length,], current_ind+1).to(self.device), torch.tensor([current_ind]).to(self.device)))
            tgt3_sub = torch.cat((tgt3_sub, torch.tensor([current_ind+1]).to(self.device)))
            
            
            src1 = torch.cat((src1, src1_sub))
            tgt1 = torch.cat((tgt1, tgt1_sub))
            src3 = torch.cat((src3, src3_sub))
            tgt3 = torch.cat((tgt3, tgt3_sub))
            
            
            node_index += node_length
            proxy_inds.append(current_ind+1)
            current_ind += 2
        
        # NOTE(review): assumes the COO order of graphs.adj() matches the edge-id order of
        # graphs.edata['feat'] (moot at present because SGU_k is re-initialised below).
        src1 = torch.cat((graphs.adj().indices()[0].to(self.device), src1))
        tgt1 = torch.cat((graphs.adj().indices()[1].to(self.device), tgt1))
        SGU_g = dgl.graph((src1.long(), tgt1.long())).to(self.device)
        # The last proxy (N + 2B - 1) has no SGU edge, so add it explicitly.
        SGU_g.add_nodes(1)
        GGU_P2_g = dgl.graph((src3.long(), tgt3.long())).to(self.device)
        
        h = torch.cat((graphs.ndata['feat'], proxy_vertices))
        # NOTE(review): xavier_uniform_ overwrites every entry in place, so the original
        # edge features and the appended ones are both replaced by random values.
        SGU_k = torch.cat((graphs.edata['feat'], torch.ones(proxy_edge_counter, 1).to(self.device)))
        SGU_k = nn.init.xavier_uniform_(SGU_k)
        # One feature per GGU part-2 edge: the proxy edges plus one extra edge per graph.
        GGU_P2_k = torch.randn(proxy_edge_counter+len(indices), 1).to(self.device)
#         GGU_P2_k = torch.ones(graphs.num_nodes(), 1).to(self.device)
        GGU_P2_k = nn.init.xavier_uniform_(GGU_P2_k)
        
        num_global_neighbors = 9
        proxy_inds = torch.tensor(proxy_inds).to(self.device)     
        proxy_inds_new = torch.arange(len(node_lengths)*2).to(self.device) 
        # Integer halving: 9 -> 4 neighbours in each of the similar / non-similar sets.
        num_global_neighbors = int(num_global_neighbors/2)
        target_inds_similar, target_inds_nonsimilar = self.global_neighbors(proxy_vertices_gi_lo, num_global_neighbors)
 

        for ind, i in enumerate(proxy_inds):
            # GGU part-1 edges, both directions, to the similar and non-similar proxies.
            src1_1 = proxy_inds[target_inds_similar[ind,:]].to(self.device)
            tgt1_1 = torch.full([num_global_neighbors,], i).to(self.device)
            src1_2 = torch.full([num_global_neighbors,], i).to(self.device)
            tgt1_2 = proxy_inds[target_inds_similar[ind,:]].to(self.device)
            
            src2_1 = proxy_inds[target_inds_nonsimilar[ind,:]].to(self.device)
            tgt2_1 = torch.full([num_global_neighbors,], i).to(self.device)
            src2_2 = torch.full([num_global_neighbors,], i).to(self.device)
            tgt2_2 = proxy_inds[target_inds_nonsimilar[ind,:]].to(self.device)
            
            src1 = torch.cat((src1_1, src1_2))
            tgt1 = torch.cat((tgt1_1, tgt1_2))
            src2 = torch.cat((src2_1, src2_2))
            tgt2 = torch.cat((tgt2_1, tgt2_2))
            
            src_ggu2 = torch.cat((src1, src2))
            tgt_ggu2 = torch.cat((tgt1, tgt2))
                        
            global_src = torch.cat((global_src, src_ggu2))
            global_tgt = torch.cat((global_tgt, tgt_ggu2)) 
            
            
            # Same connectivity on a compact 2B-node graph (node 2j+1 <-> proxy N+2j+1,
            # node 2j <-> proxy N+2j, which appears to line the PEs up with the proxy rows of h).
            src3_1 = proxy_inds_new[target_inds_similar[ind,:]].to(self.device)
            tgt3_1 = torch.full([num_global_neighbors,], proxy_inds_new[ind]).to(self.device)
            src3_2 = torch.full([num_global_neighbors,], proxy_inds_new[ind]).to(self.device)
            tgt3_2 = proxy_inds_new[target_inds_similar[ind,:]].to(self.device)
            
            src4_1 = proxy_inds_new[target_inds_nonsimilar[ind,:]].to(self.device)
            tgt4_1 = torch.full([num_global_neighbors,], proxy_inds_new[ind]).to(self.device)
            src4_2 = torch.full([num_global_neighbors,], proxy_inds_new[ind]).to(self.device)
            tgt4_2 = proxy_inds_new[target_inds_nonsimilar[ind,:]].to(self.device)
            
            src3 = torch.cat((src3_1*2+1, src3_2*2+1))
            src3 = torch.cat((src3, torch.tensor([proxy_inds_new[ind]*2]).to(self.device)))
            tgt3 = torch.cat((tgt3_1*2+1, tgt3_2*2+1))
            tgt3 = torch.cat((tgt3, torch.tensor([proxy_inds_new[ind]*2+1]).to(self.device)))
            src4 = torch.cat((src4_1*2+1, src4_2*2+1))
            tgt4 = torch.cat((tgt4_1*2+1, tgt4_2*2+1))
            
            src_ggu2_pe = torch.cat((src3, src4))
            tgt_ggu2_pe = torch.cat((tgt3, tgt4))

            global_src_pos = torch.cat((global_src_pos, src_ggu2_pe))
            global_tgt_pos = torch.cat((global_tgt_pos, tgt_ggu2_pe)) 
        

        GGU_P1_g = dgl.graph((global_src.long(), global_tgt.long())).to(self.device)
        # 4 * k edges per proxy (similar + non-similar, each in both directions).
        GGU_P1_k = torch.ones(proxy_inds.shape[0]*num_global_neighbors*4, 1).to(self.device)
#         GGU_P1_k = torch.ones(proxy_inds.shape[0]*num_global_neighbors*2, 1).to(self.device)
        GGU_P1_k = nn.init.xavier_uniform_(GGU_P1_k)
    
    
        GGU_P1_g_pos = dgl.graph((global_src_pos.long(), global_tgt_pos.long())).to(self.device)
        pos_enc_ggu1 = self.positional_encoding(GGU_P1_g_pos, self.pos_enc_dim)
        # Append the proxy PEs after the original nodes' PEs (same row order as h).
        batch_pos_enc = torch.cat((batch_pos_enc.to(self.device), pos_enc_ggu1.to(self.device)), dim=0).to(self.device)

        return h, SGU_g, GGU_P1_g, GGU_P2_g, SGU_k, GGU_P1_k, GGU_P2_k, batch_pos_enc
    
    
    
    def forward(self, graphs, batch_pos_enc=None):
        """Return class logits for every original node in the batched graph ``graphs``.

        Side effect: ``graphs.ndata['feat']`` is replaced by its embedded version.
        NOTE(review): GDG calls ``batch_pos_enc.to(...)``, so the ``None`` default would
        raise; a positional-encoding tensor is effectively required.
        """
        original_length = graphs.num_nodes()
        # Per-graph adjacency / sizes; GDG only uses len(indices) and node_lengths.
        indices = []    
        edge_lengths = []
        node_lengths = []
        for graph in dgl.unbatch(graphs):
            adj = graph.adjacency_matrix().indices()
            ind1 = adj[0]
            ind2 = adj[1]
            inds = [ind1, ind2]
            edge_lengths.append(graph.num_edges())
            node_lengths.append(graph.num_nodes())
            indices.append(inds)
        
        h_p = self.embedding_h(graphs.ndata['feat'])
        graphs.ndata['feat'] = h_p
        h, SGU_g, GGU_P1_g, GGU_P2_g, SGU_k, GGU_P1_k, GGU_P2_k, batch_pos_enc = self.GDG(graphs, indices, edge_lengths, node_lengths, batch_pos_enc)

        # Layer index doubles as `stage`: only layer 0 applies the shared input projections.
        for ind, conv in enumerate(self.layers):
            if ind < len(self.layers)-1:
                SGU_g, GGU_P1_g, GGU_P2_g, h, SGU_k, GGU_P1_k, GGU_P2_k = conv(SGU_g, GGU_P1_g, GGU_P2_g, h, SGU_k, GGU_P1_k, GGU_P2_k, original_length, ind, batch_pos_enc)
            else:
                best_scores = conv(SGU_g, GGU_P1_g, GGU_P2_g, h, SGU_k, GGU_P1_k, GGU_P2_k, original_length, ind, batch_pos_enc)
        return best_scores

    
    def loss(self, pred, label):
        """Class-balanced cross entropy over the nodes of one batch.

        The weight of class ``c`` is ``(V - n_c) / V`` where ``n_c`` is the number of
        nodes labelled ``c`` in this batch and ``V`` the total; classes absent from the
        batch get weight 0. Assumes ``label`` holds ints in ``[0, n_classes)``.
        """
        V = label.size(0)
        label_count = torch.bincount(label)
        label_count = label_count[label_count.nonzero()].squeeze()
        cluster_sizes = torch.zeros(self.n_classes).long().to(self.device)
        cluster_sizes[torch.unique(label)] = label_count
        weight = (V - cluster_sizes).float() / V
        weight *= (cluster_sizes>0).float()
        
        criterion = nn.CrossEntropyLoss(weight=weight)
        loss = criterion(pred, label)

        return loss

In [None]:
# ---- Run configuration -------------------------------------------------------
SEED = 41
BATCH_SIZE = 64

# Seed every RNG before anything random happens. Model parameters are initialised on
# the CPU generator and DataLoader shuffling also draws from it, so seeding only CUDA
# (the previous `torch.cuda.manual_seed(41)`) left both non-reproducible.
# `torch.manual_seed` seeds the CPU generator and all CUDA devices.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Single device definition shared by the model config and the training loop.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net_params = {}
net_params['device'] = device
# Embedding-table size: number of distinct node-feature ids in the first training graph
# (feat is an integer). NOTE(review): assumes that graph contains every feature id.
net_params['in_dim_node'] = torch.unique(trainset[0][0].ndata['feat'],dim=0).size(0)
net_params['in_dim'] = 15
net_params['hidden_dim'] = 100
net_params['out_dim'] = 100
num_classes = torch.unique(trainset[0][1],dim=0).size(0)
net_params['n_classes'] = num_classes
# NOTE: GIGLayer overrides 'L' for each sub-network (SGU: 10, GGU parts: 3), so this
# value does not reach the GIG model.
net_params['L'] = 7
net_params['readout'] = "mean"
# The next three keys are not read by any model code in this notebook.
net_params['layer_norm'] = True
net_params['batch_norm'] = True
net_params['in_feat_dropout'] = 0.0
net_params['dropout'] = 0.0
net_params['residual'] = True
net_params['edge_feat'] = False
net_params['self_loop'] = False
net_params['pos_enc'] = True
net_params['pos_enc_dim'] = 4

start0 = time.time()
if net_params['pos_enc']:
    print("[!] Adding graph positional encoding.")
    dataset._add_positional_encodings(net_params['pos_enc_dim'])
    print('Time PE:',time.time()-start0)

# Re-read the splits: positional encodings were just added to the dataset.
trainset, valset, testset = dataset.train, dataset.val, dataset.test

# Only the training loader drops the ragged last batch; val/test keep every graph so
# the reported metrics cover the full splits.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate, drop_last=True)
val_loader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=dataset.collate, drop_last=False)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=dataset.collate, drop_last=False)

In [None]:
# Two stacked GIGLayers; the last one carries the MLP readout head.
N_GIG_LAYERS = 2
model = GIGNet(net_params, num_classes, N_GIG_LAYERS, device).to(device)

optimizer = optim.AdamW(model.parameters(), lr=0.003, weight_decay=0)
# Halve the LR when the monitored validation loss has not improved for 5 scheduler
# steps, down to a floor of 1e-5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=0.00001,
    verbose=True,
)
# Alternative schedule: CosineAnnealingLR(optimizer=optimizer, T_max=45, verbose=True)

In [None]:
def evaluate_network_sparse(model, device, data_loader, epoch):
    """Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    Args:
        model: network exposing ``forward(graphs, pos_enc=None)`` and ``loss(scores, labels)``.
        device: device the batch graphs and labels are moved to.
        data_loader: iterable of ``(batch_graphs, batch_labels)`` pairs.
        epoch: current epoch index (unused; kept for a uniform call signature).

    Returns:
        ``(mean_loss, mean_accuracy)``, each averaged over batches (batches weighted equally).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch_graphs, batch_labels in data_loader:
            batch_graphs = batch_graphs.to(device)
            batch_labels = batch_labels.to(device)

            # Explicit membership test instead of a bare `except:`, which used to swallow
            # genuine forward() errors and silently retry without positional encodings.
            if 'pos_enc' in batch_graphs.ndata:
                batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
                batch_scores = model.forward(batch_graphs, batch_pos_enc)
            else:
                batch_scores = model.forward(batch_graphs)

            loss = model.loss(batch_scores, batch_labels)
            total_loss += loss.detach().item()
            total_acc += accuracy_SBM(batch_scores, batch_labels)
            n_batches += 1

    if n_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return total_loss / n_batches, total_acc / n_batches

def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    """Run one optimisation pass over ``data_loader``.

    When a batch graph carries ``ndata['pos_enc']``, the sign of each positional-encoding
    column is flipped at random (once per batch) before the forward pass. This is
    presumably augmentation against the sign ambiguity of eigenvector encodings.

    Args:
        model: network exposing ``forward(graphs, pos_enc=None)`` and ``loss(scores, labels)``.
        optimizer: torch optimizer over ``model``'s parameters.
        device: device the batch graphs and labels are moved to.
        data_loader: iterable of ``(batch_graphs, batch_labels)`` pairs.
        epoch: current epoch index (unused; kept for a uniform call signature).

    Returns:
        ``(mean_loss, mean_accuracy, optimizer)`` with means taken over batches.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    n_batches = 0

    for batch_graphs, batch_labels in data_loader:
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()

        # Explicit membership test instead of a bare `except:`, which used to swallow
        # genuine forward() errors and silently retry without positional encodings.
        if 'pos_enc' in batch_graphs.ndata:
            batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
            sign_flip = torch.rand(batch_pos_enc.size(1)).to(device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0)
            batch_scores = model.forward(batch_graphs, batch_pos_enc)
        else:
            batch_scores = model.forward(batch_graphs)

        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        total_acc += accuracy_SBM(batch_scores, batch_labels)
        n_batches += 1

    if n_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return total_loss / n_batches, total_acc / n_batches, optimizer


def accuracy_SBM(scores, targets):
    S = targets.cpu().numpy()
    C = np.argmax( torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy() , axis=1 )
    CM = confusion_matrix(S,C).astype(np.float32)
    nb_classes = CM.shape[0]
    targets = targets.cpu().detach().numpy()
    nb_non_empty_classes = 0
    pr_classes = np.zeros(nb_classes)
    for r in range(nb_classes):
        cluster = np.where(targets==r)[0]
        if cluster.shape[0] != 0:
            pr_classes[r] = CM[r,r]/ float(cluster.shape[0])
            if CM[r,r]>0:
                nb_non_empty_classes += 1
        else:
            pr_classes[r] = 0.0
    acc = 100.* np.sum(pr_classes)/ float(nb_classes)
    return acc  

In [None]:
N_EPOCHS = 360
log_dir = os.path.join('./tf-logs/', "DATA_PATTERN")
# Checkpoints go to ./MODEL_PATTERN/epoch_<n>.pkl (one state_dict per epoch).
ckpt_dir = os.path.join(os.getcwd(), "MODEL_PATTERN")
os.makedirs(ckpt_dir, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)
try:
    with tqdm(range(N_EPOCHS)) as t:
        for epoch in t:
            t.set_description('Epoch %d' % epoch)
            # Start timing before training: the clock used to be reset after train/eval,
            # so the reported time only covered logging and checkpointing.
            start = time.time()

            epoch_train_loss, epoch_train_acc, optimizer = train_epoch_sparse(model, optimizer, device, train_loader, epoch)
            epoch_val_loss, epoch_val_acc = evaluate_network_sparse(model, device, val_loader, epoch)
            _, epoch_test_acc = evaluate_network_sparse(model, device, test_loader, epoch)

            writer.add_scalar('train/_loss', epoch_train_loss, epoch)
            writer.add_scalar('val/_loss', epoch_val_loss, epoch)
            writer.add_scalar('train/_acc', epoch_train_acc, epoch)
            writer.add_scalar('val/_acc', epoch_val_acc, epoch)
            writer.add_scalar('test/_acc', epoch_test_acc, epoch)
            writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

            torch.save(model.state_dict(), os.path.join(ckpt_dir, 'epoch_{}.pkl'.format(epoch)))

            t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
                          train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                          train_acc=epoch_train_acc, val_acc=epoch_val_acc,
                          test_acc=epoch_test_acc)

            # LR schedule is driven by the validation loss.
            scheduler.step(epoch_val_loss)
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

In [None]:
# torch.save(model, os.getcwd()+'/best_model.pth')