In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np

torch.manual_seed(0)
np.random.seed(0)
cudnn.deterministic = True
cudnn.benchmark = False

import seaborn as sns
import matplotlib.pyplot as plt
from aug import TUDataset_aug as TUDataset
from torch_geometric.data import DataLoader
from gsimclr_pt import simclr
from arguments import arg_parse
import sys

from sklearn.metrics.pairwise import cosine_similarity


In [3]:
sns.set_style("darkgrid")

## Load Data

In [4]:
### define dataloader to be shared over random and trained models
DS = 'PTC_MR'
sys.argv = [".. ",'--DS={}'.format(DS)]
args = arg_parse() 
args.dataset_num_features = 7
batch_size = 128

dataset = TUDataset("/home/sc/eslubana/graphssl/GraphCL/unsupervised_TU/data/{}".format(DS), name=DS, aug="random2").shuffle()
dataset_eval = TUDataset("/home/sc/eslubana/graphssl/GraphCL/unsupervised_TU/data/{}".format(DS), name=DS, aug="none").shuffle()

dataloader = DataLoader(dataset, batch_size=batch_size,shuffle=False)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size,shuffle=False)

In [5]:
### book-keeping for initializing model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Random Plots

In [6]:
### Init random model
model = simclr(32, 3,args).to(device)
model.eval()


### Pull out the labels that will be used to index the TRAINED model in the same way!
emb, y = model.encoder.get_embeddings(dataloader_eval)
pos_num = len(np.where(y == 0)[0])
neg_num = len(np.where(y != 0)[0])

print("Positive samples: ",pos_num)
print("Negative samples: ",neg_num)

RuntimeError: mat1 dim 1 must match mat2 dim 0

In [None]:
### generate two permutations that will be shared across trained and random model
pos_idx = np.where(y==0)
pos_idx_1 = (np.random.permutation(pos_idx[0]),)
assert (pos_idx[0][0] - pos_idx_1[0][0]) != 0

neg_idx = np.where(y!=0)
neg_idx_1 = (np.random.permutation(neg_idx[0]),)
assert (neg_idx[0][0] - neg_idx_1[0][0])  != 0

In [None]:
### SORT the random embeddings into POSITIVE samples, then NEGATIVE SAMPLES
sorted_embs = np.concatenate((emb[pos_idx],emb[neg_idx]))
print("Pos Samples: ",len(pos_idx[0]))
print("Neg Samples: ",len(neg_idx[0]))

In [None]:
### Compute Sim over all pairs. 
sim = cosine_similarity(sorted_embs)
ax =sns.heatmap(sim,center=0.5,square=True)
ax.set_title("Random Init: {} Representation Similarity".format(DS));

In [None]:
### Viz Random, Normalized Embeddings
normed_emb = sorted_embs/ np.linalg.norm(sorted_embs,axis=1,ord=2,keepdims=True)
ax = sns.heatmap(normed_emb,center=0)
ax.set_xticks([],[])
ax.set_yticks([],[])
ax.set_title("Random Init: {} Norm. Embeddings".format(DS));

In [None]:
print("Expected Std: ",1/np.sqrt(96))
print("Std: ",(sorted_embs/ np.linalg.norm(sorted_embs,axis=1,keepdims=True)).std(axis=0).mean())

In [None]:
### represent the embeddings in histogram form (upper triangle only)
pos_v_pos = np.triu(sim[0:len(pos_idx[0]),0:len(pos_idx[0])])
pos_v_pos[np.tril_indices(pos_v_pos.shape[0], -1)] = np.nan

neg_v_neg = np.triu(sim[len(pos_idx[0]):,len(pos_idx[0]):])
neg_v_neg[np.tril_indices(neg_v_neg.shape[0], -1)] = np.nan


pos_v_neg = np.triu(sim[0:len(pos_idx[0]),len(pos_idx[0]):])
pos_v_neg[np.tril_indices(pos_v_neg.shape[0], -1)] = np.nan

assert len(pos_v_pos.reshape(-1)) ==  pos_num * pos_num
assert len(pos_v_neg.reshape(-1)) == pos_num * neg_num
assert len(neg_v_neg.reshape(-1)) == neg_num * neg_num
plt.hist(pos_v_pos[~np.isnan(pos_v_pos)],alpha=0.5,label='PvP')
plt.hist(neg_v_neg[~np.isnan(neg_v_neg)],alpha=0.5,label='NvN')#,density=True)
plt.hist(pos_v_neg[~np.isnan(pos_v_neg)],alpha=0.5,label='PvN')#,density=True)
plt.legend()
plt.xlabel("Similarity")
plt.ylabel("Counts")
p_txt = "Pos vs. Pos Mean: {0:0.4f} Std: {1:.4f}".format(np.nanmean(pos_v_pos),np.nanstd(pos_v_pos))
n_txt = "Neg vs. Neg  Mean: {0:0.4f} Std: {1:.4f}".format(np.nanmean(neg_v_neg),np.nanstd(neg_v_neg))
pvn_txt = "Pos vs. Neg  Mean: {0:0.4f} Std: {1:.4f} ".format(np.nanmean(pos_v_neg),np.nanstd(pos_v_neg))

