This notebook walks through the optimization steps used when training an encoder-decoder transformer model.

In [1]:
# Notebook housekeeping (IPython magics).
# Autosave the notebook every 300 seconds.
%autosave 300
# (Re)load the autoreload extension and set it to mode 2, so edits to the
# imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Turn off the Jedi backend for tab completion.
%config Completer.use_jedi = False

Autosaving every 300 seconds


In [2]:
import os

# Switch the working directory to the project root. Presumably this is what
# lets the later `S18_code` imports resolve -- verify against your setup.
#
# The root used to be hard-coded to one user's cluster mount. It can now be
# overridden with the ERAV2_PROJECT_ROOT environment variable; the original
# path is kept as the default, so behaviour on the original machine is unchanged.
PROJECT_ROOT = os.environ.get(
    "ERAV2_PROJECT_ROOT",
    "/mnt/batch/tasks/shared/LS_root/mounts/clusters/insights-model-run2/code/Users/soutrik.chowdhury/EraV2_Transformers",
)
os.chdir(PROJECT_ROOT)
print(os.getcwd())

/mnt/batch/tasks/shared/LS_root/mounts/clusters/insights-model-run2/code/Users/soutrik.chowdhury/EraV2_Transformers


In [3]:
from S18_code.config import get_config, get_weights_file_path
from S18_code.model import build_transformer
import torch
from S18_code.dataloader import get_ds
from datasets import load_dataset
import torch.optim as optim
import torch.nn as nn
from torchinfo import summary
from torch.optim.lr_scheduler import CyclicLR
from torch.optim import Adam
from torchmetrics.text import BLEUScore
from torch.cuda.amp import GradScaler

In [4]:
# Read the compute device from the project configuration and echo it.
# NOTE(review): later cells access `device.index`, so this is presumably a
# `torch.device` rather than a plain string -- confirm in S18_code.config.
device = get_config()["device"]
print(device)

cuda


In [5]:
# GPU diagnostics: availability, device identity, memory usage and hardware
# properties. The output lines and their order are unchanged from before.
#
# Every query after the availability check needs a working CUDA runtime, so
# they are guarded: on a CPU-only host this cell prints `False` and finishes
# instead of raising.
cuda_ready = torch.cuda.is_available()
print(cuda_ready)

if cuda_ready:
    gpu_index = device.index
    # Fetch the properties object once; it is reused for three prints below.
    gpu_props = torch.cuda.get_device_properties(gpu_index)

    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name())
    print(torch.cuda.memory_allocated())
    print(torch.cuda.memory_reserved())
    print(torch.cuda.get_device_name(gpu_index))
    print(gpu_props.total_memory / 1024**3)  # total device memory in GiB
    print(torch.cuda.memory_summary(device=gpu_index))
    print(torch.cuda.get_device_capability(gpu_index))
    print(gpu_props)
    print(gpu_props.multi_processor_count)

True
0
Tesla V100-PCIE-16GB
0
0
Tesla V100-PCIE-16GB
15.7725830078125
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 

#### Download the dataset and create data via dataloader

In [6]:
# Fetch the training split of the configured dataset. The dataset name and
# the "<src>-<tgt>" language-pair subset both come from the project config.
dataset_name = get_config()["datasource"]
language_pair = f"{get_config()['src_lang']}-{get_config()['tgt_lang']}"

ds_raw = load_dataset(dataset_name, language_pair, split="train")
print(len(ds_raw))

In [None]:
# Build the train/validation loaders and both tokenizers from the raw dataset.
# (The original note describes this as batch-wise padding data loading; the
# padding logic itself lives in S18_code.dataloader.)
dataset_bundle = get_ds(ds_raw, get_config())
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = dataset_bundle

In [None]:
# Sanity check: print the tensor shapes of the first training batch only.
for data in train_dataloader:
    for field in ("encoder_input", "decoder_input", "encoder_mask", "decoder_mask"):
        print(data[field].shape)
    break  # one batch is enough for a shape check

In [None]:
# Sanity check: print the tensor shapes of the first validation batch only.
for data in val_dataloader:
    for field in ("encoder_input", "decoder_input", "encoder_mask", "decoder_mask"):
        print(data[field].shape)
    break  # one batch is enough for a shape check

#### Model Initialization and Training

In [None]:
# Size the model from the two tokenizer vocabularies, then move it onto the
# configured device.
src_vocab_size = tokenizer_src.get_vocab_size()
tgt_vocab_size = tokenizer_tgt.get_vocab_size()

transformer_model = build_transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
)
transformer_model = transformer_model.to(device)

In [None]:
# Display the torchinfo summary of the model built above.
summary(transformer_model)

First, we run basic training and validation using:
* A basic training loop
* A custom learning-rate scheduler

In [None]:
from S18_code.trainer import (
    CustomLRScheduler,
    TranslationLoss,
    run_training_loop_basic,
    run_validation_loop,
    greedy_decode,
    run_inference_loop,
    run_training_loop_opt,
)

from S18_code.utils import start_timer, end_timer

In [None]:
# Look up the ids of the special tokens and set the label-smoothing factor
# used by the loss.
# NOTE(review): the ids come from the *source* tokenizer, but they are later
# passed to the loss and the inference loop alongside `tokenizer_tgt`. This
# is only safe if both tokenizers assign the same ids to the special tokens
# -- confirm.
pad_idx, sos_idx, eos_idx = (
    tokenizer_src.token_to_id(token) for token in ("[PAD]", "[SOS]", "[EOS]")
)
label_smoothing = 0.1

In [None]:
# Adam optimizer over all model parameters. No `weight_decay` argument is
# passed, so no explicit weight decay is configured here.
optimizer = Adam(
    transformer_model.parameters(),
    lr=get_config()["learning_rate"],
    betas=(0.9, 0.98),
    eps=1e-9,
)

# Custom learning-rate schedule driven by the model width.
# NOTE(review): the trailing 1000 is presumably the number of warm-up steps
# -- confirm against CustomLRScheduler.
scheduler = CustomLRScheduler(optimizer, get_config()["d_model"], 1000)

# Translation loss, built from the padding id, the label-smoothing factor
# and the target tokenizer.
criterion = TranslationLoss(pad_idx, label_smoothing, tokenizer_tgt)

# BLEU score used to evaluate the generated translations.
metric = BLEUScore()

Basic Training

In [None]:
# Baseline (un-optimized) training loop, intentionally left disabled.
# It is kept for reference only: the same train -> validate -> inference
# sequence as the active loop, but using `run_training_loop_basic` and no
# gradient scaler. Uncomment it to reproduce the baseline run.
# NOTE(review): as written, `global_step` is reset to 0 at the start of every
# epoch -- confirm whether that is intended before re-enabling.
#
# for epoch in range(get_config()["num_epochs"]):
#     global_step = 0
#     train_loss, global_step = run_training_loop_basic(
#         transformer_model,
#         train_dataloader,
#         optimizer,
#         criterion,
#         device,
#         global_step,
#         scheduler,
#     )
#     print(f"Epoch: {epoch}, Train loss: {train_loss}")

#     val_loss = run_validation_loop(transformer_model, val_dataloader, criterion, device)
#     print(f"Epoch: {epoch}, Validation loss: {val_loss}")

#     run_inference_loop(
#         transformer_model,
#         val_dataloader,
#         tokenizer_tgt,
#         device,
#         5,
#         metric,
#         sos_idx,
#         eos_idx,
#     )

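Next, we apply optimization techniques to the training loop to improve performance:
* Automatic Mixed Precision (AMP)
* Gradient scaling (loss scaling with `GradScaler`, which keeps small FP16 gradients from underflowing; note this is a different technique from gradient clipping)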

In [None]:
# Gradient scaler for the mixed-precision training loop; uses the
# `GradScaler` name imported from torch.cuda.amp at the top of the notebook.
scaler = GradScaler()

In [None]:
# Optimized training run: for each epoch, train with the gradient scaler,
# validate, then run the inference loop on the validation loader. The whole
# run is timed with start_timer()/end_timer().
start_timer()

# Initialise the step counter ONCE, before the epoch loop. It used to be
# reset to 0 at the top of every epoch, which threw away the value returned
# by the training function, so the "global" step never went past one epoch.
global_step = 0

for epoch in range(get_config()["num_epochs"]):
    # The training function returns the updated step count, which is carried
    # into the next epoch.
    train_loss, global_step = run_training_loop_opt(
        transformer_model,
        train_dataloader,
        optimizer,
        criterion,
        device,
        global_step,
        scaler,
        scheduler,
    )
    print(f"Epoch: {epoch}, Train loss: {train_loss}")

    val_loss = run_validation_loop(transformer_model, val_dataloader, criterion, device)
    print(f"Epoch: {epoch}, Validation loss: {val_loss}")

    # NOTE(review): the literal 5 is presumably the number of validation
    # examples to decode and score -- confirm against run_inference_loop.
    run_inference_loop(
        transformer_model,
        val_dataloader,
        tokenizer_tgt,
        device,
        5,
        metric,
        sos_idx,
        eos_idx,
    )

# NOTE(review): the timer label still says "Simple Neural Network" although
# this run trains the transformer; left unchanged to keep the output stable.
end_timer("Simple Neural Network")