In [5]:
import os
import pathlib

import numpy as np
from torch_geometric.utils import from_networkx
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader

data = torch.load("datasets/data.pt")

In [6]:
# GCN model with 2 layers
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(data.num_node_features, 32)
        self.conv2 = GCNConv(32, int(data.y.max() + 2))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(500):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    model.eval()
    pred = model(data).argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())

    print("Epoch: " + str(epoch))
    print(f'Loss: {loss:.4f}, Accuracy: {acc:.4f}')

Epoch: 0
Loss: 22.9644, Accuracy: 0.0006
Epoch: 1
Loss: 12.5930, Accuracy: 0.1635
Epoch: 2
Loss: 3.9960, Accuracy: 0.4189
Epoch: 3
Loss: 1.4703, Accuracy: 0.5196
Epoch: 4
Loss: 0.8084, Accuracy: 0.5687
Epoch: 5
Loss: 1.1733, Accuracy: 0.5686
Epoch: 6
Loss: 1.2425, Accuracy: 0.5631
Epoch: 7
Loss: 0.8748, Accuracy: 0.4504
Epoch: 8
Loss: 0.8789, Accuracy: 0.4327
Epoch: 9
Loss: 1.0960, Accuracy: 0.4431
Epoch: 10
Loss: 0.8347, Accuracy: 0.5601
Epoch: 11
Loss: 0.7797, Accuracy: 0.5682
Epoch: 12
Loss: 0.9575, Accuracy: 0.5679
Epoch: 13
Loss: 0.8747, Accuracy: 0.5657
Epoch: 14
Loss: 0.7074, Accuracy: 0.4316
Epoch: 15
Loss: 0.8385, Accuracy: 0.4332
Epoch: 16
Loss: 0.8498, Accuracy: 0.5279
Epoch: 17
Loss: 0.7015, Accuracy: 0.5676
Epoch: 18
Loss: 0.7738, Accuracy: 0.5682
Epoch: 19
Loss: 0.8230, Accuracy: 0.5649
Epoch: 20
Loss: 0.7254, Accuracy: 0.4603
Epoch: 21
Loss: 0.7223, Accuracy: 0.4253
Epoch: 22
Loss: 0.7957, Accuracy: 0.4563
Epoch: 23
Loss: 0.7272, Accuracy: 0.5698
Epoch: 24
Loss: 0.7118, 