In [1]:
# This model is currently unusable 

In [3]:
import torch
import numpy as np
import math
import time
import argparse
import scipy.sparse as sp
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import trange, tqdm

from models import process
from models import AvgReadout, LogReg
from models import Discriminator
from models import GCN
from models import DGI

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv

In [4]:
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.data import DataLoader

In [5]:
# Compute device selection: use the CUDA GPU with index `device_args` when
# CUDA is available, otherwise fall back to the CPU.
device_args = 0  # default GPU index
device = torch.device(
    f'cuda:{device_args}' if torch.cuda.is_available() else 'cpu'
)

In [6]:
# Load the ogbn-arxiv node-property-prediction dataset. The ToSparseTensor
# transform is applied on access, which is why the graph below exposes its
# connectivity through `data.adj_t`.
# NOTE(review): `dataset` is rebound to the string 'cora' in a later cell, and
# none of the cells visible in this notebook read `data`, `split_idx` or
# `train_idx` afterwards - confirm whether this cell is still needed.
dataset = PygNodePropPredDataset(name='ogbn-arxiv', transform=T.ToSparseTensor())
data = dataset[0]

# Symmetrise the adjacency, then move the whole graph to the chosen device.
data.adj_t = data.adj_t.to_symmetric()
data = data.to(device)

# Query the dataset split a single time and keep the training node indices
# on the same device as the graph.
split_idx = dataset.get_idx_split()
train_idx = split_idx['train'].to(device)

In [7]:
class DGI(nn.Module):
    """DGI model: a GCN encoder, a readout and a discriminator.

    NOTE: this definition shadows the ``DGI`` class imported from ``models``
    at the top of the notebook; every later reference to ``DGI`` uses this one.

    Args:
        n_in: Input size forwarded to ``GCN`` (the notebook passes the
            feature dimension).
        n_h: Hidden size forwarded to ``GCN`` and ``Discriminator``.
        activation: Activation specifier forwarded to ``GCN``.
    """

    def __init__(self, n_in, n_h, activation):
        super().__init__()
        # Attribute names double as state_dict keys for saved checkpoints,
        # so they must not be renamed.
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        """Score two inputs that share one adjacency.

        Args:
            seq1: First input to the encoder; its encoding also produces the
                summary vector.
            seq2: Second input to the encoder (the training loop passes the
                row-shuffled features here).
            adj: Adjacency forwarded unchanged to the encoder.
            sparse: Flag forwarded unchanged to the encoder.
            msk: Mask forwarded to the readout (may be ``None``).
            samp_bias1: Forwarded to the discriminator (may be ``None``).
            samp_bias2: Forwarded to the discriminator (may be ``None``).

        Returns:
            Whatever the discriminator returns for
            ``(summary, enc(seq1), enc(seq2))``; the notebook feeds it to
            ``BCEWithLogitsLoss`` as logits.
        """
        h_1 = self.gcn(seq1, adj, sparse)
        # Summary vector: readout of the first encoding squashed by a sigmoid.
        c = self.sigm(self.read(h_1, msk))
        h_2 = self.gcn(seq2, adj, sparse)
        return self.disc(c, h_1, h_2, samp_bias1, samp_bias2)

    def embed(self, seq, adj, sparse, msk):
        """Return detached encoder output and readout for ``seq``.

        The pass runs under ``torch.no_grad()`` so no autograd graph is built
        for tensors that are detached anyway. Unlike ``forward``, the readout
        is returned without the sigmoid.

        Returns:
            Tuple ``(h_1, c)``: encoder output and its readout, both without
            gradient tracking.
        """
        with torch.no_grad():
            h_1 = self.gcn(seq, adj, sparse)
            c = self.read(h_1, msk)
        return h_1.detach(), c.detach()

In [8]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
dataset = 'cora'

# training params
batch_size, nb_epochs, patience = 1, 10000, 50
lr, l2_coef, drop_prob = 0.001, 0.0, 0.0
hid_units = 512
sparse = True
nonlinearity = 'prelu' # special name to separate parameters

# ---------------------------------------------------------------------------
# Data loading and preprocessing (helpers come from the local `process` module)
# ---------------------------------------------------------------------------
adj, features, labels, idx_train, idx_val, idx_test = process.load_data(dataset)
features, _ = process.preprocess_features(features)

nb_nodes, ft_size = features.shape[0], features.shape[1]
nb_classes = labels.shape[1]

# Add self-loops before normalising the adjacency.
adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))

if sparse:
    sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)
else:
    # NOTE(review): the identity is added a second time here, after the
    # normalisation above - confirm this is intended for the dense path.
    adj = (adj + sp.eye(adj.shape[0])).todense()
    adj = torch.FloatTensor(adj[np.newaxis])

# Prepend a leading axis of size 1 to features and labels.
features = torch.FloatTensor(features[np.newaxis])
labels = torch.FloatTensor(labels[np.newaxis])
idx_train, idx_val, idx_test = (
    torch.LongTensor(ix) for ix in (idx_train, idx_val, idx_test)
)

# ---------------------------------------------------------------------------
# Model, optimiser and device placement
# ---------------------------------------------------------------------------
model = DGI(ft_size, hid_units, nonlinearity)
optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)

