In [1]:
import sys
import os

# Make the parent of the current working directory importable so that
# `scripts.*` resolves when the notebook is run from its own sub-directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Re-running this cell must not keep appending the same entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
from scripts.utils import load_config
# Project settings; later cells index this object by string key.
config = load_config("../config.json")

Configuration loaded successfully from ../config.json


In [3]:
from scripts.data_service import DataService

# Pull the data-related settings out of the loaded configuration.
BATCH_SIZE, SRC_LANGUAGE, TGT_LANGUAGE = (
    config[key] for key in ("BATCH_SIZE", "SRC_LANGUAGE", "TGT_LANGUAGE")
)

# Single entry point used by later cells for vocabularies and the train loader.
data_service = DataService(
    batch_size=BATCH_SIZE,
    src_language=SRC_LANGUAGE,
    tgt_language=TGT_LANGUAGE,
)



In [4]:
import torch
from torch import nn

from scripts.model_service import ModelService

# Vocabulary sizes are handed to the model factory below.
src_vocab, tgt_vocab = data_service.get_vocabularies()
src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)

# Architecture hyper-parameters all come from the shared configuration.
EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, DROPOUT = (
    config[key]
    for key in ("EMBED_DIM", "NUM_HEADS", "FF_DIM", "NUM_LAYERS", "DROPOUT")
)

# Prefer CUDA when PyTorch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_service = ModelService()
model = model_service.get_model(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    device=device,
)

# Targets equal to the data service's PAD_IDX are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=data_service.PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [5]:
def save_checkpoint(epoch, model, optimizer, loss, checkpoint_dir):
    """
    Save a checkpoint at the current training state.

    The file is written to ``<checkpoint_dir>/checkpoint_epoch_<epoch + 1>.pth``
    and holds a dict with the keys ``"epoch"`` (stored as ``epoch + 1``),
    ``"model_state_dict"``, ``"optimizer_state_dict"`` and ``"loss"``.

    Args:
        epoch: Epoch index; the checkpoint records and is named with ``epoch + 1``.
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        loss: Value stored under the ``"loss"`` key.
        checkpoint_dir: Directory to write into; created if it does not exist.
    """
    # torch.save does not create missing parent directories, so make sure the
    # target exists instead of relying on the caller having created it.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(
        checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth")
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")


def load_checkpoint(checkpoint_path, model, optimizer, device):
    """
    Restore training state from a checkpoint file.

    Reads the file with ``torch.load`` (mapping storages to ``device``), feeds
    the stored state dicts into ``model`` and ``optimizer``, prints a summary
    line, and returns ``(epoch, loss)`` as recorded in the checkpoint.
    """
    state = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])

    resumed_epoch, saved_loss = state["epoch"], state["loss"]
    summary = (
        f"Checkpoint loaded: {checkpoint_path} (Epoch {resumed_epoch}, Loss: {saved_loss:.4f})"
    )
    print(summary)
    return resumed_epoch, saved_loss


# Checkpoint / resume configuration for the training cell below.
checkpoint_dir = "../checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
resume_training = False
checkpoint_path = os.path.join(
    checkpoint_dir, "checkpoint_epoch_10.pth")
start_epoch = 0

if resume_training:
    if os.path.exists(checkpoint_path):
        start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer, device)
    else:
        # A resume was requested but there is nothing to resume from; say so
        # instead of silently restarting from epoch 0.
        print(
            f"Checkpoint not found: {checkpoint_path}. Training will start from epoch {start_epoch + 1}.")

In [None]:
from tqdm import tqdm

num_epochs = 10

train_loader = data_service.get_train_loader()
# Batches per epoch; used as the divisor for the average loss.
num_batches = len(train_loader)

# Runs epochs start_epoch .. num_epochs - 1, writing one checkpoint per epoch
# and the final weights once the loop finishes.
for epoch in range(start_epoch, num_epochs):
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError further down.
        raise ValueError("train_loader is empty; there is nothing to train on.")

    model.train()
    total_loss = 0.0

    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t:
        for batch_idx, (src, tgt) in enumerate(t):
            src, tgt = src.to(device), tgt.to(device)
            # Shift the target by one position along dim 0: the model is fed
            # tgt[:-1] and scored against tgt[1:].
            # Assumes dim 0 is the sequence axis -- TODO confirm against the collate fn.
            tgt_input = tgt[:-1, :]
            tgt_output = tgt[1:, :]

            optimizer.zero_grad()
            output = model(src, tgt_input)
            # reshape() gives the same result as view() on contiguous tensors
            # and also works if the tensor is not contiguous.
            output = output.reshape(-1, output.size(-1))
            tgt_output = tgt_output.reshape(-1)

            loss = criterion(output, tgt_output)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for both the sum and the progress bar.
            batch_loss = loss.item()
            total_loss += batch_loss
            t.set_postfix(loss=f"{batch_loss:.4f}")

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")

    save_checkpoint(epoch, model, optimizer, avg_loss, checkpoint_dir)

# Only the model weights (no optimizer state) go into the final file.
final_model_path = os.path.join(checkpoint_dir, "final_model.pth")
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")

Epoch 1/10:   0%|          | 2/907 [00:06<45:25,  3.01s/batch, loss=8.5719]  


KeyboardInterrupt: 

In [None]:
model_save_path = "../models/transformer_model.pth"
# torch.save does not create missing parent directories, and nothing earlier
# in the notebook creates ../models, so make sure it exists first.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to transformer_model.pth
