In [3]:
# Make edits to the local `trext` package visible without restarting the kernel:
# `%autoreload 2` re-imports changed modules before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import time
from argparse import ArgumentParser
from pathlib import Path

import torch
from torchtext.data import Field, BucketIterator

from trext.datamodules import DeEnDataModule
#from trext.loggers import NeptuneLogger
from trext.models import (
    TransformerTranslator,
    TransformerEncoder,
    TransformerDecoder,
)
from trext.trainer import Trainer
from trext.utils import Editor, Vocabulary


# Run configuration shared by the data and trainer cells.
# NOTE(review): the encoder_*/decoder_* entries are not read by any cell shown here;
# the transformer cell defines its own HID_DIM / *_DROPOUT constants instead.
args = {
    'batch_size': 64,
    'decoder_dropout_p': 0.5,
    'decoder_hidden_dim': 128,
    'decoder_embedding_dim': 128,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'encoder_dropout_p': 0.5,
    'encoder_hidden_dim': 128,
    'encoder_embedding_dim': 128,
    'max_epoch': 10,
    'verbose': True,
    'version': '0.1',
}
# Convenience alias used throughout the notebook.
device = args['device']

In [5]:
# Build and set up the project's DeEn datamodule, timing how long it takes.
# NOTE(review): `datamodule` is reassigned later to a `DM` wrapper around the
# torchtext iterators, so this object is not the one the trainer receives.
start_time = time.time()
print(f"Device is: {args['device']}")
print("Preparing datamodule...")

datamodule = DeEnDataModule(
    num_workers=4,
    batch_size=args['batch_size'],
    data_dir=Path('data/homework_machine_translation_de-en'),
)
datamodule.setup()

elapsed = time.time() - start_time
print(f"Datamodule is prepared ({elapsed} seconds)")

Device is: cuda
Preparing datamodule...
Datamodule is prepared (9.539065599441528 seconds)


In [6]:
from torchtext.datasets import TranslationDataset


DATA_PATH = Path("data/homework_machine_translation_de-en")


def _whitespace_tokenize(text):
    """Split an already-tokenized line on whitespace."""
    return text.split()


def _make_field(language):
    """Create a lower-casing text Field with <sos>/<eos> markers.

    `tokenizer_language` is passed through unchanged; with a callable tokenizer
    it is not needed for tokenization.
    """
    return Field(
        tokenize=_whitespace_tokenize,
        tokenizer_language=language,
        init_token='<sos>',
        eos_token='<eos>',
        lower=True,
    )


def _load_split(prefix, exts, fields):
    """Load one parallel split `<prefix><ext>` from DATA_PATH."""
    return TranslationDataset(str(DATA_PATH / prefix), exts, fields=fields)


SRC = _make_field("de")
TRG = _make_field("en")

train_data = _load_split('train.de-en.', ['de', 'en'], (SRC, TRG))
valid_data = _load_split('val.de-en.', ['de', 'en'], (SRC, TRG))
# The test split has no English side: the German file fills both slots.
test_data = _load_split('test1.de-en.', ['de', 'de'], (SRC, SRC))

# Vocabularies come from the training split only; tokens seen once are dropped.
for field in (SRC, TRG):
    field.build_vocab(train_data, min_freq=2)



In [7]:
def _make_iterator(dataset, train):
    """Build a BucketIterator over a TranslationDataset.

    Args:
        dataset: a TranslationDataset whose examples expose `.src` / `.trg`.
        train: True for the training split. Training batches are shuffled and
            grouped by source length to reduce padding. With False the dataset's
            original order is kept (no shuffle, no sort), so predictions stay
            line-aligned with the source file.

    Returns:
        A torchtext BucketIterator.
    """
    return BucketIterator(
        dataset,
        batch_size=args['batch_size'],
        # TranslationDataset examples have `.src`/`.trg`; the previous
        # `x.comment_text` key did not exist on them.
        sort_key=lambda example: len(example.src),
        # Needed for real bucketing: without it examples are not grouped by length.
        sort_within_batch=train,
        train=train,
        # train=False would otherwise sort evaluation data by length and
        # break the alignment between predictions and input lines.
        sort=False,
        device=device,
    )


train_iterator = _make_iterator(train_data, train=True)
val_iterator = _make_iterator(valid_data, train=False)
test_iterator = _make_iterator(test_data, train=False)



