In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from math import sqrt


In [35]:
# Model hyper-parameters (the corresponding TinyLlama values are noted for reference).
SEED = 0
torch.manual_seed(SEED)  # inputs and weight init below are random; fix them for Restart & Run All

n_layers = 6  # TinyLlama: 22
n_heads = 6  # TinyLlama: 32
d_model = 768  # TinyLlama: 2048
intermediate_dim = d_model * 4  # width of the gated MLP inside each block
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"



### Multi-Head Attention (MHA)
<img src="https://data-science-blog.com/wp-content/uploads/2022/01/mha_img_original.png" width=500>

In [20]:
# Random stand-in for already-embedded tokens: [batch_size, sequence_length, d_model]
batch_size = 5
sequence_length = 10  # tokens per sequence
input_data = torch.rand(batch_size, sequence_length, d_model)

In [21]:
input_data.size()  # [batch_size, sequence_length, d_model]

torch.Size([5, 10, 768])

In [24]:
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head.

    Projects the input into query/key/value vectors of size ``hidden_dim`` and
    returns ``softmax(q @ k^T / sqrt(hidden_dim)) @ v``.

    Args:
        embed_dim: size of the last dimension of the input.
        hidden_dim: size of each projected q/k/v vector (the head dimension).
    """

    def __init__(self, embed_dim, hidden_dim):
        super(AttentionHead, self).__init__()
        self.q = nn.Linear(embed_dim, hidden_dim)
        self.k = nn.Linear(embed_dim, hidden_dim)
        self.v = nn.Linear(embed_dim, hidden_dim)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Attend with already-projected queries, keys and values.

        Args:
            q, k, v: tensors of shape ``[..., seq_len, hidden_dim]``; any number of
                leading batch dimensions is accepted, including none.
            mask: optional tensor broadcastable to ``[..., seq_len, seq_len]``;
                positions where ``mask == 0`` receive zero attention weight.

        Returns:
            Tensor of shape ``[..., seq_len, hidden_dim]``.
        """
        dim_k = q.size(-1)
        # k.transpose(-2, -1): [..., seq_len, hidden_dim] -> [..., hidden_dim, seq_len].
        # matmul (unlike bmm) broadcasts over any number of leading batch dims.
        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(dim_k)
        if mask is not None:
            # -inf becomes weight 0 after softmax; a fully masked row yields NaN.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v)

    def forward(self, hidden_state, mask=None):
        """Project ``hidden_state`` to q/k/v and attend over its sequence axis."""
        return self.scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state), mask=mask
        )


In [25]:
attn = AttentionHead(embed_dim=d_model, hidden_dim=d_model // n_heads)
attn(input_data).size()  # one head: [batch_size, sequence_length, d_model // n_heads]

torch.Size([5, 10, 128])

In [26]:
class MHA(nn.Module):
    """Multi-head attention built from independent ``AttentionHead`` modules.

    Each of the ``n_heads`` heads projects the input to ``hidden_dim // n_heads``
    features. The head outputs are concatenated on the last dimension and mapped
    back to ``hidden_dim`` by ``out_proj``.

    Args:
        n_heads: number of attention heads, between 1 and ``hidden_dim``.
        hidden_dim: size of the last dimension of both input and output.

    Raises:
        ValueError: if ``n_heads`` is outside ``[1, hidden_dim]``.
    """

    def __init__(self, n_heads, hidden_dim):
        super(MHA, self).__init__()
        if not 1 <= n_heads <= hidden_dim:
            raise ValueError(
                f"n_heads must be between 1 and hidden_dim ({hidden_dim}), got {n_heads}"
            )
        embed_dim = hidden_dim
        head_dim = hidden_dim // n_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(n_heads)]
        )
        # Size the projection by the real concatenated width. It equals embed_dim
        # when hidden_dim % n_heads == 0 and still works when it does not.
        self.out_proj = nn.Linear(n_heads * head_dim, embed_dim)

    def forward(self, hidden_state, mask=None):
        """Run every head (forwarding ``mask``), concatenate them, and project to ``hidden_dim``."""
        x = torch.cat([h(hidden_state, mask=mask) for h in self.heads], dim=-1)
        return self.out_proj(x)

In [27]:
mha = MHA(n_heads=n_heads, hidden_dim=d_model)
mha(input_data).size()  # heads concatenated, then projected back to d_model

torch.Size([5, 10, 768])

In [40]:
input_data.size()  # unchanged input: [batch_size, sequence_length, d_model]

torch.Size([5, 10, 768])

