In [1]:
# Automatically re-import edited modules (e.g. the project's `trext` package)
# before each cell runs, so library changes take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import random
import time
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import torch

from trext.datamodules import DeEnDataModule
#from trext.loggers import NeptuneLogger
from trext.models import (
    TransformerTranslator,
    TransformerEncoder,
    TransformerDecoder,
)
from trext.trainer import Trainer
from trext.utils import Editor, Vocabulary


# Seed every RNG the notebook relies on *before* anything random happens.
# Model weights are (presumably) initialized when the encoder/decoder are
# constructed, and torchtext's BucketIterator snapshots Python's `random`
# state when it is created. Seeding only right before training would
# therefore leave both weight init and batch shuffling unseeded.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Run configuration shared by the cells below.
# NOTE(review): the encoder_*/decoder_* entries are not read anywhere in this
# notebook; the Transformer is built from separate constants further down.
args = dict(
    batch_size=64,
    decoder_dropout_p=0.5,
    decoder_hidden_dim=128,
    decoder_embedding_dim=128,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    encoder_dropout_p=0.5,
    encoder_hidden_dim=128,
    encoder_embedding_dim=128,
    max_epoch=10,
    verbose=True,
    version='0.1',
)
device = args['device']

In [3]:
# Build and set up the project's DeEnDataModule, reporting how long setup takes.
# NOTE(review): `datamodule` is reassigned to the torchtext-based `DM` wrapper
# in a later cell, and that wrapper is what training uses, so this instance
# is not used for training. TODO: keep only one data pipeline.
start_time = time.perf_counter()  # monotonic clock, suitable for measuring durations
print(f"Device is: {args['device']}")

print("Preparing datamodule...")
datamodule = DeEnDataModule(
    data_dir=Path('data/homework_machine_translation_de-en'),
    batch_size=args['batch_size'],
    num_workers=4,
)
datamodule.setup()
print(f"Datamodule is prepared ({time.perf_counter() - start_time} seconds)")

Device is: cuda
Preparing datamodule...
Datamodule is prepared (9.839361190795898 seconds)


In [5]:
from torchtext.data import Field, BucketIterator

In [8]:
from torchtext.datasets import TranslationDataset

DATA_PATH = Path("data/homework_machine_translation_de-en")


def whitespace_tokenize(text):
    """Split an already-tokenized line on whitespace.

    A named module-level function is used instead of a lambda so that the
    Field objects holding it can be pickled (e.g. with ``torch.save``).
    """
    return text.split()


def make_text_field(language):
    """Build a lower-casing text Field that wraps sequences in <sos>/<eos>.

    Args:
        language: forwarded as ``tokenizer_language``. torchtext only consults
            it for named library tokenizers; with the callable tokenizer used
            here it is informational.
    """
    return Field(tokenize=whitespace_tokenize,
                 tokenizer_language=language,
                 init_token='<sos>',
                 eos_token='<eos>',
                 lower=True)


SRC = make_text_field("de")
TRG = make_text_field("en")

train_data = TranslationDataset(str(DATA_PATH / 'train.de-en.'), ['de', 'en'], fields=(SRC, TRG))
valid_data = TranslationDataset(str(DATA_PATH / 'val.de-en.'), ['de', 'en'], fields=(SRC, TRG))
# The test split is read with the German file on both sides (and SRC for both
# fields) so that TranslationDataset gets its (src, trg) pair; presumably there
# is no English reference for test1 -- TODO confirm.
test_data = TranslationDataset(str(DATA_PATH / 'test1.de-en.'), ['de', 'de'], fields=(SRC, SRC))

# Vocabularies come from the training split only; tokens seen once map to <unk>.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)



In [9]:
def source_length(example):
    """Bucketing key: number of source-side tokens of a TranslationDataset example."""
    return len(example.src)


