In [1]:
%pip install torch dataclasses



In [7]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional

In [3]:
class RMSNorm(nn.Module):
  """Root-mean-square normalization over the last dimension with a learned scale.

  Each vector along the last axis is divided by sqrt(mean(x**2) + eps) and then
  multiplied element-wise by ``weight``.

  Args:
    hidden_size: Size of the last dimension; ``weight`` has this many entries.
    eps: Constant added to the mean square before the reciprocal square root.
  """

  def __init__(self, hidden_size: int, eps: float=1e-6):
    super().__init__()
    # Scale starts at one, so a freshly built layer only normalizes.
    self.weight = nn.Parameter(torch.ones(hidden_size))
    self.variance_epsilon = eps

  def forward(self, x):
    """Return the RMS-normalized, scaled copy of ``x`` (``x`` is not modified)."""
    variance = x.pow(2).mean(-1, keepdim=True)
    # Build a new tensor rather than scaling in place: callers may still hold
    # ``x`` (e.g. as a residual) and must not see it change.
    hidden_states = x * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states

In [4]:
class RoPE(nn.Module):
  """Produces cosine and sine tables for rotary position embeddings.

  Args:
    dim: Width of the returned tables; one inverse frequency is created for
      every second index in ``range(dim)``.
    max_positional_embeddings: Stored on the module; not consulted when the
      tables are built.
    rope_theta: Base of the geometric frequency progression.
  """

  def __init__(self, dim: int, max_positional_embeddings: int, rope_theta: float = 10000.0):
    super().__init__()
    self.dim = dim
    self.max_positional_embeddings = max_positional_embeddings
    self.rope_theta = rope_theta

    # Exponents 0, 2/dim, 4/dim, ... give one frequency per pair of channels.
    exponents = torch.arange(0, self.dim, 2).float() / self.dim
    # Non-persistent: the buffer follows the module across devices but is
    # left out of the state dict.
    self.register_buffer("inv_freq", 1.0 / self.rope_theta ** exponents, persistent=False)

  def forward(self, x: torch.Tensor, seq_len: int):
    """Return ``(cos, sin)`` for positions ``0 .. seq_len - 1``.

    ``x`` is only used to pick the device on which positions are created.
    """
    positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
    angles = torch.outer(positions, self.inv_freq)
    # Duplicate the angles so the table spans both halves of the last axis.
    table = torch.cat((angles, angles), dim=-1)
    return table.cos(), table.sin()

In [5]:
def rotate_half(x):
  """Swap the two halves of the last axis, negating the half moved to the front.

  The split point is ``x.shape[-1] // 2``; the result is ``[-second, first]``
  concatenated along the last axis.
  """
  half = x.shape[-1] // 2
  front, back = x[..., :half], x[..., half:]
  return torch.cat((-back, front), dim=-1)

In [6]:
def apply_rotary_pos_emb(q, k, cos, sin):
  """Apply the same rotary transform to ``q`` and ``k``.

  Each input ``t`` becomes ``t * cos + rotate_half(t) * sin``; ``cos`` and
  ``sin`` must therefore broadcast against both ``q`` and ``k``.

  Returns:
    The rotated ``(q, k)`` pair, in that order.
  """
  def _rotate(states):
    return (states*cos) + (rotate_half(states)*sin)

  return _rotate(q), _rotate(k)

