In [3]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m45.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [5]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

dataset = Planetoid(root='/tmp/Cora', name='Cora')

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test():
    model.eval()
    _, pred = model(data).max(dim=1)
    correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
    accuracy = correct / data.test_mask.sum().item()
    return accuracy

for epoch in range(200):
    loss = train()
    if epoch % 10 == 0:
        acc = test()
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Test Accuracy: {acc:.4f}')

Epoch: 0, Loss: 1.9592, Test Accuracy: 0.5190
Epoch: 10, Loss: 0.7539, Test Accuracy: 0.7740
Epoch: 20, Loss: 0.2462, Test Accuracy: 0.7960
Epoch: 30, Loss: 0.1192, Test Accuracy: 0.7930
Epoch: 40, Loss: 0.0699, Test Accuracy: 0.7900
Epoch: 50, Loss: 0.0448, Test Accuracy: 0.7950
Epoch: 60, Loss: 0.0462, Test Accuracy: 0.7960
Epoch: 70, Loss: 0.0378, Test Accuracy: 0.7940
Epoch: 80, Loss: 0.0341, Test Accuracy: 0.7910
Epoch: 90, Loss: 0.0473, Test Accuracy: 0.7940
Epoch: 100, Loss: 0.0342, Test Accuracy: 0.8010
Epoch: 110, Loss: 0.0314, Test Accuracy: 0.7910
Epoch: 120, Loss: 0.0343, Test Accuracy: 0.8010
Epoch: 130, Loss: 0.0543, Test Accuracy: 0.7970
Epoch: 140, Loss: 0.0366, Test Accuracy: 0.8040
Epoch: 150, Loss: 0.0288, Test Accuracy: 0.7980
Epoch: 160, Loss: 0.0353, Test Accuracy: 0.7970
Epoch: 170, Loss: 0.0234, Test Accuracy: 0.7970
Epoch: 180, Loss: 0.0336, Test Accuracy: 0.8000
Epoch: 190, Loss: 0.0355, Test Accuracy: 0.8010
