In [1]:
from dataclasses import dataclass
from typing import Any, Literal, Optional, Type, Tuple, List
import torch
from torch import nn, Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum
from xformers.ops import SwiGLU
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchtune.modules import RotaryPositionalEmbeddings
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# TODO fix rotary embeddings
# TODO fix FlashAttention
# TODO remove unused config parameters
# TODO refactor the code
# TODO fix the cache
# TODO implement dataset loading and the dataloader
# TODO implement the training loop
# TODO implement the evaluation loop
# TODO implement the inference loop
# TODO add TensorBoard support
# TODO add weight initialization

# Config

In [3]:
@dataclass
class Config:
    """Hyper-parameters describing one model variant.

    Instances are built from the keyword presets in ``configs`` and consumed
    by the model modules defined later in this notebook.
    """

    # Provenance / identification of the preset.
    org: str = "StatNLP-research"
    name: str = "tiny_LLaMA_1b"

    # Maximum sequence length and vocabulary sizing.
    seq_length: int = 2048
    vocab_size: int = 32000
    padding_multiple: int = 64
    padded_vocab_size: Optional[int] = None

    # Transformer depth, number of query heads and embedding width.
    n_layer: int = 22
    n_head: int = 32
    n_embd: int = 2048

    rotary_percentage: float = 1.0
    parallel_residual: bool = False
    bias: bool = False
    # Number of query groups; the attention module derives its key/value head
    # count from n_head and this value.
    n_query_groups: int = 4
    shared_attention_norm: bool = False
    norm_eps: float = 1e-5

    # Width of the feed-forward inner projection.
    hidden_dim: int = 5632
    condense_ratio: int = 1
    # Fixed: was annotated `int` although the default (and meaning) is a float.
    dropout: float = 0.0
    device: str = "cpu"
    # Fixed: without an annotation this was a plain class attribute, not a
    # dataclass field, so `Config(dtype=...)` raised TypeError.
    dtype: torch.dtype = torch.float16

    @property
    def head_size(self) -> int:
        """Per-head dimension: ``n_embd // n_head``."""
        return self.n_embd // self.n_head

# Keyword presets for the supported model variants; every key must be a
# field accepted by Config.
configs = [
    {
        "org": "StatNLP-research",
        "name": "tiny_LLaMA_1b",
        "seq_length": 2048,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 22,
        "n_head": 32,
        "n_embd": 2048,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "norm_eps": 1e-5,
        "hidden_dim": 5632,
        "n_query_groups": 4,
    },
    {
        "org": "ufv",
        "name": "tiny_tiny_LLaMA",
        "seq_length": 256,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 12,
        "n_head": 8,
        "n_query_groups": 4,
        "n_embd": 768,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "norm_eps": 1e-5,
        "hidden_dim": 512,
    },
]

# Lookup table: preset name -> fully constructed Config.
name_to_config: dict[str, Config] = dict(
    (preset["name"], Config(**preset)) for preset in configs
)

# RoPE

In [4]:
# class LlamaRotaryEmbedding(nn.Module):
#     def __init__(
#         self,
#         dim=None,
#         max_position_embeddings=2048,
#         base=10000,
#         device=None,
#         scaling_factor=1.0,
#         rope_type="default",
#         config: Optional[LlamaConfig] = None,
#     ):
#         super().__init__()
#         # TODO (joao): remove the `if` below, only used for BC
#         self.rope_kwargs = {}
#         if config is None:
#             logger.warning_once(
#                 "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
#                 "`config` argument. All other arguments will be removed in v4.46"
#             )
#             self.rope_kwargs = {
#                 "rope_type": rope_type,
#                 "factor": scaling_factor,
#                 "dim": dim,
#                 "base": base,
#                 "max_position_embeddings": max_position_embeddings,
#             }
#             self.rope_type = rope_type
#             self.max_seq_len_cached = max_position_embeddings
#             self.original_max_seq_len = max_position_embeddings
#         else:
#             # BC: "rope_type" was originally "type"
#             if config.rope_scaling is not None:
#                 self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
#             else:
#                 self.rope_type = "default"
#             self.max_seq_len_cached = config.max_position_embeddings
#             self.original_max_seq_len = config.max_position_embeddings

#         self.config = config
#         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

