In [None]:
import torch
from torch.utils.data import DataLoader
from linguistic_style_transfer_pytorch.config import GeneralConfig, ModelConfig
from linguistic_style_transfer_pytorch.data_loader import TextDataset
from linguistic_style_transfer_pytorch.model import AdversarialVAE
from tqdm import tqdm, trange
import os
import numpy as np
import pickle
import json

#%load_ext autoreload

%autoreload 2

In [None]:

gconfig = GeneralConfig()
config = GeneralConfig()
mconfig = ModelConfig()

# load word embeddings
weights = torch.FloatTensor(np.load(gconfig.word_embedding_path))
# load checkpoint
model_checkpoint = torch.load('linguistic_style_transfer_pytorch/checkpoints/model_epoch_20.pt')
# Load model
model = AdversarialVAE(weight=weights)
model.load_state_dict(model_checkpoint)
model.eval()

# Load average style embeddings
with open(config.avg_style_emb_path, 'rb') as f:
    avg_style_embeddings = pickle.load(f)
# set avg_style_emb attribute of the model
model.avg_style_emb = avg_style_embeddings
# load word2index
with open(gconfig.w2i_file_path) as f:
    word2index = json.load(f)
# load index2word
with open(gconfig.i2w_file_path) as f:
    index2word = json.load(f)
label2index = {'neg': 0, 'pos': 1}
# Read input sentence
source_sentence = "this soup is good"#input("Enter the source sentence")
target_style = "neg"#input("Enter the target style: pos or neg")
# Get token ids



In [None]:

token_ids = [word2index.get(word, gconfig.unk_token)
             for word in source_sentence.split()]
token_ids = torch.LongTensor(token_ids)
target_style_id = torch.LongTensor(label2index[target_style])
# Get transfered sentence token ids

In [None]:
z = torch.FloatTensor([[1,2,3,4,5,6,7,8]]).view([8])

In [None]:
z.repeat(128).view([128,8])

In [None]:
torch.from_numpy(np.array([4]))

In [None]:
torch.LongTensor([0, 1]) == torch.tensor([0, 1])

In [None]:
torch.tensor([0, 1])

In [None]:
test_dataset = TextDataset(mode='test')
test_dataloader = DataLoader(test_dataset, batch_size=mconfig.batch_size)
for iteration, batch in enumerate(tqdm(test_dataloader)):
    
    # unpacking
    sequences, seq_lens, labels, bow_rep = batch
    print(sequences.shape, seq_lens.shape)
    model.transfer_style(sequences, seq_lens.view([128]), 1)

In [None]:
rnn = torch.nn.RNN(10, 20, 2, batch_first=True)
input_ = torch.autograd.Variable(torch.randn(5, 3, 10).transpose(0, 1))
h0 = torch.autograd.Variable(torch.randn(2, 3, 20))
# output, hn = rnn(input, h0)
input_.size(), h0.size()

In [None]:
mconfig.embedding_size + mconfig.generative_emb_dim

In [None]:
mconfig.hidden_dim

In [None]:

target_tokenids = model.transfer_style(token_ids, target_style_id)
target_sentence = "".join([index2word.get(idx) for idx in target_tokenids])
print("Style transfered sentence: {}".format(target_sentence))

### Generate From Scratch

In [50]:
import torch
from torch.utils.data import DataLoader
from linguistic_style_transfer_pytorch.config import GeneralConfig, ModelConfig
from linguistic_style_transfer_pytorch.data_loader import TextDataset
from linguistic_style_transfer_pytorch.model import AdversarialVAE
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm, trange
import os
import numpy as np
import pickle

### Loading for generation

In [8]:
use_cuda = False

In [10]:
gconfig = GeneralConfig()
config = GeneralConfig()
mconfig = ModelConfig()

# load word embeddings
weights = torch.FloatTensor(np.load(gconfig.word_embedding_path))
# load checkpoint
model_checkpoint = torch.load('linguistic_style_transfer_pytorch/checkpoints/model_epoch_10.pt')
# Load model
model = AdversarialVAE(weight=weights)
model.load_state_dict(model_checkpoint)
model.eval()

# Load average style embeddings
with open(config.avg_style_emb_path, 'rb') as f:
    avg_style_embeddings = pickle.load(f)
# set avg_style_emb attribute of the model
model.avg_style_emb = avg_style_embeddings
# load word2index
with open(gconfig.w2i_file_path) as f:
    word2index = json.load(f)
# load index2word
with open(gconfig.i2w_file_path) as f:
    index2word = json.load(f)
label2index = {'neg': 0, 'pos': 1}

In [14]:

if use_cuda:
    model = model.to("cuda")

#=============== Define dataloader ================#
test_dataset = TextDataset(mode='train')
test_dataloader = DataLoader(test_dataset, batch_size=mconfig.batch_size)
content_discriminator_params, style_discriminator_params, vae_and_classifier_params = model.get_params()
#model.transfer_style(sequences, seq_lens.view([128]), 1)

