In [None]:
class Namespace:
    """Attribute bag standing in for ``argparse.Namespace``.

    Lets the notebook supply, by hand, the values a command-line parser would
    otherwise produce: every keyword argument becomes an instance attribute.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


# Hand-written replacement for parsing the command line.
args = Namespace(
    data_name="yelp",
    bsz=128,
    load_path="checkpoint/ours-yelp/20201211-184811-first_run/",
)

In [None]:
from utils.exp_utils import create_exp_dir
from utils.text_utils import MonoTextData
import argparse
import os
import torch
import time
import config
# from models.decomposed_vae import DecomposedVAE
import numpy as np
from file_io import *
from vocab import Vocabulary, build_vocab
from models.vae import TrainerVAE, VAE

# Everything below runs on the CPU; the data batches are created there too.
device = torch.device("cpu")
# Fixed seeds for numpy and torch so repeated runs are comparable.
np.random.seed(0)
torch.manual_seed(0)

# Dataset-specific configuration.
# TODO(review): the original marked this lookup as needing an update -- confirm
# that config.CONFIG[args.data_name] matches the CSV files loaded below.
conf = config.CONFIG[args.data_name]
data_pth = f"data/{args.data_name}"

# Training split; its loader also provides the vocabulary used everywhere else.
train_data_pth = os.path.join(data_pth, "train_identical_sentiment_90_tense.csv")
train_class = MonoTextData(train_data_pth, glove=True)
train_data, train_sentiments, train_tenses = train_class.create_data_batch_labels(args.bsz, device)

vocab = train_class.vocab
print(f"Vocabulary size: {len(vocab):d}")

# Test split, loaded with the training vocabulary.
test_data_pth = os.path.join(data_pth, "test_identical_sentiment_90_tense.csv")
test_class = MonoTextData(test_data_pth, vocab=vocab, glove=True)
test_data, test_sentiments, test_tenses = test_class.create_data_batch_labels(args.bsz, device)

print("data done.")

# Runtime-only constructor arguments for the VAE, injected into the config dict.
params = conf["params"]
params["vae_params"].update(vocab=vocab, device=device)

def add_args(parser):
    """Register the command-line options of the original training script.

    Each entry below is forwarded verbatim to ``parser.add_argument`` in the
    listed order, so the resulting parser is identical to registering the
    options one call at a time.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible) to extend.
    """
    option_specs = (
        ('--data_name', {'type': str, 'default': 'yelp', 'help': 'data name'}),
        ('--train', {'type': str, 'default': './data/yelp/sentiment.train',
                     'help': 'train data path'}),
        ('--dev', {'type': str, 'default': './data/yelp/sentiment.dev',
                   'help': 'val data path'}),
        ('--test', {'type': str, 'default': './data/yelp/sentiment.test',
                    'help': 'test data path'}),
        ('--load_path', {'type': str, 'default': '',
                         'help': 'directory name to load'}),
        ('--bsz', {'type': int, 'default': 128,
                   'help': 'batch size for training'}),
        ('--vocab', {'type': str, 'default': './tmp/yelp.vocab'}),
        ('--embedding', {'type': str, 'default': './data/glove.840B.300d.txt'}),
        ('--dim_emb', {'type': int, 'default': 300}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     add_args(parser)
#     args = parser.parse_args()

#     main(args)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
import os
import torch.nn.functional as F
import pandas as pd

from models.base_network import LSTMEncoder, StyleClassifier, ContentDecoder, SgivenC, LSTMDecoder

class EvaluateVAE:
    """Qualitative style-transfer evaluation of a trained ``VAE`` checkpoint.

    The model is built from ``vae_params``, placed on ``self.device`` and its
    weights are restored from ``<load_path>/model.pt``.

    Args:
        test: list of batched token-id tensors. Judging by the indexing in
            ``eval_style_transfer``, dim 0 of each batch is walked token by
            token and dim 1 selects a sentence, i.e. (seq_len, batch_size).
        test_labels1: per-batch label tensors for the first style attribute
            (the notebook passes sentiment labels here).
        test_labels2: per-batch label tensors for the second style attribute
            (the notebook passes tense labels here).
        load_path: directory containing ``model.pt``.
        vocab: vocabulary object exposing ``id2word``.
        vae_params: keyword arguments forwarded to ``VAE``.
        device: device for the model. Defaults to CPU, matching the CPU-resident
            data batches built earlier in the notebook.
    """

    COLUMNS = ["OriginalA", "OriginalB", "A - Sent Swapped", "A - Tense Swapped",
               "B - Sent Swapped", "B - Tense Swapped"]

    def __init__(self, test, test_labels1, test_labels2, load_path, vocab, vae_params, device=None):
        # Model and data must live on the same device. Previously the model was
        # moved to CUDA whenever available while the data stayed on the CPU,
        # which made ``encode`` fail on GPU machines.
        self.device = torch.device("cpu") if device is None else torch.device(device)
        self.use_cuda = self.device.type == "cuda"
        self.load_path = load_path

        self.vocab = vocab
        self.test_data = test
        self.test_labels1 = test_labels1
        self.test_labels2 = test_labels2

        self.vae = VAE(**vae_params)
        self.vae.to(self.device)

        self.nbatch = len(self.test_data)
        # Sentences are selected along dim 1, so that is the batch size;
        # ``len(batch)`` would return the sequence length instead.
        self.batch_size = self.test_data[0].size(1) if self.nbatch else 0
        self.load(self.load_path)

    def load(self, path):
        """Restore VAE weights from ``<path>/model.pt``.

        ``map_location`` remaps tensors saved on another device (e.g. a GPU
        checkpoint opened on a CPU-only machine) onto ``self.device``.
        NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        model_path = os.path.join(path, "model.pt")
        state_dict = torch.load(model_path, map_location=self.device)
        self.vae.load_state_dict(state_dict)

    def eval_style_transfer(self, start_batch=800, end_batch=820, max_pairs=10):
        """Decode every content/style combination for pairs of sentences.

        For each batch index in ``[start_batch, end_batch)`` (clamped to the
        number of available batches), sentence 0 is paired with the first later
        sentence in that batch whose labels differ on *both* attributes. Both
        sentences are encoded into (content, style1, style2) codes and decoded
        with every swap of interest. Collection stops after ``max_pairs`` pairs.

        Column provenance, written as (content, style1, style2) sources:
        OriginalA=AAA, OriginalB=BBB, A - Sent Swapped=ABA,
        A - Tense Swapped=AAB, B - Sent Swapped=BAB, B - Tense Swapped=BBA.

        Args:
            start_batch: first batch index to scan (inclusive).
            end_batch: last batch index to scan (exclusive).
            max_pairs: maximum number of sentence pairs (rows) to collect.

        Returns:
            pandas.DataFrame with one row per pair and the columns listed above.
        """
        sent_idx = 0
        rows = []
        last_batch = min(end_batch, len(self.test_data))

        # Inference only: disable train-time behaviour (e.g. dropout) and skip
        # autograd bookkeeping.
        self.vae.eval()
        with torch.no_grad():
            for i in range(start_batch, last_batch):
                if len(rows) >= max_pairs:
                    break
                partner = self._find_partner(i, sent_idx)
                if partner is None:
                    continue

                batch = self.test_data[i]
                labels1 = self.test_labels1[i]
                labels2 = self.test_labels2[i]
                # labels1 carries the first style attribute (sentiment, per the
                # call site and the "Sent Swapped" columns); labels2 is tense.
                print("11 original_sentence:", self._ids_to_sentence(batch, sent_idx),
                      "sentiment:", labels1[sent_idx:sent_idx + 1],
                      "tense:", labels2[sent_idx:sent_idx + 1])
                print("22 original_sentence:", self._ids_to_sentence(batch, partner),
                      "sentiment:", labels1[partner:partner + 1],
                      "tense:", labels2[partner:partner + 1])

                c1, s1_1, s1_2, _ = self.vae.encode(batch[:, sent_idx:sent_idx + 1])
                c2, s2_1, s2_2, _ = self.vae.encode(batch[:, partner:partner + 1])

                rows.append({
                    "OriginalA": self._decode(c1, s1_1, s1_2),
                    "OriginalB": self._decode(c2, s2_1, s2_2),
                    "A - Sent Swapped": self._decode(c1, s2_1, s1_2),
                    "A - Tense Swapped": self._decode(c1, s1_1, s2_2),
                    "B - Sent Swapped": self._decode(c2, s1_1, s2_2),
                    "B - Tense Swapped": self._decode(c2, s2_1, s1_2),
                })

        # Global pandas display option so the returned table shows long sentences.
        pd.options.display.max_colwidth = 100
        return pd.DataFrame(rows, columns=self.COLUMNS)

    def _find_partner(self, batch_idx, sent_idx):
        """Return the first index after ``sent_idx`` whose both labels differ, else None."""
        labels1 = self.test_labels1[batch_idx]
        labels2 = self.test_labels2[batch_idx]
        for other in range(sent_idx + 1, labels1.size(0)):
            if labels1[sent_idx] != labels1[other] and labels2[sent_idx] != labels2[other]:
                return other
        return None

    def _ids_to_sentence(self, batch, col):
        """Render column ``col`` of ``batch`` as words, each followed by a space."""
        return "".join(
            self.vocab.id2word(batch[j, col:col + 1]) + " " for j in range(batch.size(0))
        )

    def _decode(self, content, style1, style2):
        """Beam-search decode one code triple and return it as a single string.

        ``[0]`` takes the first decoded sequence and ``[:-1]`` drops its last
        token (presumably the end-of-sentence marker -- TODO confirm).
        NOTE(review): the result goes through ``.cpu()`` and is then joined as
        strings, so ``beam_search_decode`` presumably returns an object exposing
        ``.cpu()`` whose entries are word strings -- confirm against the decoder.
        """
        decoded = self.vae.decoder.beam_search_decode(content, style1, style2).cpu()
        return " ".join(decoded[0][:-1])

In [None]:
# NOTE(review): this evaluates on the *training* batches even though test
# batches are loaded earlier. eval_style_transfer reads batches 800-819 by
# default, which presumably needs more batches than the test split provides;
# confirm before switching to test_data / test_sentiments / test_tenses.
evalVAE = EvaluateVAE(
    test=train_data,
    test_labels1=train_sentiments,
    test_labels2=train_tenses,
    load_path=args.load_path,
    vocab=vocab,
    vae_params=params["vae_params"],
)
df = evalVAE.eval_style_transfer()

In [None]:
# Display the style-transfer table returned by eval_style_transfer
# (one row per sentence pair).
df