#         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
#         self.register_buffer("inv_freq", inv_freq, persistent=False)
#         self.original_inv_freq = self.inv_freq

#     def _dynamic_frequency_update(self, position_ids, device):
#             self.max_seq_len_cached = self.original_max_seq_len

#     @torch.no_grad()
#     def forward(self, x, position_ids):
#         if "dynamic" in self.rope_type:
#             self._dynamic_frequency_update(position_ids, device=x.device)

#         # Core RoPE block
#         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
#         position_ids_expanded = position_ids[:, None, :].float()
#         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
#         device_type = x.device.type
#         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
#         with torch.autocast(device_type=device_type, enabled=False):
#             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
#             emb = torch.cat((freqs, freqs), dim=-1)
#             cos = emb.cos()
#             sin = emb.sin()

#         # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
#         cos = cos * self.attention_scaling
#         sin = sin * self.attention_scaling

#         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

In [5]:
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
) -> Tuple[Tensor, Tensor]:
    """
    Build rotary position embedding cache.
    Note: n_elem should be head_size, not head_size//2

    Args:
        seq_len: number of positions to precompute.
        n_elem: size of the rotated dimension (even values 0, 2, ... index the frequencies).
        dtype: target dtype; bfloat16 yields bfloat16 tables, float16/int8 yield
            float16 tables, anything else keeps the computed floating dtype.
        device: device on which the tables are created.
        base: base of the geometric frequency progression.
        condense_ratio: divisor applied to the position indices.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, 1, n_elem)``. The last axis
        uses the half-split layout ``[f_0..f_{n/2-1}, f_0..f_{n/2-1}]`` so that
        element ``i`` and element ``i + n_elem // 2`` share a frequency, which
        is what ``apply_rope`` / ``rotate_half`` (split in halves) require.
    """

    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))

    seq_idx = torch.arange(seq_len, device=device) / condense_ratio

    # (seq_len, n_elem // 2): angle for every (position, frequency) pair.
    idx_theta = torch.outer(seq_idx, theta)

    # Fixed: the previous `repeat_interleave(..., 2)` produced the interleaved
    # layout [f0, f0, f1, f1, ...], which does not match the half-split
    # rotation used by apply_rope/rotate_half. Concatenating the two halves
    # keeps the output shape but pairs the right elements.
    idx_theta = torch.cat((idx_theta, idx_theta), dim=-1)

    # Singleton axis kept for broadcasting, as in the original layout.
    cos = torch.cos(idx_theta).unsqueeze(1)
    sin = torch.sin(idx_theta).unsqueeze(1)

    if dtype == torch.bfloat16:
        return cos.bfloat16(), sin.bfloat16()
    # bfloat16 was listed here too but could never reach this branch.
    if dtype in (torch.float16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """
    Rotate ``x`` with precomputed cosine/sine tables (half-split convention).

    The last axis of ``x`` is split into a front and a back half; the result is
    ``x * cos + cat(-back, front) * sin``, cast back to the type of ``x``.
    ``cos`` and ``sin`` must broadcast against ``x``.

    Reference implementation: https://github.com/jzhang38/TinyLlama/blob/main/lit_gpt/model.py
    """
    half = x.size(-1) // 2
    front, back = x[..., :half], x[..., half:]
    swapped = torch.cat((-back, front), dim=-1)
    out = (x * cos) + (swapped * sin)
    return out.type_as(x)

def rotate_half(x):
    """Swap the two halves of the last axis and negate the (original) second half."""
    split = x.shape[-1] // 2
    return torch.cat((-x[..., split:], x[..., :split]), dim=-1)
  
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table of the rotary embedding.
        sin (`torch.Tensor`): sine table of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis inserted into
            `cos` and `sin` so they broadcast against `q` and `k`. Use 1 when
            `q`/`k` are laid out as [batch, heads, seq_len, head_dim] and 2 when
            they are laid out as [batch, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same formula for queries and keys.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)



# GroupQueryAttention

In [6]:
class GroupQueryAttention(nn.Module):
    """Causal self-attention in which groups of query heads share key/value heads.

    The number of key/value heads is derived from ``config.n_head`` and
    ``config.n_query_groups``; projections carry no bias and are created in
    ``config.dtype``.
    """

    def __init__(
        self,
        config: Config,
    ):
        super().__init__()

        self.dtype = config.dtype
        self.dim = config.n_embd
        self.query_heads = config.n_head
        # How many query heads attend through each key/value head.
        self.queries_per_kv = config.n_head // config.n_query_groups
        self.key_value_heads = config.n_head // self.queries_per_kv

        # Width of the key/value projections: head_dim * number of kv heads.
        self.kv_dim = self.dim // self.query_heads * self.key_value_heads

        self.q_proj = nn.Linear(self.dim, self.dim,    bias=False, dtype=self.dtype)
        self.k_proj = nn.Linear(self.dim, self.kv_dim, bias=False, dtype=self.dtype)
        self.v_proj = nn.Linear(self.dim, self.kv_dim, bias=False, dtype=self.dtype)

        self.o_proj = nn.Linear(self.dim, self.dim,    bias=False, dtype=self.dtype)
        self.config = config

    def scaled_dot_product_gqa(self, query: Tensor, key: Tensor, value: Tensor):
        """Causal scaled dot-product attention with grouped queries.

        Args:
            query: ``(b, query_heads, n, d)``.
            key: ``(b, key_value_heads, s, d)``.
            value: ``(b, key_value_heads, s, d)``.

        Returns:
            Tensor of shape ``(b, query_heads, n, d)``.
        """
        scale_factor = 1 / query.size(-1) ** 0.5

        L, S = query.size(-2), key.size(-2)

        # Additive causal bias: 0 where attention is allowed, -inf above the
        # diagonal. Fixed: built on the query's device (it was created on the
        # default device, which breaks as soon as the model is not on CPU), and
        # the discarded no-op `attn_bias.to(...)` call is gone.
        allowed = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        attn_bias.masked_fill_(allowed.logical_not(), float("-inf"))

        # Split the query heads into (kv head, member of group) and move the
        # group axis in front: (b, g, h_kv, n, d).
        query = rearrange(query, "b (h g) n d -> b g h n d", g=self.queries_per_kv)

        # Fixed: give key/value an explicit singleton group axis. Without it the
        # batch axis of key/value was broadcast against the group axis of the
        # query, which only worked when the batch size was 1.
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)

        attn_weight = query @ key.transpose(-1, -2) * scale_factor

        attn_weight += attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ value
        # Merge (kv head, group member) back into the flat query-head axis.
        y = rearrange(y, "b g h n d -> b (h g) n d")

        return y

    def forward(self, x: Tensor, position_embeddings: Tensor,  mask: Optional[Tensor] = None) -> torch.Tensor:
        """Project ``x``, apply rotary embeddings, attend causally and project back.

        Args:
            x: input of shape ``(b, n, n_embd)``.
            position_embeddings: ``(cos, sin)`` pair passed to ``apply_rotary_pos_emb``.
            mask: accepted for interface compatibility; it is not read here, the
                causal mask is always built inside ``scaled_dot_product_gqa``.

        Returns:
            Tensor of shape ``(b, n, n_embd)``.
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # (b, n, heads * d) -> (b, heads, n, d)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.query_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.key_value_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.key_value_heads)

        cos, sin = position_embeddings

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        y = self.scaled_dot_product_gqa(
            query=q,
            key=k,
            value=v
        )

        y = rearrange(y, "b h n d -> b n (h d)")

        y = self.o_proj(y)

        return y


# FFN

In [7]:
class FFN(nn.Module):
    """Gated feed-forward block: ``down(SiLU(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        width, inner = config.n_embd, config.hidden_dim
        self.gate_proj = nn.Linear(width, inner, bias=False, dtype=config.dtype)
        self.up_proj = nn.Linear(width, inner, bias=False, dtype=config.dtype)
        self.down_proj = nn.Linear(inner, width, bias=False, dtype=config.dtype)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Return the gated projection of ``x`` back in the embedding width."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)

# RMSNorm

In [8]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Note: the epsilon comes from the ``eps`` argument only; ``config.norm_eps``
    is not read here, and ``dim`` is accepted but unused.
    """

    def __init__(self, config: Config,  dim: int = -1, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(config.n_embd, dtype=config.dtype))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise in float32, cast back to the input dtype, then scale."""
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)

# Decoder

In [9]:
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.attn = GroupQueryAttention(config)
        self.ffn = FFN(config)
        self.norm1 = RMSNorm(config)
        self.norm2 = RMSNorm(config)

    def forward(self, x, pos_emb, mask=None):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.norm1(x), pos_emb, mask=mask)
        x = x + attn_out
        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out

