We mostly follow the blog ["Transformers Laid Out"](https://goyalpramod.github.io/blogs/Transformers_laid_out/) to implement a vanilla transformer presented in the seminal paper ["Attention is All You Need"](https://arxiv.org/abs/1706.03762) but make the following changes:

1. Use separate dropouts after multi-headed attention and feed-forward network in `EncoderLayer` and `DecoderLayer`.
2. Query may have different sequence length from Key and Value in `MultiHeadAttention.forward`.
4. Use logical or operation to combine the padding mask and future mask.
5. Work around the `RuntimeError`: "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" in `create_masks` method.
6. Fix 2 issues in `TransformerLRScheduler.step()`.
7. As for the soft target distribution of Label Smoothing, set the probability of the true label to `(1 - smoothing) + smoothing / vocab_size` and the probability of other tokens to `smoothing / vocab_size`.
8. Remove the redundant `create_mask` in method `training_transformer` since we create the source and target masks in `Transformer.forward`.
9. Add the missing `AdamW` optimizer for training the transformer.
10. Replace `view()` with `reshape()` to reshape `tgt_output` in `train_transformer` since `tgt_output` is not contiguous.

**TODO**:

1. Use the data loader in HuggingFace `datasets` library.
2. Use Accelerate to train the transformer.


In [1]:
import math

import torch
import torch.nn as nn
from torch.nn.functional import softmax

# Multi-Head Attention

In [2]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Args:
        query: (batch_size, num_heads, seq_len_q, d_q)
        key: (batch_size, num_heads, seq_len_k, d_k)
        value: (batch_size, num_heads, seq_len_v, d_v)
        mask: Optional mask to prevent attention to certain positions
    """
    # Shape checks
    assert query.dim() == 4, f"Query should be 4-dim but got {query.dim()}-dim"
    assert key.dim() == 4, f"Key should be 4-dim but got {key.dim()}-dim"
    assert value.dim() == 4, f"Value should be 4-dim but got {value.dim()}-dim"
    assert query.size(-1) == key.size(
        -1
    ), f"Query depth {query.size(-1)} != Key depth {key.size(-1)}"
    assert key.size(-2) == value.size(
        -2
    ), f"Key length {key.size(-2)} != Value length {value.size(-2)}"

    # Get the Key depth.
    d_k = key.size(-1)

    # Calculate the attention scores from Query and Key.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Set the masked elements to -inf for the decoder, which are usually the upper right half.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    # Calculate the attention weights by taking the softmax for each Query token along Keys.
    weights = softmax(scores, dim=-1)

    # Return the mulplication of the attention weights and Value.
    return torch.matmul(weights, value)

In [29]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert (
            d_model % num_heads == 0
        ), f"d_model = {d_model} is not divisible by num_heads = {num_heads}"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Create the learnable projection matrices
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    @staticmethod
    def scaled_dot_product_attention(query, key, value, mask=None):
        """
        Args:
            query: (batch_size, num_heads, seq_len_q, d_q)
            key: (batch_size, num_heads, seq_len_k, d_k)
            value: (batch_size, num_heads, seq_len_v, d_v)
            mask: Optional mask to prevent attention to certain positions
        """
        # Shape checks
        assert query.dim() == 4, f"Query should be 4-dim but got {query.dim()}-dim"
        assert key.dim() == 4, f"Key should be 4-dim but got {key.dim()}-dim"
        assert value.dim() == 4, f"Value should be 4-dim but got {value.dim()}-dim"
        assert query.size(-1) == key.size(
            -1
        ), f"Query depth {query.size(-1)} != Key depth {key.size(-1)}"
        assert key.size(-2) == value.size(
            -2
        ), f"Key length {key.size(-2)} != Value length {value.size(-2)}"

        # Get the Key depth.
        d_k = key.size(-1)

        # Calculate the attention scores from Query and Key.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Set the masked elements to -inf for the decoder, which are usually the upper right half.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Calculate the attention weights by taking the softmax for each Query token along Keys.
        weights = softmax(scores, dim=-1)

        # Return the mulplication of the attention weights and Value.
        return torch.matmul(weights, value)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch_size, query_seq_len, d_model)
            key: (batch_size, key_seq_len, d_model)
            value: (batch_size, key_seq_len, d_model)
            mask: Optional mask to prevent attention to certain positions
        """
        assert key.size(1) == value.size(
            1
        ), f"key sequence length {key.size(1)} != value sequence length {value.size(1)}"

        batch_size = query.size(0)
        query_seq_len = query.size(1)
        key_seq_len = key.size(1)

        # 1. Linear projections
        Q = self.W_q(query)  # (batch_size, query_seq_len, d_model)
        K = self.W_k(key)  # (batch_size, key_seq_len, d_model)
        V = self.W_v(value)  # (batch_size, key_seq_len, d_model)

        # 2. Split into heads with the shape (batch_size, num_heads, seq_len, d_k).
        # Note that d_model = num_heads * d_k.
        Q = Q.view(batch_size, query_seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, key_seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, key_seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Apply scaled dot product attention with the shape (batch_size, num_heads, seq_len, d_k).
        self_attention = MultiHeadAttention.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate num_heads (batch_size, query_seq_len, d_k) into (batch_size, query_seq_len, d_model).
        # Note that d_model = num_heads * d_k.
        concatenated_self_attention = (
            self_attention.transpose(1, 2)
            .contiguous()
            .view(batch_size, query_seq_len, self.d_model)
        )

        # 5. Final projection
        return self.W_o(concatenated_self_attention)

# Feed Forward Network

In [4]:
class FeedForwardNetwork(nn.Module):
    """Position-wise Feed-Forward Network
    Args:
        d_model: input/output dimension
        d_ff: hidden dimension
        dropout: dropout rate (default=0.1)
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(d_model, d_ff),  # Output shape: (batch_size, seq_len, d_ff)
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),  # Output shape: (batch_size, seq_len, d_model)
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """
        return self.model(x)

# Positional Encoding

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()

        # Create a matrix of shape (max_seq_length, d_model).
        pe = torch.zeros(max_seq_length, d_model)

        # Create a position column vector of shape (max_seq_length, 1).
        pos = torch.arange(0, max_seq_length).unsqueeze(1)

        # Create a division 1-D array of shape (d_model // 2).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Compute the matrix of positional encodings.
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)

        # Register buffer with shape (1, max_seq_length, d_model).
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: Tensor shape (batch_size, seq_len, d_model)
        """
        return x + self.pe[:, : x.size(1)]

# Encoder Layer

In [30]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # 1. Multi-head attention
        self.mha = MultiHeadAttention(d_model, num_heads)

        # 2. Dropout
        self.dropout_mha = nn.Dropout(dropout)

        # 3. Layer normalization
        self.layer_normal_mha = nn.LayerNorm(d_model)

        # 4. Feed forward
        self.ff = FeedForwardNetwork(d_model, d_ff, dropout)

        # 5. Another Dropout
        self.dropout_ff = nn.Dropout(dropout)

        # 6. Another layer normalization
        self.layer_normal_ff = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional mask for padding
        Returns:
            x: Output tensor of shape (batch_size, seq_len, d_model)
        """
        # 1. Multi-head attention with residual connection and layer norm
        mha_output = self.mha(x, x, x, mask)
        x = self.dropout_mha(x + mha_output)
        x = self.layer_normal_mha(x)

        # 2. Feed forward with residual connection and layer norm
        ff_output = self.ff(x)
        x = self.dropout_ff(x + ff_output)
        x = self.layer_normal_ff(x)

        return x

# Decoder Layer

In [31]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # 1. Masked Multi-Head Attention
        self.masked_mha = MultiHeadAttention(d_model, num_heads)

        # 2. Dropout for Masked Multi-Head Attention
        self.dropout_masked_mha = nn.Dropout(dropout)

        # 3. Layer norm for first sub-layer
        self.layer_normal_masked_mha = nn.LayerNorm(d_model)

        # 4. Multi-Head Attention for cross attention with encoder output
        # This will take encoder output as key and value
        self.mha = MultiHeadAttention(d_model, num_heads)

        # 5. Dropout for Multi-Head Attention
        self.dropout_mha = nn.Dropout(dropout)

        # 6. Layer norm for second sub-layer
        self.layer_normal_mha = nn.LayerNorm(d_model)

        # 7. Feed forward network
        self.ff = FeedForwardNetwork(d_model, d_ff)

        # 8. Dropout for Feed Forward network
        self.dropout_ff = nn.Dropout(dropout)

        # 9. Layer norm for third sub-layer
        self.layer_normal_ff = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target sequence embedding (batch_size, target_seq_len, d_model)
            encoder_output: Output from encoder (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding
            tgt_mask: Mask for target padding and future positions
        """
        # 1. Masked self-attention
        masked_mha_output = self.masked_mha(x, x, x, tgt_mask)
        x = self.dropout_masked_mha(x + masked_mha_output)
        x = self.layer_normal_masked_mha(x)

        # 2. Cross-Attention between Query and encoder output as Key and Value.
        mha_output = self.mha(x, encoder_output, encoder_output, src_mask)
        x = self.dropout_mha(x + mha_output)
        x = self.layer_normal_mha(x)

        # 3. Feed forward network
        ff_output = self.ff(x)
        x = self.dropout_ff(x + ff_output)
        x = self.layer_normal_ff(x)

        return x

# Encoder

In [32]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        num_layers=6,
        num_heads=8,
        d_ff=2048,
        dropout=0.1,
        max_seq_length=5000,
    ):
        super().__init__()

        # 1. Input embedding
        self.embeddings = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)

        # 2. Positional encoding
        self.pe = PositionalEncoding(d_model, max_seq_length)

        # 3. Dropout
        self.dropout = nn.Dropout(dropout)

        # 4. Stack of N encoder layers
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tokens (batch_size, seq_len)
            mask: Mask for padding positions
        Returns:
            encoder_output: (batch_size, seq_len, d_model)
        """
        # 1. Pass through embedding layer and scale
        x = self.embeddings(x) * self.scale

        # 2. Add positional encoding and apply dropout
        x = self.dropout(self.pe(x))

        # 3. Pass through each encoder layer
        for layer in self.encoder_layers:
            x = layer(x, mask)

        return x

# Decoder

In [33]:
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        num_layers=6,
        num_heads=8,
        d_ff=2048,
        dropout=0.1,
        max_seq_length=5000,
    ):
        super().__init__()

        # 1. Output embedding
        self.embeddings = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)

        # 2. Positional encoding
        self.pe = PositionalEncoding(d_model, max_seq_length)

        # 3. Dropout
        self.dropout = nn.Dropout(dropout)

        # 4. Stack of N decoder layers
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target tokens (batch_size, target_seq_len)
            encoder_output: Output from encoder (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding
            tgt_mask: Mask for target padding and future positions
        Returns:
            decoder_output: (batch_size, target_seq_len, d_model)
        """
        # 1. Pass through embedding layer and scale
        x = self.embeddings(x) * self.scale

        # 2. Add positional encoding and dropout
        x = self.dropout(self.pe(x))

        # 3. Pass through each decoder layer
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return x

# Utility Code for creating masks

In [10]:
def create_padding_mask(seq):
    """
    Create mask for padding tokens (0s)
    Args:
        seq: Input sequence tensor (batch_size, seq_len)
    Returns:
        mask: Padding mask (batch_size, 1, 1, seq_len)
    """
    batch_size, seq_len = seq.shape
    # Input seq: 0 - padding tokens.
    # Output mask: 0 - allowed positions; 1 - padding positions.
    output = torch.eq(seq, 0).float()
    return output.view(batch_size, 1, 1, seq_len)

In [11]:
def create_future_mask(size):
    """
    Create mask to prevent attention to future positions
    Args:
        size: Size of square mask (target_seq_len)
    Returns:
        mask: Future mask (1, 1, size, size)
    """
    # Create upper triangular matrix and invert it.
    # Output mask: 0 - allowed positions; 1 - masked positions.
    return torch.triu(torch.ones(1, 1, size, size), diagonal=1) == 0

In [12]:
def create_masks(src, tgt):
    """
    Create all masks needed for training
    Args:
        src: Source sequence (batch_size, src_len)
        tgt: Target sequence (batch_size, tgt_len)
    Returns:
        src_mask: Padding mask for encoder
        tgt_mask: Combined padding and future mask for decoder
    """
    # 1. Create padding masks
    src_padding_mask = create_padding_mask(src)  # Shape: (batch_size, 1, 1, src_len)
    tgt_padding_mask = create_padding_mask(tgt)  # Shape: ()

    # 2. Create future mask
    tgt_future_mask = create_future_mask(tgt.size(1))

    # 3. Combine padding and future mask for target
    # Both masks should be 0 for allowed positions
    # BUGBUG: Manually copy the future mask to GPU.
    tgt_mask = torch.logical_or(tgt_padding_mask, tgt_future_mask.to("cuda"))

    return src_padding_mask, tgt_mask

# Transformer

In [34]:
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model,
        num_layers=6,
        num_heads=8,
        d_ff=2048,
        dropout=0.1,
        max_seq_length=5000,
    ):
        super().__init__()

        # Pass all necessary parameters to Encoder and Decoder
        self.encoder = Encoder(
            src_vocab_size,
            d_model,
            num_layers,
            num_heads,
            d_ff,
            dropout,
            max_seq_length,
        )
        self.decoder = Decoder(
            tgt_vocab_size,
            d_model,
            num_layers,
            num_heads,
            d_ff,
            dropout,
            max_seq_length,
        )

        # The final linear layer should project from d_model to tgt_vocab_size
        self.final_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        # Create masks for source and target
        src_mask, tgt_mask = create_masks(src, tgt)

        # Pass through encoder
        encoder_output = self.encoder(src, src_mask)

        # Pass through decoder
        decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)

        # Project to vocabulary size
        output = self.final_layer(decoder_output)

        # Note: Usually don't apply softmax here if using CrossEntropyLoss
        # as it applies log_softmax internally
        return output

# Utility code for Transformer

In [46]:
class TransformerLRScheduler:
    def __init__(self, optimizer, d_model, warmup_steps):
        """
        Args:
            optimizer: Optimizer to adjust learning rate for
            d_model: Model dimensionality
            warmup_steps: Number of warmup steps
        """
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0  # Track the number of steps

    def step(self):
        """
        Update learning rate based on step number
        """
        # Increment step count
        self.step_num += 1

        # Convert integers to tensors
        step_num_tensor = torch.tensor(self.step_num, dtype=torch.float32)
        warmup_steps_tensor = torch.tensor(self.warmup_steps, dtype=torch.float32)
        d_model_tensor = torch.tensor(self.d_model, dtype=torch.float32)

        # Copy the learning rate formula from Section 5.3 of paper "Attention is All You Need":
        #   lrate = d_model^(-0.5) * min(step_num^(-0.5), step_num * warmup_steps^(-1.5))
        lrate = torch.pow(d_model_tensor, -0.5) * torch.min(
            torch.pow(step_num_tensor, -0.5),
            step_num_tensor * torch.pow(warmup_steps_tensor, -1.5),
        )

        # Apply new learning rate to optimizer
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lrate.item()

In [15]:
class LabelSmoothing(nn.Module):
    # One regularization technique mentioned in Section 5.4 of paper "Attention is All You Need".
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits, target):
        """
        Args:
            logits: Model predictions (batch_size * seq_len, vocab_size) #each row of vocab_size contains probability score of each label
            target: True labels (batch_size * seq_len) #each row of batch size contains the index to the correct label
        """
        vocab_size = logits.size(-1)
        with torch.no_grad():
            # Create a soft target distribution
            true_dist = torch.zeros_like(
                logits
            )  # create the all-zero tensor with the same shape as logits
            true_dist.fill_(
                self.smoothing / vocab_size
            )  # set the probability of tokens other than the true label
            true_dist.scatter_(
                1, target.unsqueeze(1), self.confidence + self.smoothing / vocab_size
            )  # set the probability of the true label

        # Return cross entropy loss. Note that we use log_softmax instead of log + softmax for stability (i.e., avoid overflow
        # and underflow) and efficiency.
        return torch.mean(
            torch.sum(-true_dist * torch.log_softmax(logits, dim=-1), dim=-1)
        )

# Training the Transformer

In [43]:
def train_transformer(
    model, train_dataloader, criterion, optimizer, scheduler, num_epochs, device="cuda"
):
    """
    Training loop for transformer

    Args:
        model: Transformer model
        train_dataloader: DataLoader for training data
        criterion: Loss function (with label smoothing)
        optimizer: Optimizer
        scheduler: Learning rate scheduler
        num_epochs: Number of training epochs
    """
    # 1. Setup
    model = model.to(device)
    model.train()

    # For tracking training progress:
    total_loss = 0
    all_losses = []

    # 2. Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        epoch_loss = 0

        for batch_idx, batch in enumerate(train_dataloader):
            # Get source and target batches
            src = batch["src"].to(device)
            tgt = batch["tgt"].to(device)

            # Prepare target for input and output
            # Remove last token from target for input
            tgt_input = tgt[:, :-1]
            # Remove first token from target for output
            tgt_output = tgt[:, 1:]

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass. Note that during training, we predict all the tokens in parallel, i.e., in one forward pass.
            outputs = model(src, tgt_input)

            # Reshape outputs and target for loss calculation.
            # outputs shape: (batch_size, seq_len - 1, vocab_size) --> (batch_size * (seq_len - 1), vocab_size)
            outputs = outputs.view(-1, outputs.size(-1))
            # tgt_outputs shape: (batch_size, seq_len - 1) --> (batch_size * (seq_len - 1))
            # Note that view(-1) does not work because `tgt_output` is not contiguous.
            tgt_output = tgt_output.reshape(-1)

            # Calculate loss
            loss = criterion(outputs, tgt_output)

            # Backward pass
            loss.backward()

            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update weights
            optimizer.step()
            scheduler.step()

            # Update loss tracking
            epoch_loss += loss.item()

            # Print progress every N batches
            if batch_idx % 100 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

        # Calculate average loss for epoch
        avg_epoch_loss = epoch_loss / len(train_dataloader)
        all_losses.append(avg_epoch_loss)
        print(f"Epoch {epoch + 1}, Loss: {avg_epoch_loss:.4f}")

        # Save checkpoint
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": avg_epoch_loss,
            },
            f"checkpoint_epoch_{epoch + 1}.pt",
        )

    return all_losses

# Prepare the dataset and define the data loader

In [22]:
!pip install spacy

Collecting spacy
  Downloading spacy-3.8.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (27 kB)
Collecting spacy-legacy<3.1.0,>=3.0.11 (from spacy)
  Downloading spacy_legacy-3.0.12-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting spacy-loggers<2.0.0,>=1.0.0 (from spacy)
  Downloading spacy_loggers-1.0.5-py3-none-any.whl.metadata (23 kB)
Collecting murmurhash<1.1.0,>=0.28.0 (from spacy)
  Downloading murmurhash-1.0.12-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)
Collecting cymem<2.1.0,>=2.0.2 (from spacy)
  Downloading cymem-2.0.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.5 kB)
Collecting preshed<3.1.0,>=3.0.2 (from spacy)
  Downloading preshed-3.0.9-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.2 kB)
Collecting thinc<8.4.0,>=8.3.4 (from spacy)
  Downloading thinc-8.3.4-cp312-cp312-manylinux_2_17_x86

In [17]:
import os
import urllib.request
import zipfile

import spacy
import torch
from torch.utils.data import DataLoader, Dataset

data_dir = "../data/vanilla-transformer"


def download_multi30k():
    """Download Multi30k dataset if not present"""
    # Create data directory
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    # Download files if they don't exist
    base_url = (
        "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/"
    )
    files = {
        "train.de": "train.de.gz",
        "train.en": "train.en.gz",
        "val.de": "val.de.gz",
        "val.en": "val.en.gz",
        "test.de": "test_2016_flickr.de.gz",
        "test.en": "test_2016_flickr.en.gz",
    }

    for local_name, remote_name in files.items():
        filepath = f"{data_dir}/{local_name}"
        if not os.path.exists(filepath):
            url = base_url + remote_name
            urllib.request.urlretrieve(url, filepath + ".gz")
            os.system(f"gunzip -f {filepath}.gz")

In [18]:
def load_data(filename):
    """Load data from file"""
    with open(filename, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]

In [19]:
def create_dataset():
    """Create dataset from files"""
    # Download data if needed
    download_multi30k()

    # Load data
    train_de = load_data(f"{data_dir}/train.de")
    train_en = load_data(f"{data_dir}/train.en")
    val_de = load_data(f"{data_dir}/val.de")
    val_en = load_data(f"{data_dir}/val.en")

    return (train_de, train_en), (val_de, val_en)

In [20]:
class TranslationDataset(Dataset):
    def __init__(
        self, src_texts, tgt_texts, src_vocab, tgt_vocab, src_tokenizer, tgt_tokenizer
    ):
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        src_text = self.src_texts[idx]
        tgt_text = self.tgt_texts[idx]

        # Tokenize
        src_tokens = [tok.text for tok in self.src_tokenizer(src_text)]
        tgt_tokens = [tok.text for tok in self.tgt_tokenizer(tgt_text)]

        # Convert to indices
        src_indices = (
            [self.src_vocab["<s>"]]
            + [self.src_vocab.get(token, 3) for token in src_tokens]  # "<unk>": 3
            + [self.src_vocab["</s>"]]
        )
        tgt_indices = (
            [self.tgt_vocab["<s>"]]
            + [self.tgt_vocab.get(token, 3) for token in tgt_tokens]  # "<unk>": 3
            + [self.tgt_vocab["</s>"]]
        )

        return {"src": torch.tensor(src_indices), "tgt": torch.tensor(tgt_indices)}

In [21]:
def build_vocab_from_texts(texts, tokenizer, min_freq=2):
    """Build vocabulary from texts"""
    counter = {}
    for text in texts:
        for token in [tok.text for tok in tokenizer(text)]:
            counter[token] = counter.get(token, 0) + 1

    # Create vocabulary
    vocab = {"<s>": 0, "</s>": 1, "<blank>": 2, "<unk>": 3}
    idx = 4
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

In [22]:
def create_dataloaders(batch_size=32):
    # Load tokenizers
    spacy_de = spacy.load("de_core_news_sm")
    spacy_en = spacy.load("en_core_web_sm")

    # Get data
    (train_de, train_en), (val_de, val_en) = create_dataset()

    # Build vocabularies
    vocab_src = build_vocab_from_texts(train_de, spacy_de)
    vocab_tgt = build_vocab_from_texts(train_en, spacy_en)

    # Create datasets
    train_dataset = TranslationDataset(
        train_de, train_en, vocab_src, vocab_tgt, spacy_de, spacy_en
    )

    val_dataset = TranslationDataset(
        val_de, val_en, vocab_src, vocab_tgt, spacy_de, spacy_en
    )

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
    )

    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
    )

    return train_dataloader, val_dataloader, vocab_src, vocab_tgt

In [23]:
def collate_batch(batch):
    src_tensors = [item["src"] for item in batch]
    tgt_tensors = [item["tgt"] for item in batch]

    # Pad sequences
    src_padded = torch.nn.utils.rnn.pad_sequence(
        src_tensors, batch_first=True, padding_value=2
    )
    tgt_padded = torch.nn.utils.rnn.pad_sequence(
        tgt_tensors, batch_first=True, padding_value=2
    )

    return {"src": src_padded, "tgt": tgt_padded}

# Start the training loop

In [32]:
!python -m spacy download de_core_news_sm

Collecting de-core-news-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: de-core-news-sm
Successfully installed de-core-news-sm-3.8.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('de_core_news_sm')


In [35]:
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
Installing collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.8.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [24]:
# Create the dataloader
train_dataloader, val_dataloader, vocab_src, vocab_tgt = create_dataloaders()

In [36]:
# Initialize your transformer with the vocabulary sizes
device = "cuda"
model = Transformer(
    src_vocab_size=len(vocab_src),
    tgt_vocab_size=len(vocab_tgt),
    d_model=512,
    num_layers=6,
    num_heads=8,
    d_ff=2048,
    dropout=0.1,
)
criterion = LabelSmoothing(smoothing=0.1).to(device)

In [37]:
# Create the optimizer.
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), betas=(0.9, 0.98), eps=1e-9, weight_decay=0.01)

In [47]:
# Create the learning rate scheduler
d_model = 512
warmup_steps = 4000  # Common warmup step value
scheduler = TransformerLRScheduler(optimizer, d_model, warmup_steps)

In [48]:
# Now you can use your training loop
losses = train_transformer(
    model=model,
    train_dataloader=train_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=10,
    device=device,
)

Epoch 1/10
Batch 0, Loss: 7.3699
Batch 100, Loss: 4.7729
Batch 200, Loss: 3.9463
Batch 300, Loss: 3.2740
Batch 400, Loss: 3.6947
Batch 500, Loss: 3.9301
Batch 600, Loss: 2.9096
Batch 700, Loss: 3.3312
Batch 800, Loss: 3.5278
Batch 900, Loss: 3.2851
Epoch 1, Loss: 3.7799
Epoch 2/10
Batch 0, Loss: 3.4062
Batch 100, Loss: 3.3547
Batch 200, Loss: 3.5276
Batch 300, Loss: 3.4446
Batch 400, Loss: 3.1315
Batch 500, Loss: 3.2641
Batch 600, Loss: 2.9427
Batch 700, Loss: 3.0760
Batch 800, Loss: 3.5282
Batch 900, Loss: 3.5347
Epoch 2, Loss: 3.1905
Epoch 3/10
Batch 0, Loss: 3.0667
Batch 100, Loss: 2.8320
Batch 200, Loss: 2.6779
Batch 300, Loss: 3.2682
Batch 400, Loss: 3.2879
Batch 500, Loss: 2.7618
Batch 600, Loss: 3.3126
Batch 700, Loss: 3.3712
Batch 800, Loss: 3.1600
Batch 900, Loss: 3.0311
Epoch 3, Loss: 3.0968
Epoch 4/10
Batch 0, Loss: 3.4050
Batch 100, Loss: 3.1580
Batch 200, Loss: 3.0060
Batch 300, Loss: 2.9221
Batch 400, Loss: 3.3623
Batch 500, Loss: 3.5122
Batch 600, Loss: 3.2513
Batch 700,