In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os, glob, json
import config as cfg
from IPython.display import display
from ipywidgets import widgets, Layout
from module.relgan_d import RelSpaceGAN_D
from module.cluster import Cluster
from module.relgan_g import RelSpaceG
from dataset import TextSubspaceDataset, seq_collate
from constant import Constants
from utils import get_fixed_temperature, get_losses
from sklearn.cluster import KMeans
import numpy as np
from tensorboardX import SummaryWriter
from utils import gradient_penalty, str2bool, chunks
from sklearn.manifold import SpectralEmbedding
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from shutil import copyfile
from collections import namedtuple

In [2]:
class Struct:
    """Attribute-style view over a set of keyword arguments.

    Every keyword passed to the constructor becomes an instance attribute
    with the same name, so ``Struct(a=1).a == 1``.
    """

    def __init__(self, **entries):
        # Write the keywords straight into the instance namespace.
        vars(self).update(entries)

In [5]:
# Number of latent bins; handed to the dataset and generator as `k_bins` and
# used as the number of sliders built further down.
K_BINS = 20
# Run directory that is joined with 'relgan_g.py', 'params.json' and the
# checkpoint file name when the run is loaded.
log_name = 'save/subrelgan-kmeans-2020-04-08-01-19-21'
# Iteration number embedded in the generator checkpoint file name.
iteration = 0
checkpoint_name = f'relgan_G_{iteration}.pt'

In [7]:
# --- Locate the run artefacts -------------------------------------------------
# The generator source snapshot and the requested checkpoint must both exist in
# the run directory before anything is copied or loaded.
generator_src = os.path.join(log_name, 'relgan_g.py')
checkpoint_path = os.path.join(log_name, checkpoint_name)
assert os.path.exists(generator_src), 'missing generator source: {}'.format(generator_src)
assert os.path.exists(checkpoint_path), 'missing checkpoint: {}'.format(checkpoint_path)

# --- Import the generator class saved with the run ----------------------------
# The run's own relgan_g.py is copied into the local package and RelSpaceG is
# taken from that copy, rebinding the RelSpaceG name imported at the top of the
# notebook. The import must stay AFTER the copy.
copyfile(generator_src, 'module/temp_g.py')
import importlib
import sys
importlib.invalidate_caches()
# A plain `from module.temp_g import ...` would be served from sys.modules when
# this cell is re-run (e.g. after changing log_name), silently keeping the class
# from the previously copied file. Reload when the module is already loaded.
_loaded_temp_g = sys.modules.get('module.temp_g')
if _loaded_temp_g is not None:
    temp_g = importlib.reload(_loaded_temp_g)
else:
    temp_g = importlib.import_module('module.temp_g')
RelSpaceG = temp_g.RelSpaceG

# --- Load checkpoint and training parameters ----------------------------------
checkpoint = torch.load(checkpoint_path)
# The checkpoint must provide the generator state dict ('model') together with
# the 'p' and 'latent' entries read below; check the keys themselves rather
# than only the number of entries.
missing_keys = {'model', 'p', 'latent'} - set(checkpoint)
assert not missing_keys, 'checkpoint is missing keys: {}'.format(sorted(missing_keys))
with open(os.path.join(log_name, 'params.json'), 'r') as f:
    params = json.load(f)
# Expose the saved parameters as attributes (args.max_seq_len, ...).
args = Struct(**params)

# --- Rebuild the dataset with the stored p / latent ---------------------------
p = checkpoint['p']
latent = checkpoint['latent']
dataset = TextSubspaceDataset(-1, 'data/kkday_dataset/train_title.txt', prefix='train_title', embedding=None,
    max_length=args.max_seq_len, force_fix_len=args.grad_penalty or args.full_text, k_bins=K_BINS, token_level=args.tokenize)
# Overwrite the dataset's p / latent attributes with the values stored in the
# checkpoint.
dataset.p = p
dataset.latent = latent
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4,
                collate_fn=seq_collate, batch_size=args.batch_size, shuffle=True)