# Training batches: shuffled, but with sort_within_batch=True the iterator
# sorts each pool of batches by `source_length` before cutting batches. Each
# batch therefore holds sentences of similar length, which means less padding.
train_iterator = BucketIterator(
    train_data,
    batch_size=args['batch_size'],
    sort_key=source_length,
    sort_within_batch=True,
)
# Evaluation batches: no shuffling and no sorting, so batches follow file
# order and predictions line up with the lines of the split's files.
val_iterator = BucketIterator(
    valid_data,
    batch_size=args['batch_size'],
    sort_key=source_length,
    train=False,
    shuffle=False,
    sort=False,
)
test_iterator = BucketIterator(
    test_data,
    batch_size=args['batch_size'],
    sort_key=source_length,
    train=False,
    shuffle=False,
    sort=False,
)



In [10]:
class DM:
    """Minimal datamodule-like wrapper that hands pre-built iterators to the Trainer.

    Exposes the same ``*_dataloader`` accessors the Trainer calls on a
    datamodule, each returning the iterator it was constructed with.
    """

    def __init__(self, train_iterator, val_iterator, test_iterator):
        self.train_iterator, self.val_iterator, self.test_iterator = (
            train_iterator,
            val_iterator,
            test_iterator,
        )

    def train_dataloader(self):
        """Return the iterator used for training."""
        return self.train_iterator

    def val_dataloader(self):
        """Return the iterator used for validation."""
        return self.val_iterator

    def test_dataloader(self):
        """Return the iterator used for testing / prediction."""
        return self.test_iterator

    
# Rebind `datamodule` to the torchtext iterators; this replaces the earlier DeEnDataModule.
datamodule = DM(
    train_iterator=train_iterator,
    val_iterator=val_iterator,
    test_iterator=test_iterator,
)

In [11]:
# Transformer hyperparameters.
# NOTE(review): these are independent of the encoder_*/decoder_* values in
# `args`, which are not used to build the model.
INPUT_DIM = len(SRC.vocab)   # source vocabulary size
OUTPUT_DIM = len(TRG.vocab)  # target vocabulary size
HID_DIM = 256
ENC_LAYERS = DEC_LAYERS = 3
ENC_HEADS = DEC_HEADS = 8
ENC_PF_DIM = DEC_PF_DIM = 512  # presumably the feed-forward inner size -- confirm in trext.models
ENC_DROPOUT = DEC_DROPOUT = 0.1

# Both halves receive `device` positionally and are also moved to it.
encoder = TransformerEncoder(
    INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device,
).to(device)
decoder = TransformerDecoder(
    OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device,
).to(device)

In [12]:
# Vocabulary index of each field's padding token, handed to the translator
# (presumably for building attention masks -- see trext.models).
SRC_PAD_IDX, TRG_PAD_IDX = (field.vocab.stoi[field.pad_token] for field in (SRC, TRG))

translator = TransformerTranslator(
    encoder=encoder,
    decoder=decoder,
    source_pad_idx=SRC_PAD_IDX,
    target_pad_idx=TRG_PAD_IDX,
    learning_rate=3e-4,
    device=device,
).to(device)

In [19]:
import random
import numpy as np

SEED = 1234


def _seed_everything(seed):
    """Seed Python, NumPy and torch (CPU and current CUDA device) RNGs; force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# NOTE(review): this runs after the encoder/decoder/translator and the
# iterators were created, so it does not control their initialization.
_seed_everything(SEED)

trainer = Trainer(
    logger=None,
    max_epoch=args['max_epoch'],
    verbose=args['verbose'],
    version=args['version'],
)

print("Let's start training!")
trainer.fit(model=translator, datamodule=datamodule)

print('Predicts!')
predicts = trainer.predict(model=translator, datamodule=datamodule)

  0%|          | 0/3062 [00:00<?, ?it/s]

Let's start training!


  9%|▉         | 284/3062 [00:30<04:54,  9.45it/s]


KeyboardInterrupt: 

In [None]:
# Display the torch.device chosen in the config cell.
device