In [8]:
class DM:
    """Minimal datamodule adapter around pre-built torchtext iterators.

    Each `*_dataloader()` accessor returns the iterator given to the constructor
    unchanged. Presumably these are the hooks `Trainer.fit` / `Trainer.predict`
    call; confirm against `trext.trainer`.
    """

    def __init__(self, train_iterator, val_iterator, test_iterator):
        # Public attribute names are kept so direct attribute access still works.
        self.train_iterator, self.val_iterator, self.test_iterator = (
            train_iterator,
            val_iterator,
            test_iterator,
        )

    def train_dataloader(self):
        """Return the training iterator."""
        iterator = self.train_iterator
        return iterator

    def val_dataloader(self):
        """Return the validation iterator."""
        iterator = self.val_iterator
        return iterator

    def test_dataloader(self):
        """Return the test iterator."""
        iterator = self.test_iterator
        return iterator

    
# Replaces the DeEnDataModule built earlier: from here on the trainer is fed the torchtext iterators.
datamodule = DM(train_iterator, val_iterator, test_iterator)

In [9]:
# Transformer hyper-parameters (not taken from the `args` dict).
INPUT_DIM = len(SRC.vocab)   # source vocabulary size
OUTPUT_DIM = len(TRG.vocab)  # target vocabulary size
HID_DIM = 256
ENC_LAYERS, DEC_LAYERS = 3, 3
ENC_HEADS, DEC_HEADS = 8, 8
ENC_PF_DIM, DEC_PF_DIM = 512, 512
ENC_DROPOUT, DEC_DROPOUT = 0.1, 0.1

# Arguments are positional: their order must match the TransformerEncoder /
# TransformerDecoder constructors.
encoder = TransformerEncoder(
    INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device,
).to(device)
decoder = TransformerDecoder(
    OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device,
).to(device)

In [10]:
# Index of the <pad> token in each vocabulary, handed to the translator.
SRC_PAD_IDX, TRG_PAD_IDX = (
    field.vocab.stoi[field.pad_token] for field in (SRC, TRG)
)

translator = TransformerTranslator(
    encoder=encoder,
    decoder=decoder,
    source_pad_idx=SRC_PAD_IDX,
    target_pad_idx=TRG_PAD_IDX,
    learning_rate=3e-4,
    device=device,
)
translator = translator.to(device)

In [11]:
import random
import numpy as np

SEED = 1234


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    NOTE(review): in this notebook the encoder/decoder/translator are built in
    earlier cells, so seeding here cannot make their initial weights
    reproducible. Move this call above model construction for that.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly cover every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels between runs; disable it too.
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)


# Logging backend is disabled (logger=None); epochs/verbosity/version come from `args`.
trainer = Trainer(
    logger=None,
    max_epoch=args['max_epoch'],
    verbose=args['verbose'],
    version=args['version'],
)

# The same model/datamodule pair is used for training and prediction.
run_kwargs = dict(model=translator, datamodule=datamodule)

print("Let's start training!")
trainer.fit(**run_kwargs)

print('Predicts!')
predicts = trainer.predict(**run_kwargs)

  0%|          | 0/3062 [00:00<?, ?it/s]

Let's start training!


100%|██████████| 3062/3062 [05:40<00:00,  8.99it/s]
 31%|███▏      | 5/16 [00:00<00:00, 41.91it/s]

Training epoch #1 is over.


100%|██████████| 16/16 [00:00<00:00, 32.96it/s]


Validation epoch #1 is over.


100%|██████████| 3062/3062 [05:45<00:00,  8.86it/s]
 31%|███▏      | 5/16 [00:00<00:00, 36.53it/s]

Training epoch #2 is over.


100%|██████████| 16/16 [00:00<00:00, 34.49it/s]


Validation epoch #2 is over.


100%|██████████| 3062/3062 [05:45<00:00,  8.87it/s]
 31%|███▏      | 5/16 [00:00<00:00, 39.62it/s]

Training epoch #3 is over.


100%|██████████| 16/16 [00:00<00:00, 32.97it/s]


Validation epoch #3 is over.


 44%|████▍     | 1351/3062 [02:33<03:14,  8.82it/s]


KeyboardInterrupt: 

In [None]:
# Leftover inspection cell: displays the torch device in use.
device

In [27]:
import copy
from pathlib import Path

import torch

from trext.datamodules import DeEnDataModule
from trext.models import AttentionTranslator, Encoder, Decoder, Attention
from trext.utils import Editor


# NOTE(review): despite the .hdf5 extension this file is read with torch.load,
# which unpickles it. Only load checkpoints you produced yourself.
checkpoint = torch.load('models/v0.1-e1.hdf5', map_location=device)

