### This notebook contains a smaller-size Llama/GPT-style model without RoPE.
Contains:
- GQA
- FFN with SWIGLU

In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Hyperparameters shared by every module in this notebook.
@dataclass
class Config:
  # Number of rows in the token-embedding table and width of the output projection.
  vocab_size: int = 50000
  # Width of the residual stream (embedding / hidden size).
  d_model: int = 512
  # Number of positions for which the positional-encoding table is precomputed.
  seq_len: int = 1024
  # Number of query heads in attention.
  n_heads: int = 8
  # Number of key/value heads; None makes the attention block fall back to n_heads.
  kv_n_heads: int = None
  # Number of DecoderBlock layers stacked by build_llama.
  n_layers: int = 6
  # Dropout probability used by the attention and ReLU feed-forward blocks.
  dropout: float = 0.1
  # Feed-forward hidden size is expansion_ratio * d_model.
  expansion_ratio: int = 4

# Module-level instance reused (and mutated) by the demo cells below.
config = Config()

## InputEmbeddings


In [3]:
class InputEmbeddings(nn.Module):
  """Learned token-embedding table.

  Reads ``vocab_size`` and ``d_model`` from ``config`` and maps integer
  token ids to vectors of width ``d_model``.
  """

  def __init__(self, config: Config):
    super().__init__()
    # One learnable row per vocabulary entry.
    self.embedding = nn.Embedding(
      num_embeddings=config.vocab_size,
      embedding_dim=config.d_model,
    )

  def forward(self, x):
    """Embed token ids: (B, T) -> (B, T, d_model)."""
    token_vectors = self.embedding(x)
    return token_vectors

# Smoke test: embed a random batch of 2 token-id sequences and inspect the output shape.
x = torch.randint(low=0, high=config.vocab_size, size=(2, config.seq_len))
x_embed = InputEmbeddings(config=config)
out = x_embed(x)
out.shape

torch.Size([2, 1024, 512])

## PositionalEncoding

In [4]:
class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding added to the input embeddings.

  A table of shape (1, seq_len, d_model) is precomputed once and stored as
  the buffer ``pe``. Even feature indices hold sines and odd feature indices
  hold cosines of ``position / 10000^(2i / d_model)``.

  Args:
    config: object providing ``seq_len`` (number of positions to precompute)
      and ``d_model`` (feature width; may be odd).
  """

  def __init__(self, config: Config):
    super().__init__()
    seq_len = config.seq_len
    d_model = config.d_model

    # positional encoding table of shape (T, d_model)
    pe = torch.zeros(seq_len, d_model)

    # position indices as a column vector: (seq_len, 1)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)

    # inverse frequencies, one per even feature index: (1, ceil(d_model / 2))
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).unsqueeze(0)

    # compute the angle matrix once and reuse it for both sine and cosine
    angles = position * div_term    # (seq_len, ceil(d_model / 2))

    # sine on even feature indices
    pe[:, 0::2] = torch.sin(angles)

    # cosine on odd feature indices; for odd d_model there is one fewer odd
    # slot than even slots, so drop the surplus frequency column
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    pe = pe.unsqueeze(0)    # (1, T, d_model)

    # a buffer moves with the module (.to / .cuda) but is never trained
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Add the encodings of the first T positions: (B, T, d_model) -> (B, T, d_model)."""
    # buffers do not require grad, so no explicit requires_grad_(False) is needed
    return x + self.pe[:, :x.shape[1], :]


# Smoke test: add positional encodings to a random (2, seq_len, d_model) batch and inspect the shape.
x = torch.rand(2, config.seq_len, config.d_model)
pos = PositionalEncoding(config=config)
out = pos(x)
out.shape

torch.Size([2, 1024, 512])

## RoPE

In [5]:
"""
Coming soon....
"""

'\nComing soon....\n'

## RMSNormalization

