In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os, glob, json
import config as cfg
from IPython.display import display
from ipywidgets import widgets, Layout
from module.relgan_d import RelSpaceGAN_D
from module.cluster import Cluster
from module.relgan_g import RelSpaceG
from dataset import TextSubspaceDataset, seq_collate
from constant import Constants
from utils import get_fixed_temperature, get_losses
from sklearn.cluster import KMeans
import numpy as np
from tensorboardX import SummaryWriter
from utils import gradient_penalty, str2bool, chunks
from sklearn.manifold import SpectralEmbedding
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from shutil import copyfile
from collections import namedtuple

In [7]:
class Struct:
    """Minimal attribute container.

    Every keyword argument passed to the constructor becomes an instance
    attribute of the same name, so a plain dict can be accessed with dot
    notation (``Struct(**d).key``).
    """

    def __init__(self, **entries):
        # vars(self) is the instance __dict__; merging the keywords into it
        # creates one attribute per key in a single step.
        vars(self).update(entries)

In [19]:
# Settings that select which saved run and checkpoint the notebook loads.
K_BINS = 20  # number of latent bins; one slider per bin is created further down
iteration = 1000  # iteration number embedded in the checkpoint file name
log_name = 'save/subrelgan-gp-kmeans-2020-04-08-05-55-22'  # directory of the saved run

# File name of the generator checkpoint inside `log_name`.
checkpoint_name = 'relgan_G_{}.pt'.format(iteration)

In [20]:
# Restore the generator, its dataset and the saved latent codes from `log_name`.
#
# Imports are local to this cell so it stays runnable on its own; they are only
# needed for the (re)loading of the generator source below.
import importlib
import sys

generator_src = os.path.join(log_name, 'relgan_g.py')
checkpoint_path = os.path.join(log_name, checkpoint_name)

# Fail early with a readable message when the run directory is incomplete.
assert os.path.exists(generator_src), 'missing generator source: {}'.format(generator_src)
assert os.path.exists(checkpoint_path), 'missing checkpoint: {}'.format(checkpoint_path)

# Use the generator definition that was saved alongside the checkpoint, so the
# class matches the stored weights even if module/relgan_g.py has changed since.
copyfile(generator_src, 'module/temp_g.py')

# The file was just (re)written at runtime: refresh the import finders, and if
# this cell already ran in the current kernel, reload the module. A plain
# `from module.temp_g import ...` would silently keep the previously imported
# class when `log_name` is changed and the cell is re-executed.
importlib.invalidate_caches()
if 'module.temp_g' in sys.modules:
    temp_g = importlib.reload(sys.modules['module.temp_g'])
else:
    temp_g = importlib.import_module('module.temp_g')
RelSpaceG = temp_g.RelSpaceG

checkpoint = torch.load(checkpoint_path)
# The checkpoint is expected to hold exactly three entries; the ones read below
# are 'model' (generator state dict), 'p' and 'latent'.
assert len(checkpoint) == 3, \
    'expected 3 checkpoint entries, got {}'.format(len(checkpoint))

# Hyper-parameters of the run, exposed as attributes (args.max_seq_len, ...).
with open(os.path.join(log_name, 'params.json'), 'r') as f:
    params = json.load(f)
args = Struct(**params)

p = checkpoint['p']
latent = checkpoint['latent']

# Rebuild the dataset, then overwrite its `p` and `latent` attributes with the
# values stored in the checkpoint.
dataset = TextSubspaceDataset(-1, 'data/kkday_dataset/train_title.txt', prefix='train_title', embedding=None,
    max_length=args.max_seq_len, force_fix_len=args.grad_penalty or args.full_text, k_bins=K_BINS, token_level=args.tokenize)
dataset.p = p
dataset.latent = latent
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4,
                collate_fn=seq_collate, batch_size=args.batch_size, shuffle=True)

model = RelSpaceG(args.mem_slots, args.num_heads, args.head_size, args.gen_embed_dim,
        args.gen_hidden_dim, dataset.vocab_size,
        k_bins=K_BINS, latent_dim=args.gen_latent_dim, noise_dim=args.gen_noise_dim,
        max_seq_len=args.max_seq_len-1, padding_idx=Constants.PAD, gpu=True)

# Load the trained weights, move the model to the GPU and switch to eval mode.
model.load_state_dict(checkpoint['model'])
model = model.cuda()
model = model.eval()

