### Prepare relevant imports

In [None]:
import torch.optim as optim 
import torch.nn as nn
import torch_geometric as tg
import torch

import sys
import os

gat_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(gat_path)
from gat import Layer_Attention_Dynamic_GATWithBias, Layer_Attention_MultiHead_GAT

### Preparing the data (Cora is a single graph)

In [None]:
# Planetoid's Cora dataset holds exactly one graph, so index it directly instead of
# building a DataLoader just to pull out its first (and only) batch.
dataset = tg.datasets.Planetoid(root='data', name='Cora', split='full')
cora_graph = dataset[0]

nodes = cora_graph.x
y = cora_graph.y

# Pin the dense adjacency to the real node count. Without max_num_nodes, to_dense_adj
# sizes the matrix by the largest node index that appears in edge_index, so any
# trailing nodes without edges would silently be dropped from the matrix.
adjacency_matrix = tg.utils.to_dense_adj(
    cora_graph.edge_index, max_num_nodes=cora_graph.num_nodes
).squeeze(dim=0)
assert adjacency_matrix.shape == (cora_graph.num_nodes, cora_graph.num_nodes)

# Boolean node masks defining the train / test / validation splits.
train_mask = cora_graph.train_mask
test_mask = cora_graph.test_mask
val_mask = cora_graph.val_mask

### Preparing the model and optimiser

The dynamic attention layer adds trainable biases before applying the LeakyReLU, so it needs to know the number of nodes in the graph up front. Because of this limitation, the model class is defined directly in this notebook, as part of the experiment.

In [None]:
class GAT_Transductive_DynamicBiases_CORA(nn.Module):
    """Two-layer transductive GAT built from dynamic-bias attention layers.

    Pipeline (the defaults reproduce the original Cora configuration):
        dropout -> attention layer 1 (`n_heads` heads of size `hidden_dim`, concatenated) -> ELU
        -> dropout -> attention layer 2 (1 head of size `num_classes`) -> per-node class scores.

    Args:
        input_dim: number of input features per node.
        num_classes: size of the per-node output score vector.
        n_nodes: number of nodes in the graph. The dynamic-bias layers need it at
            construction time; the default 2708 is the node count of Cora.
        hidden_dim: per-head representation size of the first attention layer.
        n_heads: number of attention heads in the first attention layer.
        dropout: probability used by both feature-dropout modules and forwarded as
            the `dropout` argument of both attention layers.
    """

    def __init__(self, input_dim, num_classes, n_nodes=2708, hidden_dim=8, n_heads=8, dropout=0.6):
        super().__init__()

        # Attribute names are kept unchanged so previously saved state_dicts still load.
        self.dropout_1 = nn.Dropout(p=dropout)
        self.attention_layer_1 = Layer_Attention_Dynamic_GATWithBias(input_dim=input_dim,
                                                                     repr_dim=hidden_dim,
                                                                     n_heads=n_heads,
                                                                     n_nodes=n_nodes,
                                                                     epsilon_bias=1.0,
                                                                     alpha=0.2,
                                                                     attention_aggr='concat',
                                                                     dropout=dropout)
        self.activation_1 = nn.ELU()

        self.dropout_2 = nn.Dropout(p=dropout)
        # With 'concat' aggregation, layer 1 is assumed to emit n_heads * hidden_dim
        # features per node (the original hard-coded 64 == 8 * 8) — TODO confirm in gat.py.
        self.attention_layer_2 = Layer_Attention_Dynamic_GATWithBias(input_dim=n_heads * hidden_dim,
                                                                     repr_dim=num_classes,
                                                                     n_heads=1,
                                                                     n_nodes=n_nodes,
                                                                     epsilon_bias=1.0,
                                                                     alpha=0.2,
                                                                     attention_aggr='concat',
                                                                     dropout=dropout)

    def forward(self, node_matrix, adj_matrix):
        """Return per-node class scores.

        Args:
            node_matrix: node feature matrix; presumably (n_nodes, input_dim) — TODO confirm.
            adj_matrix: dense adjacency matrix, as built with `to_dense_adj` in this notebook.
        """
        node_matrix_dropout = self.dropout_1(node_matrix)

        z_1 = self.attention_layer_1(node_matrix_dropout, adj_matrix)
        a_1 = self.activation_1(z_1)

        a_1_dropout = self.dropout_2(a_1)

        z_2 = self.attention_layer_2(a_1_dropout, adj_matrix)

        return z_2

In [10]:
# Hyperparameters for this run. NOTE: these are not the GAT paper's Cora settings
# (Veličković et al., 2018 used lr=0.005 and L2 weight decay 5e-4 on Cora).
lr = 0.01
weight_decay = 0.005
seed = 0

