<a href="https://colab.research.google.com/github/vineelkondapalli/multi30k_transformer/blob/main/firsttransformer_multi30k.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install torch==2.3.0 torchtext==0.18.0 torchvision --upgrade
# !pip install -U spacy
# !pip install datasets
# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm
# !pip install transformers



In [2]:
# =============================================================================
# 1. IMPORTS AND SETUP
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from transformers import AutoTokenizer
import triton
import triton.language as tl

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
# from torchtext.vocab import build_vocab_from_iterator

# import spacy
import math
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# =============================================================================
# 2. DATA PIPELINE
# =============================================================================
# Load the pretrained German->English Marian tokenizer; it is used for both the
# source and the target side of every batch.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")

# Load the dataset (train / validation / test splits whose examples carry
# 'de' and 'en' fields).
multi30k_dataset = load_dataset("bentrevett/multi30k")

# Source and target share one tokenizer, hence one padding id.
SRC_PAD_IDX = tokenizer.pad_token_id
TRG_PAD_IDX = tokenizer.pad_token_id

# `device` is defined once above; it is not redefined here.
BATCH_SIZE = 128

def collate_fn(batch):
    """Turn a list of Multi30k examples into padded ``(src_ids, trg_ids)`` tensors.

    Args:
        batch: List of dataset rows, each with a German ``'de'`` and an English
            ``'en'`` sentence.

    Returns:
        Tuple ``(src_ids, trg_ids)`` of ``input_ids`` tensors, each padded to the
        longest sequence in its own side of the batch and moved to ``device``.

    NOTE(review): the target ids carry no explicit decoder-start token, yet the
    training loop feeds ``trg[:, :-1]`` and beam search starts from
    ``pad_token_id`` -- confirm whether a start token should be prepended here.
    """
    src_batch = [sample['de'] for sample in batch]
    trg_batch = [sample['en'] for sample in batch]

    # The tokenizer handles padding, truncation and tensor conversion.
    src_processed = tokenizer(src_batch, padding=True, truncation=True, return_tensors="pt")
    # Marian tokenizers keep separate source and target SentencePiece models.
    # Passing the English side via `text_target` selects the target model, the
    # same one `tokenizer.decode` uses by default when turning predictions back
    # into text. Without it, English would be segmented with the German model.
    trg_processed = tokenizer(text_target=trg_batch, padding=True, truncation=True, return_tensors="pt")

    # Only the token ids are needed; the attention masks are rebuilt from pad ids
    # by the model.
    return src_processed['input_ids'].to(device), trg_processed['input_ids'].to(device)

# Create the DataLoaders. All three splits share the batch size and collate
# function; only the training split is shuffled.
_shared_loader_args = {"batch_size": BATCH_SIZE, "collate_fn": collate_fn}
train_dataloader = DataLoader(multi30k_dataset['train'], shuffle=True, **_shared_loader_args)
valid_dataloader = DataLoader(multi30k_dataset['validation'], **_shared_loader_args)
test_dataloader = DataLoader(multi30k_dataset['test'], **_shared_loader_args)

print("Modern data pipeline is ready.")

# =============================================================================
# 3. MODEL DEFINITION
# =============================================================================
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with separate Q/K/V/output projections.

    Args:
        hid_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of heads; each head works on ``hid_dim // n_heads`` features.
        dropout: Dropout probability applied to the attention weights.
        device: Device that holds the ``sqrt(head_dim)`` scaling tensor.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Plain tensor attribute (not a registered buffer), so it is not saved in
        # the state_dict and is placed on `device` explicitly.
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, len, hid_dim)`` into ``(batch, n_heads, len, head_dim)``."""
        return projected.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Scores at positions where ``mask == 0`` are replaced by -1e10 before the
        softmax; the mask must broadcast against the
        ``(batch, n_heads, q_len, k_len)`` score tensor.

        Returns:
            ``(output, attention)``: the projected output of shape
            ``(batch, q_len, hid_dim)`` and the softmax weights (before dropout)
            of shape ``(batch, n_heads, q_len, k_len)``.
        """
        batch_size = query.shape[0]
        q = self._split_heads(self.fc_q(query), batch_size)
        k = self._split_heads(self.fc_k(key), batch_size)
        v = self._split_heads(self.fc_v(value), batch_size)

        scores = (q @ k.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)
        weights = scores.softmax(dim=-1)

        context = self.dropout(weights) @ v
        merged = context.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)
        return self.fc_o(merged), weights

