### Prepare Relevant imports

In [1]:
import torch.optim as optim 
import torch.nn as nn
import torch_geometric as tg
import sklearn.metrics as metrics
import torch

import sys
import os
import time

gat_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(gat_path)
from gat import GAT_Inductive

### Preparing the data (the unique graph)

In [2]:
train_dataset = tg.datasets.PPI(root='data', split='train')
val_dataset = tg.datasets.PPI(root='data', split='val')
test_dataset = tg.datasets.PPI(root='data', split='test')

train_loader = tg.loader.DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = tg.loader.DataLoader(val_dataset, batch_size=2)
test_loader = tg.loader.DataLoader(test_dataset, batch_size=2)

# for simplicity later on we just cache those
val_graph_pair = next(iter(val_loader))
test_graph_pair = next(iter(test_loader))

nodes_val = val_graph_pair.x
y_val = val_graph_pair.y
adjacency_val = tg.utils.to_dense_adj(val_graph_pair.edge_index).squeeze(dim=0)

nodes_test = test_graph_pair.x
y_test = test_graph_pair.y
adjacency_test = tg.utils.to_dense_adj(test_graph_pair.edge_index).squeeze(dim=0)

### Preparing Model and Optimisers

In [3]:
# default values from GAT paper
lr = 0.005
weight_decay = 0.0005

ind_model = GAT_Inductive(50, 121)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(ind_model.parameters(), 
                       lr=lr, 
                       weight_decay=weight_decay)

### Running training with (very) simple early stopping with 10 epochs tolerance

In [4]:
tolerance = 0
best_val_f1 = 0

for epoch in range(50):
    ind_model.train()
    start_time = time.time()
    
    for train_graph_pair in train_loader:
        nodes_train = train_graph_pair.x
        y_train = train_graph_pair.y
        adjacency_train = tg.utils.to_dense_adj(train_graph_pair.edge_index, max_num_nodes=nodes_train.shape[0]).squeeze(dim=0)
    
        output = ind_model(nodes_train, adjacency_train)
        loss = criterion(output, y_train)
    
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
    
    # compute validation micro f1-score to fascilitate early stopping if needed
    ind_model.eval()
    with torch.no_grad():
        output = ind_model(nodes_val, adjacency_val)
        output_labeled = torch.where(output > 0.5, 1.0, 0.0)
        current_val_f1 = metrics.f1_score(output_labeled, y_val, average='micro')
    
    if current_val_f1 > best_val_f1:
        best_val_f1 = current_val_f1
        tolerance = 0
        print(f'EPOCH {epoch+1} improved validation micro f1-score to: {current_val_f1}')
        
        with torch.no_grad():
            output = ind_model(nodes_test, adjacency_test)
            output_labeled = torch.where(output > 0.5, 1.0, 0.0)
            test_f1 = metrics.f1_score(output_labeled, y_test, average='micro')
            print(f'PRINTING TEST MICRO F1-SCORE: {test_f1}')
        print('SAVING MODEL to ind_model.pt')
        torch.save(ind_model.state_dict(), 'ind_model.pt')
        
    else:
        tolerance += 1
        print(f'EPOCH {epoch+1} did not improve validation micro f1-score')
        
        if tolerance == 19:
            print('Tolerance 20 epochs reached, exiting')
            break
    end_time = time.time()
    print('Seconds taken for this epoch: {:.4f}s'.format(end_time - start_time))
    print('----------------------------------------------- \n')

EPOCH 1 improved validation micro f1-score to: 0.49921082990921806
PRINTING TEST MICRO F1-SCORE: 0.503750089531469
SAVING MODEL to ind_model.pt
Seconds taken for this epoch: 182.7012s
----------------------------------------------- 

EPOCH 2 did not improve validation micro f1-score
Seconds taken for this epoch: 163.8051s
----------------------------------------------- 

EPOCH 3 improved validation micro f1-score to: 0.5157148060300792
PRINTING TEST MICRO F1-SCORE: 0.5180258538973553
SAVING MODEL to ind_model.pt
Seconds taken for this epoch: 171.8168s
----------------------------------------------- 

EPOCH 4 improved validation micro f1-score to: 0.5300411591838794
PRINTING TEST MICRO F1-SCORE: 0.5346166062208697
SAVING MODEL to ind_model.pt
Seconds taken for this epoch: 173.5455s
----------------------------------------------- 

EPOCH 5 improved validation micro f1-score to: 0.5382811652494639
PRINTING TEST MICRO F1-SCORE: 0.5428615347828315
SAVING MODEL to ind_model.pt
Seconds taken 

### Printing the final accuracy

In [5]:
ind_model.load_state_dict(torch.load('ind_model.pt'))
ind_model.eval()

with torch.no_grad():
    output = ind_model(nodes_test, adjacency_test)
    output_labeled = torch.where(output > 0.5, 1.0, 0.0)
    final_test_f1 = metrics.f1_score(output_labeled, y_test, average='micro')
    print(f'FINAL test micro F1-score is: {final_test_f1}')

FINAL test micro F1-score is: 0.7660773039235489
