In [1]:
import random

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate

from settings import EXPERIMENTS_DIR
from experiment import Experiment
from utils import to_device, load_weights, load_embeddings, create_embeddings_matrix
from vocab import Vocab
from train import create_model
from preprocess import load_dataset, create_dataset_reader

In [2]:
# ID of the training experiment to inspect; Experiment.load resolves it under EXPERIMENTS_DIR.
exp_id = 'train.jkmkvrrr'

# Load the trained experiment, its dataset, and the model

In [3]:
# Load the saved training experiment; its config and experiment_dir are used below.
exp = Experiment.load(EXPERIMENTS_DIR, exp_id)

In [4]:
# Show the hyper-parameters the model was trained with.
exp.config

TrainConfig(model_class=<class 'models.Seq2SeqMeaningStyle'>, preprocess_exp_id='preprocess.buppgpnf', embedding_size=300, hidden_size=256, dropout=0.2, scheduled_sampling_ratio=0.5, pretrained_embeddings=True, trainable_embeddings=False, meaning_size=128, style_size=128, lr=0.001, weight_decay=1e-07, grad_clipping=5, D_num_iterations=10, D_loss_multiplier=1, P_loss_multiplier=10, P_bow_loss_multiplier=1, use_discriminator=True, use_predictor=False, use_predictor_bow=True, use_motivator=True, use_gauss=False, num_epochs=500, batch_size=1024, best_loss='loss')

In [5]:
# The training config records which preprocessing experiment produced its data.
# Load that experiment to get its dataset splits, vocabularies and embedding matrix.
preprocess_exp = Experiment.load(EXPERIMENTS_DIR, exp.config.preprocess_exp_id)
(
    dataset_train, dataset_val, dataset_test,
    vocab, style_vocab, W_emb,
) = load_dataset(preprocess_exp)

Dataset: 453655, val: 10000, test: 10000
Vocab: 9419, style vocab: 2
W_emb: (9419, 300)


In [6]:
# Reader built from the preprocessing config; create_inputs uses its clean/spacy/preprocess steps on raw text.
dataset_reader = create_dataset_reader(preprocess_exp.config)

In [7]:
# Rebuild the architecture from the training config, vocabularies, dataset max_len and embeddings.
# The trained weights are loaded separately from best.th.
model = create_model(exp.config, vocab, style_vocab, dataset_train.max_len, W_emb)

In [8]:
# Restore the checkpoint stored as best.th in the training experiment's directory.
load_weights(model, exp.experiment_dir.joinpath('best.th'))

In [9]:
# eval() switches layers such as dropout (dropout=0.2 in this config) to inference behaviour.
# It does NOT disable autograd; forward passes below must be wrapped in torch.no_grad() for that.
model = model.eval()

## Predict

In [10]:
def create_inputs(instances, style=None):
    """Build a model-ready batch from raw sentences and/or encoded instances.

    Args:
        instances: a single item, or a list/tuple of items. Each item is
            either a raw sentence string (cleaned, tokenized and encoded
            here) or an instance dict that already carries 'sentence_enc'
            and 'style_enc' (e.g. an element of dataset_val.instances).
        style: style label attached to raw sentences before encoding. It
            must be a key of style_vocab.token2id and defaults to the first
            key. It is ignored for dict items, which keep their own encoding.

    Returns:
        default_collate applied to [{'sentence': ..., 'style': ...}, ...]
        and then passed through to_device, i.e. a dict with batched
        'sentence' and 'style' encodings. Device placement is decided by
        the to_device helper.

    Raises:
        ValueError: if no instances are given or `style` is not a known style.
    """
    # A list or tuple is a batch; anything else is a single item.
    if isinstance(instances, (list, tuple)):
        instances = list(instances)
    else:
        instances = [instances]
    if not instances:
        raise ValueError('create_inputs() needs at least one sentence or instance')

    known_styles = list(style_vocab.token2id.keys())
    if style is None:
        style = known_styles[0]
    elif style not in style_vocab.token2id:
        raise ValueError(f'unknown style {style!r}; expected one of {known_styles}')

    def _encode_raw(sent):
        # Clean, run spaCy, then preprocess the raw text with the dataset reader.
        tokens = dataset_reader.preprocess_sentence(
            dataset_reader.spacy(dataset_reader.clean_sentence(sent))
        )
        inst = {'sentence': tokens, 'style': style}
        # encode_instance supplies the 'sentence_enc' / 'style_enc' fields read below.
        inst.update(dataset_train.encode_instance(inst))
        return inst

    # Decide per item, so lists that mix raw strings and encoded dicts also work.
    encoded = [inst if isinstance(inst, dict) else _encode_raw(inst) for inst in instances]

    batch = [
        {'sentence': inst['sentence_enc'], 'style': inst['style_enc']}
        for inst in encoded
    ]
    return to_device(default_collate(batch))

In [11]:
def get_sentences(outputs):
    """Turn the model's predicted token ids into lists of token strings.

    Args:
        outputs: mapping whose "predictions" entry holds one row of token ids
            per sequence, as a numpy array or a torch tensor. Tensors are
            detached and copied to CPU first.

    Returns:
        One list of tokens per row, cut just before the first END_TOKEN id
        if that id occurs in the row.
    """
    token_ids = outputs["predictions"]
    end_id = vocab[Vocab.END_TOKEN]

    if not isinstance(token_ids, np.ndarray):
        # Tensor input: leave the autograd graph and move to host memory.
        token_ids = token_ids.detach().cpu().numpy()

    decoded = []
    for row in token_ids:
        row = list(row)
        stop = row.index(end_id) if end_id in row else len(row)
        decoded.append([vocab.id2token[token_id] for token_id in row[:stop]])
    return decoded

