In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np

torch.manual_seed(0)
np.random.seed(0)
cudnn.deterministic = True
cudnn.benchmark = False

import seaborn as sns
import matplotlib.pyplot as plt
from aug import TUDataset_aug as TUDataset
from torch_geometric.data import DataLoader
from gsimclr_pt import simclr, Encoder
from arguments import arg_parse
import sys

from sklearn.metrics.pairwise import cosine_similarity
import tqdm.autonotebook as tqdm

In [3]:
sns.set_style("whitegrid")

In [4]:
class Supervised(torch.nn.Module):
    def __init__(self, args):
        super(Supervised, self).__init__()
    
        self.args = args
        self.encoder = Encoder(args.dataset_num_features, args.hidden_dim, args.num_gc_layers)
        self.embedding_dim = mi_units = args.hidden_dim * args.num_gc_layers
        self.classifer = torch.nn.Linear(self.embedding_dim,args.num_labels)

        if args.bn_int == True:
            self.bn_int = torch.nn.BatchNorm1d(args.hidden_dim) 
        else:
            self.bn_int = None

        self.init_emb()
    def init_emb(self):
        initrange = -1.5 / self.embedding_dim
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
    
    def forward(self, x, edge_index, batch, num_graphs):
        # batch_size = data.num_graphs
        if x is None:
            x = torch.ones(batch.shape[0]).to(device)
        y, M = self.encoder(x, edge_index, batch)
        
        if self.bn_int is not None:
            y = self.bn_int(y)
        y = self.classifer(y)
        return y

In [5]:
def test(net, data_loader, device):
    net.eval()
    correct = 0
    for data in data_loader:
        data,_  = data
        data.to(device)
        pred = net(data.x, data.edge_index, data.batch, data.num_graphs)
        pred = pred.argmax(dim=1)
        target = data.y
        correct += pred.eq(target).sum().item()

    acc = correct / data_loader.dataset.__len__()
    return acc

## Load Data

In [6]:
### define dataloader to be shared over random and trained models
DS = 'MUTAG'

sys.argv = [".. ",'--DS={}'.format(DS)]
args = arg_parse() 
args.seed=3
args.batch_size = 128
args.bn_int = False
args.epochs = 30
args.lr = 0.01
args.num_gc_layers = 3
args.prior =False

if DS == 'COLLAB':
    args.num_labels = 3
    args.aug = 'random4'

else:
    args.num_labels=2
    args.aug = 'random2'


dataset = TUDataset("/home/sc/eslubana/graphssl/GraphCL/unsupervised_TU/data/{}".format(DS), name=DS, aug=args.aug)
dataset_eval = TUDataset("/home/sc/eslubana/graphssl/GraphCL/unsupervised_TU/data/{}".format(DS), name=DS, aug="none")

args.dataset_num_features = dataset_eval[0][0].num_node_features

train_test_split = int(np.floor(len(dataset)*0.9))
idx = list(np.random.permutation(len(dataset_eval)))

print("Train vs. Test: ",train_test_split, len(dataset))

dataloader = DataLoader(dataset[idx[0:train_test_split]], batch_size=args.batch_size,shuffle=False)
dataloader_eval = DataLoader(dataset[idx[train_test_split:]], batch_size=args.batch_size,shuffle=False)

Downloading http://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/MUTAG.zip
Extracting /home/sc/eslubana/graphssl/GraphCL/unsupervised_TU/data/MUTAG/MUTAG/MUTAG.zip
Processing...
Done!
Train vs. Test:  169 188


## MixUp!

In [13]:
for sample in dataloader:
    break

In [14]:
sample = sample[0]

In [20]:
sample.num_graphs

128

In [19]:
sample.edge_index

tensor([[   0,    0,    1,  ..., 2246, 2247, 2248],
        [   1,    9,    0,  ..., 2246, 2247, 2248]])

In [24]:
alpha = 1.
beta = torch.distributions.beta.Beta(alpha, alpha)
#shuffle on batchsize
randind = torch.randperm(sample.num_graphs, device=sample.edge_index.device)
randind

