In [None]:
import numpy as np
import torch
import pickle
import time
import os
%matplotlib inline
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
from tensorboardX import SummaryWriter
import glob
import dgl.function as fn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax

In [None]:
import os

# Move to the project root (two levels above this notebook, per the original
# '../../'). The root is resolved only once per kernel so that re-running this
# cell does not keep climbing further up the directory tree.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
os.chdir(PROJECT_ROOT)
print(os.getcwd())


In [None]:
from data.dataset import DatasetProcess
from ogb.graphproppred import Evaluator

# Processed ogbg-ppa graphs, plus the official OGB evaluator for the same task.
dataset, evaluator = DatasetProcess('ogbg-ppa'), Evaluator(name='ogbg-ppa')

In [None]:
import pickle

%load_ext autoreload
%autoreload 2


from torch.utils.data import DataLoader


In [None]:
BATCH_SIZE = 32
split_idx = dataset.get_idx_split()


def _make_loader(split, shuffle):
    """DataLoader over one OGB split, batched into DGL graphs by dataset.collate_dgl."""
    return DataLoader(dataset[split_idx[split]], batch_size=BATCH_SIZE,
                      shuffle=shuffle, collate_fn=dataset.collate_dgl)


train_loader = _make_loader('train', shuffle=True)
val_loader = _make_loader('valid', shuffle=False)
test_loader = _make_loader('test', shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import dgl
import torch.optim as optim
from dgl.nn.pytorch import GATConv
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts



def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its case-insensitive name.

    'relu'      -> nn.ReLU(inplace)
    'leakyrelu' -> nn.LeakyReLU with negative slope ``neg_slope``
    'prelu'     -> nn.PReLU with ``n_prelu`` slopes initialised to ``neg_slope``

    Raises NotImplementedError for any other name.
    """
    builders = {
        'relu': lambda: nn.ReLU(inplace),
        'leakyrelu': lambda: nn.LeakyReLU(neg_slope, inplace),
        'prelu': lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
    }
    key = act_type.lower()
    if key not in builders:
        raise NotImplementedError('activation layer [%s] is not found' % key)
    return builders[key]()


def norm_layer(norm_type, nc):
    """Build a 1-D normalisation module over ``nc`` channels.

    'batch'    -> nn.BatchNorm1d (affine)
    'layer'    -> nn.LayerNorm (elementwise affine)
    'instance' -> nn.InstanceNorm1d (not affine)

    The name is matched case-insensitively; anything else raises
    NotImplementedError.
    """
    kind = norm_type.lower()
    if kind == 'batch':
        return nn.BatchNorm1d(nc, affine=True)
    if kind == 'layer':
        return nn.LayerNorm(nc, elementwise_affine=True)
    if kind == 'instance':
        return nn.InstanceNorm1d(nc, affine=False)
    raise NotImplementedError(f'Normalization layer {kind} is not supported.')


class MLP(nn.Sequential):
    """Sequential stack of Linear layers with widths given by ``channels``.

    Every Linear except the last is followed by, in order: an optional norm
    layer (``norm``), an optional activation (``act``) and a Dropout(``dropout``).
    Passing None or the string 'none' (any case) for ``norm``/``act`` skips it.
    The final Linear is emitted bare.
    """
    def __init__(self, channels, act='relu', norm=None, dropout=0., bias=True):
        def _enabled(name):
            return name is not None and name.lower() != 'none'

        modules = []
        final_step = len(channels) - 1
        for step, (fan_in, fan_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
            modules.append(nn.Linear(fan_in, fan_out, bias))
            if step == final_step:
                continue
            if _enabled(norm):
                modules.append(norm_layer(norm, fan_out))
            if _enabled(act):
                modules.append(act_layer(act))
            modules.append(nn.Dropout(dropout))

        super(MLP, self).__init__(*modules)


class MessageNorm(nn.Module):
    r"""Message normalisation (MsgNorm) from DeeperGCN.

    Rescales each aggregated message so that its p-norm matches the p-norm of
    the node feature it will be added to, times a scale ``s``:
    ``s * ||feats||_p * msg / ||msg||_p``, computed row-wise over the last dim.

    Parameters
    ----------
    learn_scale: bool
        Whether the scale ``s`` (initialised to 1.0) is trainable. Default is False.
    """
    def __init__(self, learn_scale=False):
        super(MessageNorm, self).__init__()
        self.scale = nn.Parameter(torch.tensor([1.0]), requires_grad=learn_scale)

    def forward(self, feats, msg, p=2):
        # The same ``p`` is used for both norms. Previously ``msg`` was always
        # L2-normalised while ``feats`` used the requested ``p``.
        msg = F.normalize(msg, p=p, dim=-1)
        feats_norm = feats.norm(p=p, dim=-1, keepdim=True)
        return msg * feats_norm * self.scale


class GENConv(nn.Module):
    r"""Generalized graph convolution (GENConv) from DeeperGCN.

    Each edge forms a message ``m_uv = h_u + edge_encoder(e_uv)``. Messages are
    aggregated at the destination node with either
      * 'softmax': ``sum_u softmax_u(beta * m_uv) * m_uv`` (after ReLU + eps), or
      * 'power':   ``(mean_u m_uv ** p) ** (1 / p)`` (values clamped to [1e-7, 10]).
    The aggregate is optionally rescaled by MessageNorm, added to the node
    features and passed through an MLP.

    Parameters
    ----------
    dataset: str
        Name of ogb dataset ('ogbg-molhiv' or 'ogbg-ppa'); selects the edge encoder.
    in_dim: int
        Size of input dimension.
    out_dim: int
        Size of output dimension.
    aggregator: str
        Type of aggregator scheme ('softmax', 'power'), default is 'softmax'.
    beta: float
        A continuous variable called an inverse temperature. Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable variable or not (softmax only). Default is False.
    p: float
        Initial power for power mean aggregation. Default is 1.0.
    learn_p: bool
        Whether p is a learnable variable or not. Default is False.
    msg_norm: bool
        Whether message normalization is used. Default is False.
    learn_msg_scale: bool
        Whether s is a learnable scaling factor or not in message normalization. Default is False.
    norm: str
        Type of ('batch', 'layer', 'instance') norm layer in MLP layers. Default is 'batch'.
    mlp_layers: int
        The number of MLP layers. Default is 1.
    eps: float
        A small positive constant in message construction function. Default is 1e-7.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the supported names.
    """
    def __init__(self,
                 dataset,
                 in_dim,
                 out_dim,
                 aggregator='softmax',
                 beta=1.0,
                 learn_beta=False,
                 p=1.0,
                 learn_p=False,
                 msg_norm=False,
                 learn_msg_scale=False,
                 norm='batch',
                 mlp_layers=1,
                 eps=1e-7):
        super(GENConv, self).__init__()

        self.aggr = aggregator
        self.eps = eps

        # Hidden MLP layers are twice as wide as the input.
        channels = [in_dim] + [in_dim * 2] * (mlp_layers - 1) + [out_dim]

        self.mlp = MLP(channels, norm=norm)
        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

        self.beta = nn.Parameter(torch.Tensor([beta]), requires_grad=True) if learn_beta and self.aggr == 'softmax' else beta
        self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p

        if dataset == 'ogbg-molhiv':
            # BondEncoder was used here without ever being imported.
            from ogb.graphproppred.mol_encoder import BondEncoder
            self.edge_encoder = BondEncoder(in_dim)
        elif dataset == 'ogbg-ppa':
            self.edge_encoder = nn.Linear(in_dim, in_dim)
        else:
            raise ValueError(f'Dataset {dataset} is not supported.')

    def forward(self, g, node_feats, edge_feats):
        """Return updated node features of width ``out_dim``."""
        with g.local_scope():
            # u_add_e sums node and encoded edge features, so their widths must match.
            g.ndata['h'] = node_feats
            g.edata['h'] = self.edge_encoder(edge_feats)
            g.apply_edges(fn.u_add_e('h', 'h', 'm'))

            if self.aggr == 'softmax':
                g.edata['m'] = F.relu(g.edata['m']) + self.eps
                g.edata['a'] = edge_softmax(g, g.edata['m'] * self.beta)
                g.update_all(lambda edge: {'x': edge.data['m'] * edge.data['a']},
                             fn.sum('x', 'm'))

            elif self.aggr == 'power':
                minv, maxv = 1e-7, 1e1
                g.edata['m'] = torch.clamp(g.edata['m'], minv, maxv)
                g.update_all(lambda edge: {'x': torch.pow(edge.data['m'], self.p)},
                             fn.mean('x', 'm'))
                aggregated = torch.clamp(g.ndata['m'], minv, maxv)
                # Power mean: (mean x^p)^(1/p). The original raised to p again.
                g.ndata['m'] = torch.pow(aggregated, 1.0 / self.p)

            else:
                raise NotImplementedError(f'Aggregator {self.aggr} is not supported.')

            if self.msg_norm is not None:
                g.ndata['m'] = self.msg_norm(node_feats, g.ndata['m'])

            feats = node_feats + g.ndata['m']

            return self.mlp(feats)
    
    
    
    
class DeeperGCN(nn.Module):
    r"""Residual stack of GENConv layers in pre-activation ("res+") order.

    Node and edge inputs are first projected to ``hid_dim`` by separate linear
    encoders. Each layer then computes
    ``h = GENConv(dropout(ReLU(norm(h)))) + h``.
    ``forward`` returns the updated node features together with the encoded
    edge features; it applies neither pooling nor ``self.output``.

    Parameters
    ----------
    dataset: str
        Name of ogb dataset (forwarded to every GENConv).
    node_feat_dim: int
        Size of node feature dimension.
    edge_feat_dim: int
        Size of edge feature dimension.
    hid_dim: int
        Size of hidden dimension.
    out_dim: int
        Size of output dimension (only used to size ``self.output``).
    num_layers: int
        Number of graph convolutional layers.
    update_type: int
        Stored on the module; currently unused. Default is 1.
    dropout: float
        Dropout rate. Default is 0.
    norm: str
        Type of ('batch', 'layer', 'instance') norm layer, used before each
        GENConv and inside its MLP. Default is 'batch'.
    pooling: str
        Currently unused: no pooling is applied. Default is 'mean'.
    beta: float
        A continuous variable called an inverse temperature. Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable weight. Default is False.
    aggr: str
        Type of aggregator scheme ('softmax', 'power'). Default is 'softmax'.
    mlp_layers: int
        Number of MLP layers inside each GENConv. Default is 1.
    """
    def __init__(self,
                 dataset,
                 node_feat_dim,
                 edge_feat_dim,
                 hid_dim,
                 out_dim,
                 num_layers,
                 update_type=1,
                 dropout=0.,
                 norm='batch',
                 pooling='mean',
                 beta=1.0,
                 learn_beta=False,
                 aggr='softmax',
                 mlp_layers=1):
        super(DeeperGCN, self).__init__()

        self.dataset = dataset
        self.num_layers = num_layers
        self.dropout = dropout
        self.gcns = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.update_type = update_type

        for i in range(self.num_layers):
            conv = GENConv(dataset=dataset,
                           in_dim=hid_dim,
                           out_dim=hid_dim,
                           aggregator=aggr,
                           beta=beta,
                           learn_beta=learn_beta,
                           mlp_layers=mlp_layers,
                           norm=norm)

            self.gcns.append(conv)
            self.norms.append(norm_layer(norm, hid_dim))

        self.node_encoder = nn.Linear(node_feat_dim, hid_dim)
        self.edge_encoder = nn.Linear(edge_feat_dim, hid_dim)

        # Graph-level pooling is left to the caller; ``pooling`` is not used.
        # NOTE(review): self.output is never applied in forward.
        self.output = nn.Linear(hid_dim, out_dim)

    def forward(self, g, hv, he):
        """Return ``(node_feats, edge_feats)``, both of width ``hid_dim``."""
        hv = self.node_encoder(hv)
        # Edge inputs go through the edge encoder. The original reused
        # node_encoder here and left edge_encoder unused.
        he = self.edge_encoder(he)

        with g.local_scope():
            for layer in range(self.num_layers):
                hv1 = self.norms[layer](hv)
                hv1 = F.relu(hv1)
                hv1 = F.dropout(hv1, p=self.dropout, training=self.training)
                hv = self.gcns[layer](g, hv1, he) + hv

            return hv, he
    
    
    

In [None]:
class GIGLayer(nn.Module):
    """One Graph-in-Graph layer: SGU, then GGU part 1, then GGU part 2.

    Each unit is a DeeperGCN applied to its own graph over the same batched node
    set. The node output of one unit is the node input of the next. Non-final
    layers return the graphs, the node features and each unit's encoded edge
    features for the next layer. The final layer returns the per-graph mean of
    the node features (readout).

    Parameters
    ----------
    net_params: dict
        Reads 'hidden_dim', 'n_classes', 'L' and, optionally, 'dataset'
        (default 'ogbg-ppa').
    device: torch.device
    is_last_layer: bool
        Whether this layer performs the readout.
    """
    def __init__(self, net_params, device, is_last_layer):
        super(GIGLayer, self).__init__()
        self.device = device
        self.is_last_layer = is_last_layer

        # GIGNet encodes node and edge inputs to ``hidden_dim`` before calling
        # this layer, and maps the readout with Linear(hidden_dim, n_classes).
        # Every unit must therefore consume and produce that width. The
        # original hard-coded 7 / 100 here, which cannot match those inputs.
        hidden_dim = net_params['hidden_dim']
        dataset_name = net_params.get('dataset', 'ogbg-ppa')

        def make_unit(num_layers):
            return DeeperGCN(dataset=dataset_name,
                             node_feat_dim=hidden_dim,
                             edge_feat_dim=hidden_dim,
                             hid_dim=hidden_dim,
                             out_dim=net_params['n_classes'],
                             num_layers=num_layers).to(device)

        # NOTE(review): an earlier version prepared an SGU depth of 2 but then
        # used net_params['L']; the latter is what is kept here.
        self.SGU = make_unit(net_params['L'])
        self.GGU_Part1 = make_unit(1)
        self.GGU_Part2 = make_unit(1)

    def forward(self, SGU_graph, GGU_P1_graph, GGU_P2_graph, h, SGU_e, GGU_P1_e, GGU_P2_e, stage):
        # ``stage`` (this layer's index in GIGNet) is accepted but unused.
        h1, e1 = self.SGU(SGU_graph, h, SGU_e)
        h2, e2 = self.GGU_Part1(GGU_P1_graph, h1, GGU_P1_e)
        h3, e3 = self.GGU_Part2(GGU_P2_graph, h2, GGU_P2_e)

        if not self.is_last_layer:
            return SGU_graph, GGU_P1_graph, GGU_P2_graph, h3, e1, e2, e3

        # Readout: mean of node features per graph (proxy vertex included).
        # local_scope keeps the temporary 'h' field off the caller's graph.
        with SGU_graph.local_scope():
            SGU_graph.ndata['h'] = h3
            return dgl.mean_nodes(SGU_graph, 'h')

In [None]:
class GIGNet(nn.Module):
    """Graph-in-Graph network for graph classification.

    Each input graph is augmented with a proxy vertex, initialised to the mean
    of its node features. Three edge sets are then built over the batched,
    augmented node set:

    * SGU: the original edges, plus an edge from every vertex to its proxy;
    * GGU part 1: edges in both directions between each proxy and its
      ``num_neighbors`` most similar proxies in the batch (inner product);
    * GGU part 2: an edge from each proxy to every vertex of its graph.

    The GIGLayer modules run in sequence. The last one mean-pools per graph,
    and ``self.output`` maps the result to class logits.

    Parameters
    ----------
    net_params: dict
        Reads 'in_dim', 'hidden_dim' and 'n_classes' here; GIGLayer reads more.
    n_layers: int
        Number of GIGLayer modules.
    num_neighbors: int
        Number of global (proxy-to-proxy) neighbours per graph.
    device: torch.device
    """
    def __init__(self, net_params, n_layers, num_neighbors, device):
        super().__init__()
        self.device = device
        self.num_neighbors = num_neighbors
        self.layers = nn.ModuleList([GIGLayer(net_params, device, False) for _ in range(n_layers-1)])
        self.layers.append(GIGLayer(net_params, device, True))
        # Both encoders read in_dim-wide inputs: the raw edge features, and the
        # node features built by summing incoming edge features (see forward).
        self.edge_encoder = nn.Linear(net_params['in_dim'], net_params['hidden_dim'])
        self.node_encoder = nn.Linear(net_params['in_dim'], net_params['hidden_dim'])
        self.output = nn.Linear(net_params['hidden_dim'], net_params['n_classes'])

    def global_neighbors(self, h, num_global_neighbors):
        """For each row of ``h``, return the indices of its most similar other rows.

        Similarity is the inner product. A row is never its own neighbour, and
        k is capped at ``len(h) - 1`` so that small batches do not make topk
        fail. Returns a (len(h), k) long tensor on ``self.device``.
        """
        h = h.detach().to(self.device)  # only indices are needed; no gradient
        sim = h @ h.transpose(0, 1)
        # Exclude self-matches outright. The original only zeroed the diagonal,
        # which still let a row pick itself when its other scores were negative.
        sim.fill_diagonal_(float('-inf'))
        k = min(num_global_neighbors, max(h.shape[0] - 1, 0))
        _, target_inds = sim.topk(k=k, dim=1, largest=True)
        return target_inds

    def GDG(self, graphs, indices, edge_lengths, node_lengths):
        """GIG data generation: build the SGU, GGU part 1 and GGU part 2 graphs.

        Parameters
        ----------
        graphs: batched DGLGraph whose ndata['feat'] / edata['feat'] hold the
            encoded node / edge features, with graphs concatenated in batch order.
        indices: per-graph [src, dst] tensors of local node ids, in the same
            order as that graph's rows of edata['feat'].
        edge_lengths, node_lengths: per-graph edge / node counts.

        Returns
        -------
        (SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, SGU_k, GGU_P1_k, GGU_P2_k),
        where each *_k is the edge-feature tensor of the matching graph. In the
        batched numbering, graph i owns node_lengths[i] + 1 consecutive ids and
        its proxy vertex is the last of them.
        """
        device = self.device
        node_feats = graphs.ndata['feat']
        edge_feats = graphs.edata['feat']
        # Width of the random features given to added edges. It must equal the
        # real edge-feature width because the SGU concatenates the two; the
        # original hard-coded 100 here.
        edge_dim = edge_feats.shape[-1]

        SGU_graphs = []
        GGU_P2_graphs = []
        proxy_vertices = []
        node_index = 0
        edge_index = 0
        for (src, dst), node_length, edge_length in zip(indices, node_lengths, edge_lengths):
            h = node_feats[node_index:node_index + node_length].to(device)
            e = edge_feats[edge_index:edge_index + edge_length].to(device)
            node_index += node_length
            edge_index += edge_length

            # Local ids 0..n-1 are the graph's own vertices; id n is its proxy.
            local_ids = torch.arange(node_length, device=device)
            proxy_ids = torch.full((node_length,), node_length, dtype=torch.long, device=device)

            # SGU: original edges, followed by vertex -> proxy edges.
            SGU_g = dgl.graph((torch.cat((src.to(device), local_ids)),
                               torch.cat((dst.to(device), proxy_ids))),
                              num_nodes=node_length + 1)
            # GGU part 2: proxy -> vertex edges.
            GGU_P2_g = dgl.graph((proxy_ids, local_ids), num_nodes=node_length + 1)

            # The proxy vertex starts as the mean of the graph's vertex features
            # (keepdim replaces the original width-specific reshape(1, 100)).
            proxy_vertex = h.mean(dim=0, keepdim=True)
            proxy_vertices.append(proxy_vertex)
            node_feat = torch.cat((h, proxy_vertex), dim=0)

            SGU_g.ndata['feat'] = node_feat
            GGU_P2_g.ndata['feat'] = node_feat
            # NOTE(review): added edges get fresh Gaussian features on every
            # call, evaluation included, so outputs are not deterministic.
            SGU_g.edata['feat'] = torch.cat((e, torch.randn(node_length, edge_dim, device=device)))
            GGU_P2_g.edata['feat'] = torch.randn(node_length, edge_dim, device=device)

            SGU_graphs.append(SGU_g)
            GGU_P2_graphs.append(GGU_P2_g)

        SGU_graphs = dgl.batch(SGU_graphs).to(device)
        GGU_P2_graphs = dgl.batch(GGU_P2_graphs).to(device)
        SGU_k = SGU_graphs.edata['feat']
        GGU_P2_k = GGU_P2_graphs.edata['feat']

        # Batched id of each proxy: graph i contributes node_lengths[i] + 1 ids,
        # and its proxy is the last of them. The tensor lives on ``device`` so it
        # can be indexed by target_inds; the original kept it on the CPU.
        sizes = torch.tensor(node_lengths, dtype=torch.long, device=device) + 1
        proxy_inds = torch.cumsum(sizes, dim=0) - 1

        # GGU part 1: each proxy <-> its most similar proxies, in both directions.
        target_inds = self.global_neighbors(torch.cat(proxy_vertices, dim=0), self.num_neighbors)
        num_graphs, k = target_inds.shape
        centre = proxy_inds.unsqueeze(1).expand(num_graphs, k)
        neighbours = proxy_inds[target_inds]
        # Per proxy: k outgoing edges followed by k incoming ones.
        global_src = torch.cat((centre, neighbours), dim=1).reshape(-1)
        global_tgt = torch.cat((neighbours, centre), dim=1).reshape(-1)

        # Self-loop on the last batched node so the graph spans every node id.
        last_node = torch.tensor([SGU_graphs.num_nodes() - 1], dtype=torch.long, device=device)
        global_src = torch.cat((global_src, last_node))
        global_tgt = torch.cat((global_tgt, last_node))

        GGU_P1_graphs = dgl.graph((global_src, global_tgt)).to(device)
        GGU_P1_k = torch.randn(global_src.numel(), edge_dim, device=device)

        return SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, SGU_k, GGU_P1_k, GGU_P2_k

    def forward(self, graphs):
        """Return class logits of shape (num_graphs, n_classes) for a batched graph."""
        indices = []
        edge_lengths = []
        node_lengths = []
        for graph in dgl.unbatch(graphs):
            # graph.edges() yields (src, dst) in edge-id order, i.e. aligned with
            # this graph's rows of edata['feat']. It replaces the DGL-version-
            # dependent adjacency_matrix(transpose=False)._indices().
            src, dst = graph.edges()
            indices.append([src, dst])
            edge_lengths.append(graph.num_edges())
            node_lengths.append(graph.num_nodes())

        # local_scope keeps the encoded features written below off the caller's graph.
        with graphs.local_scope():
            # Initial node features: the sum of incoming raw edge features.
            graphs.edata['h'] = graphs.edata['feat']
            graphs.update_all(fn.copy_e('h', 'm'), fn.sum('m', 'h1'))
            graphs.ndata['feat'] = self.node_encoder(graphs.ndata['h1'])
            graphs.edata['feat'] = self.edge_encoder(graphs.edata['feat'])

            SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, SGU_k, GGU_P1_k, GGU_P2_k = \
                self.GDG(graphs, indices, edge_lengths, node_lengths)

        h = SGU_graphs.ndata['feat']
        last_index = len(self.layers) - 1
        for ind, layer in enumerate(self.layers):
            if ind < last_index:
                SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, h, SGU_k, GGU_P1_k, GGU_P2_k = \
                    layer(SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, h, SGU_k, GGU_P1_k, GGU_P2_k, ind)
            else:
                hg = layer(SGU_graphs, GGU_P1_graphs, GGU_P2_graphs, h, SGU_k, GGU_P1_k, GGU_P2_k, ind)
        return self.output(hg)

    def loss(self, pred, label):
        """Cross-entropy between raw logits ``pred`` and integer class ``label``."""
        return F.cross_entropy(pred, label.long())

In [None]:
# Choose the compute device once (the original assigned `device` twice).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG torch uses (CPU and all CUDA devices). The original only
# seeded CUDA, so CPU-side randomness such as weight init (the model is built
# on the CPU and then moved) was unseeded.
SEED = 41
torch.manual_seed(SEED)

N_GIG_LAYERS = 2         # GIGLayer modules stacked in GIGNet
N_GLOBAL_NEIGHBORS = 6   # proxy-to-proxy neighbours per graph in GGU part 1

# GIGNet builds node inputs by summing incoming edge features, so both input
# widths come from the edge-feature size.
sample_graph = dataset[split_idx['train']][0][0]
edge_feat_dim = sample_graph.edata['feat'].size()[-1]

net_params = {
    'device': device,
    'gated': True,
    'in_dim': edge_feat_dim,
    'in_dim_edge': edge_feat_dim,
    'residual': True,
    'hidden_dim': 105,
    'out_dim': 105,
    'n_classes': int(dataset.num_classes),
    'n_heads': -1,
    'L': 4,
    'readout': 'mean',
    'layer_norm': True,
    'batch_norm': True,
    'in_feat_dropout': 0.1,
    'dropout': 0.05,
    'edge_feat': True,
    'self_loop': False,
}
num_classes = net_params['n_classes']

model = GIGNet(net_params, N_GIG_LAYERS, N_GLOBAL_NEIGHBORS, device).to(device)

optimizer = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0)
# `verbose` is deprecated for LR schedulers in recent PyTorch; the training loop
# already shows the LR in the tqdm postfix and writes it to TensorBoard.
scheduler = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=30)

In [None]:
def train_epoch_sparse(model, optimizer, device, data_loader):
    """Run one training epoch and return the mean per-batch loss.

    The model is expected to return raw class scores; GIGNet.forward ends in a
    Linear layer. The loss is therefore cross-entropy. The original applied
    ``F.nll_loss`` to those raw scores, which is only correct for
    log-probabilities.

    Parameters
    ----------
    model: module called as ``model.forward(g)``.
    optimizer: torch optimizer over the model's parameters.
    device: device that graphs and labels are moved to.
    data_loader: iterable of ``(g, labels)`` batches, with labels as a
        (batch, 1) column of class ids.

    Returns
    -------
    float
        Mean loss over batches, or ``float('nan')`` if the loader is empty.
    """
    model.train()

    batch_losses = []
    for g, labels in data_loader:
        g = g.to(device)
        labels = labels.to(device)
        logits = model.forward(g)
        loss = F.cross_entropy(logits, labels.squeeze(1).long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    if not batch_losses:
        return float('nan')
    return sum(batch_losses) / len(batch_losses)


@torch.no_grad()
def evaluate_network_sparse(model, device, data_loader, evaluator):
    """Return the evaluator's 'acc' for ``model`` over every batch of ``data_loader``.

    The predicted class is the arg-max score per graph, kept as a (batch, 1)
    column via ``keepdim``. Labels and predictions are gathered on the CPU and
    passed to ``evaluator.eval`` as numpy arrays.
    """
    model.eval()
    targets = []
    predictions = []

    for batch_graph, batch_labels in data_loader:
        scores = model.forward(batch_graph.to(device))
        targets.append(batch_labels.detach().cpu())
        predictions.append(scores.argmax(dim=-1, keepdim=True).detach().cpu())

    result = evaluator.eval({
        'y_true': torch.cat(targets, dim=0).numpy(),
        'y_pred': torch.cat(predictions, dim=0).numpy(),
    })
    return result['acc']

In [None]:
N_EPOCHS = 360
# TensorBoard root; override with the TF_LOG_ROOT environment variable instead
# of editing the hard-coded absolute path.
LOG_ROOT = os.environ.get('TF_LOG_ROOT', '/root/tf-logs/')
log_dir = os.path.join(LOG_ROOT, "DATA_OGBG_PPA")
ckpt_dir = os.path.join(os.getcwd(), "MODEL_OGBG_PPA")
os.makedirs(ckpt_dir, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)
try:
    with tqdm(range(N_EPOCHS)) as t:
        for epoch in t:
            t.set_description('Epoch %d' % epoch)
            # Start the clock before training. The original started it after
            # training and evaluation, so it only timed logging and checkpointing.
            start = time.time()

            epoch_train_loss = train_epoch_sparse(model, optimizer, device, train_loader)
            epoch_val_acc = evaluate_network_sparse(model, device, val_loader, evaluator)
            epoch_test_acc = evaluate_network_sparse(model, device, test_loader, evaluator)

            writer.add_scalar('train/_loss', epoch_train_loss, epoch)
            writer.add_scalar('val/_acc', epoch_val_acc, epoch)
            writer.add_scalar('test/_acc', epoch_test_acc, epoch)
            writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

            torch.save(model.state_dict(), os.path.join(ckpt_dir, 'epoch_{}.pkl'.format(epoch)))

            t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
                          train_loss=epoch_train_loss,
                          val_acc=epoch_val_acc,
                          test_acc=epoch_test_acc)

            scheduler.step()
finally:
    # Flush pending events even if training is interrupted.
    writer.close()