In [None]:
import argparse
import time
import shutil
import os
import os.path as osp
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
import torchvision.models as models
from resnext_specialist import VA
from data_cnn60 import NTUDataLoaders, AverageMeter, make_dir, get_cases, get_num_classes
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import torch.nn.functional as F
from cada_vae import Encoder, Decoder, KL_divergence, Wasserstein_distance, reparameterize, triplet_loss

# parser = argparse.ArgumentParser(description='View adaptive')
# parser.add_argument('--ss', type=int, help="split size")
# parser.add_argument('--st', type=str, help="split type")
# parser.add_argument('--dataset', type=str, help="dataset path")
# parser.add_argument('--wdir', type=str, help="directory to save weights path")
# parser.add_argument('--le', type=str, help="language embedding model")
# parser.add_argument('--ve', type=str, help="visual embedding model")
# parser.add_argument('--phase', type=str, help="train or val")
# parser.add_argument('--gpu', type=str, help="gpu device number")
# parser.add_argument('--ntu', type=int, help="ntu120 or ntu60")
# args = parser.parse_args()

# ---------------------------------------------------------------------------
# Run configuration (hard-coded stand-ins for the commented-out argparse flags)
# ---------------------------------------------------------------------------
gpu = '0'                          # value exported as CUDA_VISIBLE_DEVICES
ss = 5                             # split size: number of unseen classes
st = 'r'                           # split type (prefix of the ./label_splits files)
dataset_path = 'ntu_results/shift_5_r'
wdir = 'cada_vae_shift_5_r'        # sub-directory that receives the weights
le = 'bert'                        # language-embedding model
ve = 'shift'                       # visual-embedding model
phase = 'train'                    # 'train' or 'val'
num_class = 60

os.environ["CUDA_VISIBLE_DEVICES"] = gpu
seed = 5
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = torch.device("cuda")
print(torch.cuda.device_count())

# Checkpoints are written to <ckpt_root>/<wdir>/<le>. makedirs creates any missing
# parents and is a no-op when the directory already exists, so this cell can be re-run.
ckpt_root = '/ssd_scratch/cvit/pranay.gupta/language_modelling/'
os.makedirs(ckpt_root + wdir + '/' + le, exist_ok=True)

# Input feature widths for each supported visual / language embedding model.
vis_emb_sizes = {'vacnn': 2048, 'shift': 256, 'msg3d': 384}
text_emb_sizes = {'bert': 1024, 'w2v': 300}
# Fail here with a clear message instead of a NameError further down.
if ve not in vis_emb_sizes:
    raise ValueError('unknown visual embedding model %r; expected one of %s'
                     % (ve, sorted(vis_emb_sizes)))
if le not in text_emb_sizes:
    raise ValueError('unknown language embedding model %r; expected one of %s'
                     % (le, sorted(text_emb_sizes)))
vis_emb_input_size = vis_emb_sizes[ve]
text_emb_input_size = text_emb_sizes[le]

text_hidden_size = 100
vis_hidden_size = 100
latent_size = 50

# Visual VAE: features -> hidden -> latent; the text VAE maps straight to the latent.
sequence_encoder = Encoder([vis_emb_input_size, vis_hidden_size, latent_size]).to(device)
sequence_decoder = Decoder([latent_size, vis_hidden_size, vis_emb_input_size]).to(device)
text_encoder = Encoder([text_emb_input_size, latent_size]).to(device)
text_decoder = Decoder([latent_size, text_emb_input_size]).to(device)

# One optimizer over the parameters of all four modules.
params = []
for model in [sequence_encoder, sequence_decoder, text_encoder, text_decoder]:
    params += list(model.parameters())

optimizer = optim.Adam(params, lr = 0.0001)

ntu_loaders = NTUDataLoaders(dataset_path, 'max', 1)
train_loader = ntu_loaders.get_train_loader(64, 8)
zsl_loader = ntu_loaders.get_val_loader(64, 8)
val_loader = ntu_loaders.get_test_loader(64, 8)
train_size = ntu_loaders.get_train_size()
zsl_size = ntu_loaders.get_val_size()
val_size = ntu_loaders.get_test_size()
print('Train on %d samples, validate on %d samples' % (train_size, val_size))


labels = np.load('labels.npy')  # class names, indexed by class id