@triton.jit
def fused_linear_relu_fwd_kernel(
    # Pointers to matrices
    x_ptr, w_ptr, bias_ptr, output_ptr, relu_mask_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides. The weight strides are (row, col) of the (N, K) nn.Linear weight,
    # i.e. (weight.stride(0), weight.stride(1)) -- the order the launcher passes
    # them in, and the same order the backward kernels use.
    stride_xm, stride_xk,
    stride_wn, stride_wk,
    stride_om, stride_on,
    stride_mm, stride_mn,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Forward pass: output = ReLU(x @ w.T + bias), with x (M, K), w (N, K) and
    output (M, N).

    Also writes the ReLU activity mask (pre-activation > 0) to relu_mask_ptr for
    the backward pass. Each program computes one (BLOCK_M, BLOCK_N) output tile:
    program_id(0) selects the row tile, program_id(1) the column tile, and K is
    reduced in BLOCK_K steps.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Matrix multiplication
    for k in range(0, K, BLOCK_K):
        x_ptrs = x_ptr + offs_m[:, None] * stride_xm + (k + offs_k[None, :]) * stride_xk
        x_mask = (offs_m[:, None] < M) & ((k + offs_k[None, :]) < K)
        x_block = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # (BLOCK_N, BLOCK_K) tile of w, addressed as w[n, k].
        w_ptrs = w_ptr + offs_n[:, None] * stride_wn + (k + offs_k[None, :]) * stride_wk
        w_mask = (offs_n[:, None] < N) & ((k + offs_k[None, :]) < K)
        w_block = tl.load(w_ptrs, mask=w_mask, other=0.0)

        acc += tl.dot(x_block, tl.trans(w_block))

    # Add bias (skipped when the caller passes None)
    if bias_ptr is not None:
        bias_ptrs = bias_ptr + offs_n
        bias_mask = offs_n < N
        bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0)
        acc += bias[None, :]

    # Apply ReLU and keep the activity mask
    relu_mask = acc > 0.0
    acc = tl.where(relu_mask, acc, 0.0)

    # Store output and mask
    output_ptrs = output_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    output_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(output_ptrs, acc, mask=output_mask)

    # Store the ReLU mask for backward (the launcher in this file allocates it
    # as a torch.bool tensor)
    mask_ptrs = relu_mask_ptr + offs_m[:, None] * stride_mm + offs_n[None, :] * stride_mn
    tl.store(mask_ptrs, relu_mask, mask=output_mask)