# --- Rebuild the generator and restore its weights ----------------------------
model = RelSpaceG(args.mem_slots, args.num_heads, args.head_size, args.gen_embed_dim,
        args.gen_hidden_dim, dataset.vocab_size,
        k_bins=K_BINS, latent_dim=args.gen_latent_dim, noise_dim=args.gen_noise_dim,
        max_seq_len=args.max_seq_len-1, padding_idx=Constants.PAD, gpu=True)

model.load_state_dict(checkpoint['model'])
# Move to the GPU and switch to evaluation mode.
model = model.cuda()
model = model.eval()

In [18]:
# Build one slider per latent bin, initialised from a single stored latent row
# and bounded by the min / max of that bin's column across all rows of `latent`.
init_idx = 0
init_latent = latent[init_idx]
bins_slider = [
    widgets.FloatSlider(
        # Cast to plain Python floats so the widget traits do not depend on the
        # scalar type that indexing `latent` returns.
        value=float(init_latent[i]),
        min=float(latent[:, i].min()),
        max=float(latent[:, i].max()),
        step=1e-5,
        description='bin-{}'.format(i),
        # The default readout shows two decimals, which renders every value of
        # this magnitude (~1e-4) as 0.00; use scientific notation instead.
        readout_format='.2e',
    )
    for i in range(K_BINS)
]

display(widgets.HTML(
    value="<br><h3> Play with latent feature </h3> ",
))
for slider in bins_slider:
    display(slider)


HTML(value='<br><h3> Play with latent feature </h3> ')

FloatSlider(value=-3.5900782367804655e-05, description='bin-0', max=0.0004853644253621238, min=-7.849274389219…

FloatSlider(value=5.809760092086068e-06, description='bin-1', max=0.0005321432731221096, min=-0.00032013518692…

FloatSlider(value=3.4259194436485795e-06, description='bin-2', max=0.0005185124792046563, min=-0.0005044731914…

FloatSlider(value=-4.541286852483601e-05, description='bin-3', max=0.0005391386682960635, min=-0.0001826896106…

FloatSlider(value=-1.568568532238776e-05, description='bin-4', max=0.0005910106866853758, min=-0.0002856500124…

FloatSlider(value=-1.0476058704520797e-05, description='bin-5', max=0.0005460161894380403, min=-0.000308167480…

FloatSlider(value=2.9108473946050056e-05, description='bin-6', max=0.000692978852060911, min=-0.00037385085277…

FloatSlider(value=-5.48191004403147e-05, description='bin-7', max=0.00045618778942173657, min=-0.0004531215553…

FloatSlider(value=-8.554894777349051e-05, description='bin-8', max=0.0004589319201071457, min=-0.0002982958324…

FloatSlider(value=5.8369744315205454e-05, description='bin-9', max=0.0004949217279693248, min=-0.0003229007934…

FloatSlider(value=-7.497076625613125e-05, description='bin-10', max=0.0004110669473531189, min=-0.000297908911…

FloatSlider(value=0.00028240089831139286, description='bin-11', max=0.0007143132926565295, min=-0.000372019509…

FloatSlider(value=0.00018561781146423856, description='bin-12', max=0.00035750788263047686, min=-0.00033453619…

FloatSlider(value=1.6297882686739917e-05, description='bin-13', max=0.00051931824252762, min=-0.00040337457415…

FloatSlider(value=-3.099225759177233e-06, description='bin-14', max=0.0004228857766965139, min=-0.000306267429…

FloatSlider(value=-4.6562863545154437e-05, description='bin-15', max=0.0004834263977331264, min=-0.00039939276…

FloatSlider(value=-3.1306357284574506e-05, description='bin-16', max=0.0004776485454351822, min=-0.00035114031…

FloatSlider(value=1.38239885222493e-05, description='bin-17', max=0.0008535105814350084, min=-0.00023652468140…

FloatSlider(value=-4.412091347212204e-05, description='bin-18', max=0.000711213351857245, min=-0.0004114540177…

FloatSlider(value=0.00012198565188097484, description='bin-19', max=0.0005714568072616696, min=-0.000211742501…

(140950, 20)