# Class-id splits: gzsl_inds = classifier candidate classes, unseen/seen = held-out/train classes.
if phase == 'val':
    gzsl_inds = np.load('./label_splits/'+ st + 's' + str(num_class - ss) +'.npy')
    unseen_inds = np.sort(np.load('./label_splits/' + st + 'v' + str(ss) + '_0.npy'))
    seen_inds = np.load('./label_splits/'+ st + 's' + str(num_class -ss - ss) + '_0.npy')
else:
    gzsl_inds = np.arange(num_class)
    unseen_inds = np.sort(np.load('./label_splits/' + st + 'u' + str(ss) + '.npy'))
    seen_inds = np.load('./label_splits/'+ st + 's' + str(num_class  -ss) + '.npy')

unseen_labels = labels[unseen_inds]
seen_labels = labels[seen_inds]

# Per-class language embeddings, L2-normalised row-wise (keepdim broadcasts the norm).
labels_emb = torch.from_numpy(np.load(le + '_labels.npy')[:num_class,:]).view([num_class, text_emb_input_size])
labels_emb = labels_emb / torch.norm(labels_emb, dim=1, keepdim=True)

unseen_labels_emb = labels_emb[unseen_inds, :]
seen_labels_emb = labels_emb[seen_inds, :]
print("loaded language embeddings")

criterion1 = nn.MSELoss().to(device)

def get_text_data(target, labels_emb):
    """Look up the language embedding of every label in a batch.

    Args:
        target: tensor of class indices, one per sample (presumably 1-D; TODO confirm).
        labels_emb: tensor whose rows are per-class embeddings, indexed by class id.

    Returns:
        float32 tensor of shape (len(target), labels_emb.shape[-1]).
    """
    # The width comes from labels_emb itself rather than the module-level
    # text_emb_input_size, so the helper depends only on its arguments.
    return labels_emb[target].view(target.shape[0], labels_emb.shape[-1]).float()

