In [None]:
import os

# Work from the notebook's parent folder so that the local modules imported
# below (`data`, `modules`, ...) and relative paths such as `models/...` and
# `datasets/...` resolve. The target is computed only once per kernel, so
# re-running this cell does not climb one more directory each time.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('..')
os.chdir(PROJECT_ROOT)

In [None]:
import argparse
import math
import importlib
import random
import numpy as np
import pandas as pd

import torch
from torch import nn, optim
from tqdm import tqdm, trange
from matplotlib import pyplot as plt


from data import SupervisedTextData
from text_supervised import init_config
from modules import SemisupervisedVAE
from modules import LSTMEncoder, LSTMDecoder
from logger import Logger


class uniform_initializer(object):
    """Callable initializer that fills a tensor in place from U(-stdv, stdv).

    Instances are handed to the LSTM encoder/decoder constructors as weight
    and embedding initializers.
    """

    def __init__(self, stdv):
        # Half-width of the symmetric sampling interval.
        self.stdv = stdv

    def __call__(self, tensor):
        bound = self.stdv
        nn.init.uniform_(tensor, a=-bound, b=bound)

In [None]:
# Trained checkpoints, keyed by (model_seed, task_id).
SAVED_MODELS = dict([
    ((783435, 0), 'models/g06n/g06n_aggressive1_kls0.10_warm10_0_0_783435.pt'),
    ((101, 1), 'models/g06n/g06n_aggressive1_kls0.10_warm10_0_1_101.pt'),
    ((202, 2), 'models/g06n/g06n_aggressive1_kls0.10_warm10_0_2_202.pt'),
    ((303, 3), 'models/g06n/g06n_aggressive1_kls0.10_warm10_0_3_303.pt'),
])

# Key of the checkpoint that the rest of the notebook loads.
model_seed = 101
task_id = 1

In [None]:
# Build the run configuration for the selected checkpoint and load data + model.
args = init_config(f'--dataset g06n --seed {model_seed} --taskid {task_id} '
                   f'--decode_from {SAVED_MODELS[(model_seed, task_id)]}')
args.device = torch.device(f"cuda:{args.cudaid}" if args.cuda else "cpu")

print('>> Loading training data...')
train_data = SupervisedTextData(fdoc=args.train_doc, fnum=args.train_num, flabel=args.train_label)
# Raw training documents. Later cells index this array and `train_data` with the
# same positions, so both must have one entry per document in the same order.
train_texts = np.load('./datasets/g06n_data/g06n.doc.train.npy')
assert len(train_texts) == len(train_data), (
    f'train_texts ({len(train_texts)}) and train_data ({len(train_data)}) differ in length; '
    'indexing them together would pair the wrong texts with the docs')

print('>> Loading model and weights...')
model_init = uniform_initializer(0.01)
emb_init = uniform_initializer(0.1)
encoder = LSTMEncoder(args, len(train_data.vocab), model_init, emb_init)
decoder = LSTMDecoder(args, train_data.vocab, model_init, emb_init)
svae = SemisupervisedVAE(encoder, decoder, args).to(args.device)
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
svae.load_state_dict(torch.load(args.decode_from, map_location=args.device))
svae.eval()  # inference only: the cells below never train

print('>> All is loaded.')

