In [1]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.datasets import WikiCS
from torch_geometric.data import ClusterData, ClusterLoader

In [2]:
dataset = WikiCS(root='./data')[0]

print(len(dataset))

7


In [3]:
NUM_CLASSES = len(dataset.y.unique())

In [4]:
# Clusters/partitions a graph data object into multiple subgraphs

cluster_data = ClusterData(data=dataset, num_parts=20)

data_loader = ClusterLoader(cluster_data, batch_size=5, num_workers=5, shuffle=True)

Computing METIS partitioning...
Done!


In [6]:
cluster_data[1]

Data(edge_index=[2, 831], stopping_mask=[568, 20], test_mask=[568], train_mask=[568, 20], val_mask=[568, 20], x=[568, 300], y=[568])

In [7]:
from torch_geometric.nn import SAGEConv

In [8]:
class WikiNode(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        self.convs = nn.ModuleList()
        
        self.convs.append(SAGEConv(dataset.num_features, 128))
        
        self.convs.append(SAGEConv(128, 128))
        
        self.out_conv = SAGEConv(128, NUM_CLASSES)
        
    def forward(self, data):
        
        x, edge_index = data.x, data.edge_index
        
        for conv in self.convs:
                        
            x = conv(x, edge_index)
            
            x = F.relu(x)
            
        return self.out_conv(x, edge_index)

In [9]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
node_clf = WikiNode().to(DEVICE)

opt = torch.optim.Adam(node_clf.parameters(), weight_decay=0.001)

In [11]:
@torch.no_grad()
def masked_acc(logits, y, mask):
    
    a_max = torch.argmax(logits, dim=-1)
    
    eq = (a_max == y).float() * mask.float()
            
    return eq.sum() / mask.float().sum()

In [12]:
def masked_cross_entropy(logits, y, mask):
    
    masked_y = y.masked_fill(mask == False, value=-1)
    
    loss = F.cross_entropy(logits, masked_y, ignore_index=-1)
    
    return loss

In [13]:
def train_one_epoch():
        
    losses = []

    for data in data_loader:
        
        data = data.to(DEVICE)
        
        logits = node_clf(data)
        
        mask = data.train_mask[:, 0]
        
        loss = masked_cross_entropy(logits, data.y, mask)
                
        # backprop + update
        loss.backward()
        opt.step()
        opt.zero_grad()

        losses.append(loss.item())
        
    return np.array(losses).mean()

In [14]:
@torch.no_grad()
def evaluate():
        
    accuracies = []

    for data in data_loader:
        
        data = data.to(DEVICE)
        
        logits = node_clf(data)

        mask = data.test_mask
                
        # compute accuracy
        acc = masked_acc(logits, data.y, mask)

        accuracies.append(acc.item())

    return np.array(accuracies).mean()

In [15]:
for i in range(50):
    
    loss = train_one_epoch()
    
    acc = evaluate()
    
    print(f'loss = {loss: .3f}\t acc = {acc: .3f}')

loss =  2.289	 acc =  0.227
loss =  2.197	 acc =  0.229
loss =  2.115	 acc =  0.232
loss =  2.039	 acc =  0.306
loss =  1.967	 acc =  0.439
loss =  1.905	 acc =  0.508
loss =  1.820	 acc =  0.530
loss =  1.714	 acc =  0.528
loss =  1.629	 acc =  0.531
loss =  1.533	 acc =  0.534
loss =  1.432	 acc =  0.541
loss =  1.346	 acc =  0.546
loss =  1.269	 acc =  0.552
loss =  1.224	 acc =  0.572
loss =  1.166	 acc =  0.591
loss =  1.124	 acc =  0.617
loss =  1.073	 acc =  0.624
loss =  1.037	 acc =  0.652
loss =  0.989	 acc =  0.663
loss =  0.951	 acc =  0.664
loss =  0.940	 acc =  0.687
loss =  0.886	 acc =  0.695
loss =  0.882	 acc =  0.691
loss =  0.844	 acc =  0.702
loss =  0.818	 acc =  0.710
loss =  0.795	 acc =  0.716
loss =  0.754	 acc =  0.710
loss =  0.743	 acc =  0.735
loss =  0.742	 acc =  0.733
loss =  0.706	 acc =  0.736
loss =  0.660	 acc =  0.742
loss =  0.649	 acc =  0.734
loss =  0.629	 acc =  0.746
loss =  0.632	 acc =  0.747
loss =  0.586	 acc =  0.752
loss =  0.570	 acc =