In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from translation.dataset import BilingualDataset
from translation.tokenizer import get_or_build_tokenizer
from translation.config import get_config
from datasets import load_dataset
import torch.nn as nn
from pathlib import Path
import os
import argparse
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import warnings
from tokenizers import Tokenizer
from translation.model import get_model
import warnings

In [None]:
config = get_config()
lang_src = config['lang_src']
lang_tgt = config['lang_tgt']

# Download only the configured percentage of the OPUS Books training split
# for this language pair (download_size is interpolated as a percent slice).
ds_raw = load_dataset(
    'opus_books',
    f'{lang_src}-{lang_tgt}',
    split=f'train[:{config["download_size"]}%]'
)

# Build (or load) one tokenizer per language.
tokenizer_src = get_or_build_tokenizer(config, ds_raw, lang_src)
tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, lang_tgt)

# Keep 90% for training, 10% for validation. The split is seeded so that the
# same sentences land in validation on every Restart & Run All; otherwise a
# resumed run would train on sentences it previously validated on.
# NOTE(review): 'seed' is an optional config key; 42 is used when absent.
train_ds_size = int(0.9 * len(ds_raw))
valid_ds_size = len(ds_raw) - train_ds_size
split_generator = torch.Generator().manual_seed(config.get('seed', 42))
train_ds_raw, valid_ds_raw = random_split(
    ds_raw, [train_ds_size, valid_ds_size], generator=split_generator
)

# Wrap both raw splits identically (same tokenizers, languages, seq_len).
train_ds = BilingualDataset(
    train_ds_raw,
    tokenizer_src,
    tokenizer_tgt,
    lang_src,
    lang_tgt,
    config['seq_len']
)
valid_ds = BilingualDataset(
    valid_ds_raw,
    tokenizer_src,
    tokenizer_tgt,
    lang_src,
    lang_tgt,
    config['seq_len']
)

# Report the longest tokenized sentence on each side so config['seq_len'] can
# be sanity-checked against the data. Each side must be encoded with its OWN
# tokenizer -- target text goes through tokenizer_tgt.
max_len_src = 0
max_len_tgt = 0
for item in ds_raw:
    pair = item['translation']
    max_len_src = max(max_len_src, len(tokenizer_src.encode(pair[lang_src]).ids))
    max_len_tgt = max(max_len_tgt, len(tokenizer_tgt.encode(pair[lang_tgt]).ids))

print(f'Max length of source sentence: {max_len_src}')
print(f'Max length of target sentence: {max_len_tgt}')

train_dataloader = DataLoader(
    train_ds, batch_size=config['batch_size'], shuffle=True
)
valid_dataloader = DataLoader(
    valid_ds, batch_size=1, shuffle=True
)

In [None]:
# Pick the accelerator: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device {device}')

# Suppress warnings from this point on for the rest of the kernel session.
warnings.filterwarnings('ignore')

In [None]:
# Create model sized to both vocabularies and move it to the selected device.
model = get_model(
    config,
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size()
).to(device)

# TensorBoard event files go to the directory named by config['experiment_name'].
writer = SummaryWriter(config['experiment_name'])

optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

initial_epoch = 0
global_step = 0

# Loop-invariant: logits are flattened to (N, tgt_vocab_size) every step.
tgt_vocab_size = tokenizer_tgt.get_vocab_size()

# The labels are target-language token ids, so padding positions must be
# identified by the TARGET tokenizer's [PAD] id; the source and target
# tokenizers are built independently and need not agree on it.
loss_fn = nn.CrossEntropyLoss(
    ignore_index=tokenizer_tgt.token_to_id('[PAD]'),
    label_smoothing=0.1
).to(device)

for epoch in range(initial_epoch, config['num_epochs']):
    model.train()
    batch_iterator = tqdm(
        train_dataloader,
        desc=f'Processing epoch {epoch:02d}'
    )
    for batch in batch_iterator:
        # (batch, seq_len)
        encoder_input = batch['encoder_input'].to(device)

        # (batch, seq_len)
        decoder_input = batch['decoder_input'].to(device)

        # (batch, 1, 1, seq_len)
        encoder_mask = batch['encoder_mask'].to(device)

        # (batch, 1, seq_len, seq_len)
        decoder_mask = batch['decoder_mask'].to(device)

        # Run the tensors through transformer
        # (batch, seq_len, d_model)
        encoder_output = model.encode(encoder_input, encoder_mask)
        # (batch, seq_len, d_model)
        decoder_output = model.decode(
            encoder_output, encoder_mask, decoder_input, decoder_mask
        )
        # (batch, seq_len, tgt_vocab_size)
        proj_output = model.project(decoder_output)

        label = batch['label'].to(device)  # (batch, seq_len)

        # (batch, seq_len, tgt_vocab_size)
        # --> (batch * seq_len, tgt_vocab_size)
        loss = loss_fn(
            proj_output.view(-1, tgt_vocab_size),
            label.view(-1)
        )

        # .item() forces a device->host sync; do it once and reuse the value.
        loss_value = loss.item()
        batch_iterator.set_postfix({'loss': f'{loss_value:6.3f}'})

        # Log the loss (SummaryWriter buffers events and flushes periodically)
        writer.add_scalar('train_loss', loss_value, global_step)

        # Back propagate the loss
        loss.backward()

        # Update the weights
        optimizer.step()
        optimizer.zero_grad()

        global_step += 1

    # Flush once per epoch instead of on every step to avoid per-step disk I/O.
    writer.flush()



In [None]:
# Save a checkpoint for the last epoch reached in the training cell.
# `__file__` is not defined inside a Jupyter kernel (it is when this notebook is
# exported to a script), so fall back to the current working directory.
base_dir = Path(__file__).parent if '__file__' in globals() else Path.cwd()
model_folder = f'{base_dir}/{config["model_folder"]}'
Path(model_folder).mkdir(parents=True, exist_ok=True)
# `epoch` is the loop variable left over from the training cell.
model_local_path = f'{model_folder}/{config["model_filename"]}{epoch:02d}.pt'

# tokenizers.Tokenizer has no save_pretrained(); serialize each tokenizer to
# its own JSON file with Tokenizer.save() so the two do not overwrite each
# other, and store the file paths (not a None return value) in the checkpoint.
tokenizer_src_path = f'{model_folder}/tokenizer_{lang_src}.json'
tokenizer_tgt_path = f'{model_folder}/tokenizer_{lang_tgt}.json'
tokenizer_src.save(tokenizer_src_path)
tokenizer_tgt.save(tokenizer_tgt_path)

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'global_step': global_step,
    'tokenizer_src': tokenizer_src_path,
    'tokenizer_tgt': tokenizer_tgt_path
}, model_local_path)

print(f'Model saved at {model_local_path}')

print('Training finished')