In [2]:
import torch
import torch.nn as nn 
from einops import einsum, rearrange
# Seed PyTorch's global RNG so the torch.randn inputs and the nn.Linear /
# nn.Embedding initializations below are reproducible. The trailing ';'
# suppresses the notebook display of the Generator that manual_seed returns.
torch.manual_seed(123);

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last dimension is divided by its root mean square
    (plus ``eps`` for stability) and multiplied by a learnable per-feature
    gain ``scale`` initialized to ones.

    Args:
        d_model: size of the last (feature) dimension of the input.
        eps: added to the mean of squares before the reciprocal square root
            to avoid division by zero.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Keep this a bare Parameter: calling .float() on a Parameter returns a
        # plain (unregistered, untrained) tensor whenever the default dtype is
        # not float32.
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension; the result has ``x``'s dtype."""
        # Accumulate in at least float32 so float16/bfloat16 inputs do not lose
        # precision in the mean of squares; float64 inputs stay float64.
        x_f = x.to(torch.promote_types(x.dtype, torch.float32))
        var = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x_f * torch.rsqrt(var + self.eps)
        return (x_norm * self.scale).to(dtype=x.dtype)

In [4]:
# Sanity check: with the same eps and an all-ones initial gain, the custom
# RMSNorm must agree with PyTorch's built-in nn.RMSNorm.
example_batch = torch.randn(2, 3, 4)

rms_norm, rms_norm_pytorch = (
    RMSNorm(example_batch.size(-1), eps=1e-5),
    nn.RMSNorm(example_batch.size(-1), eps=1e-5),
)

assert torch.allclose(
    rms_norm(example_batch),
    rms_norm_pytorch(example_batch),
)

In [5]:
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit (swish): ``silu(x) = x * sigmoid(x)``.

    Stateless; the inherited ``nn.Module.__init__`` is all it needs.
    """

    def forward(self, x):
        # Gate every element by its own logistic sigmoid.
        gate = torch.sigmoid(x)
        return x * gate

In [6]:
# Sanity check: the custom SiLU must match torch.nn.functional.silu.
silu = SiLU()
assert torch.allclose(
    silu(example_batch),
    torch.nn.functional.silu(example_batch),
)

In [7]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``fc3(silu(fc1(x)) * fc2(x))``.

    ``fc1`` feeds the SiLU gate, ``fc2`` is the linear branch the gate
    multiplies, and ``fc3`` projects back to ``d_model``. All three
    projections are bias-free.

    Args:
        cfg: config dict; reads ``"d_model"`` and, when present,
            ``"hidden_dim"`` (otherwise the hidden size is ``4 * d_model``).
    """

    def __init__(self, cfg):
        super().__init__()
        d_model = cfg["d_model"]
        d_ff = cfg.get("hidden_dim", 4 * d_model)
        # Layers are created in the order fc1, fc2, fc3 so seeded weight
        # initialization draws from the RNG in the same sequence.
        self.fc1 = nn.Linear(d_model, d_ff, bias=False)
        self.fc2 = nn.Linear(d_model, d_ff, bias=False)
        self.fc3 = nn.Linear(d_ff, d_model, bias=False)
        self.silu = SiLU()

    def forward(self, x):
        gated = self.silu(self.fc1(x))
        return self.fc3(gated * self.fc2(x))

In [8]:
class RoPE(nn.Module):
    """Rotary positional embedding.

    Precomputes ``cos``/``sin`` tables of shape ``(ctx_len, head_dim)`` and
    rotates each feature pair ``(i, i + head_dim // 2)`` of the input by the
    angle ``pos * base ** (-2i / head_dim)``.

    Args:
        head_dim: size of the last dimension of the inputs; must be even.
        base: frequency base of the rotation angles.
        ctx_len: number of positions the tables cover (maximum sequence length).
    """

    def __init__(self, head_dim: int, base: int = 10_000, ctx_len: int = 4096):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even, got {head_dim}")
        self.base = base
        self.head_dim = head_dim
        self._compute_rope(ctx_len)

    def _compute_rope(self, ctx_len):
        """Build the ``cos``/``sin`` tables for positions ``0 .. ctx_len - 1``."""
        # The exponent must be grouped: base ** (2i / d), not (base ** 2i) / d.
        exponents = torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim
        inv_freq = 1.0 / (self.base ** exponents)         # (d//2,)
        pos = torch.arange(ctx_len, dtype=torch.float32)  # (ctx_len,)
        angles = torch.outer(pos, inv_freq)               # (ctx_len, d//2)
        angles = torch.cat((angles, angles), dim=-1)      # (ctx_len, d)

        # Non-persistent buffers move with the module on .to(device) but are
        # not written to the state_dict.
        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

    def _rotate_half(self, x):
        """Map the halves ``(x1, x2)`` of the last dimension to ``(-x2, x1)``."""
        half = self.head_dim // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x):
        """Rotate ``x``, e.g. of shape ``(batch_size, n_heads, seq_len, head_dim)``.

        The sequence axis is the second-to-last one. Returns a tensor with the
        same shape and dtype as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed context length.
        """
        seq_len = x.shape[-2]
        max_len = self.cos.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds RoPE context length {max_len}"
            )
        # (seq_len, head_dim) tables broadcast over the leading batch/head dims.
        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]
        # Tables are float32; cast back so low-precision inputs keep their dtype.
        return (x * cos + self._rotate_half(x) * sin).to(dtype=x.dtype)

