In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [15]:
class SelfAttention(nn.Module):

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.q = nn.Linear(embed_dim, hidden_dim)
        self.k = nn.Linear(embed_dim, hidden_dim)
        self.v = nn.Linear(embed_dim, hidden_dim)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        d_k = q.size(-1)
        attn_logits = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d_k)
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, -torch.inf)
        weights = F.softmax(attn_logits, dim=-1)
        return torch.bmm(scores, v)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        return self.scaled_dot_product_attention(q, k, v)

In [18]:
# Testing with random inputs
hidden_dim = 768
seq_len = 10
batch_size = 32
n_heads = 6
head_dim = hidden_dim // n_heads
print(f'hidden_dim: {hidden_dim}, n_heads: {n_heads}, head_dim: {head_dim}')

x = torch.randn(batch_size, seq_len, hidden_dim)
x.shape

hidden_dim: 768, n_heads: 6, head_dim: 128


torch.Size([32, 10, 768])

In [19]:
sa = SelfAttention(hidden_dim, head_dim)
output = sa(x)
print(output.shape)
output

torch.Size([32, 10, 128])


tensor([[[ 0.1354, -0.0800, -0.2062,  ...,  0.0406, -0.0935, -0.0527],
         [ 0.2048, -0.0702, -0.2163,  ..., -0.0254, -0.2176,  0.1759],
         [ 0.0806, -0.1309, -0.1470,  ...,  0.0566, -0.0871, -0.0675],
         ...,
         [ 0.0487, -0.0901, -0.1056,  ...,  0.1106, -0.1494,  0.0155],
         [ 0.1283, -0.0952, -0.1791,  ...,  0.0519, -0.1772,  0.1507],
         [ 0.2081, -0.0737, -0.2328,  ...,  0.0425, -0.1503,  0.0442]],

        [[ 0.2795, -0.0317,  0.0665,  ...,  0.1240, -0.0638, -0.1576],
         [ 0.3139, -0.0117,  0.1160,  ...,  0.1114,  0.0058, -0.2802],
         [ 0.2876,  0.0218,  0.0989,  ..., -0.0214,  0.0939, -0.1724],
         ...,
         [ 0.3883, -0.0011,  0.0167,  ...,  0.0582,  0.0843, -0.2127],
         [ 0.4049, -0.0512,  0.0827,  ...,  0.1879, -0.0714, -0.3038],
         [ 0.4627, -0.0953,  0.0712,  ..., -0.0777,  0.1226, -0.2759]],

        [[ 0.0140, -0.0202,  0.1011,  ...,  0.0093,  0.0582,  0.0336],
         [ 0.1245, -0.1210, -0.0290,  ..., -0

In [23]:
class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, hidden_dim):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads
        assert self.head_dim * n_heads == hidden_dim, "Hidden dimension must be divisible "
        "by the number of heads"
        
        self.heads = nn.ModuleList([
            SelfAttention(self.hidden_dim, self.head_dim) for _ in range(n_heads)
        ])
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        output = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.out_proj(output)

In [24]:
n_heads = 8

mha = MultiHeadAttention(n_heads=n_heads, hidden_dim=hidden_dim)
output = mha(x)

print(output.shape)

torch.Size([32, 10, 768])


In [25]:
class MLP(nn.Module):

    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, intermediate_dim)
        self.fc2 = nn.Linear(hidden_dim, intermediate_dim)
        self.out = nn.Linear(intermediate_dim, hidden_dim)

    def forward(self, x):
        x = F.silu(self.fc1(x)) + self.fc2(x)
        return self.out(x)

In [27]:
class Block(nn.Module):

    def __init__(self, n_heads, hidden_dim, intermediate_dim):
        super().__init__()
        self.mha = MultiHeadAttention(n_heads, hidden_dim)
        self.mlp = MLP(hidden_dim, intermediate_dim)

    def forward(self, x):
        h = self.mha(x) + x
        out = self.mlp(h) + h
        return out

In [28]:
class BabyLlama(nn.Module):

    def __init__(self, vocab_size, max_seq_len, n_layers, n_heads, hidden_dim, intermediate_dim):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_dim)
        self.layers = nn.ModuleList([
            Block(n_heads, hidden_dim, intermediate_dim) for _ in range(n_layers)
        ])
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        tok_embeddings = self.token_embedding(x)

        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        pos_embeddings = self.pos_embedding(positions)

        x = tok_embeddings + pos_embeddings
        for layer in self.layers:
            x = layer(x)

        logits = self.out(x)
        probs = F.softmax(logits, dim=-1)
        return probs

In [None]:
vocab_size = 32_000
max_seq_len = 1024
n_layers = 6