In [39]:
sequences, seq_lens, labels, bow_rep = [x for x in test_dataloader][0]

In [40]:
seq_lengths = seq_lens.view([128])
seq_lens.shape

torch.Size([128, 1])

In [41]:
sequences.shape

torch.Size([128, 40])

### model.transfer_style(self, sequences, seq_lengths, style)

In [42]:
seq_lengths, perm_index = seq_lengths.sort(descending=True)
sequences = sequences[perm_index]
print(sequences.shape)

In [45]:
print(seq_lengths)

tensor([39, 38, 36, 35, 35, 34, 34, 34, 34, 34, 34, 34, 33, 33, 33, 32, 32, 31,
        30, 30, 30, 29, 28, 28, 28, 27, 27, 27, 26, 26, 26, 25, 25, 25, 24, 24,
        24, 23, 23, 23, 23, 23, 23, 23, 22, 21, 21, 21, 20, 20, 19, 19, 19, 18,
        18, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 15,
        15, 15, 15, 15, 15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10,
        10, 10, 10, 10, 10, 10,  9,  9,  9,  9,  9,  9,  9,  9,  8,  8,  8,  7,
         7,  6])


In [48]:
embedded_seqs = model.embedding(sequences)
print(embedded_seqs.shape)

torch.Size([128, 40, 300])


In [52]:
#seq_lengths = torch.from_numpy(np.array([15]))
packed_seqs = pack_padded_sequence(
    embedded_seqs, lengths=seq_lengths, batch_first=True)
packed_output, (_) = model.encoder(packed_seqs)
output, _ = pad_packed_sequence(packed_output, batch_first=True)
print(output.shape)

torch.Size([128, 39, 512])


In [54]:
sentence_emb = output[torch.arange(output.size(0)), seq_lengths-1]
print(sentence_emb.shape)

content_emb_mu, content_emb_log_var = model.get_content_emb(
    sentence_emb)
print(content_emb_mu.shape)

torch.Size([128, 512])
torch.Size([128, 128])


In [56]:
sampled_content_emb = model.sample_prior(
    content_emb_mu, content_emb_log_var)
print(sampled_content_emb.shape)

torch.Size([128, 128])


In [58]:
target_style_emb = model.avg_style_emb[1].repeat(128).view([128,8])
print(target_style_emb.shape)


torch.Size([128, 8])


In [59]:
generative_emb = torch.cat(
    (target_style_emb, sampled_content_emb), axis=1)

In [60]:
generative_emb.shape

torch.Size([128, 136])

In [61]:
generative_emb

tensor([[-0.8261, -1.6626, -2.1619,  ...,  1.1458, -0.9463,  0.2676],
        [-0.8261, -1.6626, -2.1619,  ...,  0.9806, -0.8630,  0.3588],
        [-0.8261, -1.6626, -2.1619,  ...,  1.0237, -0.9309,  0.2863],
        ...,
        [-0.8261, -1.6626, -2.1619,  ...,  1.2252, -0.9657,  0.2883],
        [-0.8261, -1.6626, -2.1619,  ...,  1.1768, -1.0424,  0.3253],
        [-0.8261, -1.6626, -2.1619,  ...,  0.8604, -0.8106,  0.4840]],
       grad_fn=<CatBackward>)

### self.generate_sentences(self, input_sentences, latent_emb, inference=False)

In [None]:
# transfered_sentence = self.generate_sentences(
#     input_sentences=None, latent_emb=generative_emb, inference=True)

In [184]:
input_sentences = sequences
latent_emb=generative_emb

In [185]:
input_sentences.shape

torch.Size([128, 40])

In [186]:
latent_emb.shape

torch.Size([128, 136])

In [65]:
sos_token_tensor = torch.tensor(
        [gconfig.predefined_word_index['<sos>']], 
        device=input_sentences.device).unsqueeze(0).repeat(mconfig.batch_size, 1)

In [67]:
sos_token_tensor.shape

torch.Size([128, 1])

In [177]:
input_sentences = torch.cat(
        (sos_token_tensor, input_sentences), dim=1)

In [187]:
input_sentences.shape

torch.Size([128, 40])

In [188]:
sentence_embs = model.embedding(input_sentences)

In [189]:
sentence_embs.shape

torch.Size([128, 40, 300])

In [190]:
# latent_emb = latent_emb.unsqueeze(1).repeat(
#         1, mconfig.max_seq_len+1, 1)
latent_emb = latent_emb.unsqueeze(1).repeat(
        1, mconfig.max_seq_len, 1)

In [191]:
print(generative_emb.shape)
print(latent_emb.shape)

torch.Size([128, 136])
torch.Size([128, 40, 136])


In [192]:
gen_sent_embs = torch.cat(
        (sentence_embs, latent_emb), dim=2)

In [193]:
gen_sent_embs.shape

torch.Size([128, 40, 436])