txt = p_txt + "\n" + n_txt + "\n" + pvn_txt
plt.figtext(0.5, -0.1, txt, wrap=True, horizontalalignment='center', fontsize=12);
plt.title("Random Init: {} Dist of Rep. Similarities".format(DS))
plt.tight_layout()

## Trained Plots

In [None]:
### Load Ckpt for TRAINED MODEL
rep_num = 1
### Init random model
model = simclr(hidden_dim=32, num_gc_layers=3,args=args).to(device)
model.eval()

ckpt = torch.load("{}_{}.pkl".format(DS,rep_num))
#ckpt = torch.load("MUTAG_0.pkl")
model.load_state_dict(ckpt['net'])
model.eval();
print("Using trained model!")

In [None]:
### Get trained embeddings
trained_emb, trained_y = model.encoder.get_embeddings(dataloader_eval)
assert (trained_y - y).sum() == 0, print("Embeddings have been shuffled!")

### Sort w/ same indices are the random model
trained_sorted_embs = np.concatenate((trained_emb[pos_idx],trained_emb[neg_idx]))
trained_sim = cosine_similarity(trained_sorted_embs)

ax =sns.heatmap(trained_sim,center=0.5,square=True)
ax.set_title("Trained: {} Representation Similarity".format(DS));

In [None]:
### Trained Normalized Embeddings
trained_normed_emb = trained_sorted_embs/ np.linalg.norm(trained_sorted_embs,axis=1,ord=2,keepdims=True)
ax = sns.heatmap(trained_normed_emb,center=0)
ax.set_xticks([],[])
ax.set_yticks([],[])
ax.set_title("Trained: {} Norm. Embeddings".format(DS));

In [None]:
### represent the embeddings in histogram form (upper triangle only)
pos_v_pos = np.triu(trained_sim[0:len(pos_idx[0]),0:len(pos_idx[0])])
pos_v_pos[np.tril_indices(pos_v_pos.shape[0], -1)] = np.nan

neg_v_neg = np.triu(trained_sim[len(pos_idx[0]):,len(pos_idx[0]):])
neg_v_neg[np.tril_indices(neg_v_neg.shape[0], -1)] = np.nan


pos_v_neg = np.triu(trained_sim[0:len(pos_idx[0]),len(pos_idx[0]):])
pos_v_neg[np.tril_indices(pos_v_neg.shape[0], -1)] = np.nan

assert len(pos_v_pos.reshape(-1)) ==  pos_num * pos_num
assert len(pos_v_neg.reshape(-1)) == pos_num * neg_num
assert len(neg_v_neg.reshape(-1)) == neg_num * neg_num
plt.hist(pos_v_pos[~np.isnan(pos_v_pos)],alpha=0.5,label='PvP')
plt.hist(neg_v_neg[~np.isnan(neg_v_neg)],alpha=0.5,label='NvN')
plt.hist(pos_v_neg[~np.isnan(pos_v_neg)],alpha=0.5,label='PvN')
plt.legend()
p_txt = "Pos vs. Pos Mean: {0:0.4f} Std: {1:.4f}".format(np.nanmean(pos_v_pos),np.nanstd(pos_v_pos))
n_txt = "Neg vs. Neg  Mean: {0:0.4f} Std: {1:.4f}".format(np.nanmean(neg_v_neg),np.nanstd(neg_v_neg))
pvn_txt = "Pos vs. Neg  Mean: {0:0.4f} Std: {1:.4f} ".format(np.nanmean(pos_v_neg),np.nanstd(pos_v_neg))

txt = p_txt + "\n" + n_txt + "\n" + pvn_txt
plt.figtext(0.5, -0.1, txt, wrap=True, horizontalalignment='center', fontsize=12);

plt.title("Trained: {} Dist of Rep. Similarities".format(DS))

In [None]:
plt.plot(ckpt['stats']['std'],label='Std')
plt.ylabel("Std")
plt.xlabel('Epoch')
plt.title("Trained: {} Std. of Reps.".format(DS,rep_num))
plt.hlines(y=1/np.sqrt(96),xmin=0,xmax=len(ckpt['stats']['std']))
plt.legend()

In [None]:
plt.plot(ckpt['acc']['test'],label='Test Acc')
plt.plot(ckpt['acc']['val'],label='Val Acc')
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Trained: {} Accuracy".format(DS))