In [1]:
import torch
import random
import numpy as np
import argparse

import torch
from torch import Tensor
from torch_geometric.logging import init_wandb, log
from torch_geometric.datasets import Planetoid
from utils import train, test
from models import GCN, GAT, LP

In [2]:
citeseer = Planetoid(root='.', name='Citeseer')
cora = Planetoid(root='.', name='Cora')
pubmed = Planetoid(root='.', name='Pubmed')
torch.use_deterministic_algorithms(True)

In [3]:
k = 20
seeds = [42, 2021, 1234]
lr = 0.05
epochs = 200

In [4]:
dataset = citeseer
model = GCN(dataset.num_features, 24, dataset.num_classes)

# dataset = cora
# model = GCN(dataset.num_features, 8, dataset.num_classes)

# dataset = pubmed
# model = GCN(dataset.num_features, 8, dataset.num_classes)

# dataset = citeseer
# model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)

# dataset = cora
# model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)

# dataset = pubmed
# model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)

In [5]:
data = dataset[0]
for c in data.y.unique():
    idx = ((data.y == c) & data.train_mask).nonzero(as_tuple=False).view(-1)
    idx = idx[torch.randperm(idx.size(0))]
    idx = idx[k:]
    data.train_mask[idx] = False

In [6]:
av_val_acc = av_test_acc = 0
state_dict = model.state_dict()

for seed in seeds:
    print("RUNNING FOR SEED =", seed)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    
    model.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

    best_val_acc = final_test_acc = 0
    for epoch in range(1, 200):
        loss = train(model, data, optimizer, scheduler=None, loss='cross_entropy')
        train_acc, val_acc, tmp_test_acc = test(model, data)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        if epoch % 25 == 0:
            log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    print(f'Best Val Acc: {best_val_acc:.4f}', f'Test Acc: {test_acc:.4f}')
    av_val_acc += best_val_acc
    av_test_acc += test_acc
    
print(f'Average Val Acc: {av_val_acc / len(seeds):.4f}', f'Average Test Acc: {av_test_acc / len(seeds):.4f}')    

RUNNING FOR SEED = 42
Epoch: 025, Loss: 0.004577638115733862, Train: 1.0000, Val: 0.6840, Test: 0.6820
Epoch: 050, Loss: 0.005804179236292839, Train: 1.0000, Val: 0.6880, Test: 0.6940
Epoch: 075, Loss: 0.005227127578109503, Train: 1.0000, Val: 0.6840, Test: 0.6940
Epoch: 100, Loss: 0.004953151103109121, Train: 1.0000, Val: 0.6840, Test: 0.6940
Epoch: 125, Loss: 0.004842142574489117, Train: 1.0000, Val: 0.6840, Test: 0.6940
Epoch: 150, Loss: 0.004779416136443615, Train: 1.0000, Val: 0.6820, Test: 0.6940
Epoch: 175, Loss: 0.004766649100929499, Train: 1.0000, Val: 0.6820, Test: 0.6940
Best Val Acc: 0.6940 Test Acc: 0.6940
RUNNING FOR SEED = 2021
Epoch: 025, Loss: 0.008416415192186832, Train: 1.0000, Val: 0.5800, Test: 0.5640
Epoch: 050, Loss: 0.005883164703845978, Train: 1.0000, Val: 0.6700, Test: 0.6240
Epoch: 075, Loss: 0.005403239745646715, Train: 1.0000, Val: 0.6740, Test: 0.6340
Epoch: 100, Loss: 0.005045166239142418, Train: 1.0000, Val: 0.6820, Test: 0.6340
Epoch: 125, Loss: 0.00509