In [1]:
import torch
import numpy as np
import csv
import pandas as pd

import json

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv

import torch.nn.functional as F
from torch.optim import SGD
from torch import nn

from tqdm import tqdm

from collections import defaultdict

from sklearn.manifold import TSNE
import pickle

from torch.optim.lr_scheduler import ExponentialLR, StepLR


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
from dataset import Dataset
from modules import GCN
from train import train_model
from predict import model_inference

# Load Data

In [4]:
data = Dataset('datasets/facebook.npz', device=device)

In [5]:
dataloader = data.data_loader(batchsize=32)

# Build Model

# Train Model

In [59]:
gcn_model = GCN(data.graph.x.shape[1], 4, hidden_layers=[128]*2).to(device)

optimizer = SGD(gcn_model.parameters(), lr=0.65)

CELoss = torch.nn.CrossEntropyLoss()
lr_schedular = ExponentialLR(optimizer, gamma=1)

train_model(gcn_model, optimizer, CELoss, dataloader, lr_schedular=lr_schedular, epochs=400)

399/400, Loss 0.3388: 100%|██████████| 400/400 [02:46<00:00,  2.41it/s]


In [60]:
print("> Saving model to picle")
with open("models/gcn_model.pkl", "wb") as file:
    pickle.dump(gcn_model, file)
    file.close()
print("> Model Saved Sucessfully")

> Saving model to picle
> Model Saved Sucessfully


In [61]:
logits = gcn_model(data.graph)
CELoss(logits[data.graph.test_mask], data.graph.y[data.graph.test_mask])

tensor(0.3855, device='cuda:0', grad_fn=<NllLossBackward0>)

In [62]:
prediction = model_inference(gcn_model, data.graph)

print("> Inference Complete")
accuracy = (prediction == data.graph.y).sum()/len(data.graph.y)
test_accuracy = (prediction[data.graph.test_mask] == data.graph.y[data.graph.test_mask]).sum()/len(data.graph.y[data.graph.test_mask])

print("> Model Results")
print(f"\tTest Accuracy: {test_accuracy}")
print(f"\tFull graph Accuracy: {accuracy}")

> Inference Complete
> Model Results
	Test Accuracy: 0.8724966645240784
	Full graph Accuracy: 0.8751668930053711


In [63]:
for i in data.graph.y.unique():
    filt = (data.graph.test_mask) & (data.graph.y == i)
    class_acc = (prediction[filt] == data.graph.y[filt]).sum()/len(data.graph.y[filt])
    print(i.to('cpu').detach().numpy(), class_acc.to('cpu').detach().numpy())

0 0.7549168
1 0.8172783
2 0.94213647
3 0.9201359