In [9]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Args:
        d_in: input feature size.
        d_out: output feature size; split evenly across ``n_heads``.
        ctx_len: maximum sequence length (size of the causal mask and of the
            RoPE tables).
        n_heads: number of attention heads.
    """

    def __init__(self, d_in, d_out, ctx_len, n_heads):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads!"

        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)

        # mask[i, j] == 1 where j > i: query i must not see later key j.
        mask = torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1)
        self.register_buffer("mask", mask)

        self.n_heads = n_heads
        head_dim = d_out // n_heads
        self.rope = RoPE(head_dim, ctx_len=ctx_len)

    def forward(self, x):
        """Attend causally over ``x`` of shape ``(batch_size, seq_len, d_in)``.

        Returns a tensor of shape ``(batch_size, seq_len, d_out)``.
        """
        _, seq_len, _ = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = rearrange(keys, "batch_size seq_len (n_heads head_dim) -> batch_size n_heads seq_len head_dim", n_heads=self.n_heads)
        queries = rearrange(queries, "batch_size seq_len (n_heads head_dim) -> batch_size n_heads seq_len head_dim", n_heads=self.n_heads)
        values = rearrange(values, "batch_size seq_len (n_heads head_dim) -> batch_size n_heads seq_len head_dim", n_heads=self.n_heads)

        keys = self.rope(keys)
        queries = self.rope(queries)

        # Row i holds query i's scores against every key j (q_i . k_j), so the
        # causal mask and the softmax over the last dim both act on keys.
        attn_scores = einsum(queries, keys, "... s1 head_dim, ... s2 head_dim -> ... s1 s2")
        attn_scores.masked_fill_(self.mask.bool()[:seq_len, :seq_len], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = einsum(attn_weights, values, "... s1 s2, ... s2 head_dim -> ... s1 head_dim")
        context_vec = rearrange(context_vec, "batch_size n_heads seq_len head_dim -> batch_size seq_len (n_heads head_dim)")
        return self.out_proj(context_vec)


In [10]:
# Smoke test: one forward pass through MultiHeadAttention on a random batch
# shorter than the maximum context length; the module is freed afterwards.
batch_size, context_len, max_context_len = 1, 100, 4096
embed_dim, num_heads = 128, 4

example_batch = torch.randn((batch_size, context_len, embed_dim))

mha = MultiHeadAttention(
    d_in=embed_dim, d_out=embed_dim, ctx_len=max_context_len, n_heads=num_heads
)
mha(example_batch)
del mha

In [14]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Computes ``h = x + mhsa(norm1(x))`` and returns ``h + ff(norm2(h))``.

    Args:
        cfg: config dict; reads ``"d_model"``, ``"context_length"`` and
            ``"n_heads"``, and is passed whole to ``FeedForward``.
    """

    def __init__(self, cfg):
        super().__init__()
        d_model = cfg["d_model"]
        # Submodules are created in the same order as before so seeded
        # initialization is unchanged.
        self.mhsa = MultiHeadAttention(
            d_in=d_model,
            d_out=d_model,
            ctx_len=cfg["context_length"],
            n_heads=cfg["n_heads"],
        )
        self.ff = FeedForward(cfg)
        self.norm1, self.norm2 = RMSNorm(d_model), RMSNorm(d_model)

    def forward(self, x):
        attn_residual = x + self.mhsa(self.norm1(x))
        return attn_residual + self.ff(self.norm2(attn_residual))

In [15]:
class Llama2Model(nn.Module):
    """Decoder-only language model in the Llama 2 layout.

    Token embedding -> ``n_layers`` TransformerBlocks -> final RMSNorm ->
    bias-free linear head producing vocabulary logits. The head has its own
    weights (it is not tied to the embedding).

    Args:
        cfg: config dict; reads ``"vocab_size"``, ``"d_model"`` and
            ``"n_layers"``, and is passed whole to each TransformerBlock.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, d_model = cfg["vocab_size"], cfg["d_model"]
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.transformer_blocks = nn.Sequential(*blocks)
        self.final_norm = RMSNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map token ids of shape ``(batch_size, seq_len)`` to logits of shape
        ``(batch_size, seq_len, vocab_size)``."""
        hidden = self.transformer_blocks(self.tok_emb(in_idx))
        return self.out_head(self.final_norm(hidden))


In [16]:
# GPT-2 "124M" hyperparameters plugged into the Llama 2 architecture.
# GPT's "drop_rate" and "qkv_bias" entries are left out: none of the modules
# above read them.
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "d_model": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
}

model = Llama2Model(GPT_CONFIG_124M)
# The count exceeds 124M: each FeedForward holds three d_model x 4*d_model
# matrices and out_head does not share weights with tok_emb.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"{total_params:,}")