# TinyLlama

In [10]:
class LlamaModel(nn.Module):
    """Decoder-only language model: embedding, a stack of DecoderLayer blocks,
    a final RMSNorm and a linear head producing vocabulary logits."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd , dtype=config.dtype)
        self.layers       = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layer)])
        self.norm         = RMSNorm(config)
        self.lm_head      = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=config.dtype)

        self.rope = LlamaRotaryEmbedding(max_position_embeddings=config.seq_length, dim=int(config.rotary_percentage * config.head_size))
        # Lazily built (cos, sin) tables covering every position up to seq_length.
        self.rope_cache: Optional[Tuple[Tensor, Tensor]] = None
        self.mask_cache: Optional[Tensor] = None
        self.kv_caches: List[Tuple[Tensor, Tensor]] = []

    def get_kv_head_dim(self, config: Config) -> int:
        """Return the total key/value projection width (head size * kv heads).

        Fixed: derived from ``n_embd`` (as GroupQueryAttention does) rather than
        from the feed-forward ``hidden_dim``.
        """
        queries_per_kv = config.n_head // config.n_query_groups
        key_value_heads = config.n_head // queries_per_kv
        kv_dim = config.n_embd // config.n_head * key_value_heads
        return kv_dim

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        max_length: int = None,
        use_kv_cache: bool = False,
        input_pos: Optional[torch.LongTensor] = None,
    ) -> Tensor:
        """Compute next-token logits for ``input_ids``.

        Args:
            input_ids: token ids of shape ``(B, T)`` with ``T <= config.seq_length``.
            max_length: accepted for interface compatibility; not used.
            use_kv_cache: must be False. The cached path is not implemented.
            input_pos: only meaningful for the (unimplemented) cached path.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: if ``input_ids`` is None.
            NotImplementedError: if ``use_kv_cache`` is True.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        if use_kv_cache:
            # The previous cached branch could never run: it iterated over a
            # non-existent `self.transformer.h` and called each block with
            # extra arguments and a tuple return that DecoderLayer.forward
            # (x, pos_emb, mask) -> Tensor does not support. Fail explicitly
            # instead of with an unrelated AttributeError/TypeError.
            raise NotImplementedError(
                "use_kv_cache=True is not supported yet: DecoderLayer has no KV-cache interface"
            )

        B, T = input_ids.shape

        assert T <= self.config.seq_length, f"Input length {T} exceeds maximum model length {self.config.seq_length}"

        x = self.embed_tokens(input_ids)

        # Fixed: the cache used to be built for the length of the *first* input
        # only, so any later call with a different length got tables of the
        # wrong size. It now covers all positions and is rebuilt if the
        # activations live on another device.
        if self.rope_cache is None or self.rope_cache[0].device != x.device:
            self.rope_cache = self.build_rope_cache(x)

        cos, sin = self.rope_cache

        # The tables are (1, positions, dim) (position_ids is built with a
        # leading singleton axis), so positions live on axis 1. The previous
        # `cos[:T]` sliced the singleton axis and was a no-op.
        cos = cos[:, :T]
        sin = sin[:, :T]

        for block in self.layers:
            # Fixed: `max_seq_length` (an int) used to be passed in the mask
            # slot. The attention module builds its own causal mask.
            x = block(x, (cos, sin), mask=None)

        x = self.norm(x)

        return self.lm_head(x)

    def build_rope_cache(self, idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` for positions ``0 .. config.seq_length - 1``.

        ``idx`` is forwarded to the rotary module together with the position
        ids; its device is used for the position ids.
        """
        position_ids = torch.arange(self.config.seq_length, device=idx.device).unsqueeze(0)
        return self.rope(idx, position_ids)

    def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
        """Return a boolean lower-triangular mask of shape ``(1, 1, seq_length, seq_length)``."""
        ones = torch.ones((self.config.seq_length, self.config.seq_length), device=idx.device, dtype=torch.bool)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) :
        """Allocate one zero-filled ``(key, value)`` buffer pair per layer.

        Not used by ``forward`` until the cached path is implemented.
        """
        B = idx.size(0)
        # `1 if n == 1 else n` in the original is simply n.
        heads = self.config.n_query_groups

        k_cache_shape = (
            B,
            max_seq_length,
            heads,
            rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),
        )
        v_cache_shape = (B, max_seq_length, heads, self.config.head_size)
        device = idx.device
        return [
            (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
            for _ in range(self.config.n_layer)
        ]

# Load Original Model

In [11]:
import torch
from safetensors.torch import load_file
from collections import OrderedDict

def remap_state_dict(state_dict):
    """
    Remaps the state dict keys from LlamaForCausalLM format to your custom model format

    Every key is rewritten by applying the substring replacements below in
    order; values are passed through untouched and insertion order is kept.

    Args:
        state_dict: mapping from parameter name to tensor.

    Returns:
        OrderedDict with the renamed keys.

    Raises:
        ValueError: if two source keys map to the same target key (one weight
            would otherwise silently overwrite the other).
    """
    # Substring replacements translating between the two naming conventions.
    key_mapping = {
        'model.embed_tokens': 'embed_tokens',
        'model.norm': 'norm',
        'model.layers': 'layers',
        'self_attn': 'attn',
        'input_layernorm': 'norm1',
        'post_attention_layernorm': 'norm2',
        'mlp': 'ffn'
    }

    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = key
        for old, new in key_mapping.items():
            new_key = new_key.replace(old, new)

        # The former "shape validation" compared each tensor's shape with
        # itself and could never fail, so it has been removed.
        if new_key in new_state_dict:
            raise ValueError(f"Duplicate key after remapping: {key} -> {new_key}")

        new_state_dict[new_key] = value

    return new_state_dict

def load_model_weights(model, safetensors_path):
    """
    Read a safetensors checkpoint, rename its keys and load it into ``model``.

    Loading is non-strict: missing and unexpected keys are printed rather than
    raised. Any exception is caught, printed, and reported through the return
    value.

    Args:
        model: Your custom model instance
        safetensors_path: Path to the safetensors file

    Returns:
        True when loading completed, False when an exception occurred.
    """
    try:
        remapped = remap_state_dict(load_file(safetensors_path))
        load_result = model.load_state_dict(remapped, strict=False)
        missing_keys, unexpected_keys = load_result

        print("Model loaded successfully!")
        for label, keys in (("Missing keys:", missing_keys), ("Unexpected keys:", unexpected_keys)):
            if keys:
                print(label, keys)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False

    return True


# validation

In [12]:
# Instantiate the custom implementation with the "tiny_LLaMA_1b" preset.
tiny_LLaMA_1b = LlamaModel(name_to_config["tiny_LLaMA_1b"])

`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


In [13]:
# Load the checkpoint from a hard-coded snapshot path under "model/".
# NOTE(review): presumably this is the cache populated by the from_pretrained(cache_dir="model") call in a later cell — confirm it exists on a fresh run.
load_model_weights(tiny_LLaMA_1b, "model/models--TinyLlama--TinyLlama-1.1B-Chat-v1.0/snapshots/fe8a4ea1ffedaf415f4da2f062534de366a451e6/model.safetensors")

Model loaded successfully!


True

In [14]:

# Reference tokenizer and model used to validate the custom implementation.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# NOTE(review): `map_device` does not look like a tokenizer option (the model call below uses `device_map`) — confirm whether it has any effect.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,  map_device="auto", add_eos_token=True, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="model",
    # attn_implementation=attn_implementation
)

In [15]:
# Tokenize one probe sentence into a PyTorch tensor of token ids.
inputs_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids

In [16]:

# Run both models on the same ids. This is inference only, so autograd
# bookkeeping is disabled to avoid keeping activations for a backward pass.
with torch.no_grad():
    out1 = tiny_LLaMA_1b(inputs_ids)
    out2 = model(inputs_ids, output_attentions=True)



In [17]:
# Compare logits with a tolerance instead of bit-exact equality: the two
# implementations run in float16 and may order their arithmetic differently,
# and the reference model may sit on another device (device_map="auto").
torch.testing.assert_close(
    out1,
    out2.logits.to(device=out1.device, dtype=out1.dtype),
    rtol=1e-2,
    atol=1e-2,
    msg=lambda details: f"The model output is different from the reference model output\n{details}",
)