In [205]:
import pandas as pd
import tqdm
import numpy as np
import sklearn.metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn

In [1]:
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name='ogbn-mag')

Using backend: pytorch


Downloading http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip


Downloaded 0.40 GB: 100%|██████████| 413/413 [04:41<00:00,  1.46it/s]


Extracting dataset/mag.zip
Loading necessary files...
This might take a while.


100%|██████████| 1/1 [00:00<00:00, 8507.72it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Processing graphs...
Converting graphs into DGL objects...


100%|██████████| 1/1 [00:40<00:00, 40.48s/it]


Saving...


데이터셋은 다음과 같은 요소들이 포함되어있음.

- DGL graph object

- The node label tensor

GPU 명시

In [172]:
device = 'cuda'

In [3]:
graph, label = dataset[0] # graph ; dgl graph object, label ; torch tensor of shape (num_nodes, 1)

split_idx = dataset.get_idx_split()
train_nids, valid_nids, test_nids = split_idx['train'], split_idx['valid'], split_idx['test']

# Data description

In [33]:
node_type = pd.DataFrame(graph.ntypes,columns=['node_type'])
edge_type = pd.DataFrame(graph.etypes,columns=['edge_type'])
graph_data = pd.concat([node_type,edge_type,relation_type],axis=1)

In [44]:
graph_data

Unnamed: 0,node_type,edge_type,source,relation,destination
0,author,affiliated_with,author,affiliated_with,institution
1,field_of_study,writes,author,writes,paper
2,institution,cites,paper,cites,paper
3,paper,has_topic,paper,has_topic,field_of_study


Graph 가 heterogeneous 하기 때문에 우리의 node를 사전형 자료구조로 구성했음. key는 node type , value는 node id 리스트로 이루어져 있음 

In [53]:
graph.metagraph().edges()

OutMultiEdgeDataView([('author', 'institution'), ('author', 'paper'), ('paper', 'paper'), ('paper', 'field_of_study')])

In [52]:
train_nids

{'paper': tensor([     0,      1,      2,  ..., 736386, 736387, 736388])}

1939743

In [79]:
# node의 데이터갯수가 balanced 될 필요 없음.
print(f'노드 총 갯수')
display(graph.num_nodes())
print(f'author 갯수')
display(graph.num_nodes('author'))
print(f'field_of_study 갯수')
display(graph.num_nodes('field_of_study'))
print(f'institution 갯수')
display(graph.num_nodes('institution'))
print(f'paper 갯수')
display(graph.num_nodes('paper'))

노드 총 갯수


1939743

author 갯수


1134649

field_of_study 갯수


59965

institution 갯수


8740

paper 갯수


736389

In [80]:
print(f'엣지 총 갯수')
display(graph.num_edges())
print(f'affiliated 엣지 총 갯수')
display(graph.num_edges('affiliated_with'))
print(f'writes 엣지 총 갯수')
display(graph.num_edges('writes'))
print(f'cities 엣지 총 갯수')
display(graph.num_edges('cites'))
print(f'has_topic 엣지 총 갯수')
display(graph.num_edges('has_topic'))

엣지 총 갯수


21111007

affiliated 엣지 총 갯수


1043998

writes 엣지 총 갯수


7145660

cities 엣지 총 갯수


5416271

has_topic 엣지 총 갯수


7505078

In [104]:
# 자주 헷갈리는 item()의 유무에 따른 data 

display((node_labels.max() + 1))
display((node_labels.max() + 1).item())

tensor(349)

349

- node features 형태는 모두 tensor로 담겨있어야 함.

In [107]:
print(graph)

print('Node labels')
node_labels = label['paper'].flatten()

print('Shape of target node labels:', node_labels.shape)
num_classes = (node_labels.max() + 1).item()
print(f'Num of classes : {num_classes}')

print('Node features')
node_features = graph.nodes['paper'].data['feat']
num_features = node_features.shape[1]
print(f'Shape of features of paper node type : {num_features}')


Graph(num_nodes={'author': 1134649, 'field_of_study': 59965, 'institution': 8740, 'paper': 736389},
      num_edges={('author', 'affiliated_with', 'institution'): 1043998, ('author', 'writes', 'paper'): 7145660, ('paper', 'cites', 'paper'): 5416271, ('paper', 'has_topic', 'field_of_study'): 7505078},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])
Node labels
Shape of target node labels: torch.Size([736389])
Num of classes : 349
Node features
Shape of features of paper node type : 128


## Add reverse edges

Realation 이 고정된, directed 형태로 구성되어있기에 reversed 해주어 undirected 로 설정해주는 과정.

In [133]:
print(graph.metagraph().nodes())
print(graph.metagraph().edges())

['author', 'institution', 'paper', 'field_of_study']
[('author', 'institution'), ('author', 'paper'), ('paper', 'paper'), ('paper', 'field_of_study')]


In [134]:
graph

Graph(num_nodes={'author': 1134649, 'field_of_study': 59965, 'institution': 8740, 'paper': 736389},
      num_edges={('author', 'affiliated_with', 'institution'): 1043998, ('author', 'writes', 'paper'): 7145660, ('paper', 'cites', 'paper'): 5416271, ('paper', 'has_topic', 'field_of_study'): 7505078},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])

