In [120]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dgl.data import CoraGraphDataset
import random

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# NOTE(review): no visible cell moves the graph or the model onto `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Setting Up DGL Dataset

In [97]:
# Notebook-level dataset state; every name stays None until reset_dataset() fills it in.
g = in_feats = h_feats = num_classes = cora_dataset = features = None

def reset_dataset(hidden_size=64):
    """(Re)load Cora and refresh the notebook-level globals derived from it.

    Call this after an experiment has changed the data so that later cells
    start from a newly constructed ``CoraGraphDataset`` (DGL loads it from its
    on-disk cache when one is present).

    Sets the globals:
        cora_dataset: the new ``CoraGraphDataset`` instance.
        g: the dataset's graph, ``cora_dataset[0]``.
        features: ``g.ndata['feat']`` (the same tensor, not a copy).
        in_feats: ``features.shape[1]``, the per-node input feature size.
        h_feats: hidden-layer width for ``GCN`` (``hidden_size``).
        num_classes: ``cora_dataset.num_classes``.

    Args:
        hidden_size: value stored in ``h_feats``. Defaults to 64, the hidden
            width the checkpoint at ../model/cora_gt.pt was loaded with.
    """
    global g, in_feats, h_feats, num_classes, cora_dataset, features

    cora_dataset = CoraGraphDataset()
    g = cora_dataset[0]
    features = g.ndata['feat']

    in_feats = features.shape[1]
    h_feats = hidden_size
    num_classes = cora_dataset.num_classes

reset_dataset()

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


# Setting Up Saved Model + Ground Truth

In [98]:
class GCN(nn.Module):
    """Two-layer graph convolutional network: GraphConv -> ReLU -> GraphConv.

    Args:
        g: unused by the model; kept so existing ``GCN(g, ...)`` calls still work.
        in_feats: size of each input node feature vector.
        h_feats: width of the hidden layer.
        num_classes: number of output scores produced per node.

    The submodule names ``conv1`` and ``conv2`` determine the state_dict keys,
    so they must not be renamed or previously saved checkpoints stop loading.
    """

    def __init__(self, g, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = dgl.nn.GraphConv(in_feats, h_feats)
        self.conv2 = dgl.nn.GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return raw per-node class scores (no softmax applied) for graph ``g``."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [99]:
# Pretrained weights whose test accuracy serves as the experiments' baseline.
MODEL_PATH = "../model/cora_gt.pt"

model = GCN(g, in_feats, h_feats, num_classes)
# map_location="cpu": nothing in the visible cells moves the graph or model to
# `device`, so the weights are needed on the CPU; this also lets a checkpoint
# saved from GPU tensors load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()

GCN(
  (conv1): GraphConv(in=1433, out=64, normalization=both, activation=None)
  (conv2): GraphConv(in=64, out=7, normalization=both, activation=None)
)

In [100]:
def test(data):
    model.eval()
    out = model(data, features)
    pred = out.argmax(dim=1)

    acc = (pred[data.ndata["test_mask"]] == data.ndata["label"][data.ndata["test_mask"]]).sum().item() / data.ndata["test_mask"].sum().item()
    return acc

In [116]:
def changed_acc(gt, cv):
    """Print how an experiment's accuracy differs from the baseline.

    Args:
        gt: baseline ("ground truth") accuracy.
        cv: accuracy measured after the change under test.

    The difference is reported as ``cv - gt`` with an explicit sign, so a drop
    prints as negative and an improvement as positive. Nothing is returned.
    """
    print("\n----")
    if gt != cv:
        print(f'The accuracy has changed by {cv - gt:+.4f}')
    else:
        print("The accuracy has not changed.")

In [101]:
# Baseline accuracy of the pretrained model on the freshly loaded Cora graph;
# experiment results are compared against this value.
ground_truth = test(data=cora_dataset[0])
ground_truth

0.769

# Experiments
**Note:** The ground-truth (baseline) test accuracy of the unmodified pretrained model is $0.769$.