In [21]:
# Seed every slider with the latent code of one reference row of `latent`.
init_idx = 0
init_latent = latent[init_idx]


def _make_bin_slider(bin_idx):
    """Return a FloatSlider for latent dimension `bin_idx`.

    The slider starts at the reference row's value for that dimension and is
    bounded by the minimum and maximum of that column across all of `latent`.
    """
    bin_values = latent[:, bin_idx]
    return widgets.FloatSlider(
        value=init_latent[bin_idx],
        min=bin_values.min(),
        max=bin_values.max(),
        step=0.1,
        description='bin-{}'.format(bin_idx),
    )


# One slider per latent bin, in bin order.
bins_slider = list(map(_make_bin_slider, range(K_BINS)))

# Render a heading followed by the sliders.
display(widgets.HTML(value="<br><h3> Play with latent feature </h3> "))
for slider in bins_slider:
    display(slider)


HTML(value='<br><h3> Play with latent feature </h3> ')

FloatSlider(value=-1.1459346062308682, description='bin-0', max=15.492585674727687, min=-2.505448355265655)

FloatSlider(value=0.18544452105186307, description='bin-1', max=16.985741354562872, min=-10.218551585503882)

FloatSlider(value=0.10935357696551865, description='bin-2', max=16.55065280398624, min=-16.10252537742402)

FloatSlider(value=-1.449555410873142, description='bin-3', max=17.209029787384345, min=-5.831358691699588)

FloatSlider(value=-0.5006789131883659, description='bin-4', max=18.86475697696866, min=-9.117801530508219)

FloatSlider(value=-0.3343904804139407, description='bin-5', max=17.428554658245435, min=-9.836547563265324)

FloatSlider(value=0.9291275495961163, description='bin-6', max=22.11952880708638, min=-11.933126404039713)

FloatSlider(value=-1.749797414098167, description='bin-7', max=14.561279812024612, min=-14.46340701800207)

FloatSlider(value=-2.7306781962545994, description='bin-8', max=14.648870991234888, min=-9.521448554678706)

FloatSlider(value=1.8631313898973347, description='bin-9', max=15.797646888292546, min=-10.30682843554967)

FloatSlider(value=-2.3930286697600884, description='bin-10', max=13.121045220841836, min=-9.509098406339989)

FloatSlider(value=9.014090855635379, description='bin-11', max=22.800511515694488, min=-11.874670183202966)

FloatSlider(value=5.924824713945221, description='bin-12', max=11.411466476505565, min=-10.678222727543119)

FloatSlider(value=0.520219900165486, description='bin-13', max=16.576370891449276, min=-12.87550915591419)

FloatSlider(value=-0.09892492064259147, description='bin-14', max=13.498296300875742, min=-9.775896894246427)

FloatSlider(value=-1.4862626702597919, description='bin-15', max=15.43072014680294, min=-12.748411691460513)

FloatSlider(value=-0.999282652738437, description='bin-16', max=15.246294368463234, min=-11.208217345719968)

FloatSlider(value=0.4412534232347893, description='bin-17', max=27.243617786136706, min=-7.5497452920886055)

FloatSlider(value=-1.408317607965634, description='bin-18', max=22.70156204611232, min=-13.13339995350316)

FloatSlider(value=3.893719072508225, description='bin-19', max=18.240608421895416, min=-6.758710446310465)

In [23]:
# Print the latent vector that was used as the sliders' initial values.
print(init_latent)

[-1.14593461  0.18544452  0.10935358 -1.44955541 -0.50067891 -0.33439048
  0.92912755 -1.74979741 -2.7306782   1.86313139 -2.39302867  9.01409086
  5.92482471  0.5202199  -0.09892492 -1.48626267 -0.99928265  0.44125342
 -1.40831761  3.89371907]


In [1]:
# Load the pickled object stored in 'latent_variable_20.pkl' and print its
# first element.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open this file if it comes from a trusted source.
import pickle
with open('latent_variable_20.pkl', 'rb') as f:
    latents = pickle.load(f)
print(latents[0])

[-0.40110661  0.053731    0.02041557 -0.49952886 -0.1492657  -0.04508191
  0.28627305 -0.60933987 -0.83062814  0.60365014 -0.14524659  2.80684088
  2.81072806  0.21474031 -0.14883904 -0.46788561 -0.35034584  0.2746236
 -0.19765725  1.58832746]