In [228]:
src_writes, dst_writes = graph.all_edges(etype="writes")
src_topic, dst_topic = graph.all_edges(etype="has_topic")
src_aff, dst_aff = graph.all_edges(etype="affiliated_with")


graph = dgl.heterograph({
    ("author", "writes", "paper"): (src_writes, dst_writes),
    ("paper", "has_topic", "field_of_study"): (src_topic, dst_topic),
    ("author", "affiliated_with", "institution"): (src_aff, dst_aff),
    ("paper", "writes-rev", "author"): (dst_writes, src_writes),
    ("field_of_study", "has_topic-rev", "paper"): (dst_topic, src_topic),
    ("institution", "affiliated_with-rev", "author"): (dst_aff, src_aff),
})

In [151]:
display(graph)
print(' * ' * 45)
display(graph_2)

Graph(num_nodes={'author': 1134649, 'field_of_study': 59965, 'institution': 8740, 'paper': 736389},
      num_edges={('author', 'affiliated_with', 'institution'): 1043998, ('author', 'writes', 'paper'): 7145660, ('paper', 'cites', 'paper'): 5416271, ('paper', 'has_topic', 'field_of_study'): 7505078},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])

 *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  *  * 


Graph(num_nodes={'author': 1134649, 'field_of_study': 59965, 'institution': 8740, 'paper': 736389},
      num_edges={('author', 'affiliated_with', 'institution'): 1043998, ('author', 'writes', 'paper'): 7145660, ('field_of_study', 'has_topic', 'paper'): 7505078, ('institution', 'affiliated_with', 'author'): 1043998, ('paper', 'has_topic', 'field_of_study'): 7505078, ('paper', 'writes', 'author'): 7145660},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('institution', 'author', 'affiliated_with'), ('paper', 'field_of_study', 'has_topic'), ('paper', 'author', 'writes'), ('field_of_study', 'paper', 'has_topic')])

- 그래프 structure은 같으나 내부 feature 는 반영되지 않았음.

In [157]:
display(graph.nodes['paper'].data['feat'])
graph_2.nodes['paper'].data['feat']

tensor([[-0.0954,  0.0408, -0.2109,  ...,  0.0616, -0.0277, -0.1338],
        [-0.1510, -0.1073, -0.2220,  ...,  0.3458, -0.0277, -0.2185],
        [-0.1148, -0.1760, -0.2606,  ...,  0.1731, -0.1564, -0.2780],
        ...,
        [ 0.0228, -0.0865,  0.0981,  ..., -0.0547, -0.2077, -0.2305],
        [-0.2891, -0.2029, -0.1525,  ...,  0.1042,  0.2041, -0.3528],
        [-0.0890, -0.0348, -0.2642,  ...,  0.2601, -0.0875, -0.5171]])

KeyError: 'feat'

## Defining neighbor sampler and data loader in DGL

- 2개의 R-GCN (2-hop) layer 을 통해 neighbor sampling 하기 위해 우리는 연결된 관계로부터 15개의 neigbors's info를  가져올것임. 

In [229]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 15]) # sampling 이 때 layer가 2개이니 이 역시도 2개로 matching 해줘야 함.


# cuda (gpu) 를 활용하여 좀 더 빠른 sampling 할 수 있게 설정하는걸 권고함. -> training 시 input , output device Inconsisten error happens
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    device='cpu'
)

In [230]:
example_minibatch = next(iter(train_dataloader)) # neighbor sampling
print(example_minibatch)