In [6]:
class RMSNormalization(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learnable gain.

  Args:
    config: object providing ``d_model`` (size of the normalised dimension).
    eps: constant added to the mean square before the reciprocal square root.
  """

  def __init__(self, config: Config, eps: float=1e-6):
    super().__init__()
    self.eps = eps
    # per-feature scale, initialised to ones
    self.gamma = nn.Parameter(torch.ones(config.d_model))

  def _norm(self, x):
    """Scale ``x`` by the reciprocal RMS of its last dimension."""
    mean_square = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + self.eps)
    return x * inv_rms

  def forward(self, x):
    """Normalise then scale: (B, T, d_model) -> (B, T, d_model)."""
    # statistics are computed in float32, then cast back to the input dtype
    normed = self._norm(x.float()).type_as(x)
    return self.gamma * normed

# Smoke test: normalise a random (2, seq_len, d_model) batch and inspect the shape.
x = torch.rand(2, config.seq_len, config.d_model)
norm = RMSNormalization(config=config)
out = norm(x)
out.shape

torch.Size([2, 1024, 512])

## FeedForwardBlock with ReLU

In [7]:
class FeedForwardBlock(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  The hidden width is ``config.expansion_ratio * config.d_model``; the
  dropout probability is ``config.dropout``.
  """

  def __init__(self, config: Config):
    super().__init__()
    hidden_size = config.expansion_ratio * config.d_model

    self.up_proj = nn.Linear(config.d_model, hidden_size)
    self.down_proj = nn.Linear(hidden_size, config.d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x):
    """(B, T, d_model) -> (B, T, hidden_size) -> (B, T, d_model)."""
    hidden = self.up_proj(x)
    hidden = self.relu(hidden)
    hidden = self.dropout(hidden)
    return self.down_proj(hidden)

# Smoke test: run the ReLU feed-forward block on a random (1, seq_len, d_model) input.
x = torch.rand(1, config.seq_len, config.d_model)
ffn = FeedForwardBlock(config=config)
out = ffn(x)
out.shape

torch.Size([1, 1024, 512])

## FeedForwardBlock with SWIGLU

In [8]:
class FeedForwardSwigluBlock(nn.Module):
  """Bias-free SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

  The hidden width is ``config.expansion_ratio * config.d_model``.
  """

  def __init__(self, config: Config):
    super().__init__()
    hidden_size = config.expansion_ratio * config.d_model

    # w1: gate projection, w3: value projection, w2: output projection
    self.w1 = nn.Linear(config.d_model, hidden_size, bias=False)
    self.w2 = nn.Linear(hidden_size, config.d_model, bias=False)
    self.w3 = nn.Linear(config.d_model, hidden_size, bias=False)

  def forward(self, x):
    """(B, T, d_model) -> (B, T, hidden_size) -> (B, T, d_model)."""
    gate = F.silu(self.w1(x))     # (B, T, hidden_size)
    gated = gate * self.w3(x)     # (B, T, hidden_size)
    return self.w2(gated)         # (B, T, d_model)

# Smoke test: run the SwiGLU feed-forward block on a random (1, seq_len, d_model) input.
x = torch.rand(1, config.seq_len, config.d_model)
ffn = FeedForwardSwigluBlock(config=config)
out = ffn(x)
out.shape

torch.Size([1, 1024, 512])

## MultiHeadAttentionBlock

In [9]:
class MultiHeadAttentionBlock(nn.Module):
  """Multi-head self-attention with optional grouped-query attention (GQA).

  Queries use ``n_heads`` heads; keys and values use ``kv_n_heads`` heads that
  are repeated ``n_heads // kv_n_heads`` times so every group of query heads
  shares one key/value head. With ``kv_n_heads == n_heads`` this is plain
  multi-head attention.

  Args:
    config: object providing ``d_model``, ``n_heads``, ``kv_n_heads``
      (``None`` means "same as ``n_heads``") and ``dropout``.

  Raises:
    AssertionError: if ``d_model`` is not divisible by ``n_heads``.
    ValueError: if ``kv_n_heads`` is not a positive divisor of ``n_heads``.
  """

  def __init__(self, config: Config):
    super().__init__()
    self.d_model = config.d_model
    self.n_heads = config.n_heads
    dropout = config.dropout
    assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
    self.head_size = self.d_model // self.n_heads

    self.kv_n_heads = config.n_heads if config.kv_n_heads is None else config.kv_n_heads

    # Without this check the floor division below would silently yield a
    # head count that does not match the query heads.
    if self.kv_n_heads <= 0 or self.n_heads % self.kv_n_heads != 0:
      raise ValueError(
        f"n_heads ({self.n_heads}) must be a positive multiple of kv_n_heads ({self.kv_n_heads})"
      )

    # Indicates how many times keys and values should be repeated
    self.n_rep = self.n_heads // self.kv_n_heads

    self.w_q = nn.Linear(self.d_model, self.n_heads * self.head_size, bias=False)
    self.w_k = nn.Linear(self.d_model, self.kv_n_heads * self.head_size, bias=False)
    self.w_v = nn.Linear(self.d_model, self.kv_n_heads * self.head_size, bias=False)
    self.w_o = nn.Linear(self.n_heads * self.head_size, self.d_model, bias=False)

    self.dropout = nn.Dropout(dropout)

  def repeat_kv(self, x: torch.Tensor, n_rep: int):
    """Repeat each key/value head ``n_rep`` times along the head dimension.

    (B, T, n_kv_heads, head_dim) -> (B, T, n_kv_heads * n_rep, head_dim).
    Returns ``x`` itself when ``n_rep == 1``.
    """
    if n_rep == 1:
      return x

    B, T, n_kv_heads, head_dim = x.shape
    # expand creates a broadcast view; reshape materialises the repeated heads
    return (
      x[:, :, :, None, :]
      .expand(B, T, n_kv_heads, n_rep, head_dim)
      .reshape(B, T, n_kv_heads * n_rep, head_dim)
    )   # (B, T, n_heads, head_dim)

  def forward(self, x, mask):
    """Apply self-attention.

    Args:
      x: input of shape (B, T, d_model).
      mask: ``None`` or a tensor broadcastable to (B, nH, T, T); positions
        where ``mask == 0`` are excluded from attention.

    Returns:
      Tensor of shape (B, T, d_model).
    """
    B, T, _ = x.shape
    query = self.w_q(x)    # (B, T, nH * hd)
    key = self.w_k(x)      # (B, T, nKVH * hd)
    value = self.w_v(x)    # (B, T, nKVH * hd)

    query = query.view(B, T, self.n_heads, self.head_size)       # (B, T, nH, Hd)
    key = key.view(B, T, self.kv_n_heads, self.head_size)        # (B, T, nKVH, Hd)
    value = value.view(B, T, self.kv_n_heads, self.head_size)    # (B, T, nKVH, Hd)

    # Since every group of Q shares same K and V heads, just repeat the K and V heads for every Q in the same group
    key = self.repeat_kv(key, self.n_rep)             # (B, T, nH, Hd)
    value = self.repeat_kv(value, self.n_rep)         # (B, T, nH, Hd)

    query = query.transpose(1, 2)   # (B, nH, T, Hd)
    key = key.transpose(1, 2)       # (B, nH, T, Hd)
    value = value.transpose(1, 2)   # (B, nH, T, Hd)

    # compute the scaled attention scores
    attention_scores = query @ key.transpose(2, 3) / math.sqrt(self.head_size)    # (B, nH, T, T)

    # apply mask; the dtype's minimum is used instead of -1e9 so the fill
    # value stays representable in reduced precision (e.g. float16)
    if mask is not None:
      fill_value = torch.finfo(attention_scores.dtype).min
      attention_scores = attention_scores.masked_fill(mask == 0, fill_value)     # (B, nH, T, T)

    attention_weights = attention_scores.softmax(dim=-1)    # (B, nH, T, T)

    # Dropout must act on the probabilities, not the raw scores: zeroing a
    # masked score would turn it back into 0 and un-mask that position.
    attention_weights = self.dropout(attention_weights)     # (B, nH, T, T)

    # compute the context vector
    z = attention_weights @ value      # (B, nH, T, Hd)

    z = z.transpose(1, 2).contiguous().view(B, T, self.head_size * self.n_heads)    # (B, T, d_model)

    return self.w_o(z)


def causal_mask(size):
  """Return a boolean causal mask of shape (1, size, size).

  Entry ``[0, i, j]`` is True when ``j <= i`` (position ``i`` may attend to
  position ``j``) and False above the diagonal.
  """
  positions = torch.arange(size)
  query_idx = positions.view(1, size, 1)
  key_idx = positions.view(1, 1, size)
  return query_idx >= key_idx

# NOTE: this cell mutates the shared `config`; kv_n_heads stays at 4 afterwards.
## Without GQA (as in MHA): every query head gets its own key/value head
config.kv_n_heads = config.n_heads
x = torch.rand(1, config.seq_len, config.d_model)
mask = causal_mask(config.seq_len)
attn = MultiHeadAttentionBlock(config=config)
out = attn(x, mask)
print(out.shape)

## With GQA: 4 key/value heads shared by the query heads
config.kv_n_heads = 4
x = torch.rand(1, config.seq_len, config.d_model)
mask = causal_mask(config.seq_len)
attn = MultiHeadAttentionBlock(config=config)
out = attn(x, mask)
print(out.shape)

torch.Size([1, 1024, 512])
torch.Size([1, 1024, 512])


## DecoderBlock

In [10]:
class DecoderBlock(nn.Module):
  """Pre-norm decoder layer: attention and SwiGLU feed-forward, each with a residual.

  Each sub-layer receives an RMS-normalised copy of the stream and its output
  is added back onto the un-normalised stream.
  """

  def __init__(self, config: Config):
    super().__init__()
    self.norm1 = RMSNormalization(config)
    self.norm2 = RMSNormalization(config)
    self.attn = MultiHeadAttentionBlock(config)
    self.ffn = FeedForwardSwigluBlock(config)

  def forward(self, x, mask):
    """Run one decoder layer.

    x: (B, T, d_model); mask is forwarded unchanged to the attention block.
    """
    attn_out = self.attn(self.norm1(x), mask)
    after_attn = x + attn_out

    ffn_out = self.ffn(self.norm2(after_attn))
    return after_attn + ffn_out

# Smoke test: run a single decoder block with a causal mask and inspect the shape.
x = torch.rand(1, config.seq_len, config.d_model)
mask = causal_mask(config.seq_len)
decoder_block = DecoderBlock(config=config)
out = decoder_block(x, mask)
out.shape

torch.Size([1, 1024, 512])

## Decoder

In [11]:
class Decoder(nn.Module):
  """Stack of decoder layers followed by a final RMS normalisation.

  Args:
    config: passed to the final ``RMSNormalization``.
    layers: modules applied in order, each called as ``layer(x, mask)``.
  """

  def __init__(self, config: Config, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = RMSNormalization(config)

  def forward(self, x, mask):
    """Thread ``x`` through every layer with the same mask, then normalise."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    normed = self.norm(hidden)
    return normed

## ProjectionLayer

In [12]:
class ProjectionLayer(nn.Module):
  """Linear head mapping hidden states of width ``d_model`` to ``vocab_size`` scores."""

  def __init__(self, config: Config):
    super().__init__()
    self.proj = nn.Linear(
      in_features=config.d_model,
      out_features=config.vocab_size,
    )

  def forward(self, x):
    """(B, T, d_model) -> (B, T, vocab_size)."""
    vocab_scores = self.proj(x)
    return vocab_scores

## LlamaModel

In [13]:
class LLAMA(nn.Module):
  """Decoder-only language model assembled from pre-built components.

  Args:
    embed: maps token ids to embeddings.
    pos_enc: adds positional information to the embeddings.
    decoder: called as ``decoder(x, mask)`` on the position-encoded embeddings.
    projection_layer: maps decoder output to vocabulary scores.
  """

  def __init__(self, embed: InputEmbeddings, pos_enc: PositionalEncoding, decoder: Decoder, projection_layer: ProjectionLayer):
    super().__init__()
    self.embed = embed
    self.pos_enc = pos_enc
    self.decoder = decoder
    self.projection = projection_layer

  def decode(self, x, mask):
    """Embed token ids ``x`` (B, T), add positions, and run the decoder with ``mask``."""
    hidden = self.pos_enc(self.embed(x))
    return self.decoder(hidden, mask)

  def project(self, x):
    """Apply the projection layer: (B, T, d_model) -> (B, T, vocab_size)."""
    return self.projection(x)


## Build Llama model

In [14]:
def build_llama(config: Config):
  """Assemble a ``LLAMA`` model from ``config``.

  Builds the embedding, positional encoding, ``config.n_layers`` decoder
  blocks wrapped in a ``Decoder``, and the output projection, in that order.
  """
  embed = InputEmbeddings(config)
  pos_enc = PositionalEncoding(config)

  # one independently initialised block per layer
  blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layers)])
  decoder = Decoder(config, blocks)

  projection_layer = ProjectionLayer(config)

  return LLAMA(embed, pos_enc, decoder, projection_layer)

In [15]:
# Build the full model from the shared config and display its module tree.
# NOTE: if the attention demo cell ran first, config.kv_n_heads is 4 here.
model = build_llama(config)
model

LLAMA(
  (embed): InputEmbeddings(
    (embedding): Embedding(50000, 512)
  )
  (pos_enc): PositionalEncoding()
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderBlock(
        (norm1): RMSNormalization()
        (norm2): RMSNormalization()
        (attn): MultiHeadAttentionBlock(
          (w_q): Linear(in_features=512, out_features=512, bias=False)
          (w_k): Linear(in_features=512, out_features=256, bias=False)
          (w_v): Linear(in_features=512, out_features=256, bias=False)
          (w_o): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ffn): FeedForwardSwigluBlock(
          (w1): Linear(in_features=512, out_features=2048, bias=False)
          (w2): Linear(in_features=2048, out_features=512, bias=False)
          (w3): Linear(in_features=512, out_features=2048, bias=False)
        )
      )
    )
    (norm): RMSNormalization()
  )
  (projection): ProjectionLayer(
    (pr

In [18]:
# Random batch of B token-id sequences of length seq_len.
B = 1
x = torch.randint(low=0, high=config.vocab_size, size=(B, config.seq_len))
x.shape

torch.Size([1, 1024])

In [19]:
# Causal mask of shape (1, seq_len, seq_len).
mask = causal_mask(config.seq_len)
mask.shape

torch.Size([1, 1024, 1024])

In [20]:
# Add a leading axis so the mask is (1, 1, seq_len, seq_len) and broadcasts over batch and heads.
mask = mask.unsqueeze(0)     # (batch dimension)
mask.shape

torch.Size([1, 1, 1024, 1024])

In [22]:
# NOTE: despite its name, `logits` holds the decoder's hidden states of width
# d_model; vocabulary scores are only produced by model.project.
logits = model.decode(x, mask)
logits.shape

torch.Size([1, 1024, 512])

In [24]:
# Project the decoder output to per-token vocabulary scores: (B, T, vocab_size).
out = model.project(logits)
out.shape

torch.Size([1, 1024, 50000])

In [27]:
# Compute loss, acc, ....
# Done!!