### Prepare Relevant imports

In [1]:
import torch.optim as optim 
import torch.nn as nn
import torch_geometric as tg
import torch

import sys
import os

# Put the parent of the current working directory on the module search path
# so that the project-local `gat` module can be imported on the next line.
# NOTE(review): this assumes the kernel's working directory sits directly
# beneath the directory that contains `gat` - confirm before moving the notebook.
gat_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(gat_path)
from gat import GAT_Transductive

### Preparing the data (the single Cora graph)

In [2]:
# Load Cora through torch_geometric's Planetoid loader (using its 'full' split)
# and pull the first batch out of a DataLoader to obtain the graph object.
dataset = tg.datasets.Planetoid(root='data', name='Cora', split='full')
cora_dataloader = tg.loader.DataLoader(dataset)
cora_graph = next(iter(cora_dataloader))

# Node features and per-node labels.
nodes = cora_graph.x
y = cora_graph.y

# Dense adjacency matrix built from the sparse edge list. The node count is
# passed explicitly: without `max_num_nodes`, `to_dense_adj` sizes the matrix
# from the largest index present in `edge_index`, so isolated nodes with the
# highest ids would be silently dropped and the matrix would no longer line up
# with `nodes`. The result carries a leading batch dimension, removed here.
adjacency_matrix = tg.utils.to_dense_adj(
    cora_graph.edge_index,
    max_num_nodes=nodes.size(0),
).squeeze(dim=0)

# Boolean node masks selecting the train / test / validation subsets.
train_mask = cora_graph.train_mask
test_mask = cora_graph.test_mask
val_mask = cora_graph.val_mask

### Preparing Model and Optimisers

In [3]:
# default values from GAT paper
lr = 0.005
weight_decay = 0.0005

# Read the input / output sizes from the dataset rather than hard-coding them.
# For Cora these are the 1433 features and 7 classes that were previously
# written as literals, so the constructed model is unchanged.
trans_model = GAT_Transductive(dataset.num_features, dataset.num_classes)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(
    trans_model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

### Running training with (very) simple early stopping, using a 20-epoch tolerance

In [4]:
def _masked_accuracy(logits, labels, mask):
    """Return the fraction of mask-selected rows whose argmax equals the label."""
    n_correct = (logits[mask].argmax(dim=1) == labels[mask]).sum().item()
    return n_correct / mask.sum().item()


# Full-batch training for up to 300 epochs with early stopping on validation
# accuracy. Every time validation accuracy improves, the test accuracy is
# printed and the model weights are saved to 'trans_model.pt'.
patience = 20          # consecutive non-improving epochs allowed before stopping
tolerance = 0          # consecutive epochs without improvement seen so far
best_val_accuracy = 0

for epoch in range(300):
    trans_model.train()

    output = trans_model(nodes, adjacency_matrix)
    loss = criterion(output[train_mask], y[train_mask])

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    # compute validation accuracy to facilitate early stopping if needed
    trans_model.eval()
    with torch.no_grad():
        output = trans_model(nodes, adjacency_matrix)
    current_val_accuracy = _masked_accuracy(output, y, val_mask)

    if current_val_accuracy > best_val_accuracy:
        best_val_accuracy = current_val_accuracy
        tolerance = 0
        print(f'EPOCH {epoch+1} improved validation accuracy to: {current_val_accuracy}')

        # The model is still in eval mode and the inputs are unchanged, so the
        # logits computed above are reused instead of running another forward pass.
        test_accuracy = _masked_accuracy(output, y, test_mask)
        print(f'PRINTING TEST ACCURACY: {test_accuracy}')
        print('SAVING MODEL to trans_model.pt')
        torch.save(trans_model.state_dict(), 'trans_model.pt')

    else:
        tolerance += 1
        print(f'EPOCH {epoch+1} did not improve validation accuracy')

        # Stop after `patience` consecutive non-improving epochs. The previous
        # check (`tolerance == 19`) stopped one epoch earlier than the message states.
        if tolerance >= patience:
            print(f'Tolerance {patience} epochs reached, exiting')
            break
    print('----------------------------------------------- \n')

EPOCH 1 improved validation accuracy to: 0.468
PRINTING TEST ACCURACY: 0.454
SAVING MODEL to trans_model.pt
----------------------------------------------- 

EPOCH 2 improved validation accuracy to: 0.488
PRINTING TEST ACCURACY: 0.484
SAVING MODEL to trans_model.pt
----------------------------------------------- 

EPOCH 3 improved validation accuracy to: 0.512
PRINTING TEST ACCURACY: 0.506
SAVING MODEL to trans_model.pt
----------------------------------------------- 

EPOCH 4 improved validation accuracy to: 0.544
PRINTING TEST ACCURACY: 0.535
SAVING MODEL to trans_model.pt
----------------------------------------------- 

EPOCH 5 improved validation accuracy to: 0.596
PRINTING TEST ACCURACY: 0.59
SAVING MODEL to trans_model.pt
----------------------------------------------- 

EPOCH 6 improved validation accuracy to: 0.652
PRINTING TEST ACCURACY: 0.643
SAVING MODEL to trans_model.pt
----------------------------------------------- 

EPOCH 7 improved validation accuracy to: 0.712
PRINTI

### Printing the final accuracy

In [5]:
# Restore the weights stored in 'trans_model.pt' and switch to evaluation mode.
saved_state = torch.load('trans_model.pt')
trans_model.load_state_dict(saved_state)
trans_model.eval()

# Single forward pass over the whole graph with gradient tracking disabled.
with torch.no_grad():
    output = trans_model(nodes, adjacency_matrix)

# Accuracy on the test nodes: correct argmax predictions / number of test nodes.
test_predictions = output[test_mask].argmax(dim=1)
test_hits = (test_predictions == y[test_mask]).sum().item()
final_test_accuracy = test_hits / test_mask.sum().item()
print(f'FINAL test accuracy is: {final_test_accuracy}')

FINAL test accuracy is: 0.873
