# Transformer from Scratch with PyTorch

This is a Transformer implementation using PyTorch. It is meant as a personal learning exercise rather than a fully optimized version. The example uses the `airline-passengers` dataset as reference data to check that the model works properly by running both the Torch bidirectional multi-layer LSTM and the handmade version.

In [68]:
import os
import math

os.chdir("/Users/yenchenchou/Documents/GitHub/ml-learning")

import torch
import torch.nn as nn
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import polars as pl


from typing import Optional

In [2]:
class EnvInit:
    """Environment helpers: pick a compute device and seed the RNGs."""

    def available_device(self) -> torch.device:
        """Return the first available device, trying MPS, then CUDA, then CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def fix_seed(self, seed: int) -> int:
        """Seed torch, ``random`` and NumPy with ``seed`` and return ``seed``.

        The CUDA and MPS generators are seeded as well when those backends
        report themselves as available.
        """
        # Host-side generators first, in a fixed order.
        for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
            seed_fn(seed)

        if torch.cuda.is_available():
            for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
                seed_fn(seed)

        if torch.backends.mps.is_available():
            torch.mps.manual_seed(seed)

        return seed

In [78]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by bias-free ``d_model -> d_model``
    linear layers, split into ``n_heads`` heads of width ``d_k``, attended per
    head, merged back together and passed through the output projection ``wo``.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, length, d_model) -> (batch, n_heads, length, d_k)."""
        batch_size, length, _ = x.size()
        return x.view(batch_size, length, self.n_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention of ``q`` over ``k``/``v``.

        Args:
            q: Queries, (batch, q_len, d_model).
            k: Keys, (batch, kv_len, d_model).
            v: Values, (batch, kv_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, n_heads, q_len, kv_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len, _ = q.size()

        # Each projection is split using its OWN sequence length so that keys
        # and values may be longer or shorter than the queries (cross-attention).
        query = self._split_heads(self.wq(q))  # (batch, n_heads, q_len, d_k)
        key = self._split_heads(self.wk(k))  # (batch, n_heads, kv_len, d_k)
        value = self._split_heads(self.wv(v))  # (batch, n_heads, kv_len, d_k)

        # (batch, n_heads, q_len, d_k) @ (batch, n_heads, d_k, kv_len)
        #   -> (batch, n_heads, q_len, kv_len)
        attention_score = query @ key.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # NOTE: a query row whose keys are all masked becomes all -inf and
            # therefore NaN after the softmax.
            attention_score = attention_score.masked_fill(mask == 0, float("-inf"))
        attn = self.softmax(attention_score)

        # (batch, n_heads, q_len, kv_len) @ (batch, n_heads, kv_len, d_k)
        #   -> (batch, n_heads, q_len, d_k)
        attn_output = attn @ value
        # Merge heads: (batch, n_heads, q_len, d_k) -> (batch, q_len, n_heads, d_k)
        #   -> (batch, q_len, d_model)
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, q_len, self.d_model)
        )
        # (batch, q_len, d_model) -> (batch, q_len, d_model)
        return self.wo(attn_output)

In [63]:
# Lower-triangular matrix of ones: row i keeps columns 0..i and zeroes the rest.
torch.tril(torch.ones(3, 3))


tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [65]:
# Keep only the strict upper triangle (diagonal=1), then compare with 0:
# the result is a boolean matrix that is True on and below the diagonal.
torch.triu(torch.ones(3, 3), diagonal=1) == 0


tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])

In [4]:
def create_padding_mask(seq, pad_token=0):
    """Flag the padding positions of ``seq``.

    Returns a boolean tensor that is True wherever ``seq`` equals
    ``pad_token``, with two singleton dimensions inserted right after the
    first dimension: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len).

    Note that True marks *padding* here, which is the opposite polarity of a
    mask consumed through ``masked_fill(mask == 0, ...)``.
    """
    is_pad = seq == pad_token
    # Indexing with None inserts the same two axes as unsqueeze(1).unsqueeze(2).
    return is_pad[:, None, None]  # (batch_size, 1, 1, seq_len)

# Example usage
# NOTE: `seq` below is 3-D with shape (2, 2, 4), so the two inserted singleton
# axes give a 5-D mask of shape (2, 1, 1, 2, 4) rather than the
# (batch_size, 1, 1, seq_len) produced for a 2-D (batch_size, seq_len) input.
seq = torch.tensor(
    [
        [[7, 6, 0, 0], [1, 2, 3, 0]],
        [[7, 6, 0, 0], [1, 0, 0, 0]]
    ]
)
print(seq.shape)
padding_mask = create_padding_mask(seq)
print(padding_mask)

torch.Size([2, 2, 4])
tensor([[[[[False, False,  True,  True],
           [False, False, False,  True]]]],



        [[[[False, False,  True,  True],
           [False,  True,  True,  True]]]]])


In [37]:
# Stack two copies of a (2, 3) tensor along a new dimension 1 -> shape (2, 2, 3):
# each row of x appears twice, side by side.
x = torch.randn(2, 3)
tmp = torch.stack([x, x], dim=1)
print(tmp)
print(tmp.shape)

