In [1]:
import os
import sys
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(f'..{os.sep}utils'))))
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname( '..'))))
from utils.constants import *
import torch
import torch.nn as nn
from transformer_v2 import Transformer
from utils.function_utils import *
from func_load_model import *
from utils.optimizer_n_scheduler import *
from utils.logging_tensorboard import create_summary_writer, log_loss, log_learning_rate, log_gradients, log_attention_weights
from utils.distributions import *
from torch.cuda.amp import GradScaler, autocast

In [2]:
num_workers = os.cpu_count()

In [3]:
!nvidia-smi

Fri Apr 28 12:08:37 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 531.14                 Driver Version: 531.14       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                      TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 Ti    WDDM | 00000000:29:00.0  On |                  N/A |
| 30%   37C    P0               43W / 200W|    437MiB /  8192MiB |      5%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 16
max_len = MODEL_MAX_SEQ_LEN
d_model = MODEL_DIM
num_layers = MODEL_N_LAYERS
num_heads = MODEL_N_HEADS
dropout = MODEL_DROPOUT
num_epochs = 10
learning_rate = 1e-4
warmup_steps = 2000
weight_decay = 1e-4
VOCAB_SIZE = 64_000
d_ff = MODEL_FF
label_smoothing = MODEL_LABEL_SMOTHING

NUM_PHRASES = 10_000

n=0
LOGGING_FILE = f'runs{os.sep}translation_experiment_{n}'

In [5]:
tokenizer = load_tokenizer()
model = Transformer(VOCAB_SIZE,
                    VOCAB_SIZE, 
                    d_model, 
                    num_heads, 
                    num_layers, 
                    d_ff, 
                    dropout, 
                    max_len).to(device)

In [6]:
optimizer, scheduler = create_optimizer_and_scheduler(model, d_model, warmup_steps, learning_rate, weight_decay)

In [7]:
writer = create_summary_writer(LOGGING_FILE)

In [8]:
sentence_pairs = load_dataset(FILE_PATH, limit=NUM_PHRASES)
preprocessed_pairs = [(preprocess_text(en), preprocess_text(pt)) for en, pt in sentence_pairs]
split_idx = int(len(preprocessed_pairs) * 0.9)
train_sentence_pairs = preprocessed_pairs[:split_idx]
val_sentence_pairs = preprocessed_pairs[split_idx:]

In [9]:
train_dataset = preprocess_data(train_sentence_pairs, tokenizer, max_len)
val_dataset = preprocess_data(val_sentence_pairs, tokenizer, max_len)

In [10]:
train_dataloader = create_dataloader(train_dataset, batch_size, tokenizer, shuffle=True, num_workers=num_workers)
val_dataloader = create_dataloader(val_dataset, batch_size, tokenizer, shuffle=False, num_workers=num_workers)

In [11]:
pad_idx = tokenizer.token_to_id("<pad>")
criterion = LabelSmoothingKLDivergenceLoss(label_smoothing, VOCAB_SIZE, ignore_index=pad_idx)

In [16]:
def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, tgt_vocab, pad_idx, device, log_interval=100):
    global_step = 0
    accumulation_steps = 4
    for epoch in range(num_epochs):
        print('Starting epoch: ', epoch+1)
        model.train()
        accumulation_steps = 4
        optimizer.zero_grad()
        for batch_idx, (src, tgt) in enumerate(train_loader):
            src, tgt = src.to(device), tgt.to(device)
            src_mask, tgt_mask = generate_masks(src, tgt, pad_idx)
            if batch_idx == 0: 
                _, enc_attention_weights, dec_self_attention_weights, dec_enc_attention_weights = model(src, tgt, src_mask, tgt_mask, return_attention=True)
                attention_weights = {
                    "encoder": enc_attention_weights,
                    "decoder_self": dec_self_attention_weights,
                    "decoder_enc_dec": dec_enc_attention_weights
                }
                log_attention_weights(writer, attention_weights, MODEL_N_LAYERS, MODEL_N_HEADS, global_step)

            output = model(src, tgt, src_mask, tgt_mask)
            _, loss = criterion(output, tgt)
            loss.backward()
            if (batch_idx + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            # Log loss, learning rate, weights, and attention weights to TensorBoard
            log_loss(writer, loss, global_step)
            log_learning_rate(writer, scheduler.learning_rate(), global_step)
            global_step += 1

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch + 1}/{num_epochs} | Batch {batch_idx + 1}/{len(train_loader)} | Train Loss: {loss.item():.4f}")
            


        # Evaluate the model on the validation set after each epoch
        val_loss = evaluate_model(model, val_loader, criterion, device, pad_idx)
        print(f"Epoch: {epoch + 1} | Validation Loss: {val_loss:.4f}")

        #TODO: Implement model forward function withouth tgt_mask
        # May use greedy decoding or beam search
        
        #bleu_score = evaluate_metrics(model, val_loader, pad_idx, tokenizer, device)
        #print(f"Epoch: {epoch + 1}, BLEU Score: {bleu_score:.4f}")

        # Save the model checkpoint after each epoch
        save_checkpoint(model, optimizer, scheduler, epoch, f"checkpoints{os.sep}checkpoint_epoch_{epoch+1}_val_loss_{val_loss:.4f}.pt")

