## Demo

In [1]:
import os
# os.environ[“CUDA_DEVICE_ORDER”]=“PCI_BUS_ID”
os.environ["CUDA_VISIBLE_DEVICES"]="4"

In [2]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [43]:
from utils.exp_utils import create_exp_dir
from utils.text_utils import MonoTextData
import argparse
import os
import torch
import time
import config
# from models.decomposed_vae import DecomposedVAE
import numpy as np
# from vocab import Vocabulary, build_vocab
from models.vae import VAE
import pandas as pd

import random

class AmolData:
    def __init__(self, test, test_labels1, test_labels2, load_path, vocab, vae_params):
        super(AmolData, self).__init__()

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.load_path = load_path

        self.vocab = vocab
        self.test_data = test
        self.test_labels1 = test_labels1
        self.test_labels2 = test_labels2

        self.vae = VAE(**vae_params)
        if self.use_cuda:
            self.vae.to(self.device)

        self.nbatch = len(self.test_data)
        self.load(self.load_path)
        self.load_embeddings()
    
    def load(self, path):
        model_path = os.path.join(path, "model.pt")
        self.vae.load_state_dict(torch.load(model_path))
    
    def load_embeddings(self):
        data = np.load('data/demo_embeddings.npz')

        self.mu_pos = torch.tensor(data['mu_pos']).unsqueeze(0).to(self.device)
        self.logvar_pos = torch.tensor(data['logvar_pos']).unsqueeze(0).to(self.device)
        self.mu_neg = torch.tensor(data['mu_neg']).unsqueeze(0).to(self.device)
        self.logvar_neg = torch.tensor(data['logvar_neg']).unsqueeze(0).to(self.device)
        self.mu_past = torch.tensor(data['mu_past']).unsqueeze(0).to(self.device)
        self.logvar_past = torch.tensor(data['logvar_past']).unsqueeze(0).to(self.device)
        self.mu_present = torch.tensor(data['mu_present']).unsqueeze(0).to(self.device)
        self.logvar_present = torch.tensor(data['logvar_present']).unsqueeze(0).to(self.device)

    
    def run_conversion(self, utterance, sentiment=0, tense=0):
#         print(utterance, sentiment, tense)
        self.vae.eval()
        mu_c, logvar_c, _, logvar_s1, _, logvar_s2 = self.vae.encoder(utterance)
        if (sentiment == 0):
            mu_s1 = self.mu_neg
        else:
            mu_s1 = self.mu_pos
        if (tense == 0):
            mu_s2 = self.mu_past
        else:
            mu_s2 = self.mu_present
        return self.vae.decoder.beam_search_decode(mu_c.unsqueeze(0), mu_s1, mu_s2)

In [44]:
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # np.random.seed(0)
    # torch.manual_seed(0)

    conf = config.CONFIG[args.data_name] # Need to update !!
    data_pth = "data/%s" % args.data_name
    print(data_pth)
    train_data_pth = os.path.join(data_pth, "train_identical_sentiment_90_tense.csv")
    train_class = MonoTextData(train_data_pth, glove=True)
    train_data, train_sentiments, train_tenses = train_class.create_data_batch_labels(args.bsz, device)

    vocab = train_class.vocab
    print('Vocabulary size: %d' % len(vocab))

    test_data_pth = os.path.join(data_pth, "eval_data.csv")
    test_class = MonoTextData(test_data_pth, vocab=vocab, glove=True)
    test_data, test_sentiments, test_tenses = test_class.create_data_batch_labels(args.bsz, device)

    print("data done.")

    params = conf["params"]
    params["vae_params"]["vocab"] = vocab
    params["vae_params"]["device"] = device

    amolData = AmolData(test_data, test_sentiments, test_tenses, args.load_path, vocab, params["vae_params"])
    return amolData, train_class

In [45]:
args = Namespace(data_name="yelp", load_path = './checkpoint/ours-yelp/20201211-184811-first_run/', vocab = './tmp/yelp.vocab', bsz=256, embedding = './data/glove.840B.300d.txt', dim = 300)

In [6]:
obj, mtd = main(args)

data/yelp
Vocabulary size: 9482
data done.


In [24]:
# input_sentence = input()
input_sentence = "The service is bad"
processed_input = [mtd.vocab[w] for w in input_sentence.split()]

In [25]:
processed_input = mtd._to_tensor([processed_input], False, 'cuda')

In [26]:
# processed_input

In [37]:
obj.run_conversion(processed_input[0], 1, 1)

tensor([[  1],
        [  3],
        [ 21],
        [  9],
        [114],
        [  2]], device='cuda:0') 1 1


[['<s>', '5-star', 'service', 'is', 'good', '.', '</s>']]

In [42]:
while(True):
    input_sentence = input("Input Sentence: \n")
    sentiment = input("Choose Sentiment:\nNegative:0\nPositive:1\n")
    tense = input("Choose Tense:\nPast:0\nPresent:1\n")
    # input_sentence = "The service is bad"
    # sentiment=0
    # tense = 0
    processed_input = [mtd.vocab[w] for w in input_sentence.split()]
    processed_input = mtd._to_tensor([processed_input], False, 'cuda')
    output = obj.run_conversion(processed_input[0], int(sentiment), int(tense))
    output = " ".join(output[0])
    print(output)

Input Sentence: 
the service is good
Choose Sentiment:
Negative:0
Positive:1
0
Choose Tense:
Past:0
Present:1
1
tensor([[ 1],
        [ 5],
        [21],
        [ 9],
        [19],
        [ 2]], device='cuda:0') 0 1
<s> the service is horrible . </s>


KeyboardInterrupt: Interrupted by user