# Seed torch's global RNG before building the model so that weight initialisation
# and the nn.Dropout masks drawn during training are reproducible across runs (CPU).
torch.manual_seed(seed)

# Feature / class counts are read from the dataset instead of hard-coding 1433 / 7.
trans_model = GAT_Transductive_DynamicBiases_CORA(dataset.num_features, dataset.num_classes)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(trans_model.parameters(),
                       lr=lr,
                       weight_decay=weight_decay)

### Training with (very) simple early stopping (patience of 20 epochs)

In [None]:
MAX_EPOCHS = 300
PATIENCE = 20  # stop after this many consecutive epochs without a validation improvement
CHECKPOINT_PATH = 'trans_model_dynamicbias.pt'


def masked_accuracy(logits, labels, mask):
    """Fraction of nodes selected by boolean `mask` whose argmax prediction equals the label."""
    correct = (logits[mask].argmax(dim=1) == labels[mask]).sum().item()
    return correct / mask.sum().item()


epochs_without_improvement = 0
best_val_accuracy = 0

for epoch in range(MAX_EPOCHS):
    trans_model.train()

    output = trans_model(nodes, adjacency_matrix)
    loss = criterion(output[train_mask], y[train_mask])

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    # Compute validation accuracy in eval mode (dropout off) to drive early stopping.
    trans_model.eval()
    with torch.no_grad():
        output = trans_model(nodes, adjacency_matrix)
    current_val_accuracy = masked_accuracy(output, y, val_mask)

    if current_val_accuracy > best_val_accuracy:
        best_val_accuracy = current_val_accuracy
        epochs_without_improvement = 0
        print(f'EPOCH {epoch+1} improved validation accuracy to: {current_val_accuracy}')

        # Reuse the eval-mode output above rather than running an identical second
        # forward pass. Test accuracy is only printed for monitoring; model selection
        # is driven by validation accuracy alone.
        test_accuracy = masked_accuracy(output, y, test_mask)
        print(f'PRINTING TEST ACCURACY: {test_accuracy}')
        print(f'SAVING MODEL to {CHECKPOINT_PATH}')
        torch.save(trans_model.state_dict(), CHECKPOINT_PATH)

    else:
        epochs_without_improvement += 1
        print(f'EPOCH {epoch+1} did not improve validation accuracy')

        # The original compared against 19, stopping one epoch earlier than the
        # advertised 20-epoch tolerance.
        if epochs_without_improvement >= PATIENCE:
            print(f'Tolerance {PATIENCE} epochs reached, exiting')
            break
    print('----------------------------------------------- \n')

EPOCH 1 improved validation accuracy to: 0.502
PRINTING TEST ACCURACY: 0.482
SAVING MODEL to trans_model_dynamicbias.pt
----------------------------------------------- 

EPOCH 2 improved validation accuracy to: 0.578
PRINTING TEST ACCURACY: 0.559
SAVING MODEL to trans_model_dynamicbias.pt
----------------------------------------------- 

EPOCH 3 improved validation accuracy to: 0.644
PRINTING TEST ACCURACY: 0.632
SAVING MODEL to trans_model_dynamicbias.pt
----------------------------------------------- 

EPOCH 4 improved validation accuracy to: 0.692
PRINTING TEST ACCURACY: 0.687
SAVING MODEL to trans_model_dynamicbias.pt
----------------------------------------------- 

EPOCH 5 improved validation accuracy to: 0.746
PRINTING TEST ACCURACY: 0.737
SAVING MODEL to trans_model_dynamicbias.pt
----------------------------------------------- 

EPOCH 6 improved validation accuracy to: 0.788
PRINTING TEST ACCURACY: 0.798
SAVING MODEL to trans_model_dynamicbias.pt
------------------------------

### Test accuracy of the best (validation-selected) checkpoint

In [None]:
# Restore the weights saved at the epoch with the best validation accuracy,
# then score them once on the held-out test nodes with dropout disabled.
best_state = torch.load('trans_model_dynamicbias.pt')
trans_model.load_state_dict(best_state)
trans_model.eval()

with torch.no_grad():
    output = trans_model(nodes, adjacency_matrix)
    test_predictions = output[test_mask].argmax(dim=1)
    n_correct = (test_predictions == y[test_mask]).sum().item()
    final_test_accuracy = n_correct / test_mask.sum().item()
    print(f'FINAL test accuracy is: {final_test_accuracy}')

FINAL test accuracy is: 0.865