In [17]:
train(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler, num_epochs, VOCAB_SIZE, VOCAB_SIZE, device)

Starting epoch:  1
Epoch 1/10 | Batch 10/563 | Train Loss: 6.9945
Epoch 1/10 | Batch 20/563 | Train Loss: 6.7886
Epoch 1/10 | Batch 30/563 | Train Loss: 6.5420
Epoch 1/10 | Batch 40/563 | Train Loss: 6.7686
Epoch 1/10 | Batch 50/563 | Train Loss: 7.0416
Epoch 1/10 | Batch 60/563 | Train Loss: 6.7818
Epoch 1/10 | Batch 70/563 | Train Loss: 6.7614
Epoch 1/10 | Batch 80/563 | Train Loss: 7.1194
Epoch 1/10 | Batch 90/563 | Train Loss: 6.8572
Epoch 1/10 | Batch 100/563 | Train Loss: 6.7055
Epoch 1/10 | Batch 110/563 | Train Loss: 6.9877
Epoch 1/10 | Batch 120/563 | Train Loss: 6.6758
Epoch 1/10 | Batch 130/563 | Train Loss: 6.8668
Epoch 1/10 | Batch 140/563 | Train Loss: 6.9361
Epoch 1/10 | Batch 150/563 | Train Loss: 6.8894
Epoch 1/10 | Batch 160/563 | Train Loss: 6.7731
Epoch 1/10 | Batch 170/563 | Train Loss: 6.6445
Epoch 1/10 | Batch 180/563 | Train Loss: 6.9944
Epoch 1/10 | Batch 190/563 | Train Loss: 6.7453
Epoch 1/10 | Batch 200/563 | Train Loss: 6.8134
Epoch 1/10 | Batch 210/563 | T

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x000002774468D990>
Traceback (most recent call last):
  File "c:\Users\reidp\miniconda3\envs\torch_gpu\lib\site-packages\torch\utils\data\dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "c:\Users\reidp\miniconda3\envs\torch_gpu\lib\site-packages\torch\utils\data\dataloader.py", line 1437, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'


Epoch 2/10 | Batch 10/563 | Train Loss: 6.6535
Epoch 2/10 | Batch 20/563 | Train Loss: 6.3902
Epoch 2/10 | Batch 30/563 | Train Loss: 6.5809
Epoch 2/10 | Batch 40/563 | Train Loss: 6.8837
Epoch 2/10 | Batch 50/563 | Train Loss: 6.8297
Epoch 2/10 | Batch 60/563 | Train Loss: 6.8025
Epoch 2/10 | Batch 70/563 | Train Loss: 6.6526
Epoch 2/10 | Batch 80/563 | Train Loss: 6.6831
Epoch 2/10 | Batch 90/563 | Train Loss: 6.7584
Epoch 2/10 | Batch 100/563 | Train Loss: 6.5045
Epoch 2/10 | Batch 110/563 | Train Loss: 6.6813
Epoch 2/10 | Batch 120/563 | Train Loss: 7.0154
Epoch 2/10 | Batch 130/563 | Train Loss: 7.1124
Epoch 2/10 | Batch 140/563 | Train Loss: 7.0285
Epoch 2/10 | Batch 150/563 | Train Loss: 6.9532
Epoch 2/10 | Batch 160/563 | Train Loss: 7.0226
Epoch 2/10 | Batch 170/563 | Train Loss: 6.6541
Epoch 2/10 | Batch 180/563 | Train Loss: 6.7288
Epoch 2/10 | Batch 190/563 | Train Loss: 6.6636
Epoch 2/10 | Batch 200/563 | Train Loss: 6.8242
Epoch 2/10 | Batch 210/563 | Train Loss: 6.9517
E

KeyboardInterrupt: 