@triton.jit
def fused_linear_relu_bwd_kernel(
    # Pointers
    grad_output_ptr, relu_mask_ptr, x_ptr, w_ptr,
    grad_x_ptr, grad_w_ptr, grad_bias_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides
    stride_gom, stride_gon,
    stride_mm, stride_mn,
    stride_xm, stride_xk,
    stride_wn, stride_wk,
    stride_gxm, stride_gxk,
    stride_gwn, stride_gwk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    COMPUTE_GRAD_X: tl.constexpr,
    COMPUTE_GRAD_W: tl.constexpr,
):
    """
    Backward pass for fused linear + ReLU with respect to the input x.

    Computes grad_x = (grad_output * relu_mask) @ w, where grad_output and the
    ReLU mask are (M, N) and w is (N, K). Each program writes one
    (BLOCK_M, BLOCK_K) tile of grad_x: program_id(0) selects the row tile,
    program_id(1) the column tile, and N is reduced in BLOCK_N steps.

    Only the COMPUTE_GRAD_X branch is implemented. x_ptr, grad_w_ptr,
    grad_bias_ptr, their strides and COMPUTE_GRAD_W are accepted but unused;
    the weight and bias gradients come from fused_linear_relu_bwd_weight_kernel.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)

    if COMPUTE_GRAD_X:
        # Compute grad_x = grad_output @ w (with ReLU mask applied to grad_output)
        acc_gx = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)

        for n in range(0, N, BLOCK_N):
            # Load grad_output block
            go_ptrs = grad_output_ptr + offs_m[:, None] * stride_gom + (n + offs_n[None, :]) * stride_gon
            go_mask = (offs_m[:, None] < M) & ((n + offs_n[None, :]) < N)
            grad_out = tl.load(go_ptrs, mask=go_mask, other=0.0)

            # Load ReLU mask and apply: the ReLU derivative is 1 where the forward
            # output was positive and 0 elsewhere
            mask_ptrs = relu_mask_ptr + offs_m[:, None] * stride_mm + (n + offs_n[None, :]) * stride_mn
            relu_mask = tl.load(mask_ptrs, mask=go_mask, other=0.0)
            grad_out = tl.where(relu_mask, grad_out, 0.0)

            # Load weight block: (BLOCK_N, BLOCK_K) tile addressed as w[n, k]
            w_ptrs = w_ptr + (n + offs_n[:, None]) * stride_wn + offs_k[None, :] * stride_wk
            w_mask = ((n + offs_n[:, None]) < N) & (offs_k[None, :] < K)
            w_block = tl.load(w_ptrs, mask=w_mask, other=0.0)

            # Accumulate: grad_output @ w
            acc_gx += tl.dot(grad_out, w_block)

        # Store grad_x
        gx_ptrs = grad_x_ptr + offs_m[:, None] * stride_gxm + offs_k[None, :] * stride_gxk
        gx_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        tl.store(gx_ptrs, acc_gx, mask=gx_mask)


@triton.jit
def fused_linear_relu_bwd_weight_kernel(
    # Pointers
    grad_output_ptr, relu_mask_ptr, x_ptr,
    grad_w_ptr, grad_bias_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides
    stride_gom, stride_gon,
    stride_mm, stride_mn,
    stride_xm, stride_xk,
    stride_gwn, stride_gwk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    COMPUTE_GRAD_BIAS: tl.constexpr,
):
    """
    Backward pass for the weight (and optionally the bias) of fused linear + ReLU.

    With g = grad_output * relu_mask of shape (M, N) and x of shape (M, K):
        grad_w (N, K) = g.T @ x
        grad_b (N,)   = g.sum(axis=0)
    Each program writes one (BLOCK_N, BLOCK_K) tile of grad_w: program_id(0)
    selects the row tile (over N), program_id(1) the column tile (over K), and
    M is reduced in BLOCK_M steps. The bias gradient is accumulated into
    grad_bias_ptr with atomic_add, so the caller must zero it first.
    """
    pid_n = tl.program_id(0)
    pid_k = tl.program_id(1)

    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offs_m = tl.arange(0, BLOCK_M)

    # Accumulators for grad_w = (grad_output * relu_mask).T @ x and grad_b
    acc_gw = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
    acc_gb = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for m in range(0, M, BLOCK_M):
        # Load grad_output block
        go_ptrs = grad_output_ptr + (m + offs_m[:, None]) * stride_gom + offs_n[None, :] * stride_gon
        go_mask = ((m + offs_m[:, None]) < M) & (offs_n[None, :] < N)
        grad_out = tl.load(go_ptrs, mask=go_mask, other=0.0)

        # Load ReLU mask and apply
        mask_ptrs = relu_mask_ptr + (m + offs_m[:, None]) * stride_mm + offs_n[None, :] * stride_mn
        relu_mask = tl.load(mask_ptrs, mask=go_mask, other=0.0)
        grad_out = tl.where(relu_mask, grad_out, 0.0)

        # Load x block
        x_ptrs = x_ptr + (m + offs_m[:, None]) * stride_xm + offs_k[None, :] * stride_xk
        x_mask = ((m + offs_m[:, None]) < M) & (offs_k[None, :] < K)
        x_block = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # Accumulate: grad_out.T @ x
        acc_gw += tl.dot(tl.trans(grad_out), x_block)

        # Accumulate bias gradient (depends only on offs_n, not on offs_k)
        if COMPUTE_GRAD_BIAS:
            acc_gb += tl.sum(grad_out, axis=0)

    # Store grad_w
    gw_ptrs = grad_w_ptr + offs_n[:, None] * stride_gwn + offs_k[None, :] * stride_gwk
    gw_mask = (offs_n[:, None] < N) & (offs_k[None, :] < K)
    tl.store(gw_ptrs, acc_gw, mask=gw_mask)

    # Store grad_bias. Every program along axis 1 (pid_k) computes the same
    # acc_gb for its row tile, so only pid_k == 0 contributes; otherwise the
    # bias gradient would be scaled by cdiv(K, BLOCK_K).
    if COMPUTE_GRAD_BIAS:
        if pid_k == 0:
            gb_ptrs = grad_bias_ptr + offs_n
            gb_mask = offs_n < N
            tl.atomic_add(gb_ptrs, acc_gb, mask=gb_mask)


class FusedLinearReLUFunction(torch.autograd.Function):
    """Autograd Function computing ``relu(x @ weight.T + bias)`` with Triton kernels.

    The forward kernel also records which outputs were positive (the ReLU mask).
    Backward reuses that mask instead of recomputing the pre-activation.
    """

    @staticmethod
    def forward(ctx, x, weight, bias):
        """Run the fused forward kernel.

        Args:
            x: Input of shape ``(..., K)``; flattened to ``(M, K)`` for the kernel.
            weight: ``(N, K)`` weight matrix (``nn.Linear`` layout).
            bias: ``(N,)`` bias, or ``None`` to skip the bias add.

        Returns:
            Tensor of shape ``(..., N)`` with the dtype and device of ``x``.
        """
        # Reshape input to 2D if needed
        original_shape = x.shape
        x_2d = x.reshape(-1, x.shape[-1])

        M, K = x_2d.shape
        N, K_w = weight.shape
        assert K == K_w, f"x has {K} features but weight expects {K_w}"

        # Allocate output and ReLU mask
        output = torch.empty((M, N), device=x.device, dtype=x.dtype)
        relu_mask = torch.empty((M, N), device=x.device, dtype=torch.bool)

        # Block sizes
        BLOCK_M = 64
        BLOCK_N = 64
        BLOCK_K = 32

        # One program per (BLOCK_M, BLOCK_N) output tile
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

        fused_linear_relu_fwd_kernel[grid](
            x_2d, weight, bias, output, relu_mask,
            M, N, K,
            x_2d.stride(0), x_2d.stride(1),
            weight.stride(0), weight.stride(1),
            output.stride(0), output.stride(1),
            relu_mask.stride(0), relu_mask.stride(1),
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
        )

        # Save for backward
        ctx.save_for_backward(x_2d, weight, relu_mask)
        ctx.original_shape = original_shape

        return output.reshape(*original_shape[:-1], N)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, grad_weight, grad_bias)``; each is ``None`` when not required."""
        x, weight, relu_mask = ctx.saved_tensors
        original_shape = ctx.original_shape

        # Flatten grad_output to (M, N) to match the saved 2D input
        grad_output = grad_output.reshape(-1, grad_output.shape[-1])

        M, K = x.shape
        N = weight.shape[0]

        # Allocate gradients
        grad_x = torch.empty_like(x) if ctx.needs_input_grad[0] else None
        grad_weight = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
        grad_bias = torch.empty(N, device=weight.device, dtype=weight.dtype) if ctx.needs_input_grad[2] else None

        BLOCK_M = 64
        BLOCK_N = 64
        BLOCK_K = 32

        # grad_x = (grad_output * relu_mask) @ weight
        if grad_x is not None:
            grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
            fused_linear_relu_bwd_kernel[grid](
                grad_output, relu_mask, x, weight,
                grad_x, None, None,
                M, N, K,
                grad_output.stride(0), grad_output.stride(1),
                relu_mask.stride(0), relu_mask.stride(1),
                x.stride(0), x.stride(1),
                weight.stride(0), weight.stride(1),
                grad_x.stride(0), grad_x.stride(1),
                0, 0,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                BLOCK_K=BLOCK_K,
                COMPUTE_GRAD_X=True,
                COMPUTE_GRAD_W=False,
            )
            grad_x = grad_x.reshape(original_shape)

        # grad_weight = (grad_output * relu_mask).T @ x, grad_bias = column sums
        if grad_weight is not None or grad_bias is not None:
            if grad_bias is not None:
                # The kernel accumulates the bias gradient with atomic_add.
                grad_bias.zero_()

            # The weight kernel always stores a grad_w tile. When only the bias
            # gradient is requested, give it a scratch buffer. Aliasing
            # grad_output here (with zero strides) would make every program write
            # into grad_output[0, 0] while other programs are still reading it.
            grad_w_out = grad_weight if grad_weight is not None else torch.empty_like(weight)

            grid = (triton.cdiv(N, BLOCK_N), triton.cdiv(K, BLOCK_K))
            fused_linear_relu_bwd_weight_kernel[grid](
                grad_output, relu_mask, x,
                grad_w_out,
                grad_bias,
                M, N, K,
                grad_output.stride(0), grad_output.stride(1),
                relu_mask.stride(0), relu_mask.stride(1),
                x.stride(0), x.stride(1),
                grad_w_out.stride(0), grad_w_out.stride(1),
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                BLOCK_K=BLOCK_K,
                COMPUTE_GRAD_BIAS=(grad_bias is not None),
            )

        return grad_x, grad_weight, grad_bias