if torch.cuda.is_available():
    print('Using CUDA')
    model.cuda()
    features = features.cuda()
    if sparse:
        sp_adj = sp_adj.cuda()
    else:
        adj = adj.cuda()
    labels = labels.cuda()
    idx_train, idx_val, idx_test = idx_train.cuda(), idx_val.cuda(), idx_test.cuda()

# Losses: binary cross-entropy on logits for the DGI objective, categorical
# cross-entropy for the downstream classifier.
b_xent = nn.BCEWithLogitsLoss()
xent = nn.CrossEntropyLoss()

# Early-stopping bookkeeping: epochs without improvement, best loss so far,
# and the epoch at which it was observed.
cnt_wait, best, best_t = 0, 1e9, 0


Using CUDA


In [10]:
# Targets are identical every epoch: ones for the first `nb_nodes` logits and
# zeros for the next `nb_nodes` (matching the (features, shuf_fts) order of the
# model call below - presumably the discriminator concatenates its scores in
# that order; verify against `Discriminator`). Build them once, on the same
# device as the features, instead of rebuilding and moving them every epoch.
lbl = torch.cat(
    (torch.ones(batch_size, nb_nodes), torch.zeros(batch_size, nb_nodes)), 1
).to(features.device)

for epoch in tqdm(range(nb_epochs)):
    model.train()
    optimiser.zero_grad()

    # Second input: the same features with the node axis randomly permuted.
    idx = np.random.permutation(nb_nodes)
    shuf_fts = features[:, idx, :]

    logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None)
    loss = b_xent(logits, lbl)

    # Track the best loss as a plain float. Keeping the tensor itself would
    # hold on to that epoch's autograd graph and print a tensor repr.
    loss_val = loss.item()

    if loss_val < best:
        best = loss_val
        best_t = epoch
        cnt_wait = 0
        # Saved before the optimiser step, so the checkpoint holds exactly the
        # weights that produced this loss.
        torch.save(model.state_dict(), 'best_dgi.pkl')
        print('Loss:', loss_val)
    else:
        cnt_wait += 1

    # Stop once `patience` consecutive epochs failed to improve the loss.
    if cnt_wait == patience:
        print('Early stopping!')
        break

    loss.backward()
    optimiser.step()

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: tensor(0.0112, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.0110, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.0096, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.0095, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.0088, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.0081, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.0077, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Early stopping!


In [27]:
# Restore the checkpoint written by the training loop and compute embeddings.
print('Loading {}th epoch'.format(best_t))
# `map_location` lets a checkpoint saved on one device be restored on whatever
# device the features live on. torch.load unpickles the file, so only load
# checkpoints produced by this notebook or another trusted source.
model.load_state_dict(torch.load('best_dgi.pkl', map_location=features.device))

# `embed` returns detached encoder outputs; the readout is not needed here.
embeds, _ = model.embed(features, sp_adj if sparse else adj, sparse, None)
train_embs = embeds[0, idx_train]
val_embs = embeds[0, idx_val]
test_embs = embeds[0, idx_test]

# `labels` holds one row per node; argmax over dim 1 yields the class index
# expected by CrossEntropyLoss.
train_lbls = torch.argmax(labels[0, idx_train], dim=1)
val_lbls = torch.argmax(labels[0, idx_val], dim=1)
test_lbls = torch.argmax(labels[0, idx_test], dim=1)

# Accumulators for the evaluation runs. Allocate on the embeddings' device
# rather than calling .cuda() unconditionally, which raises on CPU-only hosts.
tot = torch.zeros(1, device=embeds.device)

accs = []

Loading 125th epoch


In [28]:
# Downstream evaluation: repeatedly fit a fresh LogReg classifier on the
# training embeddings and measure its accuracy on the test embeddings.
n_runs = 50           # number of independently initialised classifiers
n_probe_steps = 100   # optimisation steps per classifier
probe_device = train_embs.device  # follow the embeddings instead of forcing CUDA

# (Re)initialise the accumulators in this cell so that re-running it neither
# adds to a previous total nor calls .append on an already-stacked tensor.
tot = torch.zeros(1, device=probe_device)
run_accs = []

for _ in tqdm(range(n_runs)):
    log = LogReg(hid_units, nb_classes).to(probe_device)
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

    for _ in range(n_probe_steps):
        log.train()
        opt.zero_grad()

        logits = log(train_embs)
        loss = xent(logits, train_lbls)

        loss.backward()
        opt.step()

    # No gradients are needed for scoring the test split.
    with torch.no_grad():
        logits = log(test_embs)
    preds = torch.argmax(logits, dim=1)
    acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
    run_accs.append(acc * 100)  # stored as a percentage
    tot += acc                  # accumulated as a fraction

print('Average accuracy:', tot / n_runs)
accs = torch.stack(run_accs)

  0%|          | 0/50 [00:00<?, ?it/s]

Average accuracy: tensor([0.8250], device='cuda:0')


In [31]:
# Summarise the per-run test accuracies (percentages) collected in `accs`.
for template, statistic in (
    ("mean accuracy: {}", accs.mean()),
    ("standard deviation accuracy: {}", accs.std()),
):
    print(template.format(statistic))

mean accuracy: 82.49600219726562
standard deviation accuracy: 0.15380848944187164
