In [1]:
import torch
import torch.nn as nn

import torch_sparse
import torch_geometric.transforms as T

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

In [2]:
class Conv(nn.Module):
    """Graph convolution layer computing ``adj @ (x @ weight) + bias``.

    Args:
        input_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
    """

    def __init__(self, input_channels, out_channels):
        super().__init__()
        # Allocated uninitialised; reset_parameters() fills both in immediately.
        self.weight = nn.Parameter(torch.empty(input_channels, out_channels))
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for the weight matrix, zeros for the bias."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        """Project node features, aggregate them over ``adj``, then add the bias.

        ``adj`` is handed to ``torch_sparse.matmul``; in this notebook it is the
        (normalised) ``data.adj_t`` produced by ``T.ToSparseTensor()``.
        """
        projected = x @ self.weight
        aggregated = torch_sparse.matmul(adj, projected)
        return aggregated + self.bias

class GCN(nn.Module):
    """Multi-layer GCN built from ``Conv`` layers.

    Every layer except the last is followed by BatchNorm, ReLU and dropout;
    the last layer emits raw per-node logits.

    Args:
        input_channels: number of input node features.
        out_channels: number of output channels (e.g. classes).
        hidden_channels: width of every hidden layer.
        num_layers: total number of ``Conv`` layers; must be >= 1. With a
            single layer the model is one ``Conv(input_channels, out_channels)``.
        p: dropout probability applied after each hidden layer.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_channels, out_channels, hidden_channels, num_layers, p):
        super(GCN, self).__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        if num_layers == 1:
            # A single layer must map input features straight to the output;
            # there is no hidden representation to normalise.
            self.convs.append(Conv(input_channels, out_channels))
        else:
            self.convs.append(Conv(input_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(Conv(hidden_channels, hidden_channels))
                self.bns.append(nn.BatchNorm1d(hidden_channels))
            self.convs.append(Conv(hidden_channels, out_channels))

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p)
        self.num_layers = num_layers

    def reset_parameters(self):
        """Re-initialise every Conv and BatchNorm layer (e.g. between runs)."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj):
        """Run all layers over node features ``x`` with adjacency ``adj``.

        Returns the un-normalised output of the last ``Conv`` layer.
        """
        # convs[:-1] and bns have the same length (num_layers - 1).
        for conv, bn in zip(self.convs[:-1], self.bns):
            x = conv(x, adj)
            x = bn(x)
            x = self.relu(x)
            x = self.dropout(x)
        return self.convs[-1](x, adj)

def train(model, data, train_idx, optimizer, criterion):
    """Perform one full-graph optimisation step on the training nodes.

    The forward pass runs over every node; the loss is computed only on the
    rows selected by ``train_idx``. ``data.y`` has a trailing singleton
    dimension that is squeezed away for ``criterion``.

    Returns:
        The scalar training loss as a Python float.
    """
    model.train()
    model.zero_grad()

    logits = model(data.x, data.adj_t)
    targets = data.y.squeeze(1)
    loss = criterion(logits[train_idx], targets[train_idx])

    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test(model, data, split_idx, evaluator):
    model.eval()
    
    outputs = model(data.x, data.adj_t)
    y_pred = outputs.argmax(dim=1, keepdim=True)

    train_acc = evaluator.eval({
        'y_true': data.y[split_idx['train']],
        'y_pred': y_pred[split_idx['train']],
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': data.y[split_idx['valid']],
        'y_pred': y_pred[split_idx['valid']],
    })['acc']
    test_acc = evaluator.eval({
        'y_true': data.y[split_idx['test']],
        'y_pred': y_pred[split_idx['test']],
    })['acc']

    return train_acc, valid_acc, test_acc

In [3]:
# Run configuration: every knob a reader might want to tune lives here.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

hidden_channels = 256  # width of every hidden GCN layer
num_layers = 3         # total number of Conv layers in the GCN
dropout = 0.5          # dropout probability after each hidden layer
epochs = 500
log_steps = 1          # print a progress line every `log_steps` epochs
lr = 0.01              # Adam learning rate

# Seed torch so weight initialisation and dropout masks are reproducible under
# Restart & Run All (some GPU sparse/scatter kernels may still be nondeterministic).
seed = 0
torch.manual_seed(seed)

In [4]:
# Load ogbn-arxiv; ToSparseTensor exposes the graph structure as `adj_t`.
dataset = PygNodePropPredDataset(
    name='ogbn-arxiv',
    root='../dataset',
    transform=T.ToSparseTensor(),
)

# Symmetrise the adjacency (the normalisation cell later uses one degree vector
# for both sides), then move the whole graph to the target device.
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()
data = data.to(device)

In [5]:
# Index tensors keyed by 'train' / 'valid' / 'test' (the keys `test` reads).
split_idx = dataset.get_idx_split()
# Training indices are moved to the same device as `data` for the loss.
train_idx = split_idx['train'].to(device=device)

In [6]:
# Keyword arguments make the (input, out, hidden) ordering of GCN explicit.
model = GCN(
    input_channels=data.num_features,
    out_channels=dataset.num_classes,
    hidden_channels=hidden_channels,
    num_layers=num_layers,
    p=dropout,
).to(device)

In [7]:
# Loss, optimiser and the OGB accuracy evaluator for ogbn-arxiv.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
evaluator = Evaluator(name='ogbn-arxiv')

In [8]:
# Symmetric GCN normalisation of the adjacency:
#   A_hat = D^{-1/2} (A + I) D^{-1/2}
# Guarded by a flag stored on `data` so re-running this cell does not
# normalise an already-normalised matrix a second time.
if not getattr(data, 'adj_t_normalized', False):
    data.adj_t = torch_sparse.fill_diag(data.adj_t, 1)
    deg = torch_sparse.sum(data.adj_t, 0).pow_(-0.5)
    data.adj_t = torch_sparse.mul(data.adj_t, deg.view(-1, 1))
    data.adj_t = torch_sparse.mul(data.adj_t, deg.view(1, -1))
    data.adj_t_normalized = True

# NOTE: this continues training whatever `model` / `optimizer` currently hold;
# re-run the model and optimizer cells first for a fresh run.
test_scores = []
best_valid_acc, best_test_acc, best_epoch = -1.0, 0.0, 0
for epoch in range(1, 1 + epochs):
    loss = train(model, data, train_idx, optimizer, criterion)
    result = test(model, data, split_idx, evaluator)
    train_acc, valid_acc, test_acc = result

    # Model selection uses validation accuracy only; the test score is read
    # off at the selected epoch. Taking max(test accuracy) would select on the
    # test set. Tracked every epoch, independent of `log_steps`.
    if valid_acc > best_valid_acc:
        best_valid_acc, best_test_acc, best_epoch = valid_acc, test_acc, epoch

    if epoch % log_steps == 0:
        test_scores.append(test_acc)
        print(f'Run: {1:02d}, '
              f'Epoch: {epoch:02d}, '
              f'Loss: {loss:.4f}, '
              f'Train: {100 * train_acc:.2f}%, '
              f'Valid: {100 * valid_acc:.2f}% '
              f'Test: {100 * test_acc:.2f}%')

if best_epoch:
    print(f"Best valid accuracy: {best_valid_acc * 100:.2f}% (epoch {best_epoch}), "
          f"test accuracy at that epoch: {best_test_acc * 100:.2f}%")

Run: 01, Epoch: 01, Loss: 3.9477, Train: 26.65%, Valid: 29.29% Test: 26.34%
Run: 01, Epoch: 02, Loss: 2.3603, Train: 29.00%, Valid: 30.13% Test: 35.09%
Run: 01, Epoch: 03, Loss: 1.9769, Train: 29.28%, Valid: 29.36% Test: 32.77%
Run: 01, Epoch: 04, Loss: 1.7967, Train: 28.94%, Valid: 24.46% Test: 23.08%
Run: 01, Epoch: 05, Loss: 1.6758, Train: 34.74%, Valid: 29.80% Test: 29.49%
Run: 01, Epoch: 06, Loss: 1.5711, Train: 41.19%, Valid: 37.40% Test: 41.55%
Run: 01, Epoch: 07, Loss: 1.4975, Train: 41.76%, Valid: 38.06% Test: 42.68%
Run: 01, Epoch: 08, Loss: 1.4542, Train: 43.07%, Valid: 37.61% Test: 41.57%
Run: 01, Epoch: 09, Loss: 1.4036, Train: 44.34%, Valid: 34.39% Test: 36.70%
Run: 01, Epoch: 10, Loss: 1.3596, Train: 45.19%, Valid: 32.98% Test: 34.49%
Run: 01, Epoch: 11, Loss: 1.3296, Train: 46.05%, Valid: 34.82% Test: 37.33%
Run: 01, Epoch: 12, Loss: 1.3057, Train: 46.39%, Valid: 37.94% Test: 41.49%
Run: 01, Epoch: 13, Loss: 1.2796, Train: 47.41%, Valid: 43.46% Test: 47.23%
Run: 01, Epo

Run: 01, Epoch: 109, Loss: 0.9026, Train: 73.91%, Valid: 71.48% Test: 70.08%
Run: 01, Epoch: 110, Loss: 0.8995, Train: 74.06%, Valid: 71.95% Test: 71.09%
Run: 01, Epoch: 111, Loss: 0.8998, Train: 74.11%, Valid: 71.94% Test: 71.19%
Run: 01, Epoch: 112, Loss: 0.8965, Train: 74.21%, Valid: 71.74% Test: 70.52%
Run: 01, Epoch: 113, Loss: 0.8942, Train: 74.18%, Valid: 71.87% Test: 70.98%
Run: 01, Epoch: 114, Loss: 0.8948, Train: 73.91%, Valid: 71.38% Test: 70.68%
Run: 01, Epoch: 115, Loss: 0.8939, Train: 74.13%, Valid: 71.16% Test: 69.91%
Run: 01, Epoch: 116, Loss: 0.8898, Train: 74.12%, Valid: 70.90% Test: 69.21%
Run: 01, Epoch: 117, Loss: 0.8901, Train: 74.19%, Valid: 72.00% Test: 71.30%
Run: 01, Epoch: 118, Loss: 0.8904, Train: 74.26%, Valid: 72.12% Test: 71.43%
Run: 01, Epoch: 119, Loss: 0.8885, Train: 74.46%, Valid: 71.69% Test: 70.52%
Run: 01, Epoch: 120, Loss: 0.8889, Train: 74.54%, Valid: 71.67% Test: 70.43%
Run: 01, Epoch: 121, Loss: 0.8888, Train: 74.51%, Valid: 71.82% Test: 71.36%

Run: 01, Epoch: 216, Loss: 0.8026, Train: 76.85%, Valid: 72.93% Test: 72.12%
Run: 01, Epoch: 217, Loss: 0.8000, Train: 76.88%, Valid: 72.63% Test: 71.68%
Run: 01, Epoch: 218, Loss: 0.8024, Train: 76.41%, Valid: 70.76% Test: 68.21%
Run: 01, Epoch: 219, Loss: 0.8025, Train: 75.85%, Valid: 71.64% Test: 71.58%
Run: 01, Epoch: 220, Loss: 0.8013, Train: 76.90%, Valid: 72.38% Test: 71.25%
Run: 01, Epoch: 221, Loss: 0.8014, Train: 76.57%, Valid: 72.26% Test: 71.37%
Run: 01, Epoch: 222, Loss: 0.8009, Train: 76.58%, Valid: 72.18% Test: 70.91%
Run: 01, Epoch: 223, Loss: 0.8000, Train: 76.91%, Valid: 72.31% Test: 71.29%
Run: 01, Epoch: 224, Loss: 0.8010, Train: 76.94%, Valid: 72.46% Test: 71.62%
Run: 01, Epoch: 225, Loss: 0.7958, Train: 76.66%, Valid: 72.54% Test: 72.09%
Run: 01, Epoch: 226, Loss: 0.7962, Train: 77.03%, Valid: 71.94% Test: 69.87%
Run: 01, Epoch: 227, Loss: 0.7950, Train: 76.98%, Valid: 72.24% Test: 70.64%
Run: 01, Epoch: 228, Loss: 0.7960, Train: 76.81%, Valid: 71.36% Test: 69.30%

Run: 01, Epoch: 323, Loss: 0.7507, Train: 78.36%, Valid: 71.80% Test: 69.63%
Run: 01, Epoch: 324, Loss: 0.7505, Train: 78.00%, Valid: 72.15% Test: 71.24%
Run: 01, Epoch: 325, Loss: 0.7484, Train: 78.20%, Valid: 72.24% Test: 70.74%
Run: 01, Epoch: 326, Loss: 0.7463, Train: 78.29%, Valid: 72.17% Test: 70.96%
Run: 01, Epoch: 327, Loss: 0.7480, Train: 78.20%, Valid: 72.47% Test: 71.48%
Run: 01, Epoch: 328, Loss: 0.7477, Train: 78.15%, Valid: 72.70% Test: 72.11%
Run: 01, Epoch: 329, Loss: 0.7465, Train: 77.88%, Valid: 72.36% Test: 71.59%
Run: 01, Epoch: 330, Loss: 0.7446, Train: 77.70%, Valid: 71.98% Test: 70.47%
Run: 01, Epoch: 331, Loss: 0.7439, Train: 77.83%, Valid: 70.89% Test: 68.66%
Run: 01, Epoch: 332, Loss: 0.7442, Train: 78.01%, Valid: 71.77% Test: 70.33%
Run: 01, Epoch: 333, Loss: 0.7514, Train: 78.26%, Valid: 71.52% Test: 69.54%
Run: 01, Epoch: 334, Loss: 0.7455, Train: 78.28%, Valid: 72.05% Test: 70.65%
Run: 01, Epoch: 335, Loss: 0.7450, Train: 77.07%, Valid: 72.00% Test: 72.02%

Run: 01, Epoch: 430, Loss: 0.7172, Train: 78.96%, Valid: 72.09% Test: 71.06%
Run: 01, Epoch: 431, Loss: 0.7181, Train: 78.55%, Valid: 72.45% Test: 72.02%
Run: 01, Epoch: 432, Loss: 0.7189, Train: 78.95%, Valid: 72.18% Test: 70.52%
Run: 01, Epoch: 433, Loss: 0.7142, Train: 79.27%, Valid: 72.55% Test: 70.43%
Run: 01, Epoch: 434, Loss: 0.7167, Train: 79.36%, Valid: 72.38% Test: 70.58%
Run: 01, Epoch: 435, Loss: 0.7167, Train: 79.31%, Valid: 71.97% Test: 70.00%
Run: 01, Epoch: 436, Loss: 0.7135, Train: 79.24%, Valid: 71.51% Test: 69.44%
Run: 01, Epoch: 437, Loss: 0.7140, Train: 79.17%, Valid: 72.57% Test: 71.48%
Run: 01, Epoch: 438, Loss: 0.7153, Train: 79.25%, Valid: 72.91% Test: 71.79%
Run: 01, Epoch: 439, Loss: 0.7138, Train: 79.08%, Valid: 72.10% Test: 69.73%
Run: 01, Epoch: 440, Loss: 0.7133, Train: 79.39%, Valid: 72.31% Test: 70.10%
Run: 01, Epoch: 441, Loss: 0.7120, Train: 79.30%, Valid: 72.23% Test: 70.37%
Run: 01, Epoch: 442, Loss: 0.7105, Train: 78.90%, Valid: 71.89% Test: 69.89%