# Deep-copy the encoder/decoder. Passing the very same module objects that
# `translator` holds would make `load_state_dict` below overwrite the trained
# translator's weights as well.
model = TransformerTranslator(
    encoder=copy.deepcopy(encoder),
    decoder=copy.deepcopy(decoder),
    source_pad_idx=SRC_PAD_IDX,
    target_pad_idx=TRG_PAD_IDX,
    learning_rate=3e-4,
    device=device,
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference only: disable dropout and gradient tracking.
model.eval()

with torch.no_grad():
    # Inspect the output of a single validation batch.
    for b in datamodule.val_dataloader():
        outs = model.test_step(b, 1)
        # To decode via the project helper:
        # Editor.tags_lists2tokens_lists(tags_lists=outs, vocabulary=datamodule.en_vocabulary)
        print(b.src.shape, outs.shape)
        break

torch.Size([56, 64]) torch.Size([64, 50, 34047])


In [65]:
def tags2tokens(indices, vocab):
    """Map token ids to their strings via `vocab` and join them with single spaces.

    Args:
        indices: iterable of ids usable as list indices (ints or 0-d integer tensors).
        vocab: index -> token lookup, e.g. `TRG.vocab.itos`.

    Returns:
        Space-separated string. Special tokens such as `<pad>`/`<eos>` are kept.
    """
    return ' '.join(vocab[token_id] for token_id in indices)

# Greedy token ids per position. The shape printed above suggests `outs` is
# (batch, len, vocab), so `a` is (batch, len).
a = outs.argmax(2)


def _without_pads(sentence):
    """Drop `<pad>` tokens and collapse the whitespace they leave behind."""
    return ' '.join(token for token in sentence.split() if token != '<pad>')


# Print prediction and reference for the first few sentences of the batch;
# min() guards against a final batch with fewer than 10 examples.
for idx in range(min(10, a.size(0))):
    print(_without_pads(tags2tokens(a[idx], TRG.vocab.itos)))
    # `b.trg` is indexed (position, example); [1:] skips the leading <sos>.
    print(_without_pads(tags2tokens(b.trg[1:, idx], TRG.vocab.itos)))
    print()

you can take a can idea . <eos>                                          
you can give somebody an idea . <eos>                                         

so , i , i 'm think with with with i is a great problem with me , 's with like with with a 's are going a with with <eos>                   
so , well , i do applied math , and this is a peculiar problem for anyone who does applied math , is that we are like management consultants . <eos>                  

and that 's what thing thing , 'm that can see with and and i thing is 'm to do that . what <eos>                          
and that 's the important lesson i think you can take away , and the one i want to leave you with . <eos>                         

but , to make more more make more , make the , more the , more the the , the , , , more than the to make make the , <eos>                  
politicians try to pick words and use words to shape reality and control reality , but in fact , reality changes words far more than words ca

In [None]:
def indices2sentence(indices, vocab, eos_token='<eos>'):
    """Decode token ids into a space-separated sentence, stopping at the first EOS.

    The previous version referenced an undefined `field` (NameError). It also
    kept every token after `<eos>`, which is noise in greedy model output.

    Args:
        indices: iterable of integer token ids (Python ints or 0-d tensors).
        vocab: index -> token lookup, e.g. `TRG.vocab.itos`.
        eos_token: decoding stops (exclusive) at the first occurrence of this token.

    Returns:
        The decoded tokens joined with single spaces ('' for empty input).
    """
    tokens = []
    for index in indices:
        token = vocab[int(index)]
        if token == eos_token:
            break
        tokens.append(token)
    return ' '.join(tokens)

def compute_prediction(model, iterator, vocab, progress=None):
    """Greedy-decode every batch of `iterator` and return one sentence per example.

    Args:
        model: translation model. It is called as `model(src, src, 0)`; the third
            argument is the teacher-forcing ratio (0 = off).
        iterator: torchtext iterator whose batches expose `.src` and `len()`.
        vocab: index -> token lookup passed to `indices2sentence`.
        progress: optional callable wrapping the iterator, e.g. `tqdm`. `tqdm`
            used to be referenced here without being imported (NameError), so a
            progress bar is now opt-in.

    Returns:
        List of decoded sentences, in iteration order.
    """
    model.eval()
    batches = iterator if progress is None else progress(iterator)

    preds = []
    with torch.no_grad():
        for batch in batches:
            src = batch.src
            output = model(src, src, 0)  # 0 -> turn off teacher forcing

            # NOTE(review): the indexing below assumes output is (trg_len, batch, vocab).
            # The transformer's test_step output inspected earlier is batch-first;
            # confirm the layout of model(...) before relying on this.
            predicted = output.argmax(2)
            for j in range(len(batch)):
                preds.append(indices2sentence(predicted[:, j], vocab))
    return preds

def save_test_preds(test_preds, path):
    """Write predictions to `path`, one per line, UTF-8 encoded, no trailing newline.

    Args:
        test_preds: iterable of sentence strings.
        path: destination file path (str or Path).
    """
    # An explicit encoding avoids platform defaults (e.g. cp1252 on Windows)
    # failing on non-ASCII tokens.
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(test_preds))