def fused_linear_relu(x, weight, bias=None):
    """Return ``relu(x @ weight.T + bias)`` via the Triton-backed autograd Function.

    Args:
        x: Input whose last dimension equals ``weight.shape[1]``. Leading dimensions
            are flattened for the kernel and restored on the output.
        weight: ``(out_features, in_features)`` matrix, as stored by ``nn.Linear``.
        bias: Optional ``(out_features,)`` vector; ``None`` skips the bias add.
    """
    apply_fn = FusedLinearReLUFunction.apply
    return apply_fn(x, weight, bias)


class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward sub-layer applied independently at every position.

    Computes ``fc_2(dropout(relu(fc_1(x))))``. With ``use_triton=True``, the first
    projection and the ReLU run through ``fused_linear_relu`` using the same
    ``fc_1`` parameters, instead of ``nn.Linear`` followed by a separate ReLU.
    """

    def __init__(self, hid_dim, pf_dim, dropout, use_triton=False):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)  # expand hid_dim -> pf_dim
        self.fc_2 = nn.Linear(pf_dim, hid_dim)  # project back pf_dim -> hid_dim
        self.dropout = nn.Dropout(dropout)
        self.use_triton = use_triton

    def forward(self, x):
        if self.use_triton:
            hidden = fused_linear_relu(x, self.fc_1.weight, self.fc_1.bias)
        else:
            hidden = self.fc_1(x).relu()
        return self.fc_2(self.dropout(hidden))


class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder block.

    Applies self-attention and then the position-wise feed-forward layer. Each
    sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.feedforward_layer_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        attn_out, _ = self.self_attention(src, src, src, src_mask)
        hidden = self.self_attention_layer_norm(src + self.dropout(attn_out))
        ff_out = self.positionwise_feedforward(hidden)
        return self.feedforward_layer_norm(hidden + self.dropout(ff_out))

class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings plus learned positional
    embeddings, followed by a stack of ``EncoderLayer`` blocks.

    Sequences longer than ``max_length`` cannot be embedded, because positions
    index a ``max_length``-row embedding table.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=200):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        )
        self.dropout = nn.Dropout(dropout)
        # sqrt(hid_dim) as a 1-element tensor on `device`; scales token embeddings.
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):
        n_batch, seq_len = src.shape
        positions = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(n_batch, -1)
        hidden = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(positions))
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden

class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder block.

    Applies masked self-attention, then encoder-decoder attention, then the
    position-wise feed-forward layer. Each sub-layer is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.feedforward_layer_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Return the updated target states and the encoder-attention weights."""
        sa_out, _ = self.self_attention(trg, trg, trg, trg_mask)
        hidden = self.self_attention_layer_norm(trg + self.dropout(sa_out))

        ca_out, cross_weights = self.encoder_attention(hidden, enc_src, enc_src, src_mask)
        hidden = self.encoder_attention_layer_norm(hidden + self.dropout(ca_out))

        ff_out = self.positionwise_feedforward(hidden)
        hidden = self.feedforward_layer_norm(hidden + self.dropout(ff_out))
        return hidden, cross_weights

