In [4]:
import os.path as osp
from pathlib import Path

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv

from torch.utils.tensorboard import SummaryWriter

In [6]:
dataset = 'Cora'
#path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
path = Path.cwd()
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[-1]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [11]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv0 = GCNConv(dataset.num_features, 16, cached=True)
        self.conv1 = GCNConv(16, dataset.num_classes, cached=True)

    def forward(self, x, edge_index):
        x = F.relu(self.conv0(x, edge_index, None))
        x = F.dropout(x, training=self.training)
        x = self.conv1(x, edge_index, None)
        return F.log_softmax(x, dim=0)

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

In [18]:
def train():
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    loss = F.nll_loss(logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [22]:
@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

In [23]:
model(data.x, data.edge_index)
writer = SummaryWriter('runs/zkc-separation-pytorch-geometric')
writer.add_graph(model, [data.x, data.edge_index])

In [24]:
best_val_acc = test_acc = -1
for epoch in range(0, 201):
    train_loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log = 'Epoch: {:02d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_acc, best_val_acc, test_acc))

    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)
    writer.add_scalar('Accuracy/val', val_acc, epoch)
    writer.add_scalar('Accuracy/test', test_acc, epoch)

Epoch: 00, Train: 0.2643, Val: 0.1980, Test: 0.1860
Epoch: 01, Train: 0.3571, Val: 0.2460, Test: 0.2360
Epoch: 02, Train: 0.4500, Val: 0.2880, Test: 0.2840
Epoch: 03, Train: 0.5071, Val: 0.3120, Test: 0.3210
Epoch: 04, Train: 0.5643, Val: 0.3500, Test: 0.3410
Epoch: 05, Train: 0.5929, Val: 0.3780, Test: 0.3690
Epoch: 06, Train: 0.6071, Val: 0.3940, Test: 0.3940
Epoch: 07, Train: 0.6286, Val: 0.4140, Test: 0.4180
Epoch: 08, Train: 0.6786, Val: 0.4260, Test: 0.4350
Epoch: 09, Train: 0.7000, Val: 0.4540, Test: 0.4570
Epoch: 10, Train: 0.7214, Val: 0.4760, Test: 0.4800
Epoch: 11, Train: 0.7286, Val: 0.4880, Test: 0.4940
Epoch: 12, Train: 0.7500, Val: 0.5040, Test: 0.5060
Epoch: 13, Train: 0.7571, Val: 0.5080, Test: 0.5250
Epoch: 14, Train: 0.7714, Val: 0.5200, Test: 0.5360
Epoch: 15, Train: 0.7714, Val: 0.5340, Test: 0.5490
Epoch: 16, Train: 0.7929, Val: 0.5440, Test: 0.5620
Epoch: 17, Train: 0.8143, Val: 0.5600, Test: 0.5710
Epoch: 18, Train: 0.8214, Val: 0.5820, Test: 0.5780
Epoch: 19, T