In [1]:
import os
import sys
import time

import torch

from data import load_dataset
from models import StyleTransformer, Discriminator
from train import train, auto_eval, get_lengths, batch_preprocess
from main import Config
from utils import tensor2text, calc_ppl, idx2onehot, add_noise, word_drop
from evaluator import Evaluator

In [2]:
# Build the project configuration, then load the train/dev/test iterators and the
# vocabulary they share (the vocab is needed to build the models below).
config = Config()
(train_iters, dev_iters, test_iters, vocab) = load_dataset(config)

In [3]:
# Instantiate the style-transfer generator (F) and the discriminator (D), each moved
# onto config.device. NOTE(review): model_D is not used by any cell shown in this
# notebook's inference path — it may be removable if nothing else needs it.
model_F = StyleTransformer(config, vocab)
model_F = model_F.to(config.device)
model_D = Discriminator(config, vocab)
model_D = model_D.to(config.device)

In [4]:
  
# Restore the trained generator weights. `map_location` remaps every stored tensor
# onto config.device, so a checkpoint saved on a GPU can still be opened on a
# CPU-only machine (and vice versa).
CKPT_PATH = "save/Mar09174431/ckpts/2100_F.pth"
model_F.load_state_dict(torch.load(CKPT_PATH, map_location=config.device))
model_F.eval()  # switch submodules (e.g. dropout) to evaluation mode for decoding
vocab_size = len(vocab)  # NOTE(review): not used by any cell shown here
eos_idx = vocab.stoi['<eos>']  # token id that get_lengths() uses to find sequence ends


In [5]:
def calc_temperature(temperature_config):
    """Return the temperature stored in the last entry of ``temperature_config``.

    Args:
        temperature_config: non-empty sequence of 2-element pairs whose first
            element is a temperature; the second element (presumably a training
            step — TODO confirm) is ignored here.

    Returns:
        The first element of the final pair.

    Raises:
        ValueError: if ``temperature_config`` is empty, rather than returning
            ``None`` and letting it reach the model as the decoding temperature.
    """
    if len(temperature_config) == 0:
        raise ValueError(
            "temperature_config must contain at least one (temperature, step) pair"
        )
    final_temperature, _ = temperature_config[-1]
    return final_temperature

In [6]:
# Fixed decoding temperature used by inference(): the final entry of the configured schedule.
temperature = calc_temperature(temperature_config=config.temperature_config)

In [7]:
def inference(data_iter, raw_style, model=None, vocabulary=None, eos_index=None, temp=None):
    """Decode every batch of ``data_iter`` twice: in its own style and in the flipped style.

    Args:
        data_iter: iterable of batches exposing a ``.text`` token tensor
            (indexed as ``text[:, 0]`` below, so presumably (batch, seq_len) — TODO confirm).
        raw_style: style label of the inputs (0 or 1 in this notebook); the
            reversed style is ``1 - raw_style``.
        model: generator to run; defaults to the notebook-level ``model_F``.
        vocabulary: vocab handed to ``tensor2text``; defaults to ``vocab``.
        eos_index: end-of-sequence id handed to ``get_lengths``; defaults to ``eos_idx``.
        temp: decoding temperature; defaults to the notebook-level ``temperature``.

    Returns:
        ``(gold_text, raw_output, rev_output)``: three lists of decoded sentences,
        aligned by position — the inputs, the same-style outputs, and the
        reversed-style outputs.
    """
    # Resolve defaults at call time so the notebook globals are looked up exactly
    # when the function runs, as the original implicit-global version did.
    model = model_F if model is None else model
    vocabulary = vocab if vocabulary is None else vocabulary
    eos_index = eos_idx if eos_index is None else eos_index
    temp = temperature if temp is None else temp

    gold_text = []
    raw_output = []
    rev_output = []
    # Pure inference: a single no_grad context covers both decoding passes.
    with torch.no_grad():
        for batch in data_iter:
            inp_tokens = batch.text
            inp_lengths = get_lengths(inp_tokens, eos_index)
            # One style label per example, copied from the first token column's shape/dtype.
            raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
            rev_styles = 1 - raw_styles

            raw_log_probs = model(
                inp_tokens,
                None,
                inp_lengths,
                raw_styles,
                generate=True,
                differentiable_decode=False,
                temperature=temp,
            )
            rev_log_probs = model(
                inp_tokens,
                None,
                inp_lengths,
                rev_styles,
                generate=True,
                differentiable_decode=False,
                temperature=temp,
            )

            gold_text += tensor2text(vocabulary, inp_tokens.cpu())
            # argmax over the last dimension picks the highest-scoring token id.
            raw_output += tensor2text(vocabulary, raw_log_probs.argmax(-1).cpu())
            rev_output += tensor2text(vocabulary, rev_log_probs.argmax(-1).cpu())

    return gold_text, raw_output, rev_output

  

In [8]:

# Decode the negative split as style 0 and the positive split as style 1. Transposing
# with zip(*) gives gold_text / raw_output / rev_output as 2-tuples where index 0 holds
# the neg_iter results and index 1 the pos_iter results.
neg_iter = test_iters.neg_iter
pos_iter = test_iters.pos_iter

style_results = [inference(neg_iter, 0), inference(pos_iter, 1)]
gold_text, raw_output, rev_output = zip(*style_results)

In [10]:
# Output directory for the decoded text files. Set STYLE_TRANSFORMER_OUTDIR to override
# it instead of editing this machine-specific absolute path. Joining with "" guarantees
# a trailing separator, since file names may be concatenated directly onto outpath.
outpath = os.path.join(
    os.environ.get(
        "STYLE_TRANSFORMER_OUTDIR",
        "/home/ubuntu/style-transformer/outputs/soph_tagged/",
    ),
    "",
)

In [11]:
# Save every result list one sentence per line. Index 0 holds the neg_iter (style 0)
# results and index 1 the pos_iter (style 1) results. Gold text is written for BOTH
# styles so each raw/rev output file has a line-aligned reference input.
os.makedirs(outpath, exist_ok=True)


def _write_lines(filename, lines):
    """Write each string in ``lines`` to ``outpath/filename``, one per line (utf-8)."""
    with open(os.path.join(outpath, filename), 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in lines)


# Legacy name kept for existing consumers: it holds the style-0 gold text.
_write_lines('gold_text.txt', gold_text[0])
for style in (0, 1):
    _write_lines(f'gold_text_{style}.txt', gold_text[style])
    _write_lines(f'raw_output_{style}.txt', raw_output[style])
    _write_lines(f'rev_output_{style}.txt', rev_output[style])
