<a href="https://colab.research.google.com/github/vardhanreddy2003/GPT-2Training/blob/main/TransfomerBlock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
import torch
import torch.nn as nn

In [86]:
# Model hyperparameters shared by the modules defined in this notebook.
llm_config = dict(
    vocab_size=50257,     # tokenizer vocabulary size
    context_length=1024,  # longest sequence the causal mask is built for
    emb_dim=768,          # width of every token embedding / hidden state
    n_heads=12,           # attention heads per block (must divide emb_dim)
    n_layers=12,          # number of stacked transformer blocks
    dropout_rate=0.1,     # probability used by every dropout layer
    qkv_bias=False,       # whether the query/key/value projections get a bias
)

In [87]:
class feedforwardnetwork(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, GELU, project back.

    Args:
        llm_config: Mapping that provides ``"emb_dim"``, the width of the
            input and output features.
    """

    def __init__(self, llm_config):
        super().__init__()
        width = llm_config["emb_dim"]
        hidden = width * 4

        # Layers are created in application order so parameter initialisation
        # draws from the RNG in the same sequence as the Sequential is built.
        expand = nn.Linear(width, hidden)
        activation = nn.GELU()
        project = nn.Linear(hidden, width)

        self.feedforwardnetwork = nn.Sequential(expand, activation, project)

    def forward(self, X):
        """Apply the MLP to the last dimension of ``X`` and return the result."""
        return self.feedforwardnetwork(X)

In [88]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    Each feature vector is standardised to zero mean and unit variance and
    then transformed as ``scale * x_hat + shift``.

    The variance uses the biased estimator (divide by ``N``), which is what
    ``torch.nn.LayerNorm`` does. Bessel's correction would make the output
    disagree with the reference implementation and produce NaN when
    ``emb_dim == 1``.

    Args:
        emb_dim: Size of the last dimension of the inputs.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Small constant added to the variance to avoid division by zero.
        self.epis = 1e-5
        self.shift = nn.Parameter(torch.zeros(emb_dim))
        self.scale = nn.Parameter(torch.ones(emb_dim))

    def forward(self, X):
        """Return ``X`` normalised along its last dimension.

        Statistics are kept in locals rather than stored on the module, so
        no activation tensors (or their autograd graph) outlive the call.
        """
        mean = X.mean(-1, keepdim=True)
        variance = X.var(-1, keepdim=True, unbiased=False)
        norm_x = (X - mean) / torch.sqrt(variance + self.epis)
        return self.scale * norm_x + self.shift

In [89]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of ``d_out // num_heads`` features each,
    attended with a causal (upper-triangular) mask, and recombined through a
    final linear projection.

    Args:
        d_in: Feature size of the input tokens.
        d_out: Total feature size of the output (all heads combined).
        context_length: Longest sequence supported; sets the mask size.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head feature size

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # combines the head outputs
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions that a
        # token must not attend to. Registered as a buffer so it moves with
        # the module across devices.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        """Return the attended representation of ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape

        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check the failure surfaces later as an opaque
            # shape mismatch inside masked_fill_.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        keys = self.W_key(x)  # (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the feature dimension into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Put the head dimension before the token dimension:
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Slice first, then convert: only the num_tokens x num_tokens corner
        # is turned into a boolean tensor instead of the whole buffer.
        mask_bool = self.mask[:num_tokens, :num_tokens].bool()

        # Future positions get -inf so softmax assigns them zero weight.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads back together; d_out == num_heads * head_dim.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

In [90]:
class TransformerBlock(nn.Module):
    """Transformer block with layer norm applied before each sub-layer.

    The block runs two residual sub-layers in sequence: causal multi-head
    self-attention, then the feed-forward network. Each sub-layer output
    passes through the same dropout layer before being added to its input.

    Args:
        llm_config: Mapping providing ``"emb_dim"``, ``"context_length"``,
            ``"n_heads"``, ``"dropout_rate"`` and ``"qkv_bias"``.
    """

    def __init__(self, llm_config):
        super().__init__()
        width = llm_config["emb_dim"]

        # Sub-modules are created in this fixed order so parameter
        # initialisation and state_dict layout stay the same.
        self.MultiHeadAttention = MultiHeadAttention(
            d_in=width,
            d_out=width,
            context_length=llm_config["context_length"],
            num_heads=llm_config["n_heads"],
            dropout=llm_config["dropout_rate"],
            qkv_bias=llm_config["qkv_bias"],
        )
        self.dropout_layer = nn.Dropout(llm_config["dropout_rate"])
        self.feedforwardnetwork = feedforwardnetwork(llm_config)
        self.layer_norm1 = LayerNorm(width)
        self.layer_norm2 = LayerNorm(width)

    def forward(self, X):
        """Apply attention and feed-forward sub-layers, each with a residual."""
        # Attention sub-layer: norm -> attention -> dropout -> add input.
        attended = self.MultiHeadAttention(self.layer_norm1(X))
        X = self.dropout_layer(attended) + X

        # Feed-forward sub-layer: norm -> MLP -> dropout -> add input.
        transformed = self.feedforwardnetwork(self.layer_norm2(X))
        return self.dropout_layer(transformed) + X



In [91]:
# Smoke test: run one TransformerBlock on random data and check the shape.
# The order matters: the seed is set first so that both the random input and
# the block's parameter initialisation are reproducible.
torch.manual_seed(123)
# Random input with shape (2, 4, 768); the last dimension matches
# llm_config["emb_dim"].
X=torch.rand(2,4,768)
block=TransformerBlock(llm_config)
# No .eval() call is made here, so the block runs in its default mode.
out=block(X)
# Last expression of the cell: the notebook displays the output shape.
out.shape

torch.Size([2, 4, 768])

In [92]:
# Inspect the weight matrix of the first layer (index 0) in the block's
# feed-forward Sequential.
print(block.feedforwardnetwork.feedforwardnetwork[0].weight)

Parameter containing:
tensor([[ 0.0176,  0.0298, -0.0270,  ..., -0.0134, -0.0250, -0.0225],
        [-0.0216,  0.0315,  0.0016,  ...,  0.0078,  0.0077, -0.0230],
        [-0.0211,  0.0097,  0.0292,  ...,  0.0216, -0.0124,  0.0164],
        ...,
        [ 0.0077, -0.0191, -0.0316,  ...,  0.0225, -0.0091,  0.0247],
        [-0.0286, -0.0322,  0.0108,  ...,  0.0288, -0.0130,  0.0138],
        [-0.0181, -0.0186,  0.0168,  ..., -0.0075, -0.0009,  0.0138]],
       requires_grad=True)
