In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np

torch.manual_seed(0)
np.random.seed(0)
cudnn.deterministic = True
cudnn.benchmark = False

import seaborn as sns
import matplotlib.pyplot as plt
from aug import TUDataset_aug as TUDataset
from torch_geometric.data import DataLoader
from gsimclr_pt import simclr, Encoder
from arguments import arg_parse
import sys

from sklearn.metrics.pairwise import cosine_similarity
import tqdm.autonotebook as tqdm

In [3]:
sns.set_style("whitegrid")

In [4]:
class Supervised(torch.nn.Module):
    def __init__(self, args):
        super(Supervised, self).__init__()
    
        self.args = args
        self.encoder = Encoder(args.dataset_num_features, args.hidden_dim, args.num_gc_layers)
        self.embedding_dim = mi_units = args.hidden_dim * args.num_gc_layers
        self.classifer = torch.nn.Linear(self.embedding_dim,args.num_labels)

        if args.bn_int == True:
            self.bn_int = torch.nn.BatchNorm1d(args.hidden_dim * args.num_gc_layers) 
        else:
            self.bn_int = None

        self.init_emb()
    def init_emb(self):
        initrange = -1.5 / self.embedding_dim
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
    
    def forward(self, x, edge_index, batch, num_graphs):
        # batch_size = data.num_graphs
        if x is None:
            x = torch.ones(batch.shape[0]).to(device)
        y, M = self.encoder(x, edge_index, batch)
        
        if self.bn_int is not None:
            y = self.bn_int(y)
        y = self.classifer(y)
        return y

In [5]:
def test(net, data_loader, device):
    net.eval()
    correct = 0
    for data in data_loader:
        data,_  = data
        data.to(device)
        pred = net(data.x, data.edge_index, data.batch, data.num_graphs)
        pred = pred.argmax(dim=1)
        target = data.y
        correct += pred.eq(target).sum().item()

    acc = correct / data_loader.dataset.__len__()
    return acc

## Load Data

In [6]:
### define dataloader to be shared over random and trained models
DS = 'DD'

sys.argv = [".. ",'--DS={}'.format(DS)]
args = arg_parse() 
args.seed=3
args.batch_size = 32
args.bn_int = True
args.epochs = 30
args.lr = 0.01
args.num_gc_layers = 3
args.prior =False

if DS == 'COLLAB':
    args.num_labels = 3
    args.aug = 'random4'

else:
    args.num_labels=2
    args.aug = 'random2'


dataset = TUDataset("/home/sc/eslubana/graphssl/PosGraphCL/data/{}".format(DS), name=DS, aug=args.aug)
dataset_eval = TUDataset("/home/sc/eslubana/graphssl/PosGraphCL/data/{}".format(DS), name=DS, aug="none")

args.dataset_num_features = dataset_eval[0][0].num_node_features

train_test_split = int(np.floor(len(dataset)*0.9))
idx = list(np.random.permutation(len(dataset_eval)))

print("Train vs. Test: ",train_test_split, len(dataset))

dataloader = DataLoader(dataset[idx[0:train_test_split]], batch_size=args.batch_size,shuffle=False)
dataloader_eval = DataLoader(dataset[idx[train_test_split:]], batch_size=args.batch_size,shuffle=False)

Train vs. Test:  1060 1178


In [7]:
device='cuda:0'
print(device)

model = Supervised(args)
model.to(device);
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [65,85], gamma=0.6, last_epoch=-1, verbose=False)
best_ckpt = {}
args.epochs =100

cuda:0


RuntimeError: CUDA error: out of memory

In [None]:
best_loss = 1000
for epoch in range(1, args.epochs+1):
    loss_all = 0
    model.train()
    for data in dataloader:
        #unpack the augmented version
        data, _ = data
        data.to(device)
        pred = model(data.x, data.edge_index, data.batch, data.num_graphs)
        L = torch.nn.CrossEntropyLoss()(pred,data.y)
        
        optimizer.zero_grad()
        L.backward()
        #torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        
        loss_all += L.item()
    scheduler.step()
    loss_all /= len(dataloader)

    model.eval()
    train_acc= test(model, dataloader, device)
    test_acc= test(model, dataloader_eval, device)

    if loss_all < best_loss:
        best_ckpt['net'] = model.state_dict()
        best_ckpt['epoch'] = epoch
        best_ckpt['train_acc'] = train_acc
        best_ckpt['test_acc'] = test_acc
        best_loss = loss_all
        print("Best Epoch: ",epoch)
    print('Epoch: {0}, Loss: {1:.4f}, Train Acc: {2:.4f}, Test Acc: {3:.4f}'.format(epoch, loss_all,train_acc,test_acc))

In [None]:
print("Best Epoch: {0} Train Acc: {1:.4f} Test Acc: {2:.4f}".format(best_ckpt['epoch'],best_ckpt['train_acc'],best_ckpt['test_acc']))
model.load_state_dict(best_ckpt['net'])
model.eval();

### Compute Accuracies on Augmentations

In [None]:
num_correct = 0
avg_correct = 0
accs = []
model.eval()
for epoch in range(1, 5+1):
    num_correct=0
    for data_aug in dataloader:
        #unpack the augmented version
        _, data_aug = data_aug
        data_aug.to(device)
        pred = model(data_aug.x, data_aug.edge_index, data_aug.batch, data_aug.num_graphs)
        pred = pred.argmax(dim=1)
        target = data_aug.y
        num_correct += pred.eq(target).sum().item()
    avg_correct += num_correct / dataloader.dataset.__len__()
    accs.append(num_correct / dataloader.dataset.__len__())
    print('Running Average: ',num_correct / dataloader.dataset.__len__())


In [None]:
avg_correct = avg_correct / 5
print("*"*10)
print("Average Correct: ",avg_correct*100)
print("Average Correct: ",avg_correct*100)

## Create Plots

In [None]:
#acc
misclassification = {
    'NCI1':[0.601,],
    'PROTEINS': [(74.13,46.89), (76.12,54,74),(74.33,52,28)],
    'DD':[0.60],
    'MUTAG':[(92.09, 40.0),(92.31,37.63),(89.35,44.02)],
    'REDDIT-BINARY':[0.66],
    'IMDB-BINARY':[]
}

In [None]:
misclassification = {
    'NCI1':[0.601,],
    'PROTEINS': [0.58, 0.54],
    'DD':[0.60],
    'MUTAG':[0.39, 0.409],
    'REDDIT-BINARY':[0.66],
    'IMDB-BINARY':[]
}

In [None]:
best_ckpt['Aug_Acc'] = avg_correct
torch.save(best_ckpt,"{}_SUPERVISED.pkl".format(DS))