In [9]:
class GroupedQueryAttention(nn.Module):
  """Self-attention in which several query heads share one key/value head.

  Queries are projected to ``num_attention_heads`` heads, keys and values to
  ``num_kv_heads`` heads. Rotary position embeddings are applied to queries
  and keys, and each key/value head is repeated so it lines up with its group
  of query heads.

  Args:
    config: Object read by attribute (not a mapping). It must expose
      ``hidden_size``, ``num_attention_heads``, ``num_kv_heads``,
      ``attention_dropout``, ``max_positional_embeddings`` and ``rope_theta``.

  Raises:
    ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``,
      if ``num_attention_heads`` is not divisible by ``num_kv_heads``, or if
      the resulting head dimension is odd.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.num_kv_heads = config.num_kv_heads

    # Fail early with a readable message; otherwise these misconfigurations
    # only surface as shape errors deep inside forward().
    if self.hidden_size % self.num_heads != 0:
      raise ValueError(
          f"hidden_size ({self.hidden_size}) must be divisible by "
          f"num_attention_heads ({self.num_heads})"
      )
    if self.num_heads % self.num_kv_heads != 0:
      raise ValueError(
          f"num_attention_heads ({self.num_heads}) must be divisible by "
          f"num_kv_heads ({self.num_kv_heads})"
      )

    self.num_kv_groups = self.num_heads // self.num_kv_heads
    self.head_dim = self.hidden_size // self.num_heads
    if self.head_dim % 2 != 0:
      # The rotary tables are built from pairs of channels and rotate_half
      # splits the head in two, so an odd head size cannot be rotated.
      raise ValueError(f"head_dim ({self.head_dim}) must be even for rotary embeddings")
    self.attention_dropout = config.attention_dropout

    self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
    self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
    self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads*self.head_dim, bias=False)
    self.o_proj = nn.Linear(self.num_heads*self.head_dim, self.hidden_size, bias=False)

    self.rope = RoPE(
        self.head_dim,
        max_positional_embeddings=config.max_positional_embeddings,
        rope_theta=config.rope_theta
    )

  def forward(
      self,
      hidden_states: torch.Tensor,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.Tensor] = None
  ) -> torch.Tensor:
    """Run attention over ``hidden_states``.

    Args:
      hidden_states: Three-dimensional input unpacked as (batch, seq, channels).
      attention_mask: Optional tensor that is *added* to the scaled attention
        scores before the softmax, so it must broadcast to
        (batch, heads, seq, seq).
      position_ids: Accepted for interface compatibility; not used here —
        positions are always ``0 .. seq - 1``.

    Returns:
      Tensor of shape (batch, seq, hidden_size).
    """
    B, T, _ = hidden_states.size()

    # Project, split into heads, and move the head axis ahead of the sequence axis.
    query_states = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1,2)
    value_states = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

    # value_states is passed only so the tables are created on the right device.
    cos, sin = self.rope(value_states, seq_len=T)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if self.num_kv_groups > 1:
      # Repeat every key/value head so there is one per query head.
      key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1)
      value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1)

    attention_weights = query_states @ key_states.transpose(2, 3)
    attention_weights = attention_weights / math.sqrt(self.head_dim)

    if attention_mask is not None:
      attention_weights += attention_mask

    # Softmax in float32 for numerical stability, then return to the working dtype.
    attention_weights = F.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attention_weights = F.dropout(attention_weights, p=self.attention_dropout, training=self.training)

    attention_output = attention_weights @ value_states
    # Merge the heads back into a single channel axis.
    attention_output = attention_output.transpose(1,2).contiguous().view(B, T, self.hidden_size)
    attention_output = self.o_proj(attention_output)
    return attention_output

In [10]:
class SwiGLUFeedForward(nn.Module):
  """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``.

  ``config`` is read by attribute and must provide ``hidden_size`` and
  ``intermediate_size``. All three projections are bias-free.
  """

  def __init__(self, config: Dict):
    super().__init__()
    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size

    model_width, inner_width = self.hidden_size, self.intermediate_size
    self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
    self.up_proj = nn.Linear(model_width, inner_width, bias=False)
    self.down_proj = nn.Linear(inner_width, model_width, bias=False)

  def forward(self, x):
    """Apply the gated projection to ``x`` and map back to ``hidden_size``."""
    gate = F.silu(self.gate_proj(x))
    candidate = self.up_proj(x)
    return self.down_proj(gate*candidate)

In [11]:
class TransformerBlock(nn.Module):
  """One pre-norm decoder layer: attention and feed-forward, each with a residual.

  ``pre_norm`` is applied before the attention sub-layer and ``attention_norm``
  before the feed-forward sub-layer. ``config`` is read by attribute and must
  provide ``hidden_size`` and ``rms_norm_eps`` in addition to whatever the
  attention and feed-forward modules read from it.
  """

  def __init__(self, config: Dict):
    super().__init__()
    self.hidden_size = config.hidden_size

    self.attention_block = GroupedQueryAttention(config)
    self.ffn = SwiGLUFeedForward(config)
    norm_eps = config.rms_norm_eps
    self.pre_norm = RMSNorm(config.hidden_size, norm_eps)
    self.attention_norm = RMSNorm(config.hidden_size, norm_eps)

  def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
    """Return ``x`` after the attention and feed-forward sub-layers.

    ``attention_mask`` is forwarded unchanged to the attention module.
    """
    # Attention sub-layer: normalize, attend, add the untouched input back.
    attended = self.attention_block(self.pre_norm(x), attention_mask)
    attended += x

    # Feed-forward sub-layer, with the attention result as the residual.
    transformed = self.ffn(self.attention_norm(attended))
    transformed += attended

    return transformed

In [13]:
class LLaMA4(nn.Module):
  """Decoder-only language model.

  Token embedding -> stack of ``TransformerBlock`` layers -> final ``RMSNorm``
  -> linear head producing one logit per vocabulary entry.

  Args:
    config: Object read by attribute (not a mapping). This class reads
      ``vocabulary_size``, ``hidden_size``, ``num_hidden_layers``,
      ``rms_norm_eps``, ``tie_word_embeddings`` and ``initializer_range``;
      the layers it builds read further attributes from the same object.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.vocabulary_size = config.vocabulary_size

    self.embed_tokens = nn.Embedding(config.vocabulary_size, config.hidden_size)
    self.layers = nn.ModuleList(
        [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
    )
    self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
    self.lm_head = nn.Linear(config.hidden_size, config.vocabulary_size, bias=False)

    if config.tie_word_embeddings:
      # Share one parameter between the input embedding and the output head.
      self.lm_head.weight = self.embed_tokens.weight

    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialize Linear and Embedding weights from N(0, initializer_range)."""
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
      module.weight.data.normal_(mean=0.0, std=std)
      if module.bias is not None:
        module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
      # Must be a sibling of the Linear branch; nested under the bias check it
      # would never run and embeddings would keep their default init.
      module.weight.data.normal_(mean=0.0, std=std)

  def forward(self, input_ids: torch.Tensor, attention_mask:Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None):
    """Compute logits and, when ``labels`` is given, the next-token loss.

    Args:
      input_ids: Token ids; dimension 1 is taken as the sequence length.
      attention_mask: Optional mask handed unchanged to every layer, where it
        is added to the attention scores. When omitted, a causal mask is built
        (0 on and below the diagonal, -inf above it).
      labels: Optional target ids aligned with ``input_ids``. Position ``t`` of
        the logits is scored against ``labels[t + 1]``; entries equal to -100
        are ignored.

    Returns:
      ``{"loss": loss_or_None, "logits": logits}``.
    """
    hidden_states = self.embed_tokens(input_ids)

    if attention_mask is None:
      seq_len = input_ids.shape[1]
      # The mask is additive, so blocked (future) positions need -inf, not 1.
      # triu keeps the strict upper triangle and zeroes everything else.
      attention_mask = torch.triu(
          torch.full(
              (seq_len, seq_len),
              float("-inf"),
              device=input_ids.device,
              dtype=hidden_states.dtype,
          ),
          diagonal=1,
      )

    for layer in self.layers:
      hidden_states = layer(hidden_states, attention_mask=attention_mask)

    # Everything below runs once, after the whole stack — not per layer.
    hidden_states = self.norm(hidden_states)
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
      # Drop the last prediction and the first label so each position is
      # scored against the token that follows it.
      shift_logits = logits[..., :-1, :].contiguous()
      shift_labels = labels[..., 1:].contiguous()
      loss = F.cross_entropy(
          shift_logits.view(-1, self.vocabulary_size),
          shift_labels.view(-1),
          ignore_index=-100
      )

    return {"loss": loss, "logits": logits}

  def count_total_params(self):
    """Return ``{"total": ..., "trainable": ...}`` parameter counts."""
    total_params = sum(p.numel() for p in self.parameters())
    trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
    return {"total": total_params, "trainable": trainable_params}