In [1]:
import torch
import random
import numpy as np
import argparse

import torch
from torch import Tensor
from torch_geometric.logging import init_wandb, log
from torch_geometric.datasets import Planetoid
from utils import train, test
from models import GCN, GAT, LP

In [2]:
citeseer = Planetoid(root='.', name='Citeseer')
cora = Planetoid(root='.', name='Cora')
pubmed = Planetoid(root='.', name='Pubmed')
torch.use_deterministic_algorithms(True)

In [19]:
k = 1
seeds = list(range(5))
lr = 0.05
epochs = 200

In [20]:
# dataset = citeseer
# model = GCN(dataset.num_features, 8, dataset.num_classes)

# dataset = cora
# model = GCN(dataset.num_features, 8, dataset.num_classes)

# dataset = pubmed
# model = GCN(dataset.num_features, 8, dataset.num_classes)

# dataset = citeseer
# model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)

# dataset = cora
# model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)

dataset = pubmed
model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)

In [21]:
data = dataset[0]
for c in data.y.unique():
    idx = ((data.y == c) & data.train_mask).nonzero(as_tuple=False).view(-1)
    idx = idx[torch.randperm(idx.size(0))]
    idx = idx[k:]
    data.train_mask[idx] = False

In [22]:
av_val_acc = av_test_acc = 0
state_dict = model.state_dict()

for seed in seeds:
    print("RUNNING FOR SEED =", seed)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    
    model.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

    best_val_acc = final_test_acc = 0
    for epoch in range(1, 200):
        loss = train(model, data, optimizer, scheduler=None, loss='cross_entropy')
        train_acc, val_acc, tmp_test_acc = test(model, data)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        if epoch % 25 == 0:
            log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    print(f'Best Val Acc: {best_val_acc:.4f}', f'Test Acc: {test_acc:.4f}')
    av_val_acc += best_val_acc
    av_test_acc += test_acc
    
print(f'Average Val Acc: {av_val_acc / len(seeds):.4f}', f'Average Test Acc: {av_test_acc / len(seeds):.4f}')    

RUNNING FOR SEED = 0
Epoch: 025, Loss: 0.6572325825691223, Train: 1.0000, Val: 0.5100, Test: 0.5530
Epoch: 050, Loss: 0.5565707683563232, Train: 1.0000, Val: 0.5100, Test: 0.5580
Epoch: 075, Loss: 0.40692010521888733, Train: 1.0000, Val: 0.5600, Test: 0.5580
Epoch: 100, Loss: 0.1948951929807663, Train: 1.0000, Val: 0.5780, Test: 0.6350
Epoch: 125, Loss: 0.13045743107795715, Train: 1.0000, Val: 0.5120, Test: 0.6350
Epoch: 150, Loss: 0.1663239747285843, Train: 1.0000, Val: 0.4680, Test: 0.6350
Epoch: 175, Loss: 0.30762746930122375, Train: 1.0000, Val: 0.5560, Test: 0.6430
Best Val Acc: 0.6680 Test Acc: 0.6430
RUNNING FOR SEED = 1
Epoch: 025, Loss: 0.11234235018491745, Train: 1.0000, Val: 0.5060, Test: 0.5040
Epoch: 050, Loss: 0.4810583293437958, Train: 1.0000, Val: 0.5060, Test: 0.5890
Epoch: 075, Loss: 0.004034919664263725, Train: 1.0000, Val: 0.4340, Test: 0.5890
Epoch: 100, Loss: 0.0012076605344191194, Train: 1.0000, Val: 0.2900, Test: 0.5890
Epoch: 125, Loss: 0.2508697211742401, Trai