class Decoder(nn.Module):
    """Transformer decoder: scaled token embeddings plus learned positional
    embeddings, a stack of ``DecoderLayer`` blocks, and a linear projection to
    ``output_dim`` logits.

    ``forward`` returns ``(logits, attention)``, where ``attention`` is the
    encoder-attention map of the last layer.
    """

    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=200):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # sqrt(hid_dim) as a 1-element tensor on `device`; scales token embeddings.
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        n_batch, seq_len = trg.shape
        positions = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(n_batch, -1)
        hidden = self.dropout(self.tok_embedding(trg) * self.scale + self.pos_embedding(positions))
        for block in self.layers:
            hidden, attention = block(hidden, enc_src, trg_mask, src_mask)
        return self.fc_out(hidden), attention

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that builds the attention masks and runs both halves.

    Masks are boolean: ``True`` marks key positions that may be attended to.
    """

    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """``(batch, src_len)`` ids -> ``(batch, 1, 1, src_len)`` mask hiding padding keys."""
        return (src != self.src_pad_idx)[:, None, None, :]

    def make_trg_mask(self, trg):
        """``(batch, trg_len)`` ids -> ``(batch, 1, trg_len, trg_len)`` mask hiding
        padding keys and future positions."""
        length = trg.shape[1]
        keep_keys = (trg != self.trg_pad_idx)[:, None, None, :]
        causal = torch.ones((length, length), device=self.device).tril().bool()
        return keep_keys & causal

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg``; return the decoder's ``(logits, attention)``."""
        src_mask = self.make_src_mask(src)
        enc_src = self.encoder(src, src_mask)
        trg_mask = self.make_trg_mask(trg)
        logits, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return logits, attention

