In [None]:
import argparse
import time
import shutil
import os
import os.path as osp
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
import torchvision.models as models
from data_cnn60 import NTUDataLoaders, AverageMeter, make_dir, get_cases, get_num_classes
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import torch.nn.functional as F
from cada_vae import Encoder, Decoder, KL_divergence, Wasserstein_distance, reparameterize 

# parser = argparse.ArgumentParser(description='View adaptive')
# parser.add_argument('--ss', type=int, help="split size")
# parser.add_argument('--st', type=str, help="split type")
# parser.add_argument('--dataset', type=str, help="dataset path")
# parser.add_argument('--wdir', type=str, help="directory to save weights path")
# parser.add_argument('--le', type=str, help="language embedding model")
# parser.add_argument('--ve', type=str, help="visual embedding model")
# parser.add_argument('--phase', type=str, help="train or val")
# parser.add_argument('--gpu', type=str, help="gpu device number")
# parser.add_argument('--ntu', type=int, help="ntu120 or ntu60")
# args = parser.parse_args()

# ---- Experiment configuration (stands in for the commented-out argparse CLI) ----
gpu = '0'                                # gpu device number, exported below
ss = 5                                   # split size
st = 'r'                                 # split type (prefix of the label_splits files)
dataset_path = 'ntu_results/shift_5_r'   # dataset path handed to NTUDataLoaders
wdir = 'pos_aware_cada_vae_concatenated_latent_space_shift_5_r'  # weights sub-directory
le = 'bert'                              # language embedding model
ve = 'shift'                             # visual embedding model
phase = 'train'                          # 'train' or 'val'; selects the label split files
num_class = 60

# Restrict the visible GPUs, then seed every RNG the notebook draws from.
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
seed = 5
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = torch.device("cuda")
print(torch.cuda.device_count())