In [None]:
if True:
    def diff_words(row, baseline_col='recon_0'):
        """Remove the baseline reconstruction's words from every `recon_*` column.

        Args:
            row: one row of the results frame (a pandas Series) whose `recon_*`
                entries are space-joined token strings.
            baseline_col: column whose vocabulary is subtracted; defaults to the
                unshifted reconstruction `recon_0`.

        Returns:
            dict mapping each `recon_*` column name to the space-joined words of
            that column that are absent from `baseline_col` (original order kept).
        """
        baseline_vocab = set(row[baseline_col].split())
        return {
            col: ' '.join(word for word in row[col].split() if word not in baseline_vocab)
            for col in row.index if col.startswith('recon')
        }

    # Latent traversal: for each latent dimension, add a fixed offset to that
    # dimension of the posterior mean, greedily decode, and save the documents.
    args.decoding_strategy = 'greedy'
    seed = 0
    random.seed(seed)

    N_DOCS = 1024          # number of training documents to probe
    N_LATENT_DIMS = 32     # NOTE(review): presumably the latent size (args.nz?) — confirm
    SHIFTS = (-50, -10, -5, -1, 0, 1, 5, 10, 50)  # must include 0: diff_words uses recon_0
    GENERATED_DIR = './generated'
    DIFF_DIR = '../difference'  # NOTE(review): relative to the project root, i.e. outside it — confirm intended
    os.makedirs(GENERATED_DIR, exist_ok=True)
    os.makedirs(DIFF_DIR, exist_ok=True)

    # random.sample raises ValueError if asked for more items than exist.
    patent_idx = random.sample(range(len(train_data)), min(N_DOCS, len(train_data)))
    batch_texts = train_texts[patent_idx]
    batch_docs, batch_nums, batch_labels, _ = train_data[patent_idx]
    batch_docs, _ = train_data.to_tensor(batch_docs, batch_first=True, device=args.device)

    with torch.no_grad():
        z, kl, mu, var = svae.encoder.encode(batch_docs, args.nsamples, return_var=True)
        for lat_idx in trange(N_LATENT_DIMS):
            df_results = pd.DataFrame(index=batch_labels)
            df_results['origin'] = batch_texts
            for shift in SHIFTS:
                # Offset only column `lat_idx`. The shift is an absolute value
                # added to mu; it is not scaled by the posterior variance.
                mask = torch.zeros_like(var)
                mask[:, lat_idx] = shift
                df_results[f'recon_{shift}'] = [' '.join(d) for d in svae.decode(mu + mask, args.decoding_strategy)]

            filename = f'{args.dataset}_{args.seed}_seed{seed}_lat{lat_idx:02d}.csv'
            df_results.to_csv(f'{GENERATED_DIR}/{filename}')

            # The full reconstructions are too noisy to compare directly, so also
            # save only the words each shifted decode adds relative to recon_0.
            df_results.apply(diff_words, axis=1, result_type='expand').to_csv(f'{DIFF_DIR}/{filename}')

In [None]:
if False:
    # Decode random draws from the prior to eyeball generation quality.
    args.decoding_strategy = 'beam'
    assert args.decoding_strategy in ('beam', 'greedy', 'sample')

    with torch.no_grad():
        z = svae.sample_from_prior(100)
        # K is forwarded to svae.decode; presumably the beam width — confirm.
        decoded_batch = svae.decode(z, args.decoding_strategy, K=5)

    # A bare expression nested inside `if` is never echoed by Jupyter (only a
    # top-level final expression is), so print the decoded documents explicitly.
    decoded_texts = [' '.join(b) for b in decoded_batch]
    for text in decoded_texts:
        print(text)

In [None]:
if False:
    # Encode the whole training set in order and plot one latent-mean dimension
    # against the same-index column of `nums`.
    batch_size = 2048
    # Use the configured device rather than a hard-coded 'cuda', so this cell
    # also runs on CPU-only machines (consistent with the loading cell).
    train_data_iter = train_data.data_iter(batch_size, args.device, batch_first=True, shuffle=False)

    z_list = []
    mu_list = []
    nums_list = []
    with torch.no_grad():
        for docs, nums, sents_len in tqdm(train_data_iter, total=math.ceil(len(train_data) / batch_size)):
            z, kl, mu = svae.encoder.encode(docs, args.nsamples, return_mu=True)
            z_list.append(z.cpu().numpy())
            mu_list.append(mu.cpu().numpy())
            nums_list.append(nums)

    z_array = np.concatenate(z_list).squeeze()
    mu_array = np.concatenate(mu_list).squeeze()
    nums_array = np.concatenate(nums_list).squeeze()

    i = 0            # dimension compared on both axes
    n_points = 1000  # number of documents plotted
    fig, ax = plt.subplots(dpi=120)
    ax.scatter(nums_array[:n_points, i], mu_array[:n_points, i])
    ax.set(xlabel=f'nums[:, {i}]', ylabel=f'mu[:, {i}]',
           title=f'Latent mean vs. nums, dimension {i}')