In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
!pip install torch==1.10.2
!pip install torchtext==0.11.2
!python3 -m spacy download ru_core_news_sm
!python3 -m spacy download en_core_web_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ru-core-news-sm==3.3.0
  Downloading https://github.com/explosion/spacy-models/releases/download/ru_core_news_sm-3.3.0/ru_core_news_sm-3.3.0-py3-none-any.whl (15.3 MB)
[K     |████████████████████████████████| 15.3 MB 1.3 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('ru_core_news_sm')
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-sm==3.3.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl (12.8 MB)
[K     |████████████████████████████████| 12.8 MB 28.7 MB/s 
[38;5;2m✔ 

In [44]:
import os
import math
import time

import torch
import torch.nn as nn
import spacy
from torchtext.legacy.data import Field, BucketIterator, TabularDataset
from tqdm.notebook import tqdm

from transformer import Encoder, Decoder, Seq2Seq
from utils import (
    set_seed,
    calculate_bleu,
    count_parameters,
    train_transformer,
    evaluate_transformer,
    epoch_time,
    transformer_translate_sentence
)

In [23]:
# Filesystem layout: where the parallel corpus is read from and where the
# model weights are written to.
DATA_PATH = 'data'
INPUT_DATA = os.path.join(DATA_PATH, 'data.csv')
MODEL_PATH = os.path.join('models', 'transformer.pt')

# Optimisation settings shared by the data iterators and the optimizer.
BATCH_SIZE = 128
LEARNING_RATE = 5e-4

In [24]:
# Seed first (helper from utils), then pick the GPU when one is visible and
# fall back to the CPU otherwise.
set_seed()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [25]:
# Load the Russian and English spaCy pipelines; the tokenizer helpers in this
# notebook use their `.tokenizer` component.
spacy_ru, spacy_en = spacy.load('ru_core_news_sm'), spacy.load('en_core_web_sm')

def tokenize_ru(text):
    """
    Splits a Russian string into a list of token strings using the
    `spacy_ru` tokenizer
    """
    tokens = []
    for token in spacy_ru.tokenizer(text):
        tokens.append(token.text)
    return tokens

def tokenize_en(text):
    """
    Splits an English string into a list of token strings using the
    `spacy_en` tokenizer
    """
    tokens = []
    for token in spacy_en.tokenizer(text):
        tokens.append(token.text)
    return tokens

In [26]:
# Source side: English text, lower-cased and wrapped in <sos>/<eos> markers,
# batched with the batch dimension first.
SRC = Field(
    tokenize=tokenize_en,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
    include_lengths=False,
)

# Target side: Russian text with the same preprocessing options.
TRG = Field(
    tokenize=tokenize_ru,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)

In [27]:
# Read the corpus as a two-column file: first column -> `src` (processed by
# SRC), second column -> `trg` (processed by TRG).
# NOTE(review): the file is named data.csv but is parsed as 'tsv' — confirm
# that the columns really are tab-separated.
dataset = TabularDataset(
    fields=[('src', SRC), ('trg', TRG)],
    format='tsv',
    path=INPUT_DATA,
)

In [28]:
# NOTE(review): per the torchtext legacy `Dataset.split` docs, a 3-element
# `split_ratio` is read as [train, test, valid] while the splits are returned
# as (train, valid, test). If so, this call yields roughly 80% train,
# 5% validation and 15% test — verify that this is the intended partition.
train_data, valid_data, test_data = dataset.split(split_ratio=[0.8, 0.15, 0.05])

In [29]:
# Vocabularies are built from the training split only; `min_freq=2` is the
# frequency threshold handed to torchtext for both languages.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

In [30]:
# One iterator per split, all using the same batch size and placing batches on
# `device`. `sort=False` is passed through to every iterator.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
    sort=False,
)

In [31]:
# Vocabulary sizes drive the first constructor argument of each half of the model.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)

# Encoder and decoder are configured symmetrically.
HID_DIM = 256
ENC_LAYERS = DEC_LAYERS = 3
ENC_HEADS = DEC_HEADS = 4
ENC_PF_DIM = DEC_PF_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.1

# Arguments are positional; their order must match the constructors in
# transformer.py.
enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)

In [32]:
# Look up the integer id of the padding token in each vocabulary.
SRC_PAD_IDX, TRG_PAD_IDX = (
    field.vocab.stoi[field.pad_token] for field in (SRC, TRG)
)

# Assemble the full model and move its parameters to the selected device.
model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device)
model = model.to(device)

In [33]:
def initialize_weights(m):
    """
    Applies Xavier-uniform initialization to a module's weight matrix

    Intended to be passed to `nn.Module.apply`. Only weights with more than
    one dimension are re-initialized in place; modules that have no `weight`
    attribute, whose `weight` is None, or whose weight has a single dimension
    are left untouched.
    """
    weight = getattr(m, 'weight', None)
    # A module can expose `weight` as None (e.g. a normalization layer built
    # without affine parameters), so test the value and not just the
    # attribute's presence.
    if weight is None or weight.dim() <= 1:
        return
    # xavier_uniform_ updates the tensor in place without tracking gradients,
    # so the parameter can be passed directly.
    nn.init.xavier_uniform_(weight)

In [34]:
# Run initialize_weights on the model and all of its submodules; the trailing
# semicolon suppresses the notebook's echo of the returned module.
model.apply(initialize_weights);

In [35]:
# Adam over every model parameter, using the learning rate from the config cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [36]:
# Cross-entropy over target tokens; positions holding the target padding index
# are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