[{'author': tensor([   7140,   98061,  380753,  ..., 1017013,  469025, 1051417]), 'field_of_study': tensor([12271, 13062, 13979,  ..., 18351, 18970, 28496]), 'institution': tensor([ 649, 1201, 5887,  ..., 1767, 4918, 1897]), 'paper': tensor([ 18312, 622480, 158710,  ..., 401026, 591121, 555247])}, {'author': tensor([], dtype=torch.int64), 'field_of_study': tensor([], dtype=torch.int64), 'institution': tensor([], dtype=torch.int64), 'paper': tensor([ 18312, 622480, 158710,  ...,  19714, 576902, 351818])}, [Block(num_src_nodes={'author': 5003, 'field_of_study': 3514, 'institution': 1311, 'paper': 75935},
      num_dst_nodes={'author': 4783, 'field_of_study': 3514, 'institution': 0, 'paper': 1024},
      num_edges={('author', 'affiliated_with', 'institution'): 0, ('author', 'writes', 'paper'): 4823, ('field_of_study', 'has_topic-rev', 'paper'): 10543, ('institution', 'affiliated_with-rev', 'author'): 7009, ('paper', 'has_topic', 'field_of_study'): 50961, ('paper', 'writes-rev', 'author'):

In [232]:
## 이해가 안가는 파트 ... 

input_nodes, output_nodes, bipartites = example_minibatch
print("To compute {} target nodes' output we need {} nodes' input features".format(len(output_nodes['paper']), len(input_nodes['paper'])))

print("")
print("Output nodes")
print(output_nodes)

print("")
print("Input nodes")
print(input_nodes)

To compute 1024 target nodes' output we need 75935 nodes' input features

Output nodes
{'author': tensor([], dtype=torch.int64), 'field_of_study': tensor([], dtype=torch.int64), 'institution': tensor([], dtype=torch.int64), 'paper': tensor([ 18312, 622480, 158710,  ...,  19714, 576902, 351818])}

Input nodes
{'author': tensor([   7140,   98061,  380753,  ..., 1017013,  469025, 1051417]), 'field_of_study': tensor([12271, 13062, 13979,  ..., 18351, 18970, 28496]), 'institution': tensor([ 649, 1201, 5887,  ..., 1767, 4918, 1897]), 'paper': tensor([ 18312, 622480, 158710,  ..., 401026, 591121, 555247])}


In [217]:
for block in bipartites:
    print(block)
    print()

Block(num_src_nodes={'author': 22628, 'field_of_study': 0, 'institution': 0, 'paper': 26879},
      num_dst_nodes={'author': 4637, 'field_of_study': 0, 'institution': 0, 'paper': 6242},
      num_edges={('author', 'affiliated_with', 'institution'): 0, ('author', 'writes', 'paper'): 30521, ('paper', 'cites', 'paper'): 34036, ('paper', 'has_topic', 'field_of_study'): 0},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])

Block(num_src_nodes={'author': 4637, 'field_of_study': 0, 'institution': 0, 'paper': 6242},
      num_dst_nodes={'author': 0, 'field_of_study': 0, 'institution': 0, 'paper': 1024},
      num_edges={('author', 'affiliated_with', 'institution'): 0, ('author', 'writes', 'paper'): 4690, ('paper', 'cites', 'paper'): 5267, ('paper', 'has_topic', 'field_of_study'): 0},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('

# Defining model

- __ModuleList__

    - nn.Module을 리스트로 정리하는 방법이다.

    - 각 레이어를 리스트에 전달하고 레이어의 iterator를 만든다. 덕분에 forward처리를 간단하게 할 수 있다는 듯 하다.

    - 처음으로 적는 것은 아주 무식하게 하나하나 다 적어서 리스트에 넣고 for로 돌리는 방식이다.

In [233]:
class RGCN(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, rel_names):
        super().__init__()
        
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        
        self.layers.append(dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, n_hidden)
            for rel in rel_names}, aggregate='sum'))
        
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(n_hidden, n_hidden)
                for rel in rel_names}, aggregate='sum'))
            
        self.layers.append(dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(n_hidden, n_classes)
            for rel in rel_names}, aggregate='sum'))

    def forward(self, bipartites, x):
        # inputs are features of nodes
        for l, (layer, bipartite) in enumerate(zip(self.layers, bipartites)):
            x = layer(bipartite, x)
            if l != self.n_layers - 1:
                x = {k: F.relu(v) for k, v in x.items()}
        return x

## What to do about featureless nodes

이번 튜토리얼에서는 총 4개의 노드 'author' , 'field_of_study', 'institution' , 'paper' 중에서 'paper' 만 feature 가 존재함 . 다시 리캡해보면 아래 코드로 확인이 가능함.

In [184]:
display(graph.nodes['paper'].data['feat'])
display(graph.nodes['field_of_study'].data['feat'])
display(graph.nodes['institution'].data['feat'])
display(graph.nodes['author'].data['feat'])



tensor([[-0.0954,  0.0408, -0.2109,  ...,  0.0616, -0.0277, -0.1338],
        [-0.1510, -0.1073, -0.2220,  ...,  0.3458, -0.0277, -0.2185],
        [-0.1148, -0.1760, -0.2606,  ...,  0.1731, -0.1564, -0.2780],
        ...,
        [ 0.0228, -0.0865,  0.0981,  ..., -0.0547, -0.2077, -0.2305],
        [-0.2891, -0.2029, -0.1525,  ...,  0.1042,  0.2041, -0.3528],
        [-0.0890, -0.0348, -0.2642,  ...,  0.2601, -0.0875, -0.5171]])

KeyError: 'feat'

- 우리는 message passing 을 하기 위해 각 노드마다 feature 가 필요함. 그를 위해 우리는 Embedding layer로 부터 representation 으로 feature 을 대체하고자 함 !

In [186]:
# feature 을 가지고 있는 노드

graph.nodes['paper']

NodeSpace(data={'year': tensor([[2015],
        [2012],
        [2012],
        ...,
        [2016],
        [2017],
        [2014]]), 'feat': tensor([[-0.0954,  0.0408, -0.2109,  ...,  0.0616, -0.0277, -0.1338],
        [-0.1510, -0.1073, -0.2220,  ...,  0.3458, -0.0277, -0.2185],
        [-0.1148, -0.1760, -0.2606,  ...,  0.1731, -0.1564, -0.2780],
        ...,
        [ 0.0228, -0.0865,  0.0981,  ..., -0.0547, -0.2077, -0.2305],
        [-0.2891, -0.2029, -0.1525,  ...,  0.1042,  0.2041, -0.3528],
        [-0.0890, -0.0348, -0.2642,  ...,  0.2601, -0.0875, -0.5171]])})

In [189]:
# relation 만 가지고 있는 node

display(graph.nodes['field_of_study'])
display(graph.nodes['institution'])
display(graph.nodes['author'])

NodeSpace(data={})

NodeSpace(data={})

NodeSpace(data={})

In [234]:
class NodeEmbed(nn.Module):
    def __init__(self, num_nodes, embed_size,):
        super(NodeEmbed, self).__init__()
        self.embed_size = embed_size
        self.node_embeds = nn.ModuleDict()
        for ntype in num_nodes:
            node_embed = torch.nn.Embedding(num_nodes[ntype], self.embed_size)
            nn.init.uniform_(node_embed.weight, -1.0, 1.0)
            self.node_embeds[str(ntype)] = node_embed
    
    def forward(self, node_ids):
        embeds = {}
        for ntype in node_ids:
            embeds[ntype] = self.node_embeds[ntype](node_ids[ntype])
        return embeds

