<a href="https://colab.research.google.com/github/simply-pouria/The-LMs-Book/blob/main/TheLMBook_Chapter4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer Implementation

### Attention Head

In [None]:
class AttentionHead(nn.Module):
    """Single masked scaled dot-product self-attention head.

    Args:
        emb_dim: Size of the last dimension of the input passed to ``forward``.
        d_h: Size of the query/key/value vectors produced by this head.
    """

    def __init__(self, emb_dim, d_h):
        super().__init__()
        self.W_Q = nn.Parameter(torch.empty(emb_dim, d_h))
        self.W_K = nn.Parameter(torch.empty(emb_dim, d_h))
        self.W_V = nn.Parameter(torch.empty(emb_dim, d_h))
        self.d_h = d_h

        # torch.empty hands back uninitialized memory (it may hold NaN/inf),
        # so give every projection a well-defined starting value here.
        # Callers remain free to re-initialize the parameters afterwards.
        for weight in (self.W_Q, self.W_K, self.W_V):
            nn.init.xavier_uniform_(weight)

    def forward(self, x, mask):
        """Return the attention-weighted values for ``x``.

        Args:
            x: Tensor whose last dimension is ``emb_dim``
                (presumably ``(batch, seq_len, emb_dim)`` -- confirm with caller).
            mask: Tensor broadcastable to the score matrix; positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor whose last dimension is ``d_h``.
        """
        Q = x @ self.W_Q
        K = x @ self.W_K
        V = x @ self.W_V

        # `rope` is defined outside this class; presumably it applies rotary
        # position embeddings to queries and keys -- confirm at its definition.
        Q, K = rope(Q), rope(K)

        # Scale by sqrt(d_h) so score magnitude does not grow with head size.
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_h)
        # -inf becomes a weight of exactly 0 after softmax. A row whose
        # positions are all masked would turn into NaN.
        masked_scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = torch.softmax(masked_scores, dim=-1)
        return attention_weights @ V

### Multi-Head Attention Mechanism

In [None]:
class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and mix their outputs.

    Args:
        emb_dim: Size of the last dimension of the input and of the output.
        num_heads: Number of heads; must be positive and divide ``emb_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``emb_dim`` evenly.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # Each head produces emb_dim // num_heads features. Unless the division
        # is exact, the concatenated heads would not match W_O's input size and
        # forward() would fail later with an opaque shape error.
        if num_heads <= 0 or emb_dim % num_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        d_h = emb_dim // num_heads
        self.heads = nn.ModuleList([
            AttentionHead(emb_dim, d_h)
            for _ in range(num_heads)
        ])
        self.W_O = nn.Parameter(torch.empty(emb_dim, emb_dim))
        # torch.empty returns uninitialized memory; start from defined values.
        nn.init.xavier_uniform_(self.W_O)

    def forward(self, x, mask):
        """Apply every head to ``x`` and project the concatenated result.

        Args:
            x: Tensor whose last dimension is ``emb_dim``.
            mask: Attention mask passed unchanged to every head.

        Returns:
            Tensor whose last dimension is ``emb_dim``.
        """
        head_outputs = [head(x, mask) for head in self.heads]
        # Concatenate along the feature axis, then mix with W_O.
        x = torch.cat(head_outputs, dim=-1)
        return x @ self.W_O

### MLP

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, ReLU, project back.

    Args:
        emb_dim: Size of the last dimension of the input and of the output.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.W_1 = nn.Parameter(torch.empty(emb_dim, emb_dim * 4))
        self.B_1 = nn.Parameter(torch.empty(emb_dim * 4))
        self.W_2 = nn.Parameter(torch.empty(emb_dim * 4, emb_dim))
        self.B_2 = nn.Parameter(torch.empty(emb_dim))

        # torch.empty returns uninitialized memory (it may hold NaN/inf), so
        # set defined starting values: Xavier-uniform weights, zero biases.
        # Callers remain free to re-initialize the parameters afterwards.
        nn.init.xavier_uniform_(self.W_1)
        nn.init.zeros_(self.B_1)
        nn.init.xavier_uniform_(self.W_2)
        nn.init.zeros_(self.B_2)

    def forward(self, x):
        """Return ``relu(x @ W_1 + B_1) @ W_2 + B_2``.

        Args:
            x: Tensor whose last dimension is ``emb_dim``.

        Returns:
            Tensor with the same last dimension as ``x``.
        """
        x = x @ self.W_1 + self.B_1
        x = torch.relu(x)
        x = x @ self.W_2 + self.B_2
        return x

## Creating the Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Pre-norm decoder block: attention then MLP, each with a residual add.

    Args:
        emb_dim: Feature size handed to both norms, the attention and the MLP.
        num_heads: Number of attention heads.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # RMSNorm is defined outside this block. Sub-modules are registered in
        # the order norm1, attn, norm2, mlp.
        self.norm1 = RMSNorm(emb_dim)
        self.attn = MultiHeadAttention(emb_dim=emb_dim, num_heads=num_heads)
        self.norm2 = RMSNorm(emb_dim)
        self.mlp = MLP(emb_dim=emb_dim)

    def forward(self, x, mask):
        """Return ``x`` after the attention and MLP residual sub-layers.

        Args:
            x: Input tensor; each sub-layer's output is added back onto it.
            mask: Attention mask forwarded to the attention sub-layer.
        """
        # Normalize only the sub-layer input; the residual path is untouched.
        after_attention = x + self.attn(self.norm1(x), mask)
        return after_attention + self.mlp(self.norm2(after_attention))

## Creating a Decoder Language Model

In [None]:
class DecoderLanguageModel(nn.Module):
    """Decoder-only language model: embedding, decoder blocks, vocab projection.

    Args:
        vocab_size: Number of rows in the embedding table and of output logits.
        emb_dim: Embedding size used throughout the model.
        num_heads: Number of attention heads in every decoder block.
        num_blocks: Number of stacked decoder blocks.
        pad_idx: Token id passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(
        self, vocab_size, emb_dim,
        num_heads, num_blocks, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            vocab_size, emb_dim,
            padding_idx=pad_idx)
        self.layers = nn.ModuleList([
            DecoderBlock(emb_dim, num_heads) for _ in range(num_blocks)
        ])
        # torch.rand would draw every weight from [0, 1), giving an all-positive
        # projection and biased initial logits. Use zero-centred Xavier-uniform
        # values instead. Callers remain free to re-initialize afterwards.
        self.output = nn.Parameter(torch.empty(emb_dim, vocab_size))
        nn.init.xavier_uniform_(self.output)

    def forward(self, x):
        """Return next-token logits for a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.
        """
        x = self.embedding(x)
        _, seq_len, _ = x.shape
        # Lower-triangular (causal) mask: row i has ones only in columns <= i.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        for layer in self.layers:
            x = layer(x, mask)
        return x @ self.output

[Full Implementation](https://github.com/aburkov/theLMbook/blob/main/news_decoder_language_model.ipynb)