In [37]:
%%time
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = train_transformer(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate_transformer(model, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), MODEL_PATH)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 0m 44s
	Train Loss: 5.011 | Train PPL: 150.129
	 Val. Loss: 3.166 |  Val. PPL:  23.702
Epoch: 02 | Time: 0m 46s
	Train Loss: 2.916 | Train PPL:  18.467
	 Val. Loss: 2.520 |  Val. PPL:  12.430
Epoch: 03 | Time: 0m 45s
	Train Loss: 2.435 | Train PPL:  11.414
	 Val. Loss: 2.264 |  Val. PPL:   9.623
Epoch: 04 | Time: 0m 45s
	Train Loss: 2.144 | Train PPL:   8.529
	 Val. Loss: 2.108 |  Val. PPL:   8.228
Epoch: 05 | Time: 0m 45s
	Train Loss: 1.905 | Train PPL:   6.722
	 Val. Loss: 1.973 |  Val. PPL:   7.194
Epoch: 06 | Time: 0m 45s
	Train Loss: 1.712 | Train PPL:   5.541
	 Val. Loss: 1.896 |  Val. PPL:   6.659
Epoch: 07 | Time: 0m 45s
	Train Loss: 1.547 | Train PPL:   4.696
	 Val. Loss: 1.837 |  Val. PPL:   6.279
Epoch: 08 | Time: 0m 45s
	Train Loss: 1.404 | Train PPL:   4.072
	 Val. Loss: 1.821 |  Val. PPL:   6.180
Epoch: 09 | Time: 0m 45s
	Train Loss: 1.281 | Train PPL:   3.600
	 Val. Loss: 1.806 |  Val. PPL:   6.084
Epoch: 10 | Time: 0m 45s
	Train Loss: 1.172 | Train PPL

In [38]:
# Restore the best checkpoint. `map_location` lets weights that were saved on a
# GPU load on whatever device this session is using (including CPU-only).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

# Count the batches the iterator actually yields, including the final partial
# one. Floor division would under-count and would be zero (division by zero
# below) for a test split smaller than one batch.
num_batches = math.ceil(len(test_data.examples) / test_iterator.batch_size)

start_time = time.time()
test_loss = evaluate_transformer(model, test_iterator, criterion)
elapsed_time = time.time() - start_time

# Scale the per-batch time down to a batch of 32, assuming time grows linearly
# with batch size; the factor is derived from the real batch size instead of
# being hard-coded.
time_per_32_batch = elapsed_time / num_batches / (test_iterator.batch_size / 32)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} | Inference time if batch_size=32: {time_per_32_batch:.3f} sec')

| Test Loss: 1.774 | Test PPL:   5.896 | Inference time if batch_size=32: 0.013 sec


In [39]:
# Pick one training pair and show its tokenized source and reference target.
example_idx = 3

src = train_data.examples[example_idx].src
trg = train_data.examples[example_idx].trg

print(f'src = {" ".join(src)}')
print(f'trg = {" ".join(trg)}')

src = guests can prepare their meals in the kitchen with a stove , microwave and fridge .
trg = гости могут приготовить себе еду на кухне с плитой , микроволновой печью и холодильником .


In [40]:
# Translate the `src` tokens selected in the previous cell, generating at most
# 100 tokens, and print the predicted tokens joined by spaces.
translation, attention = transformer_translate_sentence(
    src, SRC, TRG, model, device, max_len=100
)

print(f'predicted trg = {" ".join(translation)}')

predicted trg = гости могут приготовить блюда на кухне с плитой , холодильником и плитой . <eos>


In [41]:
# Pick one validation pair and show its tokenized source and reference target.
example_idx = 7

src = valid_data.examples[example_idx].src
trg = valid_data.examples[example_idx].trg

print(f'src = {" ".join(src)}')
print(f'trg = {" ".join(trg)}')

src = dam square and munt square are within 700 metres of cloud9 guesthouse amsterdam .
trg = многие достопримечательности амстердама также расположены на расстоянии пешей прогулки от номера - студио . расстояние от площадей дам и мунт до гостевого дома cloud9 amsterdam составляет всего 700 метров .


In [42]:
# Pick one held-out test pair and show its tokenized source and reference target.
example_idx = 21

src = test_data.examples[example_idx].src
trg = test_data.examples[example_idx].trg

print(f'src = {" ".join(src)}')
print(f'trg = {" ".join(trg)}')

src = free private parking is possible on site .
trg = на территории имеется бесплатная частная парковка .


In [43]:
# Translate the `src` tokens selected in the previous cell, generating at most
# 100 tokens, and print the predicted tokens joined by spaces.
translation, attention = transformer_translate_sentence(
    src, SRC, TRG, model, device, max_len=100
)
print(f'predicted trg = {" ".join(translation)}')

predicted trg = на территории обустроена бесплатная частная парковка . <eos>


In [45]:
# The model was already moved to `device` when it was constructed, so this
# only re-asserts its placement before the BLEU evaluation; the trailing
# semicolon suppresses the notebook's echo of the returned module.
model.to(device);

In [46]:
# Corpus-level BLEU on the held-out test split, with translations limited to
# 100 generated tokens. The printed value is the helper's result times 100.
bleu_score = calculate_bleu(
    test_data, SRC, TRG, model, device, max_len=100
)

print(f'BLEU score = {bleu_score*100:.2f}')

BLEU score = 22.08