190,460,160


In [17]:
LLAMA2_CONFIG_7B = {
    "vocab_size": 32000,     # Vocabulary size
    "context_length": 4096,  # Context length
    "d_model": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "hidden_dim": 11008,     # NEW: Size of the intermediate dimension in FeedForward
    "dtype": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage
}
# NOTE: none of the modules above read cfg["dtype"], so a regular build would
# allocate ~6.7B float32 parameters (~27 GB) just to count them. On the "meta"
# device tensors get shapes but no storage, which is all the count needs.
with torch.device("meta"):
    model = Llama2Model(LLAMA2_CONFIG_7B)
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,}")

6,738,415,616


In [None]:
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query attention (GQA) with rotary position embeddings.

    Queries are split into ``n_heads`` heads, keys and values into only
    ``n_kv_groups`` heads. Each key/value head is shared by
    ``group_size = n_heads // n_kv_groups`` consecutive query heads.

    Args:
        d_in: input feature size.
        d_out: output feature size; split evenly across ``n_heads``.
        n_heads: number of query heads.
        n_kv_groups: number of key/value heads; must divide ``n_heads``.
    """

    def __init__(self, d_in, d_out, n_heads, n_kv_groups):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads!"
        assert n_heads % n_kv_groups == 0, "n_heads must be divisible by n_kv_groups!"

        self.n_kv_groups = n_kv_groups
        self.group_size = n_heads // n_kv_groups
        self.head_dim = d_out // n_heads

        self.W_key = nn.Linear(d_in, n_kv_groups * self.head_dim, bias=False)
        self.W_value = nn.Linear(d_in, n_kv_groups * self.head_dim, bias=False)
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)

    @staticmethod
    def _apply_rotary(x, cos, sin):
        """Rotate ``x`` of shape ``(..., seq_len, head_dim)`` with given RoPE tables.

        ``cos``/``sin`` have shape ``(>= seq_len, head_dim)``; only their first
        ``seq_len`` rows are used. The result keeps ``x``'s dtype.
        """
        seq_len, half = x.shape[-2], x.shape[-1] // 2
        cos = cos[:seq_len].to(x.device)
        sin = sin[:seq_len].to(x.device)
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        return (x * cos + rotated * sin).to(dtype=x.dtype)

    def forward(self, x, mask=None, cos=None, sin=None):
        """Attend causally over ``x``.

        Args:
            x: input of shape ``(batch_size, seq_len, d_in)``.
            mask: optional mask whose last two dims cover at least
                ``(seq_len, seq_len)``; nonzero/True entries are hidden.
                Defaults to a strictly upper-triangular causal mask.
            cos, sin: optional precomputed RoPE tables of shape
                ``(>= seq_len, head_dim)``, passed together. When omitted, a
                ``RoPE`` module is built for this call.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_out)``.

        Raises:
            ValueError: if exactly one of ``cos``/``sin`` is given.
        """
        _, seq_len, _ = x.shape
        if (cos is None) != (sin is None):
            raise ValueError("cos and sin must be passed together")

        queries = self.W_query(x) # (batch_size, seq_len, d_out)
        keys    = self.W_key(x)   # (batch_size, seq_len, n_kv_groups * head_dim)
        values  = self.W_value(x) # (batch_size, seq_len, n_kv_groups * head_dim)

        queries = rearrange(queries, "batch_size seq_len (n_heads head_dim) -> batch_size n_heads seq_len head_dim", head_dim=self.head_dim)
        keys    = rearrange(keys,    "batch_size seq_len (n_kv_groups head_dim) -> batch_size n_kv_groups seq_len head_dim", head_dim=self.head_dim)
        values  = rearrange(values,  "batch_size seq_len (n_kv_groups head_dim) -> batch_size n_kv_groups seq_len head_dim", head_dim=self.head_dim)

        if cos is None:
            rope = RoPE(self.head_dim, ctx_len=seq_len).to(x.device)
            keys = rope(keys)
            queries = rope(queries)
        else:
            keys = self._apply_rotary(keys, cos, sin)
            queries = self._apply_rotary(queries, cos, sin)

        # Share each key/value head across its query group, e.g. [K1, K2] -> [K1, K1, K2, K2].
        keys   = keys.repeat_interleave(self.group_size, dim=1)   # (batch_size, n_heads, seq_len, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1) # (batch_size, n_heads, seq_len, head_dim)

        # Row i holds query i's scores against every key j.
        attn_scores = einsum(queries, keys, "... s1 head_dim, ... s2 head_dim -> ... s1 s2")
        if mask is None:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1)
        # masked_fill_ only accepts boolean masks; slicing lets callers pass a
        # mask precomputed for a longer context.
        mask = mask[..., :seq_len, :seq_len].to(device=x.device, dtype=torch.bool)
        attn_scores.masked_fill_(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = einsum(attn_weights, values, "... s1 s2, ... s2 head_dim -> ... s1 head_dim")
        context_vec = rearrange(context_vec, "batch_size n_heads seq_len head_dim -> batch_size seq_len (n_heads head_dim)")
        return self.out_proj(context_vec)
