In [36]:
import torch 
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch import optim
import numpy as np
from random import shuffle

import dgl
from dgl import function as fn
from dgl.base import DGLError
from dgl.utils import expand_as_pair, check_eq_shape
from dgl.data import citation_graph

# Data

In [75]:
data = citation_graph.load_cora()
features = data.features
labels = torch.LongTensor(data.labels)
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
g = dgl.from_networkx(data.graph)

Loading from cache failed, re-processing.
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.


In [66]:
train_mask

tensor([ True,  True,  True,  ..., False, False, False])

# Model

In [67]:
class GATConv(nn.Module):
    def __init__(self, g, num_ins, num_outs,activation = True):
        super().__init__()
        self.fc = nn.Linear(num_ins, num_outs, bias = True)
        self.fc2 = nn.Linear(num_outs * 2, 1)
        self.activation = activation
        self.g = g
        self.reset_parameters()
    
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
    
    def message_passing(self, edges):  
        return {'m': edges.src['z'], 's': edges.data['s']}
        
    def reduce_function(self, nodes):
        return {'out': torch.sum(nodes.mailbox['m']*nodes.mailbox['s'], dim = 1)}
    
    def edge_attention(self, edges):
        eatt = F.relu(self.fc2(torch.cat([edges.src['z'], edges.dst['z']], dim = -1)))
        return {'s': F.softmax(eatt, -1)}
    
    def forward(self, inputs):

        self.g.ndata.update({'z': self.fc(inputs)})        
        self.g.apply_edges(self.edge_attention)
        self.g.update_all(self.message_passing, self.reduce_function)
        if self.activation: return F.relu(self.g.ndata['out'])
        else: return self.g.ndata['out']

In [134]:
class MultiheadGAT(nn.Module):
    def __init__(self, g, num_ins, num_outs, concat = True, k = 8, activation = True):
        super().__init__()
        self.concat = concat
        self.activation = activation
        if concat:
            assert num_outs % k == 0, 'invalid num outs'
            self.gconvs =  nn.ModuleList([GATConv(g, num_ins, num_outs//k) for i in range(k)])
        else: 
            self.gconvs =  nn.ModuleList([GATConv(g, num_ins, num_outs, activation= False) for i in range(k)])
    
    def forward(self, inputs):
        if self.concat:
            return torch.cat([gconv(inputs) for gconv in self.gconvs], dim = 1)
        else: 
            out = torch.stack([gconv(inputs) for gconv in self.gconvs], dim = 0)
            out = torch.mean(out, dim = 0)
            return F.relu(out) if self.activation else out

# Training

In [139]:
net = nn.Sequential(MultiheadGAT(g, 1433, 480), MultiheadGAT(g, 480, 7, concat= False, activation=False))
optimizer = optim.Adam(net.parameters(), lr = 1e-3)

for epoch in range(30):
    logits = net(features)
    logp = F.log_softmax(logits)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    print('epoch: {0}, loss: {1}'.format(epoch, loss.item()))

  


epoch: 0, loss: 2.0811610221862793
epoch: 1, loss: 1.8620535135269165
epoch: 2, loss: 1.8042323589324951
epoch: 3, loss: 1.7786604166030884
epoch: 4, loss: 1.742370843887329
epoch: 5, loss: 1.6985654830932617
epoch: 6, loss: 1.6517014503479004
epoch: 7, loss: 1.6036046743392944
epoch: 8, loss: 1.5528745651245117
epoch: 9, loss: 1.501745581626892
epoch: 10, loss: 1.4532300233840942
epoch: 11, loss: 1.4100029468536377
epoch: 12, loss: 1.3730733394622803
epoch: 13, loss: 1.340494990348816
epoch: 14, loss: 1.3077342510223389
epoch: 15, loss: 1.2711265087127686
epoch: 16, loss: 1.2322332859039307
epoch: 17, loss: 1.1940268278121948
epoch: 18, loss: 1.1575514078140259
epoch: 19, loss: 1.1224384307861328
epoch: 20, loss: 1.088490605354309
epoch: 21, loss: 1.055169939994812
epoch: 22, loss: 1.0215721130371094
epoch: 23, loss: 0.9876906275749207
epoch: 24, loss: 0.9543854594230652
epoch: 25, loss: 0.9222344756126404
epoch: 26, loss: 0.8913723826408386
epoch: 27, loss: 0.8616583943367004
epoch: 

In [140]:
def accuracy(features, mask):
    num = torch.tensor((features[mask]).shape[0], dtype = torch.float)
    logits = net(features)
    return torch.sum(torch.max(logits, dim = 1)[1][mask] == labels[mask])/num

In [141]:
accuracy(features, val_mask)

tensor(0.7360)

In [142]:
a = nn.Sequential(nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 2)), nn.ReLU())
list(a.parameters())

[Parameter containing:
 tensor([[ 0.3071, -0.2970,  0.4974],
         [-0.3593,  0.2232,  0.3549],
         [-0.1576, -0.5640,  0.1042],
         [-0.5513,  0.1275, -0.0775]], requires_grad=True),
 Parameter containing:
 tensor([-0.0260,  0.3171, -0.0312, -0.0210], requires_grad=True),
 Parameter containing:
 tensor([[ 0.4282, -0.0090,  0.2219, -0.3619],
         [ 0.2819,  0.0856,  0.0566, -0.4666]], requires_grad=True),
 Parameter containing:
 tensor([-0.1567,  0.4183], requires_grad=True)]