In [41]:
class LLaMAMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward network.

    Computes ``out_proj(SiLU(linear_1(x)) * linear_2(x))``. ``linear_1`` is the
    gated branch, ``linear_2`` the ungated branch, and ``out_proj`` maps the
    ``intermediate_dim`` features back to ``hidden_dim``. Note that LLaMA itself
    uses bias-free linears; biases are kept here.

    Args:
        hidden_dim: size of the last dimension of input and output.
        intermediate_dim: width of the gated hidden layer.
    """

    def __init__(self, hidden_dim, intermediate_dim):
        super(LLaMAMLP, self).__init__()
        self.linear_1 = nn.Linear(hidden_dim, intermediate_dim)
        self.linear_2 = nn.Linear(hidden_dim, intermediate_dim)
        self.activation_fn = nn.SiLU()
        self.out_proj = nn.Linear(intermediate_dim, hidden_dim)

    def forward(self, hidden_state):
        """Apply the gated MLP over the last dimension; the output shape equals the input shape."""
        gate = self.activation_fn(self.linear_1(hidden_state))
        up = self.linear_2(hidden_state)
        return self.out_proj(gate * up)

In [42]:
intermediate_dim  # MLP hidden width, defined as d_model * 4 in the config cell

3072

In [43]:
mlp = LLaMAMLP(hidden_dim=d_model, intermediate_dim=intermediate_dim)
mlp(input_data).size()  # the MLP preserves [batch_size, sequence_length, d_model]

torch.Size([5, 10, 768])

In [44]:
class Block(nn.Module):
    """One pre-norm decoder block: attention and gated MLP, each with a residual.

    ``forward`` computes::

        x   = h + mha(layer_norm(h))
        out = x + mlp(mlp_norm(x))

    Each sub-layer sees normalized input, while the residual stream itself is
    left un-normalized. (LLaMA uses RMSNorm; LayerNorm is kept here.)

    Args:
        n_heads: attention heads in ``mha``.
        hidden_dim: model width (last dimension of input and output).
        intermediate_dim: hidden width of ``mlp``.
    """

    def __init__(self, n_heads, hidden_dim, intermediate_dim):
        super(Block, self).__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.mha = MHA(n_heads=n_heads, hidden_dim=hidden_dim)
        # Pre-norm applied to the input of the attention sub-layer.
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.mlp = LLaMAMLP(hidden_dim, intermediate_dim)
        # Separate pre-norm applied to the input of the MLP sub-layer.
        self.mlp_norm = nn.LayerNorm(hidden_dim)

    def forward(self, hidden_state, mask=None):
        """Apply attention then MLP with residual connections; the output shape equals the input shape.

        ``mask`` is handed to ``self.mha`` only when it is given, so the default
        (``mask=None``) call works even if ``MHA.forward`` takes no mask argument.
        """
        normed = self.layer_norm(hidden_state)
        if mask is None:
            attn_out = self.mha(normed)
        else:
            attn_out = self.mha(normed, mask=mask)
        x = hidden_state + attn_out
        return x + self.mlp(self.mlp_norm(x))




In [45]:
block = Block(n_heads=n_heads, hidden_dim=d_model, intermediate_dim=intermediate_dim)
block(input_data).size()  # a block preserves [batch_size, sequence_length, d_model]

torch.Size([5, 10, 768])

In [53]:
class babyLLaMA(nn.Module):
    """Small decoder-style language model.

    Token embeddings plus learned positional embeddings feed a stack of
    ``Block``s, followed by a projection to the vocabulary and a softmax.

    Args:
        max_seq_len: number of learned positions; longer inputs are rejected.
        vocab_size: number of token ids (and output classes).
        n_layers: number of ``Block``s.
        n_heads: attention heads per block.
        hidden_dim: model width.
        intermediate_dim: MLP hidden width inside each block.
    """

    def __init__(self, max_seq_len, vocab_size, n_layers, n_heads, hidden_dim, intermediate_dim):
        super(babyLLaMA, self).__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.pos = nn.Embedding(max_seq_len, hidden_dim)
        self.blocks = nn.ModuleList(
            [Block(n_heads, hidden_dim, intermediate_dim) for _ in range(n_layers)]
        )
        self.out_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_state, mask=None):
        """Map token ids to per-position probability distributions over the vocabulary.

        Args:
            hidden_state: integer token ids of shape ``[batch_size, seq_len]``
                (the name is kept for backward compatibility).
            mask: optional attention mask, forwarded unchanged as ``mask=`` to
                every block. NOTE(review): no causal mask is built here, so by
                default each position can attend to later positions. For
                autoregressive use, pass e.g. ``torch.tril(torch.ones(seq_len, seq_len))``.

        Returns:
            Softmax probabilities of shape ``[batch_size, seq_len, vocab_size]``.
            NOTE(review): for training with ``F.cross_entropy``, the pre-softmax
            logits would be the usual output.

        Raises:
            IndexError: if ``seq_len`` exceeds the positional table size.
        """
        seq_len = hidden_state.size(1)
        max_seq_len = self.pos.num_embeddings
        if seq_len > max_seq_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        emb = self.emb(hidden_state)
        # Create positions on the input's device so the model also runs off-CPU.
        positions = torch.arange(
            seq_len, dtype=torch.long, device=hidden_state.device
        ).unsqueeze(0)
        x = emb + self.pos(positions)
        for block in self.blocks:
            x = block(x, mask=mask)
        logits = self.out_proj(x)
        return F.softmax(logits, dim=-1)
    


In [54]:
vocab_size = 32000
# Size of the learned positional table. It keeps its previous value (768), which
# the original call supplied by passing d_model positionally into this slot.
max_seq_len = 768
assert sequence_length <= max_seq_len, "inputs longer than max_seq_len have no position embedding"

llm = babyLLaMA(
    max_seq_len=max_seq_len,
    vocab_size=vocab_size,
    n_layers=n_layers,
    n_heads=n_heads,
    hidden_dim=d_model,
    intermediate_dim=intermediate_dim,
)
# Random token ids in [1, vocab_size); shape [batch_size, sequence_length].
input_ids = torch.randint(1, vocab_size, (batch_size, sequence_length))


In [56]:
llm(input_ids).size()  # [batch_size, sequence_length, vocabulary size]

torch.Size([5, 10, 32000])

In [58]:
def count_parameters(model):
    """Return the total number of trainable (``requires_grad``) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
count_parameters(model=llm)  # trainable parameters of the full model