<a href="https://colab.research.google.com/github/taaha3244/gpt2-scratch/blob/main/LLM_from_scratch_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention without masking.

    Every position attends to every position of the same sequence
    (including later ones).

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        qkv_bias: Whether the query/key/value projections include a bias.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_in)``. Both the batched
                ``(batch_size, seq_len, d_in)`` form and an unbatched
                ``(seq_len, d_in)`` sequence are accepted, because only the
                last two dimensions are used by the attention arithmetic.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # (..., seq_len, d_out) @ (..., d_out, seq_len) -> (..., seq_len, seq_len)
        attn_scores = queries @ keys.transpose(-2, -1)

        # Divide by sqrt(d_out) before the softmax.
        attn_scores = attn_scores / (self.d_out ** 0.5)

        # Normalise over the key axis so each query's weights sum to 1.
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # (..., seq_len, seq_len) @ (..., seq_len, d_out) -> (..., seq_len, d_out)
        return attn_weights @ values




In [2]:
# Smoke test: run SelfAttention on a small random batch.
batch_size = 2
seq_len = 4
d_in = 8
d_out = 8

x = torch.randn(batch_size, seq_len, d_in)

self_attention = SelfAttention(d_in=d_in, d_out=d_out)

context_vectors = self_attention(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {context_vectors.shape}")

# Full tensor contents, labelled as values rather than shapes.
print(f"Input: {x}")
print(f"Output: {context_vectors}")

Input shape: torch.Size([2, 4, 8])
Output shape: torch.Size([2, 4, 8])
Input shape: tensor([[[ 0.2159, -0.3583, -0.4211, -0.7047,  0.3477,  2.0457, -0.2310,
           0.8658],
         [-0.1003,  0.2543, -1.0935, -0.5496,  0.5316,  0.5560,  0.6634,
          -0.8767],
         [-0.1346,  0.6268,  0.2724, -0.0927,  0.0892, -0.8980,  1.0593,
          -0.7124],
         [-0.3824, -0.2236, -0.9022,  0.3905,  1.2533,  0.5790,  1.1562,
           0.5745]],

        [[ 0.7469, -0.6548, -0.0366,  0.6984,  0.1737,  0.3693, -0.2300,
           0.8148],
         [-0.3705, -1.1015, -0.5517, -1.6435, -0.0266, -0.3467,  0.7775,
          -1.2061],
         [-0.7777,  1.3427, -1.1442, -0.3512,  0.4188, -0.7924, -0.2582,
           1.2532],
         [-0.3196,  0.7955, -0.5760,  0.0781,  0.0887,  1.1598,  0.9838,
          -1.8984]]])
Output shape: tensor([[[-0.3020,  0.1582, -0.1137, -0.1427, -0.1558,  0.2479,  0.0260,
          -0.0947],
         [-0.3074,  0.1372, -0.1072, -0.1462, -0.1590,  0.240

In [4]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product attention with a causal mask.

    Position ``i`` may only attend to positions ``j <= i``; dropout is
    applied to the attention weights.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Longest sequence length the mask is built for.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections include a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.1, qkv_bias=False):
        super().__init__()
        self.d_out = d_out

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular boolean matrix: entry (i, j) is True when
        # j > i, i.e. when key j lies in the future of query i and must be
        # hidden. Registered as a buffer so it moves with the module
        # (.to(), .cuda()) and is stored in the state dict.
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute causally-masked context vectors for ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_in)`` with
                ``seq_len <= context_length``. Batched and unbatched
                sequences are both accepted.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``context_length``.
        """
        seq_len = x.shape[-2]
        context_length = self.mask.shape[0]
        if seq_len > context_length:
            # Without this check an over-long input either fails with an
            # opaque broadcasting error or, when context_length == 1, has the
            # 1x1 mask broadcast over every position and silently loses
            # causality.
            raise ValueError(
                f"Sequence length {seq_len} exceeds context_length {context_length}"
            )

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # (..., seq_len, d_out) @ (..., d_out, seq_len) -> (..., seq_len, seq_len)
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores = attn_scores / (self.d_out ** 0.5)

        # Future positions get -inf so the softmax assigns them zero weight.
        # The mask is cropped because the input may be shorter than
        # context_length.
        attn_scores = attn_scores.masked_fill(
            self.mask[:seq_len, :seq_len],
            float('-inf')
        )

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (..., seq_len, seq_len) @ (..., seq_len, d_out) -> (..., seq_len, d_out)
        return attn_weights @ values



In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention with an output projection.

    The ``d_out``-wide query/key/value projections are split into
    ``num_heads`` heads of ``d_out // num_heads`` features each. Every head
    runs causally-masked scaled dot-product attention, and the concatenated
    head outputs go through a final linear layer.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total projection width across all heads; also the output size.
        context_length: Longest sequence length the mask is built for.
        num_heads: Number of attention heads; must divide ``d_out``.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections include a bias.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, dropout=0.1, qkv_bias=False):
        super().__init__()
        # Raised explicitly rather than asserted so the check is still
        # performed when Python runs with -O.
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads")

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Mixes the concatenated head outputs.
        self.W_out = nn.Linear(d_out, d_out)

        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular boolean matrix: entry (i, j) is True when
        # key j lies in the future of query i and must be hidden.
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute multi-head causal attention for ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_in)`` with
                ``seq_len <= context_length``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_out)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``context_length``.
        """
        batch_size, seq_len, _ = x.shape
        context_length = self.mask.shape[0]
        if seq_len > context_length:
            # Without this check an over-long input either fails with an
            # opaque broadcasting error or, when context_length == 1, has the
            # 1x1 mask broadcast over every position and silently loses
            # causality.
            raise ValueError(
                f"Sequence length {seq_len} exceeds context_length {context_length}"
            )

        # Project, split the feature axis into heads, then move the head axis
        # ahead of the sequence axis:
        # (batch_size, seq_len, d_out)
        #   -> (batch_size, seq_len, num_heads, head_dim)
        #   -> (batch_size, num_heads, seq_len, head_dim)
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        queries = self.W_query(x).view(head_shape).transpose(1, 2)
        keys = self.W_key(x).view(head_shape).transpose(1, 2)
        values = self.W_value(x).view(head_shape).transpose(1, 2)

        # (batch_size, num_heads, seq_len, seq_len); scaled per head.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores = attn_scores / (self.head_dim ** 0.5)

        # The (seq_len, seq_len) mask broadcasts over batch and head axes.
        attn_scores = attn_scores.masked_fill(
            self.mask[:seq_len, :seq_len],
            float('-inf')
        )

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (batch_size, num_heads, seq_len, head_dim)
        context_vectors = attn_weights @ values

        # Bring heads back next to their features and merge them. The
        # transpose leaves the tensor non-contiguous, so .contiguous() is
        # needed before .view().
        # -> (batch_size, seq_len, d_out)
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_out
        )

        return self.W_out(context_vectors)