tensor([ 94, 121, 123,  98,  24,  80,  64,   4,  75,  14,   7,  38, 108,  37,
        109,  48,  95, 101,  60,  11,  63, 110,  29,  34,  27,  99,  67, 105,
          0,  97,  92, 119,  35,  89,  72,  25,  30,  26, 120,  32, 126,  68,
         96,  44,   6, 127, 106,  31,  56,  21,  86,  15, 112,  91,  90,  20,
         50,  59, 116,  62, 103,   8,  12,  54,  77,  43,  58,  41,  46,  79,
          5,  17, 107,  76,  28, 114,  70,  45,  69,  83, 122,  88,  36,  53,
         39,  57, 115,  82,  42,  13,  93,  85,   2,   3,  49,  19,   1,  33,
        125,  47,  66, 113,  10, 111,  40, 100, 118,  78,  71,  52, 117, 102,
         51,  81,  16,  61,  18,  74,  87,   9, 104,  65,  22,  55,  23, 124,
         84,  73])

In [26]:
lam = beta.sample([sample.num_graphs]).to(device=sample.edge_index.device)
lam = torch.max(lam, 1. - lam)
lam
#lam_expanded = lam.view([-1] + [1]*(input['adj'].dim()-1))

tensor([0.7665, 0.7160, 0.7792, 0.9706, 0.5464, 0.8986, 0.9390, 0.5141, 0.9137,
        0.8952, 0.6564, 0.5793, 0.6264, 0.7888, 0.6231, 0.6706, 0.7214, 0.8503,
        0.9368, 0.5031, 0.6954, 0.8532, 0.7800, 0.8139, 0.5843, 0.9006, 0.6502,
        0.7048, 0.9701, 0.9596, 0.8273, 0.8744, 0.5138, 0.5473, 0.5457, 0.7727,
        0.7894, 0.5110, 0.6307, 0.8525, 0.9864, 0.6867, 0.6926, 0.9391, 0.7805,
        0.5191, 0.7217, 0.7628, 0.5232, 0.6633, 0.6293, 0.7630, 0.8345, 0.8869,
        0.6988, 0.5396, 0.5240, 0.5166, 0.7689, 0.6458, 0.7861, 0.9998, 0.6164,
        0.7646, 0.6589, 0.6933, 0.9044, 0.8820, 0.5495, 0.5012, 0.6982, 0.8366,
        0.8541, 0.8876, 0.6043, 0.9266, 0.6468, 0.5655, 0.8032, 0.5688, 0.6496,
        0.5074, 0.9670, 0.6721, 0.8590, 0.6314, 0.7908, 0.8863, 0.8927, 0.9261,
        0.8217, 0.5944, 0.5954, 0.7951, 0.8130, 0.7058, 0.7985, 0.9203, 0.7202,
        0.9948, 0.8394, 0.7016, 0.8530, 0.5292, 0.8269, 0.8942, 0.6440, 0.5189,
        0.7588, 0.9449, 0.6244, 0.5461, 

In [34]:
ids, counts = torch.unique(sample.batch,return_counts=True)

In [54]:
end_idx = torch.cat([torch.Tensor([0]),torch.cumsum(counts,dim=0)]).int()
tups = []
for i,j in zip(end_idx, end_idx[1:]):
    tups.append((i.item(),j.item()))

In [74]:
batched_x, batched_mask = geom.utils.to_dense_batch(sample.x, batch=sample.batch)

In [75]:
batched_x.shape, batched_mask.shape

(torch.Size([128, 28, 7]), torch.Size([128, 28]))

In [61]:
lam_expanded = lam.view([-1] + [1]*(3-1))

In [67]:
sample_x = geom.utils.to_dense_batch(samplex, batch=sample.batch, fill_value=0, max_num_nodes=None)

NameError: name 'geom' is not defined

In [77]:
 mixed_mask = torch.logical_or(batched_mask,batched_mask[randind])

In [92]:
sample.batch.type()

'torch.LongTensor'

In [93]:
torch.cat([torch.Tensor([e_num] * m.sum()).long() for e_num,m in enumerate(mixed_mask)])

tensor([  0,   0,   0,  ..., 127, 127, 127])

In [60]:
sample = torch.zeros_like(sample.x)
for i in enumerate(randind):
    sample[tupis[i][0][tup]] = lam_expanded[i] * sample.x[tups[i][0]:tups[i][1],:] + (1. - lam_expanded) * sample.x[tups[randind][0]:tups[randind][1],:]

NameError: name 'lam_expanded' is not defined

In [None]:
#get the indices of x corresponding to each graph

In [None]:
def mixup(input, alpha):
    
    #mix adjacency 
    beta = torch.distributions.beta.Beta(alpha, alpha)
    randind = torch.randperm(input['adj'].shape[0], device=input['adj'].device)
    lam = beta.sample([input['adj'].shape[0]]).to(device=input['adj'].device)
    lam = torch.max(lam, 1. - lam)
    lam_expanded = lam.view([-1] + [1]*(input['adj'].dim()-1))
    mixed_adj = lam_expanded * input['adj'] + (1. - lam_expanded) * input['adj'][randind]
    
    
    #mix features to match!
    beta = torch.distributions.beta.Beta(alpha, alpha)
    randind = torch.randperm(input['x'].shape[0], device=input['x'].device)
    lam = beta.sample([input['x'].shape[0]]).to(device=input['x'].device)
    lam = torch.max(lam, 1. - lam)
    lam_expanded = lam.view([-1] + [1]*(input['x'].dim()-1))
    mixed_x = lam_expanded * input['x'] + (1. - lam_expanded) * input['x'][randind]
    
    #adjust the mask to match the mixed samples
    mixed_mask = torch.logical_or(input['mask'],input['mask'][randind])
    return mixed_adj,mixed_x, mixed_mask, randind, lam


In [29]:
 sparse_mtrx = [geom.utils.dense_to_sparse(a) for a in mixed_adj]

tensor([[   0,    0,    1,  ..., 2246, 2247, 2248],
        [   1,    9,    0,  ..., 2246, 2247, 2248]])

In [7]:
pos = 0
neg = 0
third = 0
for _,d in dataset_eval:
    if d.y.item() == 1:
        pos += 1
    elif d.y.item() == 0:
        neg += 1
    elif d.y.item() == 2:
        third += 1
    else:
        print("ERROR")
#print("Percentage Pos: ",np.round(pos/(pos+neg),4)," Num Pos: ",pos)
#print("Percentage Neg: ",np.round(neg/(pos+neg),4)," Num Neg: ",neg)

print("Percentage Pos: ",np.round(pos/(pos+neg+third),4)," Num Pos: ",pos)
print("Percentage Neg: ",np.round(neg/(pos+neg+third),4)," Num Neg: ",neg)
print("Percentage Pos: ",np.round(third/(pos+neg+third),4)," Num Pos: ",third)


Percentage Pos:  0.155  Num Pos:  775
Percentage Neg:  0.52  Num Neg:  2600
Percentage Pos:  0.325  Num Pos:  1625


In [8]:
### book-keeping for initializing model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device='cpu'

### Train Model

In [9]:
model = Supervised(args)
model.to(device);
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
best_ckpt = {}

In [10]:
best_loss = 1000
for epoch in range(1, args.epochs+1):
    loss_all = 0
    model.train()
    for data in tqdm.tqdm(dataloader):
        #unpack the augmented version
        data, _ = data
        data.to(device)
        pred = model(data.x, data.edge_index, data.batch, data.num_graphs)
        L = torch.nn.CrossEntropyLoss()(pred,data.y)
        
        optimizer.zero_grad()
        L.backward()
        optimizer.step()
        
        loss_all += L.item()
        
    loss_all /= len(dataloader)

    model.eval()
    train_acc= test(model, dataloader, device)
    test_acc= test(model, dataloader_eval, device)

    if loss_all < best_loss:
        best_ckpt['net'] = model.state_dict()
        best_ckpt['epoch'] = epoch
        best_ckpt['train_acc'] = train_acc
        best_ckpt['test_acc'] = test_acc
        best_loss = loss_all
        print("Best Epoch: ",epoch)
    print('Epoch: {0}, Loss: {1:.4f}, Train Acc: {2:.4f}, Test Acc: {3:.4f}'.format(epoch, loss_all,train_acc,test_acc))

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  1
Epoch: 1, Loss: 3.7693, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  2
Epoch: 2, Loss: 2.0104, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  3
Epoch: 3, Loss: 1.3186, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  4
Epoch: 4, Loss: 0.6465, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  5
Epoch: 5, Loss: 0.6228, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  6
Epoch: 6, Loss: 0.3825, Train Acc: 0.6450, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  7
Epoch: 7, Loss: 0.3350, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  8
Epoch: 8, Loss: 0.2325, Train Acc: 0.6509, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  9
Epoch: 9, Loss: 0.2246, Train Acc: 0.6627, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  10
Epoch: 10, Loss: 0.1920, Train Acc: 0.6746, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  11
Epoch: 11, Loss: 0.1789, Train Acc: 0.7041, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  12
Epoch: 12, Loss: 0.1617, Train Acc: 0.7278, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  13
Epoch: 13, Loss: 0.1599, Train Acc: 0.7278, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  14
Epoch: 14, Loss: 0.1458, Train Acc: 0.7456, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  15
Epoch: 15, Loss: 0.1401, Train Acc: 0.7751, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  16
Epoch: 16, Loss: 0.1328, Train Acc: 0.7929, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  17
Epoch: 17, Loss: 0.1259, Train Acc: 0.7929, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  18
Epoch: 18, Loss: 0.1200, Train Acc: 0.7929, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  19
Epoch: 19, Loss: 0.1117, Train Acc: 0.7988, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  20
Epoch: 20, Loss: 0.1057, Train Acc: 0.8107, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  21
Epoch: 21, Loss: 0.1024, Train Acc: 0.8225, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  22
Epoch: 22, Loss: 0.0956, Train Acc: 0.8462, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  23
Epoch: 23, Loss: 0.0888, Train Acc: 0.8462, Test Acc: 0.7895


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  24
Epoch: 24, Loss: 0.0886, Train Acc: 0.8580, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Epoch: 25, Loss: 0.0894, Train Acc: 0.8639, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  26
Epoch: 26, Loss: 0.0854, Train Acc: 0.8521, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Epoch: 27, Loss: 0.0903, Train Acc: 0.8994, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  28
Epoch: 28, Loss: 0.0849, Train Acc: 0.9172, Test Acc: 0.8947


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  29
Epoch: 29, Loss: 0.0774, Train Acc: 0.8876, Test Acc: 0.8421


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


Best Epoch:  30
Epoch: 30, Loss: 0.0771, Train Acc: 0.8994, Test Acc: 0.8947


In [11]:
print("Best Epoch: {0} Train Acc: {1:.4f} Test Acc: {2:.4f}".format(best_ckpt['epoch'],best_ckpt['train_acc'],best_ckpt['test_acc']))
model.load_state_dict(best_ckpt['net'])
model.eval();

Best Epoch: 30 Train Acc: 0.8994 Test Acc: 0.8947


### Compute Accuracies on Augmentations

In [12]:
num_correct = 0
avg_correct = 0
accs 
for epoch in range(1, 5+1):
    num_correct=0
    for data_aug in dataloader:
        #unpack the augmented version
        _, data_aug = data_aug
        data_aug.to(device)
        pred = model(data_aug.x, data_aug.edge_index, data_aug.batch, data_aug.num_graphs)
        pred = pred.argmax(dim=1)
        target = data_aug.y
        num_correct += pred.eq(target).sum().item()
    avg_correct += num_correct / dataloader.dataset.__len__()
    print('Running Average: ',num_correct / dataloader.dataset.__len__())


Running Average:  0.4260355029585799
Running Average:  0.40236686390532544
Running Average:  0.40828402366863903
Running Average:  0.40828402366863903
Running Average:  0.40236686390532544


In [13]:
avg_correct = avg_correct / 5
print("*"*10)
print("Average Correct: ",avg_correct*100)
print("Average Correct: ",avg_correct*100)

**********
Average Correct:  0.4094674556213017


## Create Plots

In [None]:
misclassification = {
    'NCI1':[0.601],
    'PROTEINS': [0.58],
    'DD':[0.60],
    'MUTAG':[0.39],
    'REDDIT-BINARY':[0.66],
    'IMDB-BINARY':[]
}

In [None]:
best_ckpt['Aug_Acc'] = avg_correct
torch.save(best_ckpt,"{}_SUPERVISED.pkl".format(DS))