In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from graph_data import GraphData
from models import DeeperGCN
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt


graph_data = GraphData()


# model = GCN(hidden_channels=64)
model = DeeperGCN(5, 256, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05) 

scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in graph_data.train_loader: # Iterate in batches over the training dataset.
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.


def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)  
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 100):
    train()
    train_acc = test(graph_data.train_loader)
    test_acc = test(graph_data.test_loader)
    scheduler.step()
    print(f"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}, lr: {optimizer.param_groups[0]['lr']}")

Epoch: 001, Train Acc: 0.5000, Test Acc: 0.5000, lr: 0.05
Epoch: 002, Train Acc: 0.5000, Test Acc: 0.5000, lr: 0.05
Epoch: 003, Train Acc: 0.8675, Test Acc: 0.8450, lr: 0.05
Epoch: 004, Train Acc: 0.7025, Test Acc: 0.6600, lr: 0.05
Epoch: 005, Train Acc: 0.8925, Test Acc: 0.8700, lr: 0.05
Epoch: 006, Train Acc: 0.8450, Test Acc: 0.8450, lr: 0.05
Epoch: 007, Train Acc: 0.9050, Test Acc: 0.8900, lr: 0.05
Epoch: 008, Train Acc: 0.6175, Test Acc: 0.6600, lr: 0.05
Epoch: 009, Train Acc: 0.8900, Test Acc: 0.8800, lr: 0.05
Epoch: 010, Train Acc: 0.9175, Test Acc: 0.8850, lr: 0.025
Epoch: 011, Train Acc: 0.8900, Test Acc: 0.8600, lr: 0.025
Epoch: 012, Train Acc: 0.8925, Test Acc: 0.8650, lr: 0.025
Epoch: 013, Train Acc: 0.9175, Test Acc: 0.8750, lr: 0.025
Epoch: 014, Train Acc: 0.8950, Test Acc: 0.8650, lr: 0.025
Epoch: 015, Train Acc: 0.9125, Test Acc: 0.8800, lr: 0.025
Epoch: 016, Train Acc: 0.9200, Test Acc: 0.8850, lr: 0.025
Epoch: 017, Train Acc: 0.9175, Test Acc: 0.8850, lr: 0.025
Epoch:

In [5]:
torch.save(model.state_dict(), 'gnn.pt')