In [1]:
from dataclasses import dataclass
from typing import Optional,  Tuple
import torch
from torch import nn, Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from safetensors.torch import load_file
from collections import OrderedDict


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# TODO: refactor the code
# TODO: implement dataset loading and the dataloader
# TODO: implement the training loop
# TODO: implement the evaluation loop
# TODO: add TensorBoard support
# TODO: add weight initialization

# Config

In [3]:
@dataclass
class Config:
    """Hyper-parameters for a LLaMA-style decoder-only model.

    All attributes are real dataclass fields, so every one of them can be
    overridden through the constructor (``Config(base=500000)``).

    Attributes:
        name: Identifier of the preset.
        seq_length: Maximum number of positions the model accepts.
        vocab_size: Size of the token vocabulary.
        n_layer: Number of decoder layers.
        n_head: Number of query heads.
        n_embd: Width of the residual stream.
        hidden_dim: Inner width of the feed-forward block.
        n_query_groups: Number of key/value heads used by grouped-query attention.
        rotary_percentage: Fraction of each head that receives rotary embeddings.
        device: Device on which the rotary frequencies are created.
        base: Base of the rotary frequency schedule.
        dtype: Parameter dtype used when building the layers.
        stop_token_id: Token id that terminates generation.
    """

    name: str = "tiny_LLaMA_1b"

    seq_length: int = 2048
    vocab_size: int = 32000

    n_layer: int = 22
    n_head: int = 32
    n_embd: int = 2048
    hidden_dim: int = 5632
    n_query_groups: int = 4

    rotary_percentage: float = 1.0

    device: str = "cpu"

    # These three used to be un-annotated class attributes, which ``@dataclass``
    # silently ignores.  They are declared after the pre-existing fields so
    # that positional construction keeps its original argument order.
    base: int = 10000
    dtype: torch.dtype = torch.float16
    stop_token_id: int = 2

    @property
    def head_size(self) -> int:
        """Per-head width: ``n_embd // n_head``."""
        return self.n_embd // self.n_head

    @property
    def dim(self) -> int:
        """Number of per-head channels that are rotated by RoPE."""
        return int(self.rotary_percentage * self.head_size)

# Constructor keyword arguments for every predefined model size.
configs = [
    {
        "name": "tiny_LLaMA_1b",
        "seq_length": 2048,
        "vocab_size": 32000,
        "n_layer": 22,
        "n_head": 32,
        "n_embd": 2048,
        "rotary_percentage": 1.0,
        "hidden_dim": 5632,
        "n_query_groups": 4,
    },
    {
        "name": "tiny_tiny_LLaMA",
        "seq_length": 256,
        "vocab_size": 32000,
        "n_layer": 12,
        "n_head": 8,
        "n_query_groups": 4,
        "n_embd": 768,
        "rotary_percentage": 1.0,
        "hidden_dim": 512,
    },
]

# Look-up table from preset name to its ready-made ``Config`` instance.
name_to_config: dict[str, Config] = {
    preset["name"]: Config(**preset) for preset in configs
}

# RoPE (Rotary Position Embedding)

In [4]:

class LlamaRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used by rotary position embeddings.

    The inverse frequencies are derived once from ``config.base`` and
    ``config.dim`` and kept in a non-persistent buffer named ``inv_freq``.
    """

    def __init__(
        self,
        config: Config,
    ):
        super().__init__()
        self.config = config

        frequencies, scaling = self._compute_rope_parameters(config)
        self.attention_scaling = scaling
        # Non-persistent: the buffer follows the module across devices but is
        # not written to the state dict.
        self.register_buffer("inv_freq", frequencies, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _compute_rope_parameters(
        self,
        config: Config
    ) -> Tuple[torch.Tensor, float]:
        """Return ``(inv_freq, attention_factor)`` for the given config.

        ``inv_freq[i] = 1 / base ** (2 * i / dim)`` for ``i`` in
        ``range(dim // 2)``; the attention factor is always ``1.0``.
        """
        rotary_dim = config.dim
        even_channels = torch.arange(0, rotary_dim, 2, dtype=torch.int64)
        exponents = even_channels.float().to(config.device) / rotary_dim
        return 1.0 / (config.base ** exponents), 1.0

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Compute cos/sin tables for ``position_ids``.

        Args:
            x: Tensor whose device type and dtype determine where the angles
                are computed and what dtype the result is cast to.
            position_ids: 2-D tensor of positions, indexed as (batch, seq).

        Returns:
            ``(cos, sin)``, each of shape (batch, seq, 2 * len(inv_freq)) and
            cast to ``x.dtype``.
        """
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        positions = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        # Autocast is switched off so the angles are computed in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos, sin = table.cos(), table.sin()

        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        out_dtype = x.dtype
        return cos.to(dtype=out_dtype), sin.to(dtype=out_dtype)


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (split at ``size // 2``) the result is
    ``[-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
  
def apply_rotary_pos_emb(q, k, cos, sin,  unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table of the rotary embedding.
        sin (`torch.Tensor`): Sine table of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q`
            and `k`. With tables shaped [batch_size, seq_len, head_dim], use 1
            when `q`/`k` are [batch_size, heads, seq_len, head_dim] and 2 when
            they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same formula for queries and keys: t*cos + rotate_half(t)*sin.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)



# GroupQueryAttention

In [5]:
class GroupQueryAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    ``config.n_head`` query heads share ``config.n_query_groups`` key/value
    heads; each key/value head serves ``n_head // n_query_groups`` consecutive
    query heads. All projections are bias-free and created in ``config.dtype``.
    """

    def __init__(
        self,
        config: Config,
    ):
        super().__init__()

        self.config = config
        self.dtype = config.dtype
        self.dim = config.n_embd
        self.query_heads = config.n_head
        self.queries_per_kv = config.n_head // config.n_query_groups
        self.key_value_heads = config.n_head // self.queries_per_kv

        # Total width of the key/value projections: head_size * key_value_heads.
        self.kv_dim = self.dim // self.query_heads * self.key_value_heads

        self.q_proj = nn.Linear(self.dim, self.dim,
                                bias=False, dtype=self.dtype)
        self.k_proj = nn.Linear(self.dim, self.kv_dim,
                                bias=False, dtype=self.dtype)
        self.v_proj = nn.Linear(self.dim, self.kv_dim,
                                bias=False, dtype=self.dtype)

        self.o_proj = nn.Linear(self.dim, self.dim,
                                bias=False, dtype=self.dtype)

    def scaled_dot_product_gqa(self, query: Tensor, key: Tensor, value: Tensor):
        """Causal scaled dot-product attention for grouped queries.

        Args:
            query: (batch, query_heads, q_len, head_dim).
            key: (batch, key_value_heads, kv_len, head_dim).
            value: (batch, key_value_heads, kv_len, head_dim).

        Returns:
            Tensor of shape (batch, query_heads, q_len, head_dim).
        """
        scale_factor = 1 / query.size(-1) ** 0.5

        L, S = query.size(-2), key.size(-2)

        # Additive causal mask, built on the same device as the activations so
        # the module also works when it is not on the default device.
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        visible = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(visible.logical_not(), float("-inf"))

        # Split the query heads into (kv head, member-of-group) and move the
        # group axis in front of the kv-head axis.
        query = rearrange(query, "b (h g) n d -> b g h n d",
                          g=self.queries_per_kv)

        # Give key/value an explicit singleton group axis.  Without it the
        # batch axis of key/value is right-aligned against the group axis of
        # the 5-D query, which only works when the batch size is 1.
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)

        attn_weight = query @ key.transpose(-1, -2) * scale_factor

        attn_weight += attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ value
        y = rearrange(y, "b g h n d -> b (h g) n d")

        return y

    def forward(self,
                x: Tensor,
                position_embeddings: Tuple[Tensor, Tensor]) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq, n_embd).
            position_embeddings: ``(cos, sin)`` rotary tables that broadcast
                against (batch, heads, seq, head_dim) after an axis is
                inserted at position 1.

        Returns:
            Tensor of shape (batch, seq, n_embd).
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = rearrange(q, "b n (h d) -> b h n d", h=self.query_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.key_value_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.key_value_heads)

        cos, sin = position_embeddings

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        y = self.scaled_dot_product_gqa(
            query=q,
            key=k,
            value=v
        )

        y = rearrange(y, "b h n d -> b n (h d)")

        y = self.o_proj(y)

        return y

# FFN

In [6]:
class FFN(nn.Module):
    """Gated feed-forward block: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free and created in ``config.dtype``.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        width, inner = config.n_embd, config.hidden_dim
        linear_kwargs = {"bias": False, "dtype": config.dtype}

        self.gate_proj = nn.Linear(width, inner, **linear_kwargs)
        self.up_proj = nn.Linear(width, inner, **linear_kwargs)
        self.down_proj = nn.Linear(inner, width, **linear_kwargs)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Project up with a SiLU gate, then project back to ``n_embd``."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))

# RMSNorm

In [7]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Args:
        config: Supplies the width (``n_embd``) and dtype of the scale vector.
        dim: Accepted for interface compatibility; not used by this class.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, config: Config,  dim: int = -1, eps: float = 1e-5):
        super().__init__()
        scale = torch.ones(config.n_embd, dtype=config.dtype)
        self.weight = nn.Parameter(scale)
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` in float32, cast back, then apply ``weight``."""
        original_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        normalised = x_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

# Decoder

In [8]:
class DecoderLayer(nn.Module):
    """One pre-norm transformer block: attention and feed-forward residuals."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.attn = GroupQueryAttention(config)
        self.ffn = FFN(config)
        self.norm1 = RMSNorm(config)
        self.norm2 = RMSNorm(config)

    def forward(self, x, pos_emb, mask=None):
        """Run the block.

        Args:
            x: Residual-stream input.
            pos_emb: ``(cos, sin)`` rotary tables forwarded to the attention.
            mask: Accepted but not used by this layer.

        Returns:
            The updated residual stream.
        """
        # Attention sub-layer with residual connection.
        x = x + self.attn(self.norm1(x), pos_emb)
        # Feed-forward sub-layer with residual connection.
        return x + self.ffn(self.norm2(x))

# TinyLlama

In [9]:
class LlamaModel(nn.Module):
    """Decoder-only LLaMA-style language model.

    Token embeddings pass through ``config.n_layer`` ``DecoderLayer`` blocks,
    a final ``RMSNorm`` and a bias-free linear head that produces vocabulary
    logits. Rotary cos/sin tables for ``config.seq_length`` positions are
    computed lazily and cached on the instance.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd, dtype=config.dtype)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, dtype=config.dtype)

        self.rope = LlamaRotaryEmbedding(config)
        self.rope_cache: Optional[Tuple[Tensor, Tensor]] = None

    def get_kv_head_dim(self, config: Config) -> int:
        """Return the total width of the key/value projections.

        Computed as ``head_size * key_value_heads`` from the embedding width
        (``n_embd``), matching ``GroupQueryAttention.kv_dim``.
        """
        queries_per_kv = config.n_head // config.n_query_groups
        key_value_heads = config.n_head // queries_per_kv
        # Head size derives from n_embd; hidden_dim is the feed-forward width.
        return config.n_embd // config.n_head * key_value_heads

    def build_rope_cache(self, idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute cos/sin tables for positions ``0 .. seq_length - 1``.

        ``idx`` only supplies the device for the positions and the device
        type / dtype used by the rotary module.
        """
        position_ids = torch.arange(self.config.seq_length, device=idx.device).unsqueeze(0)
        return self.rope(idx, position_ids)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
    ) -> Tensor:
        """Return next-token logits for every position.

        Args:
            input_ids: Token ids of shape (batch, seq) with
                ``seq <= config.seq_length``.

        Returns:
            Logits of shape (batch, seq, vocab_size).

        Raises:
            AssertionError: If the sequence is longer than ``config.seq_length``.
        """
        _, T = input_ids.shape

        assert T <= self.config.seq_length, f"Input length {T} exceeds maximum model length {self.config.seq_length}"

        x = self.embed_tokens(input_ids)

        # Rebuild the cached tables when the activations moved to another
        # device or dtype since the cache was filled.
        cache = self.rope_cache
        if cache is None or cache[0].device != x.device or cache[0].dtype != x.dtype:
            cache = self.rope_cache = self.build_rope_cache(x)

        cos, sin = cache
        cos = cos[:, :T]
        sin = sin[:, :T]

        for block in self.layers:
            # Causality is enforced inside the attention module; no mask is passed.
            x = block(x, (cos, sin))

        x = self.norm(x)

        return self.lm_head(x)

    def generate(self, input_ids: torch.LongTensor, max_length: int = 100, sample: Optional[bool] = False) -> torch.LongTensor:
        """Autoregressively extend ``input_ids``.

        Args:
            input_ids: Prompt token ids of shape (batch, seq).
            max_length: Maximum number of new tokens to append.
            sample: If true, draw from the softmax distribution; otherwise
                pick the arg-max token.

        Returns:
            ``input_ids`` with the generated tokens appended. Generation ends
            early once every sequence in the batch has just produced
            ``config.stop_token_id`` or the context window is full.
        """
        self.eval()
        with torch.no_grad():
            for _ in range(max_length):
                last_logits = self(input_ids)[:, -1]
                if sample:
                    next_token = torch.multinomial(last_logits.softmax(dim=-1), num_samples=1)
                else:
                    next_token = torch.argmax(last_logits, dim=-1).unsqueeze(-1)

                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # ``.all()`` keeps the batch-size-1 behaviour and avoids the
                # ``.item()`` failure for larger batches.
                if bool((next_token == self.config.stop_token_id).all()):
                    break

                # Another forward pass would exceed the supported length.
                if input_ids.size(1) > self.config.seq_length:
                    break
        return input_ids

# Load Original Model

In [10]:

def remap_state_dict(state_dict):
    """Rename LlamaForCausalLM-style state-dict keys to this notebook's names.

    Each substring listed in the mapping below is replaced wherever it occurs
    in a key; values are passed through untouched and the original iteration
    order is preserved.

    Args:
        state_dict: Mapping from parameter name to value.

    Returns:
        An ``OrderedDict`` with the renamed keys.
    """
    # Substring in the source key -> substring used by the custom model.
    key_mapping = {
        'model.embed_tokens': 'embed_tokens',
        'model.norm': 'norm',
        'model.layers': 'layers',
        'self_attn': 'attn',
        'input_layernorm': 'norm1',
        'post_attention_layernorm': 'norm2',
        'mlp': 'ffn'
    }

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in key_mapping.items():
            new_key = new_key.replace(old, new)

        # No shape check here: a value can only be compared with itself at
        # this point; mismatches are reported by ``load_state_dict``.
        new_state_dict[new_key] = value

    return new_state_dict

def load_model_weights(model, safetensors_path):
    """Load a safetensors checkpoint into ``model`` after renaming its keys.

    The state dict is loaded non-strictly; missing and unexpected keys are
    printed rather than raised.

    Args:
        model: Module that receives the weights.
        safetensors_path: Location of the ``.safetensors`` file.

    Returns:
        ``True`` when loading completed, ``False`` if any exception occurred
        (the error message is printed).
    """
    try:
        remapped = remap_state_dict(load_file(safetensors_path))

        # strict=False: report key mismatches instead of raising on them.
        missing_keys, unexpected_keys = model.load_state_dict(remapped, strict=False)

        print("Model loaded successfully!")
        reports = (
            ("Missing keys:", missing_keys),
            ("Unexpected keys:", unexpected_keys),
        )
        for label, keys in reports:
            if keys:
                print(label, keys)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False

    return True


# Validation

In [11]:
# Build the hand-written model from the 1.1B preset.
tiny_LLaMA_1b = LlamaModel(config=name_to_config["tiny_LLaMA_1b"])

In [12]:
# Copy the checkpoint weights into the hand-written model; the cell displays
# the boolean returned by the loader.
load_model_weights(
    model=tiny_LLaMA_1b,
    safetensors_path="model/models--TinyLlama--TinyLlama-1.1B-Chat-v1.0/snapshots/fe8a4ea1ffedaf415f4da2f062534de366a451e6/model.safetensors",
)

Model loaded successfully!


True

In [13]:
# Reference tokenizer and model used to validate the hand-written
# implementation.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# ``trust_remote_code`` is left at its default: this checkpoint is loaded
# through the library's built-in classes, so there is no need to allow code
# from the hub to run.  The former ``map_device`` argument is not a tokenizer
# option and has been dropped.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token=True,
    use_fast=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="model",
)

In [14]:
# Token ids of a short prompt, returned as a PyTorch tensor.
inputs_ids = tokenizer("Hello, how are you?", return_tensors="pt")["input_ids"]

In [15]:
# Run the same prompt through both implementations.  Gradients are not needed
# for this comparison, so autograd is disabled to avoid keeping the graphs of
# two full models in memory.
# NOTE(review): output_attentions=True is kept exactly as before; it may
# influence which attention code path the reference model takes - confirm
# before removing it.
with torch.no_grad():
    out1 = tiny_LLaMA_1b(inputs_ids)
    out2 = model(inputs_ids, output_attentions=True)



In [16]:
# Compare with dtype-appropriate tolerances instead of bit-exact equality:
# half-precision results from two separate implementations can differ in the
# last bits depending on kernels and hardware.  ``check_dtype=False`` keeps
# the type promotion that the previous ``==`` comparison performed.
torch.testing.assert_close(
    out1,
    out2.logits,
    check_dtype=False,
    msg=lambda details: (
        "The model output is different from the reference model output\n" + details
    ),
)

In [43]:
# Single-turn conversation rendered with the tokenizer's chat template; the
# generation prompt is appended so the model continues with its own reply.
chat = [{"role": "user", "content": "Hello, how are you?"}]

inputs_ids = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)

In [44]:
out = tiny_LLaMA_1b.generate(inputs_ids, max_length=10, sample=False)

In [45]:
# Turn the first sequence of the generated batch back into text; the cell
# displays the decoded string.
tokenizer.decode(out[0])

'<|user|>\nHello, how are you?</s> \n<|assistant|>\nI am doing well, thank you. How are'