# Create <root>/<wdir>/<le> in a single call. Unlike the previous pair of
# exists()/mkdir() checks this also creates missing parent directories and does
# not fail if the directory appears between the check and the creation.
os.makedirs('/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le, exist_ok=True)

# Width of the visual feature vector for each supported visual embedding model.
_VIS_EMB_INPUT_SIZES = {'vacnn': 2048, 'shift': 256, 'msg3d': 384}
if ve not in _VIS_EMB_INPUT_SIZES:
    # Fail here instead of with a NameError on `vis_emb_input_size` further down.
    raise ValueError("unsupported visual embedding model ve=%r; expected one of %s"
                     % (ve, sorted(_VIS_EMB_INPUT_SIZES)))
vis_emb_input_size = _VIS_EMB_INPUT_SIZES[ve]

text_hidden_size = 100
vis_hidden_size = 100
latent_size = 100

# Width of the text feature vector for each supported language embedding model.
_TEXT_EMB_INPUT_SIZES = {'bert': 1024, 'w2v': 300}
if le not in _TEXT_EMB_INPUT_SIZES:
    # Same reasoning as above: an unknown value used to be silently ignored.
    raise ValueError("unsupported language embedding model le=%r; expected one of %s"
                     % (le, sorted(_TEXT_EMB_INPUT_SIZES)))
text_emb_input_size = _TEXT_EMB_INPUT_SIZES[le]

# Build the three encoder/decoder pairs. The construction order is left exactly
# as it was (tuples evaluate left to right).
# The sequence pair works on the full latent; each text pair on half of it.
sequence_encoder, sequence_decoder = (
    Encoder([vis_emb_input_size, latent_size]).to(device),
    Decoder([latent_size, vis_emb_input_size]).to(device),
)
v_text_encoder, v_text_decoder = (
    Encoder([text_emb_input_size, latent_size//2]).to(device),
    Decoder([latent_size//2, text_emb_input_size]).to(device),
)
n_text_encoder, n_text_decoder = (
    Encoder([text_emb_input_size, latent_size//2]).to(device),
    Decoder([latent_size//2, text_emb_input_size]).to(device),
)

# A single Adam optimizer updates all six networks jointly.
params = []
for model in (sequence_encoder, sequence_decoder,
              v_text_encoder, v_text_decoder,
              n_text_encoder, n_text_decoder):
    params.extend(model.parameters())

optimizer = optim.Adam(params, lr=0.0001)

# Note the naming: `zsl_loader` comes from get_val_loader() and `val_loader`
# from get_test_loader().
ntu_loaders = NTUDataLoaders(dataset_path, 'max', 1)
train_loader, zsl_loader, val_loader = (
    ntu_loaders.get_train_loader(64, 8),
    ntu_loaders.get_val_loader(64, 8),
    ntu_loaders.get_test_loader(64, 8),
)
train_size, zsl_size, val_size = (
    ntu_loaders.get_train_size(),
    ntu_loaders.get_val_size(),
    ntu_loaders.get_test_size(),
)
print('Train on %d samples, validate on %d samples' % (train_size, val_size))


# Class names plus the noun / verb vocabularies and their one-hot assignments.
labels = np.load('labels.npy')
nouns_vocab = np.load('nouns_vocab.npy')
nouns_ohe = np.load('nouns_ohe.npy')
verbs_vocab = np.load('verbs_vocab.npy')
verbs_ohe = np.load('verbs_ohe.npy')
# The argmax over the last axis selects one vocabulary entry per row.
nouns = nouns_vocab[nouns_ohe.argmax(-1)]
verbs = verbs_vocab[verbs_ohe.argmax(-1)]

# Pick the seen / unseen class index files for the current phase.
if phase != 'val':
    gzsl_inds = np.arange(num_class)
    unseen_inds = np.sort(np.load('./label_splits/' + st + 'u' + str(ss) + '.npy'))
    seen_inds = np.load('./label_splits/'+ st + 's' + str(num_class  -ss) + '.npy')
else:
    gzsl_inds = np.load('./label_splits/'+ st + 's' + str(num_class - ss) +'.npy')
    unseen_inds = np.sort(np.load('./label_splits/' + st + 'v' + str(ss) + '_0.npy'))
    seen_inds = np.load('./label_splits/'+ st + 's' + str(num_class -ss - ss) + '_0.npy')

unseen_labels, seen_labels = labels[unseen_inds], labels[seen_inds]
seen_verbs, unseen_verbs = verbs[seen_inds], verbs[unseen_inds]
seen_nouns, unseen_nouns = nouns[seen_inds], nouns[unseen_inds]

# Language embeddings: keep the first `num_class` rows and scale every row to
# unit L2 norm (keepdim lets the division broadcast across the feature axis).
nouns_emb = torch.from_numpy(np.load(le + '_noun.npy')[:num_class,:]).view(num_class, text_emb_input_size)
nouns_emb = nouns_emb / torch.norm(nouns_emb, dim=1, keepdim=True)

verbs_emb = torch.from_numpy(np.load(le + '_verb.npy')[:num_class,:]).view(num_class, text_emb_input_size)
verbs_emb = verbs_emb / torch.norm(verbs_emb, dim=1, keepdim=True)

unseen_nouns_emb, seen_nouns_emb = nouns_emb[unseen_inds, :], nouns_emb[seen_inds, :]
unseen_verbs_emb, seen_verbs_emb = verbs_emb[unseen_inds, :], verbs_emb[seen_inds, :]
print("loaded language embeddings")

# Reconstruction loss shared by every VAE branch.
criterion1 = nn.MSELoss().to(device)

def get_text_data(target):
    """Look up the noun and verb embeddings for a batch of class indices.

    Args:
        target: index tensor of class ids; its first dimension is the batch size.

    Returns:
        Tuple ``(noun_feats, verb_feats)``, each a float tensor of shape
        ``(target.shape[0], text_emb_input_size)``.
    """
    batch_size = target.shape[0]
    noun_feats = nouns_emb[target].view(batch_size, text_emb_input_size)
    verb_feats = verbs_emb[target].view(batch_size, text_emb_input_size)
    return noun_feats.float(), verb_feats.float()

def save_checkpoint(state, filename='checkpoint.pth.tar', is_best=False):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: object to serialize (the notebook passes dicts of state_dicts).
        filename: destination path of the checkpoint.
        is_best: when true, the written file is additionally copied to
            ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')

In [None]:
# Restore all six networks from the checkpoints written for `load_epoch`.
# The files on disk use the tve/tvd/tne/tnd prefixes although the variables
# here are named vte/vtd/nte/ntd.
load_epoch = 3399
se_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/se_'+str(load_epoch)+'.pth.tar'
sd_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/sd_'+str(load_epoch)+'.pth.tar'
vte_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tve_'+str(load_epoch)+'.pth.tar'
vtd_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tvd_'+str(load_epoch)+'.pth.tar'
nte_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tne_'+str(load_epoch)+'.pth.tar'
ntd_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tnd_'+str(load_epoch)+'.pth.tar'

for module, path in (
    (sequence_encoder, se_checkpoint),
    (sequence_decoder, sd_checkpoint),
    (v_text_encoder, vte_checkpoint),
    (v_text_decoder, vtd_checkpoint),
    (n_text_encoder, nte_checkpoint),
    (n_text_decoder, ntd_checkpoint),
):
    # map_location puts the tensors on the current `device` regardless of the
    # device they were saved from.
    # Only the weights are restored; the optimizer state stored alongside the
    # sequence encoder is not loaded.
    module.load_state_dict(torch.load(path, map_location=device)['state_dict'])

In [None]:
# Joint training of the sequence VAE with the noun and verb text VAEs.
# The sequence latent is split in two: the first half is decoded by the noun
# text decoder and the second half by the verb text decoder.
half_latent = latent_size // 2

for epoch in range(1700, 3400):
    # Re-created every iteration, so `losses.avg` always equals `losses.val`.
    losses = AverageMeter()
    ce_loss_vals = []

    sequence_encoder.train()
    sequence_decoder.train()
    v_text_encoder.train()
    v_text_decoder.train()
    n_text_encoder.train()
    n_text_decoder.train()

    # Loss-weight schedule: the KL weights ramp up linearly from their start
    # epoch and the cross-reconstruction terms switch on after epoch 3100.
    # The Wasserstein terms are computed for logging only (weight 0).
    k_fact = max((0.1*(epoch-2700)/3000), 0)
    cr_fact = 1*(epoch>3100)
    v_k_fact2 = max((0.1*(epoch-3100)/3000), 0)
    n_k_fact2 = max((0.1*(epoch-3100)/3000), 0)
    v_cr_fact = 1*(epoch>3100)
    n_cr_fact = 1*(epoch>3100)
    v_lw_fact = 0
    n_lw_fact = 0

    # One mini-batch per "epoch": a fresh iterator is built each time and only
    # its first batch is consumed.
    (inputs, target) = next(iter(train_loader))
    s = inputs.to(device)
    nt, vt = get_text_data(target)
    nt = nt.to(device)
    vt = vt.to(device)

    # sequence forward pass
    smu, slv = sequence_encoder(s)
    sz = reparameterize(smu, slv)
    sout = sequence_decoder(sz)

    # noun forward pass
    ntmu, ntlv = n_text_encoder(nt)
    ntz = reparameterize(ntmu, ntlv)
    ntout = n_text_decoder(ntz)

    # Noun text reconstructed from the first half of the sequence latent.
    ntfroms = n_text_decoder(sz[:, :half_latent])

    s_recons = criterion1(s, sout)
    nt_recons = criterion1(nt, ntout)
    s_kld = KL_divergence(smu, slv).to(device)
    nt_kld = KL_divergence(ntmu, ntlv).to(device)
    nt_crecons = criterion1(nt, ntfroms)
    nl_wass = Wasserstein_distance(smu[:, :half_latent], slv[:, :half_latent], ntmu, ntlv)

    # verb forward pass
    vtmu, vtlv = v_text_encoder(vt)
    vtz = reparameterize(vtmu, vtlv)
    vtout = v_text_decoder(vtz)

    # Verb text reconstructed from the second half of the sequence latent.
    vtfroms = v_text_decoder(sz[:, half_latent:])
    vt_recons = criterion1(vt, vtout)
    vt_kld = KL_divergence(vtmu, vtlv).to(device)
    vt_crecons = criterion1(vt, vtfroms)
    vl_wass = Wasserstein_distance(smu[:, half_latent:], slv[:, half_latent:], vtmu, vtlv)

    # Sequence reconstructed from the concatenated noun + verb text latents.
    sfromt = sequence_decoder(torch.cat([ntz, vtz], 1))
    s_crecons = criterion1(s, sfromt)

    loss = s_recons + vt_recons + nt_recons
    # NOTE(review): the KL terms are subtracted, which is only correct if
    # KL_divergence returns the negated divergence - confirm in cada_vae.
    loss -= k_fact*(s_kld) + v_k_fact2*(vt_kld) + n_k_fact2*(nt_kld)
    loss += n_cr_fact*(nt_crecons) + v_cr_fact*(vt_crecons) + cr_fact*(s_crecons)
    loss += v_lw_fact*(vl_wass) + n_lw_fact*(nl_wass)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.update(loss.item(), inputs.size(0))
    ce_loss_vals.append(loss.cpu().detach().numpy())
    if epoch % 1 == 0:
        print('---------------------')
        print('Epoch-{:<3d} \t'
            'loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
            epoch, loss=losses))
        print('srecons {:.4f}\t ntrecons {:.4f}\t vtrecons {:.4f}\t'.format(s_recons.item(), nt_recons.item(), vt_recons.item()))
        print('skld {:.4f}\t ntkld {:.4f}\t vtkld {:.4f}\t'.format(s_kld.item(), nt_kld.item(), vt_kld.item()))
        # The third value is the verb cross-reconstruction; it used to be
        # printed under a second 'ntcrecons' label.
        print('screcons {:.4f}\t ntcrecons {:.4f}\t vtcrecons {:.4f}\t'.format(s_crecons.item(), nt_crecons.item(), vt_crecons.item()))
        print('nlwass {:.4f}\t vlwass {:.4f}\n'.format(nl_wass.item(), vl_wass.item()))

In [None]:
# Checkpoint paths are <root>/<wdir>/<le>/<tag>_<epoch>.pth.tar, where `epoch`
# is the value left behind by the training loop.
se_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/se_'+str(epoch)+'.pth.tar'
sd_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/sd_'+str(epoch)+'.pth.tar'
tve_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tve_'+str(epoch)+'.pth.tar'
tvd_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tvd_'+str(epoch)+'.pth.tar'
tne_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tne_'+str(epoch)+'.pth.tar'
tnd_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/' + wdir + '/' + le + '/tnd_'+str(epoch)+'.pth.tar'

# One file per network, written in the same order as before. Only the
# sequence-encoder checkpoint also carries the shared optimizer state.
for module, path in (
    (sequence_encoder, se_checkpoint),
    (sequence_decoder, sd_checkpoint),
    (v_text_encoder, tve_checkpoint),
    (v_text_decoder, tvd_checkpoint),
    (n_text_encoder, tne_checkpoint),
    (n_text_decoder, tnd_checkpoint),
):
    state = {'epoch': epoch + 1, 'state_dict': module.state_dict()}
    if module is sequence_encoder:
        state['optimizer'] = optimizer.state_dict()
    save_checkpoint(state, path)

In [None]:
from cada_vae import MLP

In [None]:
cls = MLP([100, 5]).to(device)

In [None]:
cls_optimizer = optim.Adam(cls.parameters(), lr = 0.001)

In [None]:
# Build the ZSL classifier training set: 500 latent samples per unseen class,
# drawn from the noun and verb text encoders.
with torch.no_grad():
    # Number of unseen classes, taken from the data instead of a literal 5.
    n_unseen = unseen_nouns_emb.shape[0]
    n_t = unseen_nouns_emb.to(device).float()
    n_t = n_t.repeat([500, 1])
    v_t = unseen_verbs_emb.to(device).float()
    v_t = v_t.repeat([500, 1])
    # repeat() tiles the whole class block, so row r belongs to class
    # r % n_unseen; the labels are tiled the same way.
    y = torch.arange(n_unseen).to(device)
    y = y.repeat([500])
    v_text_encoder.eval()
    n_text_encoder.eval()
    nt_tmu, nt_tlv = n_text_encoder(n_t)
    vt_tmu, vt_tlv = v_text_encoder(v_t)
    # Sampling order (verb first, then noun) is kept so the RNG stream matches.
    vt_z = reparameterize(vt_tmu, vt_tlv)
    nt_z = reparameterize(nt_tmu, nt_tlv)

In [None]:
criterion2 = nn.CrossEntropyLoss().to(device)

In [None]:
# cp = []
best = 0
model_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/'  + wdir + '/' + le + '/cls.pth.tar'
for c_e in range(300):
    cls.train()
    out = cls(torch.cat([nt_z, vt_z], 1))
#     out = cls(vt_z)
    c_loss = criterion2(out, y)
    cls_optimizer.zero_grad()
    c_loss.backward()
    cls_optimizer.step()
    c_acc = float(torch.sum(y == torch.argmax(out, -1)))/(5*500)
#     cp.append(torch.argmax(out, -1))
    print("Train Loss :", c_loss.item())
    print("Train Accuracy:", c_acc)
    cls.eval()
#     v_out = cls(v_tmu)
#     v_acc = float(torch.sum(v_y == torch.argmax(v_out, -1)))/500
#     if v_acc > best:
#         best = v_acc
#         best_epoch = c_e
#         save_checkpoint({ 'epoch': epoch + 1,
#             'state_dict': cls.state_dict(),
#         #     'optimizer': optimizer.state_dict()
#         }, model_checkpoint)
#         print(best_epoch)
#     print("Val Accuracy:", v_acc)

In [None]:
# ZSL evaluation: classify the sequence-encoder output of every sample in
# `zsl_loader` and map the predicted index back to an unseen class id.
# torch.as_tensor accepts both an ndarray and a tensor, so re-running this
# cell no longer fails the way torch.from_numpy did on the second run.
unseen_inds = torch.as_tensor(unseen_inds)
final_embs = []
with torch.no_grad():
    sequence_encoder.eval()
    cls.eval()
    count = 0
    num = 0
    preds = []
    tars = []
    for (inp, target) in zsl_loader:
        t_s = inp.to(device)
        nt_smu, t_slv = sequence_encoder(t_s)
        final_embs.append(nt_smu)
        t_out = cls(nt_smu)
        # Bring the predictions to the CPU before indexing the CPU-resident
        # `unseen_inds`; indexing a CPU tensor with a device tensor raises on
        # current PyTorch. Assumes `target` is a CPU tensor - TODO confirm.
        pred = torch.argmax(t_out, -1).cpu()
        pred_classes = unseen_inds[pred]
        preds.append(pred_classes)
        tars.append(target)
        count += torch.sum(pred_classes == target).item()
        num += len(target)
    print(float(count)/num)

In [None]:
final_embs = np.array([j.cpu().numpy() for i in final_embs for j in i])

In [None]:
p = [j.item() for i in preds for j in i]

In [None]:
t = [j.item() for i in tars for j in i]

In [None]:
p = np.array(p)
t = np.array(t)

In [None]:
np.save('/ssd_scratch/cvit/pranay.gupta/synse_5_r_embedding.npy', final_embs)
np.save('/ssd_scratch/cvit/pranay.gupta/synse_5_r_gt.npy', t)

In [None]:
# Run the ZSL classifier over `val_loader` and keep its softmax outputs.
val_out_embs = []
# CPU tensor of the unseen class ids; works whether `unseen_inds` is still an
# ndarray or was already converted to a tensor by an earlier cell.
unseen_class_ids = torch.as_tensor(unseen_inds)
with torch.no_grad():
    sequence_encoder.eval()
    cls.eval()
    count = 0
    num = 0
    preds = []
    tars = []
    for (inp, target) in val_loader:
        t_s = inp.to(device)
        t_smu, t_slv = sequence_encoder(t_s)
        t_out = cls(t_smu)
        # Explicit class axis: the implicit-dim form of F.softmax is deprecated.
        val_out_embs.append(F.softmax(t_out, dim=-1))
        # Index the CPU class-id tensor with CPU predictions.
        pred = torch.argmax(t_out, -1).cpu()
        pred_classes = unseen_class_ids[pred]
        preds.append(pred_classes)
        tars.append(target)
        count += torch.sum(pred_classes == target).item()
        num += len(target)
    print(float(count)/num)

In [None]:
val_out_embs = np.array([j.cpu().numpy() for i in val_out_embs for j in i])

In [None]:
np.save('/ssd_scratch/cvit/pranay.gupta/unseen_out/synse_5_r_gzsl_zs.npy', val_out_embs)

In [None]:
# Per-class ZSL accuracy from the confusion matrix of `t` (targets) vs `p`
# (predictions).
unseen_classes = np.asarray(unseen_inds)
# Pin the row/column order to the unseen class ids. Without `labels`, a class
# missing from both `t` and `p` shrinks the matrix and shifts every later row.
cmat = confusion_matrix(t, p, labels=unseen_classes)
unseen_acc = 0
for i, val in enumerate(unseen_classes):
    row_total = np.sum(cmat[i])
    # A class with no samples contributes 0 instead of a 0/0 NaN.
    class_acc = cmat[i, i]/row_total if row_total else 0.0
    unseen_acc += class_acc
    print(labels[val], ' : ', class_acc)
    # Classes this one is predicted as, most frequent first, then the counts.
    print(labels[unseen_classes[np.argsort(cmat[i])[::-1]]])
    print(np.sort(cmat[i])[::-1])

unseen_acc = unseen_acc/len(unseen_classes)
print('\n')
print('unseen_class_accuracy : ', unseen_acc)

In [None]:
# Display the shape of `final_embs`; this needs it to already be an ndarray
# (a list has no `.shape` attribute).
final_embs.shape

In [None]:
from cada_vae import MLP
# GZSL classifier: latent-sized input, one output per class (seen and unseen).
# Sizes follow the config instead of the literals 100 and 60.
cls = MLP([latent_size, num_class]).to(device)
cls_optimizer = optim.Adam(cls.parameters(), lr = 0.001)

In [None]:
# Group the training features by class label. Rows are collected in lists and
# concatenated once per class; growing a tensor with torch.cat for every
# sample was quadratic in the class size.
rows_by_label = {}
for (inp, target) in train_loader:
    for i, label in enumerate(target):
        rows_by_label.setdefault(label.item(), []).append(inp[i, :].view(1, vis_emb_input_size))
seen_feats = {label: torch.cat(rows, 0) for label, rows in rows_by_label.items()}

with torch.no_grad():
    # 500 text-latent samples per unseen class.
    n_t = unseen_nouns_emb.to(device).float()
    n_t = n_t.repeat([500, 1])
    v_t = unseen_verbs_emb.to(device).float()
    v_t = v_t.repeat([500, 1])
    # Label the unseen samples with their real class ids. The seen samples
    # below are labelled with real class ids too, so the previous 0..4 labels
    # collided with whichever classes have those ids.
    # NOTE(review): this assumes classifier output index == class id, which
    # holds when gzsl_inds is arange(num_class) - confirm for phase == 'val'.
    y = torch.as_tensor(unseen_inds).long().to(device)
    y = y.repeat([500])

    # Up to 200 real feature vectors per seen class, drawn without replacement.
    sampled_feats = []
    y_s = []
    for l, feats in seen_feats.items():
        n_pick = min(200, feats.shape[0])  # classes with fewer samples used to raise
        picked = np.sort(np.random.choice(feats.shape[0], n_pick, replace = False))
        sampled_feats.append(feats[torch.from_numpy(picked).long()])
        y_s += [l]*n_pick

    # Stay in torch the whole way (no np.vstack round trip over tensors).
    s_t = torch.cat(sampled_feats, 0).to(device)
    y_s = torch.tensor(y_s).to(device)
    v_text_encoder.eval()
    n_text_encoder.eval()
    nt_tmu, nt_tlv = n_text_encoder(n_t)
    vt_tmu, vt_tlv = v_text_encoder(v_t)
    # Sampling order (verb, noun, sequence) is kept so the RNG stream matches.
    vt_z = reparameterize(vt_tmu, vt_tlv)
    nt_z = reparameterize(nt_tmu, nt_tlv)

    s_tmu, s_tlv = sequence_encoder(s_t)
    s_z = reparameterize(s_tmu, s_tlv)

    # Final set: unseen text latents (noun | verb concatenated) followed by
    # the seen sequence latents.
    f_z = torch.cat([torch.cat([nt_z, vt_z], 1), s_z], 0)
    f_y = torch.cat([y, y_s], 0)

criterion2 = nn.CrossEntropyLoss().to(device)

In [None]:
# Full-batch training of the GZSL classifier on the combined latent set.
best = 0
model_checkpoint = '/ssd_scratch/cvit/pranay.gupta/language_modelling/'  + wdir + '/' + le + '/cls.pth.tar'
for c_e in range(2000):
    cls.train()
    out = cls(f_z)
    c_loss = criterion2(out, f_y)
    cls_optimizer.zero_grad()
    c_loss.backward()
    cls_optimizer.step()
    # Divide by the real number of training samples. The literal 13000 did not
    # match the set built above (5*500 unseen + 55*200 seen = 13500).
    c_acc = float(torch.sum(f_y == torch.argmax(out, -1)))/f_y.shape[0]
    print("Train Loss :", c_loss.item())
    print("Train Accuracy:", c_acc)

In [None]:
# GZSL evaluation over `val_loader`: the classifier output index is mapped to
# a class id through `gzsl_inds`.
# torch.as_tensor accepts both an ndarray and a tensor, so the cell can be
# re-run (torch.from_numpy fails once `gzsl_inds` is already a tensor).
gzsl_inds = torch.as_tensor(gzsl_inds)
final_embs = []
with torch.no_grad():
    sequence_encoder.eval()
    cls.eval()
    count = 0
    num = 0
    preds = []
    tars = []
    for (inp, target) in val_loader:
        t_s = inp.to(device)
        t_smu, t_slv = sequence_encoder(t_s)
        final_embs.append(t_smu)
        t_out = cls(t_smu)
        # Index the CPU-resident `gzsl_inds` with CPU predictions; a device
        # index tensor raises on current PyTorch. Assumes `target` is a CPU
        # tensor - TODO confirm.
        pred = torch.argmax(t_out, -1).cpu()
        pred_classes = gzsl_inds[pred]
        preds.append(pred_classes)
        tars.append(target)
        count += torch.sum(pred_classes == target).item()
        num += len(target)
    print(float(count)/num)

In [None]:
# GZSL report: per-class accuracy on unseen and seen classes plus their
# harmonic mean.
p = np.array([j.item() for i in preds for j in i])
t = np.array([j.item() for i in tars for j in i])

# The loops below index the matrix by class id, so row/column k must be class
# k. Passing `labels` guarantees that even if a class is absent from `t`/`p`.
class_ids = np.arange(num_class)
cmat = confusion_matrix(t, p, labels=class_ids)
unseen_classes = np.asarray(unseen_inds)
seen_classes = np.asarray(seen_inds)

unseen_acc = 0
seen_acc = 0
for val in unseen_classes:
    row_total = np.sum(cmat[val])
    # A class with no samples contributes 0 instead of a 0/0 NaN.
    class_acc = cmat[val, val]/row_total if row_total else 0.0
    unseen_acc += class_acc
    print(labels[val], ' : ', class_acc)
    # Classes this one is predicted as, most frequent first, then the counts.
    print(labels[class_ids[np.argsort(cmat[val])[::-1]]])
    print(np.sort(cmat[val])[::-1])

for i in seen_classes:
    row_total = np.sum(cmat[i])
    seen_acc += cmat[i, i]/row_total if row_total else 0.0

# Average over the classes actually iterated (was `ss` and a literal 60-ss).
unseen_acc = unseen_acc/len(unseen_classes)
seen_acc = seen_acc/len(seen_classes)
acc_sum = unseen_acc + seen_acc
h_mean = 2*unseen_acc*seen_acc/acc_sum if acc_sum else 0.0
print('\n')
print('unseen_class_accuracy : ', unseen_acc)
print('seen_class_accuacy : ',  seen_acc)
print('harmonic_mean : ', h_mean)

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
t_plot = []
for i in t:
    t_plot.append(np.argwhere(unseen_inds == i).flatten()[0])

In [None]:
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=1500)
tsne_results = tsne.fit_transform(final_embs)
plt.figure(figsize=(5,5))
plt.scatter(tsne_results[:,0], tsne_results[:,1], c = t_plot, cmap='Dark2')

for i in range(5):
    plt.annotate(labels[t[i]], (tsne_results[i, 0], tsne_results[i, 1]))
plt.show()

In [None]:
# 2-D t-SNE of a row subset of `tz`, coloured by the matching targets.
# NOTE(review): `tz` and `inds` are not defined in any visible cell; this cell
# depends on state left over from an earlier session.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=500)
subset_latents = tz.detach().cpu().numpy()[inds, :]
tsne_results = tsne.fit_transform(subset_latents)
plt.figure(figsize=(5, 5))
subset_targets = target.detach().cpu().numpy()[inds]
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=subset_targets, cmap='plasma')
plt.show()

In [None]:
# 2-D t-SNE of the sequence decoder output `sout`, coloured by `target`.
# NOTE(review): `sout` comes from the last training batch, while `target` is
# overwritten by the later evaluation loops - confirm they are still aligned.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=500)
decoded = sout.detach().cpu().numpy()
tsne_results = tsne.fit_transform(decoded)
plt.figure(figsize=(5, 5))
point_colors = target.detach().cpu().numpy()
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=point_colors, cmap='plasma')
plt.show()

In [None]:
# Scratch cells inspecting tensors left behind by the last training iteration
# (`s`, `smu`, `slv`, `sz`, `sout`). They only work after the training loop ran.
# Smallest entry of the input batch `s`.
torch.min(s)

In [None]:
# First output of sequence_encoder(s).
smu

In [None]:
# Decoder reconstruction of the first sample in the batch.
sout[0]

In [None]:
# Largest squared reconstruction error over the features of the first sample.
torch.max((s[0] - sout[0])**2)

In [None]:
# Largest per-sample mean of `smu` (mean taken over dim 1).
max(torch.mean(smu, 1))

In [None]:
# Largest per-sample mean of `slv` (mean taken over dim 1).
max(torch.mean(slv, 1))

In [None]:
# Treats `slv` as a log-variance: sigma = exp(0.5 * slv).
 sigma = torch.exp(0.5*slv)

In [None]:
# One N(0, 1) draw per row, expanded across all columns of `sigma`
# (so every latent dimension of a sample shares the same noise value).
eps = torch.FloatTensor(sigma.size()[0], 1).normal_(0, 1).expand(sigma.size()).cuda()

In [None]:
# Manual reparameterisation built from the pieces above.
sz_test = eps*sigma + smu

In [None]:
# Latent sample produced by reparameterize(smu, slv) in the training loop.
sz

In [None]:
# Positions in the current `target` batch whose label is 0.
# Note: this loop overwrites the notebook-level names `num` and `t`.
ind = []
for num, t in enumerate(target):
    if t == 0:
        ind.append(num)

In [None]:
# Rows of `smu` at those positions.
# NOTE(review): only meaningful if `target` still belongs to the same batch as `smu`.
smu[ind, :]

In [None]:
# NOTE(review): `tmu` is not defined in any visible cell.
max(torch.mean(tmu, 1))

In [None]:
# NOTE(review): `tlv` is not defined in any visible cell.
max(torch.mean(tlv, 1))