print("All model classes defined.")

# =============================================================================
# 4. INSTANTIATION AND TRAINING SETUP
# =============================================================================
# Size the embedding / output layers with len(tokenizer). Unlike
# tokenizer.vocab_size, it also counts tokens added on top of the base
# vocabulary, so every id the tokenizer can emit has an embedding row.
INPUT_DIM = len(tokenizer)
OUTPUT_DIM = len(tokenizer)
HID_DIM = 32
ENC_LAYERS = 1
DEC_LAYERS = 1
ENC_HEADS = 4
DEC_HEADS = 4
ENC_PF_DIM = 64
DEC_PF_DIM = 64
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2

PAD_IDX = tokenizer.pad_token_id # Use the tokenizer's padding ID

enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)

# Source and target share the tokenizer, so both padding indices are PAD_IDX.
net = Seq2Seq(enc, dec, PAD_IDX, PAD_IDX, device).to(device)

def initialize_weights(m):
    """Xavier-uniform initialise ``m.weight`` in place when it has more than one dimension.

    Intended for ``nn.Module.apply``. The following are left untouched: modules
    without a ``weight`` attribute, modules whose ``weight`` is ``None`` (e.g.
    ``nn.LayerNorm(..., elementwise_affine=False)``), and 1-D weights such as
    LayerNorm scales.
    """
    weight = getattr(m, 'weight', None)
    if weight is not None and weight.dim() > 1:
        nn.init.xavier_uniform_(weight)
