In [1]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.datasets import WikiCS
from torch_geometric.data import GraphSAINTRandomWalkSampler

In [2]:
dataset = WikiCS(root='./data')[0]

print(len(dataset))

7


In [3]:
NUM_CLASSES = len(dataset.y.unique())

In [4]:
# Clusters/partitions a graph data object into multiple subgraphs

data_loader = GraphSAINTRandomWalkSampler(data=dataset, batch_size=1024, walk_length=2, num_steps=5, num_workers=5)

# num_steps: number of iteration per epoch

In [14]:
bat = next(iter(data_loader))

In [15]:
bat

Data(edge_index=[2, 39970], stopping_mask=[2192, 20], test_mask=[2192], train_mask=[2192, 20], val_mask=[2192, 20], x=[2192, 300], y=[2192])

In [16]:
bat.train_mask

tensor([[False, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [6]:
from torch_geometric.nn import GCNConv

In [13]:
class WikiNode(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        self.convs = nn.ModuleList()
        
        self.convs.append(GCNConv(dataset.num_features, 128))
        
        self.convs.append(GCNConv(128, 128))
        
        self.out_conv = GCNConv(128, NUM_CLASSES)
        
    def forward(self, data):
        
        x, edge_index = data.x, data.edge_index
        
        for conv in self.convs:
                        
            x = conv(x, edge_index)
            
            x = F.relu(x)
            
        return self.out_conv(x, edge_index)

In [14]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [15]:
node_clf = WikiNode().to(DEVICE)

opt = torch.optim.Adam(node_clf.parameters(), weight_decay=0.001)

In [16]:
@torch.no_grad()
def masked_acc(logits, y, mask):
    
    a_max = torch.argmax(logits, dim=-1)
    
    eq = (a_max == y).float() * mask.float()
            
    return eq.sum() / mask.float().sum()

In [17]:
def masked_cross_entropy(logits, y, mask):
    
    masked_y = y.masked_fill(mask == False, value=-1)
    
    loss = F.cross_entropy(logits, masked_y, ignore_index=-1)
    
    return loss

In [19]:
def train_one_epoch():
        
    losses = []

    for data in data_loader:
        
        data = data.to(DEVICE)
        
        logits = node_clf(data)
        
        mask = data.train_mask[:, 0]
        
        loss = masked_cross_entropy(logits, data.y, mask)
                
        # backprop + update
        loss.backward()
        opt.step()
        opt.zero_grad()

        losses.append(loss.item())
        
    return np.array(losses).mean()

In [20]:
@torch.no_grad()
def evaluate():
        
    accuracies = []

    for data in data_loader:
        
        data = data.to(DEVICE)
        
        logits = node_clf(data)

        mask = data.test_mask
                
        # compute accuracy
        acc = masked_acc(logits, data.y, mask)

        accuracies.append(acc.item())

    return np.array(accuracies).mean()

In [21]:
for i in range(30):
    
    loss = train_one_epoch()
    
    acc = evaluate()
    
    print(f'loss = {loss: .3f}\t acc = {acc: .3f}')

loss =  2.252	 acc =  0.218
loss =  2.082	 acc =  0.202
loss =  2.001	 acc =  0.273
loss =  1.943	 acc =  0.469
loss =  1.844	 acc =  0.462
loss =  1.769	 acc =  0.563
loss =  1.638	 acc =  0.582
loss =  1.557	 acc =  0.575
loss =  1.458	 acc =  0.588
loss =  1.375	 acc =  0.618
loss =  1.344	 acc =  0.630
loss =  1.201	 acc =  0.623
loss =  1.165	 acc =  0.653
loss =  1.119	 acc =  0.648
loss =  1.051	 acc =  0.662
loss =  0.971	 acc =  0.667
loss =  0.989	 acc =  0.673
loss =  1.011	 acc =  0.669
loss =  0.967	 acc =  0.699
loss =  0.882	 acc =  0.695
loss =  0.907	 acc =  0.679
loss =  0.869	 acc =  0.693
loss =  0.932	 acc =  0.709
loss =  0.871	 acc =  0.704
loss =  0.810	 acc =  0.699
loss =  0.878	 acc =  0.714
loss =  0.794	 acc =  0.727
loss =  0.830	 acc =  0.716
loss =  0.819	 acc =  0.709
loss =  0.798	 acc =  0.739
