In [2]:
import torch
from trainers.rnn_trainer import RNNTrainer
from torch.optim import Adam
from torch import nn
import utils
from models.model import Seq2SeqLSTM
from data import get_vocabs
from torchtext.datasets import Multi30k

### Understand the data loading

In [4]:
# --- Experiment configuration -------------------------------------------------
# Optimizer class (not an instance); passed to the trainer as `optimizer_type`.
optimizer_type = Adam
# Targets equal to utils.PAD_IDX are ignored by the loss (ignore_index).
criterion = nn.CrossEntropyLoss(ignore_index=utils.PAD_IDX)
# Prefer the GPU, but fall back to the CPU so that anything later moved to
# `device` does not fail on a host without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): machine-specific absolute path. The trailing slash is required:
# the checkpoint path is built further down as f'{save_dir}models/{name}.pt'.
save_dir = '/home/tingchen/learning_subspace_save/'
batch_size = 32
dropout_prob = 0.15
learning_rate = 1e-3
embed_size = 256
hidden_size = 256
seed = 23

# The trainer receives the full configuration; the cells below read
# `trainer.create_dataloaders()` and `trainer.vocab_transform` from it.
trainer = RNNTrainer(
    optimizer_type=optimizer_type,
    criterion=criterion,
    device=device,
    batch_size=batch_size,
    dropout_prob=dropout_prob,
    learning_rate=learning_rate,
    save_dir=save_dir,
    embed_size=embed_size,
    hidden_size=hidden_size,
    seed=seed,
)

In [7]:
# The trainer returns two loaders, bound here as the training and validation loaders.
train_loader, val_loader = trainer.create_dataloaders()

In [8]:
# Pull a single batch from a fresh iterator over the training loader for inspection.
batch = next(iter(train_loader))

In [11]:
# A batch unpacks into a (source, target) pair of tensors; show both shapes.
src, tgt = batch
print(src.shape, tgt.shape)

torch.Size([21, 32]) torch.Size([24, 32])


In [13]:
# Token ids of the source example at index 1 along dim 1
# (dim 1 has size 32 == batch_size, so columns index examples).
src[:, 1]

tensor([  2,  84,  31,  10, 847,   0,  15,   0,   4,   3,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1])

In [14]:
# Token ids of the target example at index 1 along dim 1
# (dim 1 has size 32 == batch_size, so columns index examples).
tgt[:, 1]

tensor([   2,  165,   36,    7,  335,  287,   17, 1224,    4,  758,    0,    0,
           5,    3,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1])

In [23]:
# Map the ids of example 0 (column 0) back to source-language tokens and join them.
' '.join(
    trainer.vocab_transform[utils.src_lang].lookup_tokens(src[:, 0].tolist())
)

'<bos> Zwei junge weiße Männer sind im Freien in der Nähe <unk> <unk> . <eos> <pad> <pad> <pad> <pad> <pad> <pad>'

In [24]:
# Map the ids of example 0 (column 0) back to target-language tokens and join them.
' '.join(
    trainer.vocab_transform[utils.tgt_lang].lookup_tokens(tgt[:, 0].tolist())
)

'<bos> Two young , White males are outside near many <unk> . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>'

## Test out the trained model

In [27]:
# Vocabularies keyed by language (indexed below with utils.src_lang / utils.tgt_lang).
# NOTE(review): presumably the same vocabularies as trainer.vocab_transform — confirm.
vocab_transform = get_vocabs(utils.src_lang, utils.tgt_lang)

In [30]:
# Rebuild the network with the same hyperparameters as the training configuration;
# the vocabulary sizes come from the vocabularies loaded above.
model = Seq2SeqLSTM(
    embed_size=embed_size,
    hidden_size=hidden_size,
    dropout_prob=dropout_prob,
    src_vocab_size=len(vocab_transform[utils.src_lang]),
    tgt_vocab_size=len(vocab_transform[utils.tgt_lang]),
)

In [31]:
# Checkpoint base name; the weights are read from f'{save_dir}models/{name}.pt'.
name = 'seq2seq_vanilla_lstms'

In [32]:
# Load the trained weights into `model`.
# map_location='cpu' deserializes the stored tensors onto the CPU, so a checkpoint
# written from GPU tensors also loads on a host without CUDA; load_state_dict then
# copies the values into the model's existing parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(
    torch.load(f'{save_dir}models/{name}.pt', map_location='cpu')
)

<All keys matched successfully>

In [33]:
# Run the model on example 1 of the batch, with each sequence reshaped to (1, seq_len).
# eval() puts the module in evaluation mode so dropout (the model is built with
# dropout_prob) is disabled and repeated runs give the same output; the model stays
# in evaluation mode afterwards. no_grad() skips autograd bookkeeping during inference.
# teacher_forcing_ratio=0 presumably makes the decoder consume its own predictions
# instead of the reference target — TODO confirm against Seq2SeqLSTM.forward.
model.eval()
with torch.no_grad():
    decoder_output, decoder_hidden, decoder_cell = model(
        src[:, 1].reshape(1, -1),
        tgt[:, 1].reshape(1, -1),
        teacher_forcing_ratio=0,
    )

In [37]:
# Drop the size-1 dimensions and take the argmax along dim 1 to get one id per step
# (assumes the squeezed output is (steps, vocab) — TODO confirm).
translated = decoder_output.squeeze().argmax(dim=1)

# Map the predicted ids back to target-language tokens.
' '.join(
    trainer.vocab_transform[utils.tgt_lang].lookup_tokens(translated.tolist())
)

'A men are a hats are playing <unk> <unk> <unk> . . <eos> <eos> . . . . . . . . . .'

## Try to calculate BLEU

In [3]:
# Test split of Multi30k for the configured (source, target) language pair.
test_data = Multi30k(
    language_pair=(utils.src_lang, utils.tgt_lang),
    split='test',
)

In [5]:
# Batched, fixed-order (shuffle=False) iteration over `test_data`.
# NOTE(review): despite its name, `val_dataloader` wraps the test split,
# not a validation split.
val_dataloader = torch.utils.data.DataLoader(
    test_data,
    shuffle=False,
    collate_fn=utils.collate_fn,
    batch_size=batch_size,
)