In [15]:
import torch.nn as nn
import torch
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

import sys
sys.path.append('../../src/')
from utils.datasets import *

In [43]:
class GCN(nn.Module):
    """Stack of GCNConv layers with ReLU + dropout between consecutive convolutions.

    For ``hidden_dims=(h1, ..., hk)`` the layer sequence is::

        GCNConv(in, h1) -> ReLU -> Dropout -> ... -> GCNConv(hk, out)

    The last GCNConv has no activation, so ``forward`` returns raw scores
    (the notebook feeds them to ``nn.CrossEntropyLoss``).

    Args:
        in_channels: number of input features per node.
        out_channels: number of output scores per node (e.g. number of classes).
        hidden_dims: widths of the hidden GCNConv layers, in order. An empty
            sequence gives a single ``GCNConv(in_channels, out_channels)``.
            Defaults to an immutable tuple to avoid the shared-mutable-default
            pitfall; any iterable of ints (e.g. a list) is accepted.
        dropout: drop probability for every ``nn.Dropout`` between layers.
    """

    def __init__(self, in_channels, out_channels, hidden_dims=(16,), dropout=0.5):
        super().__init__()
        layers = []
        for dim in hidden_dims:
            layers.append(GCNConv(in_channels, dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            in_channels = dim
        layers.append(GCNConv(in_channels, out_channels))
        # ModuleList (not a plain list) so the sub-layers' parameters are
        # registered and follow .to(device) / .parameters() / .train() / .eval().
        self.conv = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        """Apply the layer stack to node features ``x`` over graph ``edge_index``.

        ``x`` is presumably ``[num_nodes, in_channels]`` (Cora's is
        ``[2708, 1433]``); the result has ``out_channels`` columns.
        """
        for layer in self.conv:
            # Only graph convolutions consume the connectivity; ReLU/Dropout are pointwise.
            x = layer(x, edge_index) if isinstance(layer, GCNConv) else layer(x)
        return x

def train(model, data, optimizer, criterion, device):
    """Perform one full-batch optimisation step on the training nodes.

    Puts ``model`` in training mode, moves ``data`` to ``device``, computes
    ``criterion`` over the nodes selected by ``data.train_mask``, back-propagates
    and applies one ``optimizer`` step.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    batch = data.to(device)
    optimizer.zero_grad()
    logits = model(batch.x, batch.edge_index)
    train_idx = batch.train_mask
    step_loss = criterion(logits[train_idx], batch.y[train_idx])
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def test(model, data, criterion, device):
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        output = model(data.x, data.edge_index)
        val_loss = criterion(output[data.val_mask], data.y[data.val_mask]).item()
        pred = output.argmax(dim=1)
        correct = (pred[data.val_mask] == data.y[data.val_mask]).sum().item()
        accuracy = correct / data.val_mask.sum().item()
    return val_loss, accuracy

In [9]:
# Configuration: seed torch's RNG (used e.g. by dropout masks and parameter
# initialisation) so Restart & Run All reproduces the same numbers, then pick
# the compute device. torch.manual_seed also seeds the CUDA generators; some
# GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
# Load the Cora citation graph (cached under /tmp/Cora) and move its single
# graph onto the selected device.
cora_dataset = Planetoid(name='Cora', root='/tmp/Cora')
data = cora_dataset[0]
data = data.to(device)

In [18]:
# Display the graph summary: node features (x), connectivity (edge_index),
# labels (y) and the train/val/test node masks.
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [48]:
# Two-layer GCN: input width = node-feature dimension, one hidden layer of 16,
# one output score per Cora class.
model = GCN(
    in_channels=data.x.size(1),
    out_channels=cora_dataset.num_classes,
    hidden_dims=[16],
).to(device)

In [49]:
# Training hyperparameters.
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 200
LOG_EVERY = 20

# Re-initialise the weights so re-running this cell trains from scratch instead
# of silently continuing from the previous run's (already trained) weights.
for module in model.modules():
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train(model, data, optimizer, criterion, device)
    # Validation only feeds the log line, so evaluate just on logging epochs
    # (eval mode draws no dropout randomness, so training is unaffected).
    if epoch % LOG_EVERY == 0:
        val_loss, val_accuracy = test(model, data, criterion, device)
        print(f'Epoch {epoch}, Train Loss - {train_loss}, Val Loss - {val_loss}, Val Accuracy - {val_accuracy}')

Epoch 20, Train Loss - 0.22458970546722412, Val Loss - 0.7786887288093567, Val Accuracy - 0.772
Epoch 40, Train Loss - 0.06594665348529816, Val Loss - 0.7225316166877747, Val Accuracy - 0.77
Epoch 60, Train Loss - 0.03585854172706604, Val Loss - 0.7300599217414856, Val Accuracy - 0.778
Epoch 80, Train Loss - 0.046450234949588776, Val Loss - 0.7358435988426208, Val Accuracy - 0.76
Epoch 100, Train Loss - 0.04005562514066696, Val Loss - 0.7371999025344849, Val Accuracy - 0.748
Epoch 120, Train Loss - 0.03932250291109085, Val Loss - 0.7474761605262756, Val Accuracy - 0.762
Epoch 140, Train Loss - 0.02957715280354023, Val Loss - 0.758110523223877, Val Accuracy - 0.76
Epoch 160, Train Loss - 0.033548079431056976, Val Loss - 0.7405297756195068, Val Accuracy - 0.774
Epoch 180, Train Loss - 0.04427113011479378, Val Loss - 0.771914541721344, Val Accuracy - 0.76
Epoch 200, Train Loss - 0.02660244330763817, Val Loss - 0.7440562844276428, Val Accuracy - 0.764
