# Transductive Transformer experiments on CORA

By Sam Barrett.
These experiments follow essentially the same steps as Kally's GAT experiments on CORA.
Most of the setup is copied from `demo_kally_transductive_cora`.

In [1]:
import torch.optim as optim 
import torch.nn as nn
import torch_geometric as tg
import torch

import sys
import os

# Put the parent of the current working directory on the module search path;
# presumably this is where the local `gat` package imported below lives.
parent_dir = os.path.join(os.getcwd(), '..')
gat_path = os.path.abspath(parent_dir)
sys.path.append(gat_path)
from gat import VanillaTransformer_Transductive

In [2]:
# Cora from the Planetoid collection, stored under ./data, using the 'full' split.
dataset = tg.datasets.Planetoid(root='data', name='Cora', split='full')

# Iterate the loader once and keep the first batch as the working graph.
cora_dataloader = tg.loader.DataLoader(dataset)
cora_graph = next(iter(cora_dataloader))

# Node features and node labels.
nodes, y = cora_graph.x, cora_graph.y

# Dense adjacency built from the edge list; squeeze(dim=0) removes the leading
# dimension when it has size 1.
adjacency_matrix = tg.utils.to_dense_adj(
    cora_graph.edge_index
).squeeze(dim=0)

# Boolean node masks selecting the train / test / validation subsets.
train_mask, test_mask, val_mask = (
    cora_graph.train_mask,
    cora_graph.test_mask,
    cora_graph.val_mask,
)

In [41]:
# Optimiser hyperparameters.
lr = 0.1
weight_decay = 3e-5

# File the training loop writes the best-validation weights to, and the
# evaluation cell reads them back from.
MODEL_FILENAME = 'vanilla_transformer_model.pt'

# NOTE(review): the leading positional arguments presumably are the Cora
# feature count (1433) and class count (7) -- confirm against the
# VanillaTransformer_Transductive constructor.
model = VanillaTransformer_Transductive(
    1433, 7, 64, 2, 8,
    dropout_hidden=0.1,
    identity_bias=0.02,
)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
# Scale the learning rate by 0.8 once every 50 scheduler steps.
sched = optim.lr_scheduler.StepLR(optimiser, step_size=50, gamma=0.8)

In [42]:
MAX_EPOCHS = 300       # upper bound on the number of training epochs
MAX_TOL = 40           # consecutive non-improving epochs allowed before stopping
tolerance = 0          # consecutive epochs without a validation improvement so far
best_val_accuracy = 0


def _masked_accuracy(logits, labels, mask):
    """Return the fraction of rows selected by `mask` whose argmax over dim 1
    equals the corresponding entry of `labels`."""
    correct = (logits[mask].argmax(dim=1) == labels[mask]).sum().item()
    return correct / mask.sum().item()


for epoch in range(MAX_EPOCHS):
    # One optimisation step per epoch; the loss only uses the training nodes.
    model.train()

    output = model(nodes, adjacency_matrix)
    loss = criterion(output[train_mask], y[train_mask])
    print("Training loss =", float(loss))

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    sched.step()

    # Compute validation accuracy to facilitate early stopping if needed.
    # Test accuracy is only reported; it never influences model selection.
    model.eval()
    with torch.no_grad():
        output = model(nodes, adjacency_matrix)
        current_val_accuracy = _masked_accuracy(output, y, val_mask)
        test_accuracy = _masked_accuracy(output, y, test_mask)

    if current_val_accuracy > best_val_accuracy:
        # New best: reset the patience counter and checkpoint the weights.
        best_val_accuracy = current_val_accuracy
        tolerance = 0
        print(f'EPOCH {epoch+1} improved validation accuracy to: {current_val_accuracy}')
        print(f'PRINTING TEST ACCURACY: {test_accuracy}')
        print('SAVING MODEL to', MODEL_FILENAME)
        torch.save(model.state_dict(), MODEL_FILENAME)

    else:
        tolerance += 1
        print(f'EPOCH {epoch+1} did not improve validation accuracy')

        # Stop only once MAX_TOL consecutive epochs have failed to improve.
        # The previous check (`== MAX_TOL - 1`) stopped one epoch early while
        # reporting that MAX_TOL epochs had been reached.
        if tolerance >= MAX_TOL:
            print(f'Tolerance {MAX_TOL} epochs reached, exiting')
            break
    print('----------------------------------------------- \n')

Training loss = 2.2685842514038086
EPOCH 1 improved validation accuracy to: 0.316
PRINTING TEST ACCURACY: 0.319
SAVING MODEL to vanilla_transformer_model.pt
----------------------------------------------- 

Training loss = 2.195664882659912
EPOCH 2 did not improve validation accuracy
----------------------------------------------- 

Training loss = 2.070953607559204
EPOCH 3 did not improve validation accuracy
----------------------------------------------- 

Training loss = 1.99169921875
EPOCH 4 did not improve validation accuracy
----------------------------------------------- 

Training loss = 1.9254754781723022
EPOCH 5 did not improve validation accuracy
----------------------------------------------- 

Training loss = 1.8780906200408936
EPOCH 6 did not improve validation accuracy
----------------------------------------------- 

Training loss = 1.8588660955429077
EPOCH 7 did not improve validation accuracy
----------------------------------------------- 

Training loss = 1.85772383

Training loss = 0.9683818817138672
EPOCH 56 did not improve validation accuracy
----------------------------------------------- 

Training loss = 0.9475275278091431
EPOCH 57 improved validation accuracy to: 0.676
PRINTING TEST ACCURACY: 0.633
SAVING MODEL to vanilla_transformer_model.pt
----------------------------------------------- 

Training loss = 0.9050538539886475
EPOCH 58 did not improve validation accuracy
----------------------------------------------- 

Training loss = 0.9258151054382324
EPOCH 59 did not improve validation accuracy
----------------------------------------------- 

Training loss = 0.9561179280281067
EPOCH 60 did not improve validation accuracy
----------------------------------------------- 

Training loss = 0.9318699836730957
EPOCH 61 did not improve validation accuracy
----------------------------------------------- 

Training loss = 0.8259556293487549
EPOCH 62 did not improve validation accuracy
----------------------------------------------- 

Training los

In [5]:
# Reload the weights stored in MODEL_FILENAME and score them on the test nodes.
saved_state = torch.load(MODEL_FILENAME)
model.load_state_dict(saved_state)
model.eval()

with torch.no_grad():
    output = model(nodes, adjacency_matrix)
    test_predictions = output[test_mask].argmax(dim=1)
    num_correct = (test_predictions == y[test_mask]).sum().item()
    final_test_accuracy = num_correct / test_mask.sum().item()

print(f'FINAL test accuracy is: {final_test_accuracy}')

FINAL test accuracy is: 0.715
