In [1]:
from importlib.metadata import version

pkgs = [
    "torch", "transformers"
]

for pkg in pkgs:
    print(f"{pkg}: {version(pkg)}")

torch: 2.4.1
transformers: 4.45.2


In [4]:
import math
import yaml

from pydantic import BaseModel
import torch
import torch.nn as nn
from torch.nn import functional as F

In [22]:
class ModelConfig(BaseModel):
    max_seq_len: int = 1024
    embed_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    attn_dropout: float = 0.1
    resid_dropout: float = 0.1
    hidden_dropout: float = 0.1

# read in the config.yaml file
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

model_config = ModelConfig(**config)
model_config

ModelConfig(max_seq_len=1024, embed_dim=768, num_heads=12, num_layers=12, attn_dropout=0.1, resid_dropout=0.1, hidden_dropout=0.1)

## Multihead (Causal) Self-Attention

In [24]:
class CausalSelfAttention(nn.Module):
    """
    Causal self-attention layer, masking the future tokens.
    """
    def __init__(self, cfg):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.q_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
        self.k_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
        self.v_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
        self.out_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)

        self.attn_dropout = nn.Dropout(0.1)
        self.resid_dropout = nn.Dropout(0.1)
        
        # Create a bias tensor to prevent attention to future tokens
        mask = torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len))
        self.register_buffer(
            'mask', (mask == 0).view(1, 1, cfg.max_seq_len, cfg.max_seq_len)
        )
        # mask will be a tensor like the following:
        # tensor([[[[False, True,  True,  ...,  True],
        #           [False, False, True,  ...,  True],
        #           [False, False, False, ...,  True],
        #           ...,
        #           [False, False, False, ..., False]]]])
        # where True values indicate that the token should be masked
        # i.e., replaced with -inf in the attention scores
        
    def forward(self, x):
        # Apply linear transformations to get queries, keys, and values
        # x: [B, T, C]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # q,k,v: [B, T, C]

        # Split the queries, keys, and values into multiple heads
        B, T, C = q.size()
        q = q.view(B, T, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = k.view(B, T, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = v.view(B, T, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # q,k,v: [B, nh, T, C//nh]
        
        # Calculate attention scores
        scores = torch.matmul(q, k.permute(0, 1, 3, 2))
        scores = scores / (math.sqrt(k.size(-1)))
        scores.masked_fill_(self.mask[:, :, :T, :T], -torch.inf)
        # scores: [B, nh, T, T]
        
        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        # attn_weights: [B, nh, T, T]

        attn_weights = self.attn_dropout(attn_weights)
        
        # Multiply attention weights with values
        out = torch.matmul(attn_weights, v)
        # out: [B, nh, T, C//nh]

        # Concatenate the heads and apply a linear transformation
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        out = self.out_proj(out)
        # out: [B, T, C]

        out = self.resid_dropout(out)
        
        return out

# testing
cfg = ModelConfig(
    max_seq_len=10,
    embed_dim=32,
    num_heads=8,
    num_layers=2,
    attn_dropout=0.1,
    resid_dropout=0.1,
    hidden_dropout=0.1
)
x = torch.randn(2, 5, cfg.embed_dim)
mha = CausalSelfAttention(cfg)
print(mha)

out = mha(x)
print("\nOutput:", out.size())  # torch.Size([2, 5, 10])

CausalSelfAttention(
  (q_proj): Linear(in_features=32, out_features=32, bias=True)
  (k_proj): Linear(in_features=32, out_features=32, bias=True)
  (v_proj): Linear(in_features=32, out_features=32, bias=True)
  (out_proj): Linear(in_features=32, out_features=32, bias=True)
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

Output: torch.Size([2, 5, 32])


## Feed-Forward Network (FFN)

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardNetwork(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        embed_dim = cfg.embed_dim
        hidden_dim = cfg.embed_dim * 4
        p_drop = cfg.hidden_dropout
        # Two linear layers with activation in between
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.gelu = nn.GELU(approximate='tanh')
        self.resid_dropout = nn.Dropout(p_drop)

    def forward(self, x):            # [B, T, C]
        x = self.gelu(self.fc1(x))  # [B, T, 2C]
        x = self.fc2(x)              # [B, T, C]
        x = self.resid_dropout(x)

        return x

# testing
ffn = FeedForwardNetwork(cfg)
print(ffn)
x = torch.randn(2, 5, cfg.embed_dim)
out = ffn(x)
print("\nOutput:", out.size())

FeedForwardNetwork(
  (fc1): Linear(in_features=32, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=32, bias=True)
  (gelu): GELU(approximate='tanh')
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

Output: torch.Size([2, 5, 32])


## Transformer Block

In [21]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mha = CausalSelfAttention(config)
        self.ln1 = nn.LayerNorm(config.embed_dim)
        self.ffn = FeedForwardNetwork(config)
        self.ln2 = nn.LayerNorm(config.embed_dim)

        self.resid_dropout = nn.Dropout(config.resid_dropout)

    def forward(self, x):
        # Apply self-attention and add residual connection
        shortcut = x
        x = self.ln1(x)
        x = self.mha(x)[0]
        x = self.resid_dropout(x)
        x = shortcut + x

        # Apply feedforward network and add residual connection
        shortcut = x
        x = self.ln2(x)
        x = self.ffn(x)
        x = self.resid_dropout(x)
        x = shortcut + x

        return x
    
# testing
transformer_block = TransformerBlock(cfg)
print(transformer_block)
x = torch.randn(2, 5, cfg.embed_dim)
out = transformer_block(x)
print("\nInput:", x.size())
print("Output:", out.size())

TransformerBlock(
  (mha): CausalSelfAttention(
    (q_proj): Linear(in_features=32, out_features=32, bias=True)
    (k_proj): Linear(in_features=32, out_features=32, bias=True)
    (v_proj): Linear(in_features=32, out_features=32, bias=True)
    (out_proj): Linear(in_features=32, out_features=32, bias=True)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (ffn): FeedForwardNetwork(
    (fc1): Linear(in_features=32, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=32, bias=True)
    (gelu): GELU(approximate='tanh')
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

Input: torch.Size([2, 5, 32])
Output: torch.Size([2, 5, 32])