## Initialize model and optimizer

In [193]:
# recap section
display(graph.ntypes)
display(graph.etypes)


['author', 'field_of_study', 'institution', 'paper']

['affiliated_with', 'writes', 'cites', 'has_topic']

In [235]:
# feature 가 있는 paper 은 제외

num_nodes = {ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes if ntype != 'paper'}
num_layers = 2
hidden_dim = 128
embed = NodeEmbed(num_nodes, hidden_dim)
model = RGCN(num_features, hidden_dim, num_classes, num_layers, graph.etypes).cuda()
opt = torch.optim.Adam(list(model.parameters()) + list(embed.parameters()))

In [201]:
embed

NodeEmbed(
  (node_embeds): ModuleDict(
    (author): Embedding(1134649, 128)
    (field_of_study): Embedding(59965, 128)
    (institution): Embedding(8740, 128)
  )
)

In [202]:
model

RGCN(
  (layers): ModuleList(
    (0): HeteroGraphConv(
      (mods): ModuleDict(
        (affiliated_with): GraphConv(in=128, out=128, normalization=both, activation=None)
        (writes): GraphConv(in=128, out=128, normalization=both, activation=None)
        (cites): GraphConv(in=128, out=128, normalization=both, activation=None)
        (has_topic): GraphConv(in=128, out=128, normalization=both, activation=None)
      )
    )
    (1): HeteroGraphConv(
      (mods): ModuleDict(
        (affiliated_with): GraphConv(in=128, out=349, normalization=both, activation=None)
        (writes): GraphConv(in=128, out=349, normalization=both, activation=None)
        (cites): GraphConv(in=128, out=349, normalization=both, activation=None)
        (has_topic): GraphConv(in=128, out=349, normalization=both, activation=None)
      )
    )
  )
)

## Defining Training Loop

validation 을 통해 score ( model selection ) 을 해야하는데 본 데이터에서는 train test(label)만 존재하므로 validation set 을 만들어줌. 

In [203]:
valid_nids # validation node ids

{'paper': tensor([   332,    756,    784,  ..., 736364, 736367, 736370])}

In [236]:
valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