# Re-initialise every weight with more than one dimension (embedding tables and
# Linear matrices) with Xavier-uniform.
net.apply(initialize_weights)

# Training runs through the compiled wrapper.
compiled_net = torch.compile(net)
print("Model compiled.")

NUM_EPOCHS = 20
optimizer = optim.SGD(
    compiled_net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-3,
)
# Cosine schedule over NUM_EPOCHS scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# Padding targets are ignored by the loss; label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)

print("Model instantiated and ready for training.")

# =============================================================================
# 5. TRAINING LOOP
# =============================================================================
def train(model, iterator, optimizer, criterion, clip):
    """Run one teacher-forced training epoch and return the mean per-batch loss.

    Each batch is a ``(src, trg)`` pair. The model receives ``trg[:, :-1]`` and
    its logits are scored against ``trg[:, 1:]``. Gradients are clipped to a
    global norm of ``clip`` before every optimizer step.
    """
    model.train()
    running = 0
    for src, trg in iterator:
        optimizer.zero_grad()
        logits, _ = model(src, trg[:, :-1])
        flat_logits = logits.contiguous().view(-1, logits.shape[-1])
        gold = trg[:, 1:].contiguous().view(-1)
        loss = criterion(flat_logits, gold)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        running += loss.item()
    return running / len(iterator)

def evaluate(model, iterator, criterion):
    """Return the mean per-batch loss of ``model`` over ``iterator`` without updating it.

    The model is switched to eval mode and run under ``torch.no_grad()``. Each
    ``(src, trg)`` batch feeds ``trg[:, :-1]`` to the model and scores the logits
    against ``trg[:, 1:]``.
    """
    model.eval()
    running = 0
    with torch.no_grad():
        for src, trg in iterator:
            logits, _ = model(src, trg[:, :-1])
            flat_logits = logits.contiguous().view(-1, logits.shape[-1])
            gold = trg[:, 1:].contiguous().view(-1)
            running += criterion(flat_logits, gold).item()
    return running / len(iterator)

CLIP = 1
best_valid_loss = float('inf')

EARLY_STOPPING_PATIENCE = 5
patience_counter = 0