tensor([[[ 1.9854, -0.1618, -0.2974],
         [ 1.9854, -0.1618, -0.2974]],

        [[ 0.5067,  2.3083, -1.2554],
         [ 0.5067,  2.3083, -1.2554]]])
torch.Size([2, 2, 3])


In [49]:
# Random integers in [0, 10) converted to float, shape (2, 3).
x = torch.randint(0, 10, (2, 3)).float()
print(x)
print(torch.mean(x, dim=1))  # reduce over dim 1: one mean per row -> shape (2,)
print(torch.mean(x, dim=0))  # reduce over dim 0: one mean per column -> shape (3,)

tensor([[9., 7., 0.],
        [9., 8., 5.]])
tensor([5.3333, 7.3333])
tensor([9.0000, 7.5000, 2.5000])


In [55]:
# Last two elements of the first row of x.
x[0][-2:]

tensor([7., 0.])

In [35]:
# Stack two copies of x along a new leading dimension (dim=0).
# NOTE(review): x comes from whichever cell last assigned it; the output
# recorded for this cell has shape (2, 1, 3), so it was presumably run with a
# (1, 3) x rather than the (2, 3) tensors created in the neighbouring cells.
tmp = torch.stack([x, x], dim=0)
print(tmp)
print(tmp.shape)

tensor([[[-0.7003, -1.9755, -0.6597]],

        [[-0.7003, -1.9755, -0.6597]]])
torch.Size([2, 1, 3])


In [20]:
def create_sequence_mask(seq):
    """Build a look-ahead mask sized from ``seq``.

    Args:
        seq: Tensor whose dimension 1 gives the sequence length. Only its
            size and device are used, not its values.

    Returns:
        A (seq_len, seq_len) float tensor on ``seq``'s device holding 1
        strictly above the diagonal (future positions) and 0 elsewhere.

    NOTE(review): 1 marks the positions to hide, which is the inverse of what
    a ``masked_fill(mask == 0, ...)`` consumer expects; confirm the polarity
    before passing this mask to such a layer.
    """
    seq_len = seq.size(1)
    # Allocate on the same device as `seq` so the mask can be combined with
    # tensors living on an accelerator without a device-mismatch error.
    ones = torch.ones((seq_len, seq_len), device=seq.device)
    mask = torch.triu(ones, diagonal=1)
    return mask  # (seq_len, seq_len)

# Example usage
# Only the size of the argument along dim 1 is read, so a (4, 4) tensor of
# zeros is enough to request a 4 x 4 mask.
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])


In [21]:
def create_look_ahead_mask(size):
    """Return a (size, size) float mask with 1 strictly above the diagonal.

    Entries on and below the diagonal are 0, so in row i the 1s mark the
    columns j > i.
    """
    all_ones = torch.ones(size, size)
    return all_ones.triu(diagonal=1)  # (seq_len, seq_len)

# Example usage
# Request a 4 x 4 mask and print it.
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])


In [23]:
def create_padding_mask(seq, pad_token=0):
    """Build a keep-mask that is False at padding positions.

    Args:
        seq: Token ids, (batch_size, seq_len).
        pad_token: Id that marks padding. Defaults to 0.

    Returns:
        Boolean tensor of shape (batch_size, 1, 1, seq_len): True for real
        tokens and False (0) where ``seq`` equals ``pad_token``.
    """
    # seq: (batch_size, seq_len)
    # Creates a binary mask where padding tokens are 0
    return (seq != pad_token).unsqueeze(1).unsqueeze(2)  # Shape: (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(seq_len):
    """Return a causal keep-mask of shape (1, 1, seq_len, seq_len).

    The mask is a float tensor with 1 on and below the diagonal and 0 above
    it, preceded by two singleton dimensions.
    """
    lower_triangle = torch.ones((seq_len, seq_len)).tril()
    # Indexing with None prepends the two singleton axes.
    return lower_triangle[None, None]  # Shape: (1, 1, seq_len, seq_len)


In [25]:
# Define sequences with padding
# (0 is the padding id; the rows have 3 and 4 real tokens respectively)
input_seq = torch.tensor([
    [4, 5, 6, 0, 0],
    [3, 8, 9, 2, 0]
])

# Create padding mask
padding_mask = create_padding_mask(input_seq)
# .int() shows the boolean mask as 1/0; printed first with its singleton
# dimensions, then with squeeze() removing them for a compact view.
print("Padding Mask:\n", padding_mask.int())
print("Padding Mask:\n", padding_mask.squeeze().int())

Padding Mask:
 tensor([[[[1, 1, 1, 0, 0]]],


        [[[1, 1, 1, 1, 0]]]], dtype=torch.int32)
Padding Mask:
 tensor([[1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0]], dtype=torch.int32)


In [26]:
# Define sequence length for look-ahead mask
seq_len = 4
look_ahead_mask = create_look_ahead_mask(seq_len)
# squeeze() drops the singleton dimensions and .int() prints the mask as 1/0.
print("Look-Ahead Mask:\n", look_ahead_mask.squeeze().int())

Look-Ahead Mask:
 tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1]], dtype=torch.int32)