In [194]:
# output_sentences = torch.zeros(
#     mconfig.max_seq_len, mconfig.batch_size, 
#     device=input_sentences.device)
output_sentences = torch.zeros(
        mconfig.max_seq_len, mconfig.batch_size, mconfig.vocab_size, device=input_sentences.device)

In [195]:
output_sentences.shape

torch.Size([40, 128, 9203])

In [196]:
hidden_states = torch.zeros(
        mconfig.batch_size, mconfig.hidden_dim, device=input_sentences.device)
print(mconfig.hidden_dim)

256


In [197]:
hidden_states.shape

torch.Size([128, 256])

In [198]:
mconfig.max_seq_len

40

In [199]:
for idx in range(mconfig.max_seq_len):
    # get words at the index idx from all the batches
    words = gen_sent_embs[:, idx, :]
    hidden_states = model.decoder(words, hidden_states)
    # project over vocab space
    next_word_logits = model.projector(hidden_states)
#     next_word = nn.Softmax(dim=1)(next_word_logits).argmax(dim=1)
    output_sentences[idx] = next_word_logits

In [214]:
output_sentences.argmax(dim=2)

tensor([[3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        ...,
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3]])

In [201]:
input_sentences.shape

torch.Size([128, 40])

In [202]:
loss = nn.CrossEntropyLoss(ignore_index=0)
recon_loss = loss(
    output_sentences.view(-1, mconfig.vocab_size), input_sentences.view(-1))

In [205]:
output_sentences.view(-1, mconfig.vocab_size).shape

torch.Size([5120, 9203])

In [152]:
next_word_logits[0,:].shape

torch.Size([9203])

In [117]:
next_word_logits.shape

torch.Size([128, 9203])

In [None]:
next_word_probs = nn.Softmax(dim=1)(self.projector(hidden_states))
next_word = max(next_word_probs.argmax(1))
output_sentence[idx] = next_word
word_emb = self.embedding(next_word)

In [139]:
# model.encoder(output_sentences)

In [109]:
next_word_logits.shape

torch.Size([128, 9203])

In [110]:
output_sentences.shape

torch.Size([40, 128, 9203])

In [None]:
# Training mode
if not inference:
    # Prepend the input sentences with <sos> token
    sos_token_tensor = torch.tensor(
        [gconfig.predefined_word_index['<sos>']], 
        device=input_sentences.device).unsqueeze(0).repeat(mconfig.batch_size, 1)
    input_sentences = torch.cat(
        (sos_token_tensor, input_sentences), dim=1)
    sentence_embs = self.dropout(self.embedding(input_sentences))
    # Make the latent embedding compatible for concatenation
    # by repeating it for max_seq_len + 1(additional one bcoz <sos> tokens were added)
    latent_emb = latent_emb.unsqueeze(1).repeat(
        1, mconfig.max_seq_len+1, 1)
    gen_sent_embs = torch.cat(
        (sentence_embs, latent_emb), dim=2)
    # Delete latent embedding and sos token tensor to reduce memory usage
    del latent_emb, sos_token_tensor
    output_sentences = torch.zeros(
        mconfig.max_seq_len, mconfig.batch_size, mconfig.vocab_size, device=input_sentences.device)
    # initialize hidden state
    hidden_states = torch.zeros(
        mconfig.batch_size, mconfig.hidden_dim, device=input_sentences.device)
    # generate sentences one word at a time in a loop
    for idx in range(mconfig.max_seq_len):
        # get words at the index idx from all the batches
        words = gen_sent_embs[:, idx, :]
        hidden_states = self.decoder(words, hidden_states)
        # project over vocab space
        next_word_logits = self.projector(hidden_states)
        output_sentences[idx] = next_word_logits
# if inference mode is on
else:

    sos_token_tensor = torch.tensor(
        [gconfig.predefined_word_index['<sos>']], device=latent_emb.device).unsqueeze(0).repeat(mconfig.batch_size, 1)
    word_emb = self.embedding(sos_token_tensor)
    hidden_states = torch.zeros(
        mconfig.batch_size, mconfig.hidden_dim, device=latent_emb.device)
    # Store output sentences
    output_sentences = torch.zeros(
        mconfig.max_seq_len, 1, device=latent_emb.device)
    latent_emb = latent_emb.unsqueeze(1)#.repeat(
        #1, mconfig.max_seq_len+1, 1)
    with torch.no_grad():
        # Greedily generate new words at a time
        for idx in range(mconfig.max_seq_len):
            gen_sent_embs = torch.cat(
                (word_emb, latent_emb), dim=2)
            words = gen_sent_embs[:, idx, :]
            hidden_states = self.decoder(words, hidden_states)
            next_word_probs = nn.Softmax(dim=1)(
                self.projector(hidden_states))
            next_word = max(next_word_probs.argmax(1))
            output_sentence[idx] = next_word
            word_emb = self.embedding(next_word)