print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    train_loss = train(compiled_net, train_dataloader, optimizer, criterion, CLIP)
    valid_loss = evaluate(compiled_net, valid_dataloader, criterion)
    end_time = time.time()

    # Report before the early-stopping check so the last epoch's numbers are
    # printed even when training stops after it.
    print(f'Epoch: {epoch+1:02} | Time: {end_time - start_time:.0f}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # Save the uncompiled `net`, which is the module reloaded for inference.
        torch.save(net.state_dict(), 'transformer-model.pt')
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print("Early stopping triggered.")
        break

    # One cosine-annealing step per completed epoch.
    scheduler.step()

print("\nFinished Training.")

Using device: cpu




Modern data pipeline is ready.
All model classes defined.
Model compiled.
Model instantiated and ready for training.

Starting training...


KeyboardInterrupt: 

Introducing beam-search translation: keep the top-k partial translations at each decoding step and pick the best-scoring result.

In [None]:
def translate_sentence_beam_search(sentence, tokenizer, model, device, beam_width=3, max_len=50):
    """Translate one source sentence with length-normalised beam search.

    Args:
        sentence: Raw source-language string.
        tokenizer: Tokenizer exposing ``pad_token_id``, ``eos_token_id`` and
            ``decode``; calling it yields the ``input_ids`` fed to the encoder.
        model: Seq2Seq-style model with ``make_src_mask``, ``make_trg_mask``,
            ``encoder`` and ``decoder``.
        device: Device the id tensors are moved to.
        beam_width: Hypotheses kept (and tokens expanded per hypothesis) at each
            step. Search also stops once this many hypotheses have emitted EOS.
        max_len: Maximum number of decoding steps.

    Returns:
        ``(translation, None)``: the decoded best hypothesis, ranked by summed
        log-probability divided by sequence length. The second slot is a
        placeholder because attention maps are not collected.

    NOTE(review): decoding starts from ``pad_token_id``, but training feeds
    ``trg[:, :-1]`` with no start token, so the model never learns to predict the
    first target token. Confirm whether the targets should get a start token.
    """
    model.eval()

    src_ids = tokenizer(sentence, return_tensors="pt")['input_ids'].to(device)
    src_mask = model.make_src_mask(src_ids)

    with torch.no_grad():
        memory = model.encoder(src_ids, src_mask)

    # The Helsinki-NLP model uses the pad id as its decoder start token.
    start_id = tokenizer.pad_token_id
    end_id = tokenizer.eos_token_id

    live = [([start_id], 0.0)]   # (token ids, summed log-prob) still being extended
    finished = []                # hypotheses whose last token is end_id

    step = 0
    while step < max_len and live:
        step += 1
        candidates = []

        for tokens, total in live:
            trg_ids = torch.LongTensor(tokens).unsqueeze(0).to(device)
            trg_mask = model.make_trg_mask(trg_ids)

            with torch.no_grad():
                logits, _ = model.decoder(trg_ids, memory, trg_mask, src_mask)

            log_probs = F.log_softmax(logits[:, -1], dim=-1)
            best_lp, best_ids = torch.topk(log_probs, beam_width)

            for lp, tok in zip(best_lp[0].tolist(), best_ids[0].tolist()):
                extended = (tokens + [tok], total + lp)
                if tok == end_id:
                    finished.append(extended)
                else:
                    candidates.append(extended)

        # Enough finished hypotheses: stop without re-pruning `live`.
        if len(finished) >= beam_width:
            break

        live = sorted(candidates, key=lambda hyp: hyp[1], reverse=True)[:beam_width]

    # Unfinished hypotheses still compete; rank everything by per-token score.
    ranked = sorted(finished + live, key=lambda hyp: hyp[1] / len(hyp[0]), reverse=True)
    translation = tokenizer.decode(ranked[0][0], skip_special_tokens=True)

    return translation, None

In [None]:
# Load the weights from your best model. map_location remaps the stored tensors
# onto `device`, so a checkpoint written on a GPU machine also loads on a
# CPU-only one.
net.load_state_dict(torch.load('transformer-model.pt', map_location=device))

# Get an example from the test set
example_idx = 13
sample = multi30k_dataset['test'][example_idx]
src_text = sample['de']
trg_text = sample['en']

print(f'Source Sentence: {src_text}')
print(f'Target Sentence: {trg_text}\n')

# Beam search takes the raw source string and tokenizes it internally.
beam_translation, _ = translate_sentence_beam_search(src_text, tokenizer, net, device, beam_width=5)

print(f'Beam Search (k=5) Translation: {beam_translation}')

Source Sentence: Ein sitzender Mann, der an einem Tisch in seinem Haus mit einem Werkzeug arbeitet.
Target Sentence: Man sitting using tool at a table in his home.

Beam Search (k=5) Translation: sitting on a table with his work.