def save_checkpoint(state, filename='checkpoint.pth.tar', is_best=False,
                    best_filename='model_best.pth.tar'):
    """Serialise `state` with torch.save and optionally keep a "best" copy.

    Args:
        state: picklable object (here a dict with epoch / state_dict / optimizer).
        filename: destination path of the checkpoint.
        is_best: when True, also copy the written file to `best_filename`.
        best_filename: path of the best-model copy. It defaults to the cwd, as
            before; pass a path next to `filename` to keep runs separate.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)

In [None]:
# Resume all four VAE modules from the checkpoints written at iteration `load_epoch`.
load_epoch = 8499
ckpt_dir = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/'
se_checkpoint, sd_checkpoint, te_checkpoint, td_checkpoint = (
    ckpt_dir + prefix + '_trip' + str(load_epoch) + '.pth.tar'
    for prefix in ('se', 'sd', 'te', 'td')
)

for module, ckpt_path in ((sequence_encoder, se_checkpoint),
                          (sequence_decoder, sd_checkpoint),
                          (text_encoder, te_checkpoint),
                          (text_decoder, td_checkpoint)):
    module.load_state_dict(torch.load(ckpt_path)['state_dict'])

In [None]:
# Continue fine-tuning the visual/text VAEs for iterations 8500-10199.
# Each "epoch" is a single optimisation step on one mini-batch; the KL and
# cross-reconstruction terms are switched on by the schedule inside the loop.


def _cycle_batches(loader):
    """Yield mini-batches from `loader` forever, starting a new pass when one ends.

    A single live iterator avoids rebuilding the DataLoader iterator (and, with
    num_workers > 0, re-spawning its worker processes) on every step. It also
    walks through the data instead of only ever taking the first batch of a new pass.
    """
    while True:
        for batch in loader:
            yield batch


train_batches = _cycle_batches(train_loader)
log_every = 1              # print the loss breakdown every `log_every` steps
losses = AverageMeter()    # running loss average over all steps of this run
ce_loss_vals = []          # per-step total-loss history
for epoch in range(8500, 10200):
    sequence_encoder.train()
    sequence_decoder.train()
    text_encoder.train()
    text_decoder.train()

    # Loss-weight schedule.
    k_trip = 0                                   # triplet term disabled
    k_fact = max((0.1*(epoch-9500)/3000), 0)     # visual KL weight: 0 until 9500, then linear ramp
    k_fact2 = k_fact*(epoch>9900)                # text KL weight: only after 9900
    cr_fact = 1*(epoch>9900)                     # cross-reconstruction: only after 9900
    lw_fact = 0                                  # Wasserstein alignment disabled

    (inputs, target) = next(train_batches)
    s = inputs.to(device)
    t = get_text_data(target, labels_emb).to(device)

    # Per-modality VAE pass.
    smu, slv = sequence_encoder(s)
    sz = reparameterize(smu, slv)
    sout = sequence_decoder(sz)

    tmu, tlv = text_encoder(t)
    tz = reparameterize(tmu, tlv)
    tout = text_decoder(tz)

    # Cross reconstruction: decode each modality from the other's latent.
    tfroms = text_decoder(sz)
    sfromt = sequence_decoder(tz)

    s_triplet = triplet_loss(smu, target, device)
    s_recons = criterion1(s, sout)
    t_recons = criterion1(t, tout)
    s_kld = KL_divergence(smu, slv).to(device)
    t_kld = KL_divergence(tmu, tlv).to(device)
    s_crecons = criterion1(s, sfromt)
    t_crecons = criterion1(t, tfroms)
    l_wass = Wasserstein_distance(smu, slv, tmu, tlv)

    loss = s_recons + t_recons
    loss += k_trip*s_triplet
    # KL terms are subtracted; presumably KL_divergence returns the negative KL
    # (ELBO form). NOTE(review): confirm the sign convention in cada_vae.
    loss -= k_fact*(s_kld)
    loss -= k_fact2*(t_kld)
    loss += cr_fact*(s_crecons + t_crecons)
    loss += lw_fact*(l_wass)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.update(loss_value, inputs.size(0))
    ce_loss_vals.append(loss_value)
    if epoch % log_every == 0:
        print('Epoch-{:<3d} \t'
            'loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
            epoch, loss=losses))
        print('srecons {:.4f}\ttrecons {:.4f}\t'.format(s_recons.item(), t_recons.item()))
        print('skld {:.4f}\ttkld {:.4f}\t'.format(s_kld.item(), t_kld.item()))
        print('screcons {:.4f}\ttcrecons {:.4f}\t'.format(s_crecons.item(), t_crecons.item()))
        print('lwass {:.4f}\t'.format(l_wass.item()))
        print('strip {:.4f}\t'.format(s_triplet.item()))


In [None]:
# Persist all four modules from the last fine-tuning step. Only the sequence
# encoder's checkpoint also carries the optimizer state.
ckpt_dir = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/'
se_checkpoint = ckpt_dir + 'se_trip' + str(epoch) + '.pth.tar'
sd_checkpoint = ckpt_dir + 'sd_trip' + str(epoch) + '.pth.tar'
te_checkpoint = ckpt_dir + 'te_trip' + str(epoch) + '.pth.tar'
td_checkpoint = ckpt_dir + 'td_trip' + str(epoch) + '.pth.tar'

for module, ckpt_path, with_optimizer in ((sequence_encoder, se_checkpoint, True),
                                          (sequence_decoder, sd_checkpoint, False),
                                          (text_encoder, te_checkpoint, False),
                                          (text_decoder, td_checkpoint, False)):
    state = {'epoch': epoch + 1, 'state_dict': module.state_dict()}
    if with_optimizer:
        state['optimizer'] = optimizer.state_dict()
    save_checkpoint(state, ckpt_path)

In [None]:
# Host-side numpy copies of the last training step's latent statistics
# (visual mu / logvar, then text mu / logvar).
a, b, c, d = (stat.detach().cpu().numpy() for stat in (smu, slv, tmu, tlv))

In [None]:
from cada_vae import MLP

In [None]:
# ZSL classifier: latent code -> one logit per unseen class (was hard-coded [50, 5]).
cls = MLP([latent_size, len(unseen_inds)]).to(device)

In [None]:
# Adam over the classifier weights only; the VAEs stay frozen from here on.
cls_optimizer = optim.Adam(params=cls.parameters(), lr=1e-3)

In [None]:
# Synthetic training / validation latents for the unseen classes, produced by the
# text encoder from each unseen class's language embedding.
n_unseen = unseen_labels_emb.shape[0]
n_train_per_class = 500   # sampled latents per unseen class (classifier training)
n_val_per_class = 100     # copies of each class's mean latent (classifier validation)
with torch.no_grad():
    text_encoder.eval()
    # Row k of unseen_labels_emb is class k of the classifier; repeat() tiles the
    # block, so row r belongs to class r % n_unseen.
    c_t = unseen_labels_emb.to(device).repeat([n_train_per_class, 1])
    y = torch.arange(n_unseen).to(device).repeat([n_train_per_class])
    t_tmu, t_tlv = text_encoder(c_t)
    t_z = reparameterize(t_tmu, t_tlv)
    # Validation uses the first encoder output directly (no sampling), so the
    # n_val_per_class copies of a class are identical.
    v_t = unseen_labels_emb.to(device).repeat([n_val_per_class, 1])
    v_y = torch.arange(n_unseen).to(device).repeat([n_val_per_class])
    v_tmu, v_tlv = text_encoder(v_t)

In [None]:
# Standard (mean-reduced) cross-entropy for the latent-space classifiers.
criterion2 = nn.CrossEntropyLoss(reduction='mean').to(device)

In [None]:
# Train the ZSL classifier on sampled text latents; track the best validation accuracy.
best = 0
model_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/'  + wdir + '/' + le + '/cls.pth.tar'
for c_e in range(300):
    cls.train()
    out = cls(t_z)
    c_loss = criterion2(out, y)
    cls_optimizer.zero_grad()
    c_loss.backward()
    cls_optimizer.step()
    # Normalise by the real sample counts instead of hard-coded 2500 / 500.
    c_acc = float(torch.sum(y == torch.argmax(out, -1))) / y.numel()
    print("Train Loss :", c_loss.item())
    print("Train Accuracy:", c_acc)
    cls.eval()
    with torch.no_grad():  # evaluation only: no autograd graph needed
        v_out = cls(v_tmu)
    v_acc = float(torch.sum(v_y == torch.argmax(v_out, -1))) / v_y.numel()
    if v_acc > best:
        best = v_acc
        best_epoch = c_e
        # TODO: optionally save_checkpoint(cls state, model_checkpoint) here.
        print(best_epoch)
    print("Val Accuracy:", v_acc)

In [None]:
# ZSL evaluation on zsl_loader: map the 5-way prediction back to a real class id.
# as_tensor is a no-op on an existing tensor, so re-running this cell no longer
# fails the way torch.from_numpy(tensor) did.
unseen_inds = torch.as_tensor(unseen_inds)
final_embs = []
with torch.no_grad():
    sequence_encoder.eval()
    cls.eval()
    count = 0
    num = 0
    preds = []
    tars = []
    for (inp, target) in zsl_loader:
        t_s = inp.to(device)
        t_smu, t_slv = sequence_encoder(t_s)
        final_embs.append(t_smu)  # encoder means, dumped for visualisation below
        t_out = cls(t_smu)
        # unseen_inds and target live on the CPU, so the index must too.
        pred = torch.argmax(t_out, -1).cpu()
        pred_cls = unseen_inds[pred]
        preds.append(pred_cls)
        tars.append(target)
        count += torch.sum(pred_cls == target)
        num += len(target)
    print(float(count) / num if num else float('nan'))

In [None]:
# Stack the per-batch embeddings into one (N, latent) numpy array with a single
# device->host copy. The isinstance guard keeps the cell safe to re-run.
if isinstance(final_embs, list):
    final_embs = torch.cat(final_embs).cpu().numpy() if final_embs else np.array([])

In [None]:
# Flatten the per-batch prediction tensors into one list of Python scalars.
p = [value for batch in preds for value in batch.tolist()]

In [None]:
# Flatten the per-batch ground-truth tensors into one list of Python scalars.
t = [value for batch in tars for value in batch.tolist()]

In [None]:
# Predictions and ground truth as numpy arrays for sklearn / np.save.
p, t = np.asarray(p), np.asarray(t)

In [None]:
# Dump ZSL embeddings and their ground truth for UMAP plots. The file tag follows
# the configured split ('cadavae_5_r' for ss=5, st='r') instead of being hard-coded.
umap_dir = '/ssd_scratch/cvit/pranay.gupta/umap_embeddings/'
run_tag = 'cadavae_%d_%s' % (ss, st)
np.save(umap_dir + run_tag + '_embedding.npy', final_embs)
np.save(umap_dir + run_tag + '_gt.npy', t)

In [None]:
# Run the unseen-class ZSL classifier on val_loader and keep its softmax scores.
unseen_idx = torch.as_tensor(unseen_inds)  # works whether unseen_inds is numpy or a tensor
val_out_embs = []
with torch.no_grad():
    sequence_encoder.eval()
    cls.eval()
    count = 0
    num = 0
    preds = []
    tars = []
    for (inp, target) in val_loader:
        t_s = inp.to(device)
        t_smu, t_slv = sequence_encoder(t_s)
        t_out = cls(t_smu)
        # Explicit dim: the implicit-dim softmax is deprecated (it chose dim=1 for 2-D input).
        val_out_embs.append(F.softmax(t_out, dim=-1))
        pred = torch.argmax(t_out, -1).cpu()  # index a CPU tensor with a CPU index
        pred_cls = unseen_idx[pred]
        preds.append(pred_cls)
        tars.append(target)
        count += torch.sum(pred_cls == target)
        num += len(target)
    print(float(count) / num if num else float('nan'))

In [None]:
# Stack the per-batch softmax outputs into one numpy array with a single
# device->host copy. The guard keeps the cell safe to re-run.
if isinstance(val_out_embs, list):
    val_out_embs = torch.cat(val_out_embs).cpu().numpy() if val_out_embs else np.array([])

In [None]:
# One softmax row per val sample; presumably one column per unseen class — confirm.
val_out_embs.shape

In [None]:
# Save the unseen-class softmax scores; the file tag follows the configured split.
unseen_out_file = 'cadavae_%d_%s_gzsl_zs.npy' % (ss, st)
np.save('/ssd_scratch/cvit/pranay.gupta/unseen_out/' + unseen_out_file, val_out_embs)

In [None]:
# Per-class ZSL accuracy with the most frequent predicted classes for each.
unseen_arr = np.asarray(unseen_inds)  # accepts the numpy array or the CPU tensor
# Fix row/column order to unseen_arr so row k is always class unseen_arr[k],
# even if some class never appears in t or p.
cmat = confusion_matrix(t, p, labels=unseen_arr)
unseen_acc = 0
for i, val in enumerate(unseen_arr):
    row = cmat[i]
    class_acc = row[i] / np.sum(row)
    unseen_acc += class_acc
    print(labels[val], ' : ', class_acc)
    order = np.argsort(row)[::-1]
    print(labels[unseen_arr[order]])
    print(row[order])

unseen_acc = unseen_acc / len(unseen_arr)
print('\n')
print('unseen_class_accuracy : ', unseen_acc)

In [None]:
from cada_vae import MLP
# GZSL classifier: latent code -> one logit per class (was hard-coded [50, 60]).
cls = MLP([latent_size, num_class]).to(device)
cls_optimizer = optim.Adam(cls.parameters(), lr = 0.001)

In [None]:
# Training set for the num_class-way (GZSL) classifier:
#   * unseen classes: latents sampled from the text VAE for their label embeddings
#   * seen classes:   latents sampled from the visual VAE for real training features
n_syn_per_unseen = 500   # sampled text latents per unseen class
n_real_per_seen = 200    # visual features drawn (without replacement) per seen class

# Group training features by class id. Rows are collected in lists and stacked once
# per class; growing a tensor with torch.cat for every sample is quadratic.
feat_rows = {}
for inp, target in train_loader:
    flat = inp.reshape(inp.shape[0], vis_emb_input_size)
    for row, label in zip(flat, target.tolist()):
        feat_rows.setdefault(label, []).append(row)
seen_feats = {label: torch.stack(rows, 0) for label, rows in feat_rows.items()}

unseen_arr = np.asarray(unseen_inds)  # accepts the numpy array or the CPU tensor

with torch.no_grad():
    c_t = unseen_labels_emb.to(device).repeat([n_syn_per_unseen, 1])
    # Label each synthetic latent with its real class id (row k of
    # unseen_labels_emb is class unseen_arr[k]). The classifier has num_class
    # outputs and the seen samples below use real ids, so 0..n_unseen-1 would
    # collide with seen classes.
    y = torch.as_tensor(unseen_arr).long().to(device).repeat([n_syn_per_unseen])

    s_parts, y_s = [], []
    for l, feats in seen_feats.items():
        # Raises ValueError if a class has fewer than n_real_per_seen samples.
        pick = sorted(np.random.choice(feats.shape[0], n_real_per_seen, replace = False))
        s_parts.append(feats[pick, :])
        y_s += [l] * n_real_per_seen
    s_t = torch.cat(s_parts, 0).to(device)
    y_s = torch.tensor(y_s).to(device)

    text_encoder.eval()
    sequence_encoder.eval()
    t_tmu, t_tlv = text_encoder(c_t)
    t_z = reparameterize(t_tmu, t_tlv)

    s_tmu, s_tlv = sequence_encoder(s_t)
    s_z = reparameterize(s_tmu, s_tlv)

    f_z = torch.cat([t_z, s_z], 0)
    f_y = torch.cat([y, y_s], 0)

criterion2 = nn.CrossEntropyLoss().to(device)

In [None]:
# Train the GZSL classifier on the mixed synthetic-unseen / real-seen latents.
best = 0
model_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/'  + wdir + '/' + le + '/cls.pth.tar'
# Actual sample count (5*500 + 55*200 = 13500 for the default split); the old
# hard-coded 13000 over-reported accuracy.
n_train = f_y.numel()
for c_e in range(2000):
    cls.train()
    out = cls(f_z)
    c_loss = criterion2(out, f_y)
    cls_optimizer.zero_grad()
    c_loss.backward()
    cls_optimizer.step()
    c_acc = float(torch.sum(f_y == torch.argmax(out, -1))) / n_train
    print("Train Loss :", c_loss.item())
    print("Train Accuracy:", c_acc)

In [None]:
# GZSL evaluation on val_loader; classifier output k is class gzsl_inds[k].
gzsl_idx = torch.as_tensor(gzsl_inds)  # works whether gzsl_inds is numpy or a tensor
final_embs = []
with torch.no_grad():
    sequence_encoder.eval()
    cls.eval()
    count = 0
    num = 0
    preds = []
    tars = []
    for (inp, target) in val_loader:
        t_s = inp.to(device)
        t_smu, t_slv = sequence_encoder(t_s)
        final_embs.append(t_smu)
        t_out = cls(t_smu)
        # Index the CPU id table with a CPU index (pred starts on the GPU).
        pred = torch.argmax(t_out, -1).cpu()
        pred_cls = gzsl_idx[pred]
        preds.append(pred_cls)
        tars.append(target)
        count += torch.sum(pred_cls == target)
        num += len(target)
    print(float(count) / num if num else float('nan'))

In [None]:
# Per-class GZSL accuracy, mean seen / unseen accuracy and their harmonic mean.
p = np.array([j.item() for i in preds for j in i])
t = np.array([j.item() for i in tars for j in i])

unseen_arr = np.asarray(unseen_inds)
seen_arr = np.asarray(seen_inds)
# Build the matrix over every class id so row/column k is class k even if some
# class never appears in t or p; the id-based indexing below relies on this.
cmat = confusion_matrix(t, p, labels=np.arange(num_class))
unseen_acc = 0
seen_acc = 0
for val in unseen_arr:
    row = cmat[val]
    class_acc = row[val] / np.sum(row)
    unseen_acc += class_acc
    print(labels[val], ' : ', class_acc)
    order = np.argsort(row)[::-1]  # columns are class ids
    print(labels[order])
    print(row[order])

for i in seen_arr:
    seen_acc += cmat[i, i]/np.sum(cmat[i])

unseen_acc = unseen_acc / len(unseen_arr)
seen_acc = seen_acc / len(seen_arr)
acc_sum = unseen_acc + seen_acc
h_mean = 2*unseen_acc*seen_acc/acc_sum if acc_sum > 0 else 0.0
print('\n')
print('unseen_class_accuracy : ', unseen_acc)
print('seen_class_accuacy : ',  seen_acc)
print('harmonic_mean : ', h_mean)

In [None]:
# Class ids held out as unseen for this split.
unseen_inds

In [None]:
# Row 19 of the most recently built confusion matrix.
# NOTE(review): magic index — presumably a class of interest; row 19 is class 19
# only if the matrix was built over all class ids. Confirm.
cmat[19]

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# Colour index for each point: the position of its ground-truth class in unseen_inds.
# NOTE(review): every value in `t` must be an unseen class (ZSL results). After
# the GZSL metrics cell, `t` also holds seen ids and this raises IndexError.
t_plot = [np.argwhere(unseen_inds == label).flatten()[0] for label in t]

In [None]:
# 2-D t-SNE of the encoder means in final_embs, coloured by t_plot.
# NOTE(review): final_embs must be a numpy array, and final_embs / t must come
# from the same loader. If the GZSL evaluation cell ran last, final_embs is a
# list of CUDA tensors. Newer scikit-learn releases rename n_iter to max_iter.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=1500)
tsne_results = tsne.fit_transform(final_embs)
plt.figure(figsize=(5,5))
plt.scatter(tsne_results[:,0], tsne_results[:,1], c = t_plot, cmap='Dark2')

# Annotate the first five points with their ground-truth class names.
for i in range(5):
    plt.annotate(labels[t[i]], (tsne_results[i, 0], tsne_results[i, 1]))
plt.show()

In [None]:
# 2-D t-SNE of the text latents tz from the last training step.
# NOTE(review): `inds` is not defined anywhere in this notebook, so this cell
# raises NameError as written. `target` has also been rebound by the later
# evaluation / feature-collection loops, so it no longer matches tz's batch.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=500)
tsne_results = tsne.fit_transform(tz.detach().cpu().numpy()[inds, :])
plt.figure(figsize=(5,5))
plt.scatter(tsne_results[:,0], tsne_results[:,1], c = target.detach().cpu().numpy()[inds], cmap='plasma')
plt.show()

In [None]:
# 2-D t-SNE of the visual reconstructions sout from the last training step,
# coloured by `target`.
# NOTE(review): `target` is rebound by the later evaluation loops, so unless this
# runs right after training the colours do not belong to sout's batch.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=500)
tsne_results = tsne.fit_transform(sout.detach().cpu().numpy())
plt.figure(figsize=(5,5))
plt.scatter(tsne_results[:,0], tsne_results[:,1], c = target.detach().cpu().numpy(), cmap='plasma')
plt.show()

In [None]:
# Quick sanity checks on the last training batch `s` and the visual VAE outputs.
torch.min(s)

In [None]:
smu

In [None]:
# First reconstructed visual feature.
sout[0]

In [None]:
# Largest squared reconstruction error for the first sample.
torch.max((s[0] - sout[0])**2)

In [None]:
# Largest per-sample mean of the visual latent means.
max(torch.mean(smu, 1))

In [None]:
# Largest per-sample mean of slv (used as a log-variance: sigma = exp(0.5*slv)).
max(torch.mean(slv, 1))

In [None]:
 sigma = torch.exp(0.5*slv)

In [None]:
# Manual reparameterisation sample, for comparison with sz.
# One standard-normal draw per row, expanded across all latent dims, so every
# dimension of a sample shares the same noise value.
# NOTE(review): presumably mirrors cada_vae.reparameterize — confirm there. The
# draw differs from the one behind sz, so values will not match exactly.
eps = torch.FloatTensor(sigma.size()[0], 1).normal_(0, 1).expand(sigma.size()).cuda()

In [None]:
sz_test = eps*sigma + smu

In [None]:
sz

In [None]:
# Row indices of the samples labelled 0 in the current `target` batch.
# The loop variable is `lab`, not `t`, so the ground-truth array `t` is not clobbered.
ind = [num for num, lab in enumerate(target) if lab == 0]

In [None]:
# Visual latent means of the class-0 samples selected above.
smu[ind, :]

In [None]:
# Largest per-sample mean of the text latent means.
max(torch.mean(tmu, 1))

In [None]:
# Largest per-sample mean of the text latent log-variances.
max(torch.mean(tlv, 1))

In [None]:
from scipy.spatial.distance import cdist

In [None]:
# Pairwise Euclidean distances between the unseen classes' language embeddings.
dists = cdist(unseen_labels_emb, unseen_labels_emb)

In [None]:
unseen_labels

In [None]:
dists

In [None]:
# First rows of t_z: one sampled latent per unseen class, because c_t tiles
# unseen_labels_emb row-block by row-block.
latent_class_embedding = t_z[:5]

In [None]:
latent_class_embedding.shape

In [None]:
# Pairwise distances between those sampled class latents.
dists = cdist(latent_class_embedding.cpu(), latent_class_embedding.cpu())

In [None]:
dists

In [None]:
# Per-class text-encoder outputs (mean / log-variance) for the unseen classes.
latent_mu = t_tmu[:5].cpu()

In [None]:
latent_lv = t_tlv[:5].cpu()

In [None]:
cdist(latent_mu, latent_mu)

In [None]:
cdist(latent_lv, latent_lv)

In [None]:
text_encoder