In [None]:
best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(100):
    model.train()
    
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            
            # Get featureless input nodes and use the node embeddings as their initial representation 
            featureless_nodes = {ntype: node_ids for ntype, node_ids in input_nodes.items() if ntype != 'paper'}
            embeddings = {ntype: node_embedding.cuda() for ntype, node_embedding in embed(featureless_nodes).items()}
            
            # Get input features for node type 'paper' which has input features
            inputs = {'paper': node_features[input_nodes['paper']].cuda()}
            
            inputs.update(embeddings) # Merge feature inputs with input that has features
            
            labels = node_labels[output_nodes['paper']].cuda()
            predictions = model(bipartites, inputs)['paper']

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
            
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
        
    model.eval()
    
    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, bipartites in tq:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            
            featureless_nodes = {ntype: node_ids for ntype, node_ids in input_nodes.items() if ntype != "paper"}
            embeddings = {ntype: node_embedding.cuda() for ntype, node_embedding in embed(featureless_nodes).items()}
            inputs = {'paper': node_features[input_nodes['paper']].cuda()}
            inputs.update(embeddings)
            
            labels.append(node_labels[output_nodes['paper']].numpy())
            predictions.append(model(bipartites, inputs)['paper'].argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

100%|██████████| 615/615 [00:43<00:00, 14.02it/s, loss=2.103, acc=0.423]
100%|██████████| 64/64 [00:03<00:00, 16.58it/s]
  0%|          | 1/615 [00:00<01:12,  8.43it/s, loss=2.147, acc=0.401]

Epoch 0 Validation Accuracy 0.3626751337104456


100%|██████████| 615/615 [00:42<00:00, 14.35it/s, loss=2.148, acc=0.411]
100%|██████████| 64/64 [00:03<00:00, 17.44it/s]
  0%|          | 1/615 [00:00<01:10,  8.74it/s, loss=2.138, acc=0.424]

Epoch 1 Validation Accuracy 0.35438277408714686


100%|██████████| 615/615 [00:42<00:00, 14.48it/s, loss=2.086, acc=0.429]
100%|██████████| 64/64 [00:03<00:00, 17.20it/s]
  0%|          | 0/615 [00:00<?, ?it/s]

Epoch 2 Validation Accuracy 0.3604247907643459


100%|██████████| 615/615 [00:43<00:00, 14.18it/s, loss=2.006, acc=0.419]
100%|██████████| 64/64 [00:03<00:00, 17.02it/s]
  0%|          | 1/615 [00:00<01:10,  8.71it/s, loss=2.070, acc=0.421]

Epoch 3 Validation Accuracy 0.36785400514804484


100%|██████████| 615/615 [00:42<00:00, 14.42it/s, loss=2.103, acc=0.424]
100%|██████████| 64/64 [00:03<00:00, 17.08it/s]
  0%|          | 1/615 [00:00<01:09,  8.86it/s, loss=2.054, acc=0.425]

Epoch 4 Validation Accuracy 0.3612262827725458


100%|██████████| 615/615 [00:42<00:00, 14.39it/s, loss=2.118, acc=0.413]
100%|██████████| 64/64 [00:03<00:00, 16.84it/s]
  0%|          | 1/615 [00:00<01:17,  7.92it/s, loss=1.984, acc=0.430]

Epoch 5 Validation Accuracy 0.3711370397200943


100%|██████████| 615/615 [00:43<00:00, 14.17it/s, loss=2.066, acc=0.434]
100%|██████████| 64/64 [00:03<00:00, 16.94it/s]
  0%|          | 1/615 [00:00<01:12,  8.49it/s, loss=2.050, acc=0.420]

Epoch 6 Validation Accuracy 0.3645093173445953


100%|██████████| 615/615 [00:43<00:00, 14.25it/s, loss=2.077, acc=0.418]
100%|██████████| 64/64 [00:03<00:00, 17.02it/s]
  0%|          | 1/615 [00:00<01:09,  8.82it/s, loss=2.002, acc=0.440]

Epoch 7 Validation Accuracy 0.35899135313429614


100%|██████████| 615/615 [00:47<00:00, 12.92it/s, loss=1.948, acc=0.457]
100%|██████████| 64/64 [00:04<00:00, 13.11it/s]
  0%|          | 1/615 [00:00<01:51,  5.52it/s, loss=2.039, acc=0.464]

Epoch 8 Validation Accuracy 0.36904082985249465


100%|██████████| 615/615 [00:42<00:00, 14.52it/s, loss=2.049, acc=0.444]
100%|██████████| 64/64 [00:03<00:00, 17.37it/s]
  0%|          | 1/615 [00:00<01:09,  8.79it/s, loss=1.952, acc=0.476]

Epoch 9 Validation Accuracy 0.36526456942924523


100%|██████████| 615/615 [00:43<00:00, 14.24it/s, loss=1.929, acc=0.438]
100%|██████████| 64/64 [00:03<00:00, 17.16it/s]
  0%|          | 1/615 [00:00<01:09,  8.82it/s, loss=1.948, acc=0.451]

Epoch 10 Validation Accuracy 0.37482082029624375


100%|██████████| 615/615 [00:44<00:00, 13.88it/s, loss=2.097, acc=0.424]
100%|██████████| 64/64 [00:04<00:00, 15.55it/s]
  0%|          | 1/615 [00:00<01:11,  8.55it/s, loss=1.954, acc=0.462]

Epoch 11 Validation Accuracy 0.3650025431957952


100%|██████████| 615/615 [00:49<00:00, 12.36it/s, loss=2.042, acc=0.398]
100%|██████████| 64/64 [00:04<00:00, 14.15it/s]
  0%|          | 0/615 [00:00<?, ?it/s]

Epoch 12 Validation Accuracy 0.3727708503521941


100%|██████████| 615/615 [00:52<00:00, 11.66it/s, loss=1.946, acc=0.446]
100%|██████████| 64/64 [00:04<00:00, 14.39it/s]
  0%|          | 1/615 [00:00<01:15,  8.10it/s, loss=1.866, acc=0.470]

Epoch 13 Validation Accuracy 0.3647096903466453


100%|██████████| 615/615 [00:50<00:00, 12.18it/s, loss=1.985, acc=0.438]
100%|██████████| 64/64 [00:04<00:00, 14.80it/s]
  0%|          | 1/615 [00:00<01:08,  8.92it/s, loss=1.867, acc=0.468]

Epoch 14 Validation Accuracy 0.36771528537739484


100%|██████████| 615/615 [00:42<00:00, 14.56it/s, loss=2.029, acc=0.431]
100%|██████████| 64/64 [00:03<00:00, 17.39it/s]
  0%|          | 1/615 [00:00<01:08,  8.93it/s, loss=1.939, acc=0.480]

Epoch 15 Validation Accuracy 0.3552305060188967


100%|██████████| 615/615 [00:42<00:00, 14.53it/s, loss=2.012, acc=0.453]
100%|██████████| 64/64 [00:03<00:00, 17.33it/s]
  0%|          | 1/615 [00:00<01:09,  8.85it/s, loss=1.919, acc=0.460]

Epoch 16 Validation Accuracy 0.3567255968803465


100%|██████████| 615/615 [00:42<00:00, 14.54it/s, loss=1.890, acc=0.453]
100%|██████████| 64/64 [00:03<00:00, 17.25it/s]
  0%|          | 1/615 [00:00<01:14,  8.25it/s, loss=1.977, acc=0.427]

Epoch 17 Validation Accuracy 0.3645401439602953


100%|██████████| 615/615 [00:41<00:00, 14.66it/s, loss=2.000, acc=0.416]
100%|██████████| 64/64 [00:03<00:00, 17.59it/s]
  0%|          | 1/615 [00:00<01:09,  8.85it/s, loss=1.907, acc=0.471]

Epoch 18 Validation Accuracy 0.3619661215493457


100%|██████████| 615/615 [00:42<00:00, 14.61it/s, loss=2.039, acc=0.454]
100%|██████████| 64/64 [00:03<00:00, 17.22it/s]
  0%|          | 1/615 [00:00<01:18,  7.86it/s, loss=1.868, acc=0.471]

Epoch 19 Validation Accuracy 0.37030472109619444


100%|██████████| 615/615 [00:42<00:00, 14.54it/s, loss=1.852, acc=0.469]
100%|██████████| 64/64 [00:03<00:00, 17.56it/s]
  0%|          | 1/615 [00:00<01:17,  7.96it/s, loss=1.904, acc=0.460]

Epoch 20 Validation Accuracy 0.3630758797145455


100%|██████████| 615/615 [00:42<00:00, 14.57it/s, loss=1.872, acc=0.469]
100%|██████████| 64/64 [00:03<00:00, 17.18it/s]
  0%|          | 1/615 [00:00<01:08,  8.92it/s, loss=1.846, acc=0.472]

Epoch 21 Validation Accuracy 0.36159620216094573


100%|██████████| 615/615 [00:41<00:00, 14.65it/s, loss=1.883, acc=0.465]
100%|██████████| 64/64 [00:03<00:00, 16.94it/s]
  0%|          | 1/615 [00:00<01:11,  8.61it/s, loss=1.867, acc=0.462]

Epoch 22 Validation Accuracy 0.36170409531589576


100%|██████████| 615/615 [00:41<00:00, 14.73it/s, loss=1.845, acc=0.467]
100%|██████████| 64/64 [00:03<00:00, 16.49it/s]
  0%|          | 1/615 [00:00<01:08,  9.02it/s, loss=1.920, acc=0.468]

Epoch 23 Validation Accuracy 0.36376947856779546


100%|██████████| 615/615 [00:42<00:00, 14.59it/s, loss=1.816, acc=0.498]
100%|██████████| 64/64 [00:03<00:00, 17.83it/s]
  0%|          | 1/615 [00:00<01:09,  8.83it/s, loss=1.823, acc=0.498]

Epoch 24 Validation Accuracy 0.36651304736509505


100%|██████████| 615/615 [00:42<00:00, 14.52it/s, loss=1.861, acc=0.467]
100%|██████████| 64/64 [00:03<00:00, 17.17it/s]
  0%|          | 1/615 [00:00<01:10,  8.76it/s, loss=1.934, acc=0.467]

Epoch 25 Validation Accuracy 0.3671449929869449


100%|██████████| 615/615 [00:42<00:00, 14.47it/s, loss=1.976, acc=0.440]
100%|██████████| 64/64 [00:03<00:00, 16.96it/s]
  0%|          | 1/615 [00:00<01:19,  7.74it/s, loss=1.918, acc=0.458]

Epoch 26 Validation Accuracy 0.36407774472479537


100%|██████████| 615/615 [00:42<00:00, 14.59it/s, loss=1.880, acc=0.454]
100%|██████████| 64/64 [00:03<00:00, 17.46it/s]
  0%|          | 1/615 [00:00<01:10,  8.75it/s, loss=1.743, acc=0.498]

Epoch 27 Validation Accuracy 0.36836264430709476


100%|██████████| 615/615 [00:42<00:00, 14.52it/s, loss=1.887, acc=0.465]
100%|██████████| 64/64 [00:03<00:00, 17.38it/s]
  0%|          | 1/615 [00:00<01:10,  8.77it/s, loss=1.869, acc=0.463]

Epoch 28 Validation Accuracy 0.3644168374974953


100%|██████████| 615/615 [00:42<00:00, 14.46it/s, loss=1.833, acc=0.467]
100%|██████████| 64/64 [00:03<00:00, 17.05it/s]
  0%|          | 1/615 [00:00<01:10,  8.74it/s, loss=1.887, acc=0.458]

Epoch 29 Validation Accuracy 0.36171950862374574


100%|██████████| 615/615 [00:42<00:00, 14.49it/s, loss=1.928, acc=0.432]
100%|██████████| 64/64 [00:03<00:00, 17.09it/s]
  0%|          | 1/615 [00:00<01:11,  8.65it/s, loss=1.908, acc=0.446]

Epoch 30 Validation Accuracy 0.3580357280475963


100%|██████████| 615/615 [00:42<00:00, 14.38it/s, loss=1.886, acc=0.446]
100%|██████████| 64/64 [00:03<00:00, 17.13it/s]
  0%|          | 1/615 [00:00<01:11,  8.64it/s, loss=1.857, acc=0.455]

Epoch 31 Validation Accuracy 0.3572342360393964


100%|██████████| 615/615 [00:43<00:00, 14.28it/s, loss=1.785, acc=0.468]
100%|██████████| 64/64 [00:03<00:00, 17.21it/s]
  0%|          | 1/615 [00:00<01:24,  7.23it/s, loss=1.892, acc=0.450]

Epoch 32 Validation Accuracy 0.3655265956626952


100%|██████████| 615/615 [00:43<00:00, 14.12it/s, loss=1.813, acc=0.472]
100%|██████████| 64/64 [00:03<00:00, 16.85it/s]
  0%|          | 1/615 [00:00<01:12,  8.44it/s, loss=1.915, acc=0.468]

Epoch 33 Validation Accuracy 0.36911789639174464


100%|██████████| 615/615 [00:42<00:00, 14.47it/s, loss=1.911, acc=0.454]
100%|██████████| 64/64 [00:03<00:00, 17.68it/s]
  0%|          | 1/615 [00:00<01:11,  8.60it/s, loss=1.833, acc=0.491]

Epoch 34 Validation Accuracy 0.36911789639174464


100%|██████████| 615/615 [00:42<00:00, 14.52it/s, loss=1.792, acc=0.478]
100%|██████████| 64/64 [00:03<00:00, 17.69it/s]
  0%|          | 1/615 [00:00<01:18,  7.87it/s, loss=1.716, acc=0.508]

Epoch 35 Validation Accuracy 0.3713374127221443


100%|██████████| 615/615 [00:41<00:00, 14.71it/s, loss=1.970, acc=0.444]
100%|██████████| 64/64 [00:03<00:00, 16.90it/s]
  0%|          | 1/615 [00:00<01:13,  8.40it/s, loss=1.868, acc=0.482]

Epoch 36 Validation Accuracy 0.36749949906749485


100%|██████████| 615/615 [00:42<00:00, 14.54it/s, loss=1.905, acc=0.471]
100%|██████████| 64/64 [00:03<00:00, 16.79it/s]
  0%|          | 1/615 [00:00<01:20,  7.62it/s, loss=1.762, acc=0.481]

Epoch 37 Validation Accuracy 0.3586368470537462


100%|██████████| 615/615 [00:42<00:00, 14.50it/s, loss=1.844, acc=0.456]
100%|██████████| 64/64 [00:03<00:00, 17.36it/s]
  0%|          | 1/615 [00:00<01:10,  8.70it/s, loss=1.816, acc=0.472]

Epoch 38 Validation Accuracy 0.3586985002851462


100%|██████████| 615/615 [00:42<00:00, 14.54it/s, loss=1.798, acc=0.467]
100%|██████████| 64/64 [00:03<00:00, 17.42it/s]
  0%|          | 1/615 [00:00<01:10,  8.67it/s, loss=1.799, acc=0.487]

Epoch 39 Validation Accuracy 0.3605018573035959


100%|██████████| 615/615 [00:42<00:00, 14.60it/s, loss=1.913, acc=0.460]
100%|██████████| 64/64 [00:03<00:00, 17.50it/s]
  0%|          | 1/615 [00:00<01:09,  8.80it/s, loss=1.806, acc=0.497]

Epoch 40 Validation Accuracy 0.36010111129949596


100%|██████████| 615/615 [00:41<00:00, 14.66it/s, loss=1.911, acc=0.451]
100%|██████████| 64/64 [00:03<00:00, 17.36it/s]
  0%|          | 1/615 [00:00<01:08,  8.96it/s, loss=1.732, acc=0.504]

Epoch 41 Validation Accuracy 0.35552335886804665


100%|██████████| 615/615 [00:42<00:00, 14.48it/s, loss=1.886, acc=0.484]
100%|██████████| 64/64 [00:03<00:00, 17.43it/s]
  0%|          | 1/615 [00:00<01:12,  8.44it/s, loss=1.783, acc=0.490]

Epoch 42 Validation Accuracy 0.3675149123753449


100%|██████████| 615/615 [00:42<00:00, 14.57it/s, loss=1.833, acc=0.473]
100%|██████████| 64/64 [00:03<00:00, 16.44it/s]
  0%|          | 1/615 [00:00<01:11,  8.55it/s, loss=1.729, acc=0.499]

Epoch 43 Validation Accuracy 0.3681160313814948


100%|██████████| 615/615 [00:42<00:00, 14.52it/s, loss=1.870, acc=0.450]
100%|██████████| 64/64 [00:03<00:00, 16.90it/s]
  0%|          | 1/615 [00:00<01:16,  7.99it/s, loss=1.690, acc=0.508]

Epoch 44 Validation Accuracy 0.3608563633841459


100%|██████████| 615/615 [00:42<00:00, 14.52it/s, loss=1.866, acc=0.486]
100%|██████████| 64/64 [00:03<00:00, 17.14it/s]
  0%|          | 1/615 [00:00<01:08,  8.93it/s, loss=1.753, acc=0.503]

Epoch 45 Validation Accuracy 0.36572696866474513


100%|██████████| 615/615 [00:42<00:00, 14.47it/s, loss=1.750, acc=0.487]
100%|██████████| 64/64 [00:03<00:00, 17.40it/s]
  0%|          | 1/615 [00:00<01:11,  8.55it/s, loss=1.748, acc=0.485]

Epoch 46 Validation Accuracy 0.36503336981149526


100%|██████████| 615/615 [00:42<00:00, 14.57it/s, loss=1.861, acc=0.473]
100%|██████████| 64/64 [00:03<00:00, 17.37it/s]
  0%|          | 1/615 [00:00<01:19,  7.72it/s, loss=1.733, acc=0.507]

Epoch 47 Validation Accuracy 0.3632300127930455


100%|██████████| 615/615 [00:42<00:00, 14.60it/s, loss=1.745, acc=0.487]
100%|██████████| 64/64 [00:03<00:00, 17.12it/s]
  0%|          | 1/615 [00:00<01:09,  8.81it/s, loss=1.764, acc=0.483]

Epoch 48 Validation Accuracy 0.36777693860879485


100%|██████████| 615/615 [00:41<00:00, 14.65it/s, loss=1.924, acc=0.456]
100%|██████████| 64/64 [00:03<00:00, 16.77it/s]
  0%|          | 1/615 [00:00<01:11,  8.56it/s, loss=1.782, acc=0.477]

Epoch 49 Validation Accuracy 0.36225897439849564


100%|██████████| 615/615 [00:42<00:00, 14.49it/s, loss=1.822, acc=0.454]
100%|██████████| 64/64 [00:03<00:00, 16.14it/s]
  0%|          | 1/615 [00:00<01:14,  8.28it/s, loss=1.787, acc=0.504]

Epoch 50 Validation Accuracy 0.3625518272476456


100%|██████████| 615/615 [00:42<00:00, 14.55it/s, loss=1.776, acc=0.481]
100%|██████████| 64/64 [00:03<00:00, 17.62it/s]
  0%|          | 1/615 [00:00<01:07,  9.10it/s, loss=1.795, acc=0.488]

Epoch 51 Validation Accuracy 0.36768445876169487


100%|██████████| 615/615 [00:41<00:00, 14.67it/s, loss=1.764, acc=0.472]
100%|██████████| 64/64 [00:03<00:00, 17.43it/s]
  0%|          | 1/615 [00:00<01:08,  8.93it/s, loss=1.867, acc=0.467]

Epoch 52 Validation Accuracy 0.3612879360039458


100%|██████████| 615/615 [00:42<00:00, 14.64it/s, loss=1.842, acc=0.486]
100%|██████████| 64/64 [00:03<00:00, 16.72it/s]
  0%|          | 1/615 [00:00<01:09,  8.78it/s, loss=1.748, acc=0.497]

Epoch 53 Validation Accuracy 0.359823671758196


100%|██████████| 615/615 [00:42<00:00, 14.57it/s, loss=1.806, acc=0.473]
100%|██████████| 64/64 [00:03<00:00, 17.68it/s]
  0%|          | 1/615 [00:00<01:09,  8.80it/s, loss=1.639, acc=0.517]

Epoch 54 Validation Accuracy 0.35646357064689654


100%|██████████| 615/615 [00:42<00:00, 14.56it/s, loss=1.745, acc=0.480]
100%|██████████| 64/64 [00:03<00:00, 17.71it/s]
  0%|          | 1/615 [00:00<01:11,  8.58it/s, loss=1.771, acc=0.485]

Epoch 55 Validation Accuracy 0.36119545615684584


100%|██████████| 615/615 [00:41<00:00, 14.65it/s, loss=1.883, acc=0.453]
100%|██████████| 64/64 [00:03<00:00, 17.45it/s]
  0%|          | 1/615 [00:00<01:10,  8.77it/s, loss=1.736, acc=0.473]

Epoch 56 Validation Accuracy 0.3689945899289446


100%|██████████| 615/615 [00:42<00:00, 14.41it/s, loss=1.777, acc=0.486]
100%|██████████| 64/64 [00:03<00:00, 16.48it/s]
  0%|          | 1/615 [00:00<01:11,  8.63it/s, loss=1.725, acc=0.491]

Epoch 57 Validation Accuracy 0.36017817783874595


100%|██████████| 615/615 [00:43<00:00, 14.26it/s, loss=1.785, acc=0.481]
100%|██████████| 64/64 [00:03<00:00, 17.04it/s]
  0%|          | 1/615 [00:00<01:13,  8.40it/s, loss=1.724, acc=0.474]

Epoch 58 Validation Accuracy 0.3586830869772962


 63%|██████▎   | 385/615 [00:27<00:16, 14.07it/s, loss=1.706, acc=0.496]

In [None]:
def inference(model, graph, input_features, batch_size):
    nodes = {ntype: torch.arange(graph.number_of_nodes(ntype)) for ntype in graph.ntypes}
    
    sampler = dgl.dataloading.MultiLayerNeighborSampler([None])  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler
        ,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    
    with torch.no_grad():
        for l, layer in enumerate(model.layers):
            # Allocate a buffer of output representations for every node
            # Note that the buffer is on CPU memory.
            output_features = {ntype: torch.zeros(
                graph.number_of_nodes(ntype), model.n_hidden if l != model.n_layers - 1 else model.n_classes)
                for ntype in graph.ntypes}

            for input_nodes, output_nodes, bipartites in tqdm.tqdm(dataloader):
                bipartite = bipartites[0].to(torch.device('cuda'))

                # send features for nodes in batch to gpu 
                x = {ntype: input_features[ntype][input_nodes[ntype]].cuda() for ntype in input_nodes}

                # the following code is identical to the loop body in model.forward()
                x = layer(bipartite, x)
                if l != model.n_layers - 1:
                    x = {k: F.relu(v) for k, v in x.items()}
                
                for ntype in x:
                    output_features[ntype][output_nodes[ntype]] = x[ntype].cpu()
            input_features = output_features
    return output_features

In [None]:
model.load_state_dict(torch.load(best_model_path))

featureless_nodes = {ntype: torch.arange(num_nodes_ntype) for ntype, num_nodes_ntype in num_nodes.items()}
embeddings = {ntype: node_embedding for ntype, node_embedding in embed(featureless_nodes).items()}
inputs = {'paper': node_features}
inputs.update(embeddings)

all_predictions = inference(model, graph, inputs, 8192)

In [None]:
test_predictions = all_predictions['paper'][test_nids['paper']].argmax(1)
test_labels = node_labels[test_nids['paper']]
test_accuracy = sklearn.metrics.accuracy_score(test_predictions.numpy(), test_labels.numpy())
print('Test accuracy:', test_accuracy)