In [12]:
# Rebuild the plain text of one validation example from its stored tokens.
sentence = ' '.join(dataset_val.instances[1]['sentence'])

In [13]:
# The input text that the model will reconstruct.
sentence

'they are really good people .'

In [14]:
# Clean, tokenize, encode and batch the raw string; a raw string gets the default style.
inputs = create_inputs(sentence)

In [15]:
# eval() alone still records the autograd graph; no_grad() skips it, which saves
# memory and time because nothing here is backpropagated.
with torch.no_grad():
    outputs = model(inputs)

In [16]:
# Map predicted ids back to tokens, truncated at the end token.
sentences = get_sentences(outputs)

In [17]:
# Reconstruction of the single input; here it matches the original sentence exactly.
' '.join(sentences[0])

'they are really good people .'

### Swap style

In [18]:
# Style labels in vocabulary order (['negative', 'positive'] for this dataset).
possible_styles = list(style_vocab.token2id)

In [19]:
# Show the available style labels.
possible_styles

['negative', 'positive']

In [20]:
# Partition the validation instances by style label: sentences0 holds
# possible_styles[0], sentences1 holds possible_styles[1].
sentences0, sentences1 = (
    [inst for inst in dataset_val.instances if inst['style'] == target_style]
    for target_style in (possible_styles[0], possible_styles[1])
)

In [52]:
# Show up to 5 distinct random examples of style possible_styles[0].
# replace=False avoids repeated indices; the fixed seed keeps the listing
# reproducible across Restart & Run All.
for i in np.random.default_rng(0).choice(
    len(sentences0), size=min(5, len(sentences0)), replace=False
):
    print(i, ' '.join(sentences0[i]['sentence']))

3239 if i could give negative stars i certainly would for this place .
2083 if it was not broke , why in the world would you fix it ?
3874 the rice had hard things in it .
1569 quite possibly the worst experience of my life .
3584 however our little one ordered buttered noodles and was pleased as punch .


In [22]:
# Show up to 5 distinct random examples of style possible_styles[1].
# replace=False avoids repeated indices; the fixed seed keeps the listing
# reproducible across Restart & Run All.
for i in np.random.default_rng(0).choice(
    len(sentences1), size=min(5, len(sentences1)), replace=False
):
    print(i, ' '.join(sentences1[i]['sentence']))

4935 which is awesome !
1561 it is also a good place to go just for dessert .
1208 they are amazing , truly , could not be happier .
1347 i had the tamales and they were the best i have ever had !
3450 the capistrami is the best thing ever .


#### Swap style latents between a style-0 and a style-1 sentence

In [53]:
# Fixed example indices, one per style group, so the swap below is reproducible.
# The trailing comments show how to draw random indices instead.
target0 = 3874 # np.random.choice(np.arange(len(sentences0)))
target1 = 4935 # np.random.choice(np.arange(len(sentences1)))

In [54]:
# print() separates its arguments with single spaces, so this prints the tokens as text.
print(*sentences0[target0]['sentence'])

the rice had hard things in it .


In [55]:
# print() separates its arguments with single spaces, so this prints the tokens as text.
print(*sentences1[target1]['sentence'])

which is awesome !


In [56]:
# Batch the two chosen dataset instances (dicts, so create_inputs uses their stored
# encodings): row 0 comes from sentences0 and row 1 from sentences1.
chosen = [sentences0[target0], sentences1[target1]]
inputs = create_inputs(chosen)
del chosen

In [57]:
# Forward pass for the two-sentence batch; the result exposes the 'style_hidden'
# and 'meaning_hidden' latents used below. no_grad() skips autograd bookkeeping.
with torch.no_grad():
    z_hidden = model(inputs)

In [58]:
# Style latent: one row per input sentence, width = style_size (128 in this config).
z_hidden['style_hidden'].shape

torch.Size([2, 128])

In [59]:
# Meaning latent: one row per input sentence, width = meaning_size (128 in this config).
z_hidden['meaning_hidden'].shape

torch.Size([2, 128])

In [60]:
# Decode each sentence from its own meaning and style latents (a reconstruction).
# no_grad(): inference only, so do not build an autograd graph.
with torch.no_grad():
    original_decoded = model.decode(z_hidden)

In [61]:
# Token lists for the reconstructions, truncated at the end token.
original_sentences = get_sentences(original_decoded)

In [62]:
# Reconstructions of the two inputs, one per line.
print(*original_sentences[0])
print(*original_sentences[1])

the rice had hard things in it .
which is awesome !


In [63]:
# Each sentence keeps its own meaning latent but takes the other sentence's style
# latent: row 0 gets row 1's style and row 1 gets row 0's. Both entries are fresh
# copies (clone / advanced indexing), so z_hidden itself is left untouched.
z_hidden_swapped = {
    'meaning_hidden': z_hidden['meaning_hidden'][:2].clone(),
    'style_hidden': z_hidden['style_hidden'][[1, 0]],
}

In [64]:
# Decode with the style latents exchanged between the two sentences.
# no_grad(): inference only, so skip autograd bookkeeping.
with torch.no_grad():
    swaped_decoded = model.decode(z_hidden_swapped)

In [65]:
# Token lists for the style-swapped decodings, truncated at the end token.
swaped_sentences = get_sentences(swaped_decoded)

In [66]:
# Reconstructions first, then a blank line, then the style-swapped decodings.
print(*original_sentences[0])
print(*original_sentences[1])
print()
print(*swaped_sentences[0])
print(*swaped_sentences[1])

the rice had hard things in it .
which is awesome !

plus is really hard to it .
the rice was awesome .
