In [99]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [100]:
from types import SimpleNamespace

In [101]:
class LlamaRMSNorm(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.weight = nn.Parameter(
            torch.ones(self.embed_dim,dtype=torch.float32),
            requires_grad=True
        )
        
    def forward(self, x):
        # x [B, S, D]
        mean = x.pow(2).mean(dim=-1,keepdim=True)
        r_sqrt = x * torch.rsqrt(mean + 1e-5) # [B, S, 1]
        y = r_sqrt * self.weight
        return y.to(x.dtype)


class SiLU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        # x [B S D]
        return x * F.sigmoid(x)


class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.intermediate_dim = config.intermediate_dim
        self.gate_proj = nn.Linear(self.embed_dim, self.intermediate_dim, bias=False, dtype=config.dtype)
        self.up_proj = nn.Linear(self.embed_dim, self.intermediate_dim, bias=False, dtype=config.dtype)
        self.down_proj = nn.Linear(self.intermediate_dim, self.embed_dim, bias=False, dtype=config.dtype)
        self.act_fn = SiLU()
        
    def forward(self, x):
        # x [B S D]
        x1 = self.gate_proj(x)
        x2 = self.up_proj(x)
        x = self.act_fn(x1) * x2
        x = self.down_proj(x)
        return x

# RoPE

In [102]:
def precompute_rope(head_dim, base_theta=10_000, context_length=4096):
    k = torch.arange(0,head_dim,2)[:head_dim//2].float()
    inv_freq = 1 / (base_theta ** (k/head_dim))

    positions = torch.arange(context_length)
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # [S, H/2]
    angles = torch.cat([angles, angles],dim=-1) # [S, H]

    cos = torch.cos(angles) # [S, H]
    sin = torch.sin(angles) # [S, H]


    return cos, sin

In [103]:
def apply_rope(x, cos, sin, positions=None):
    B, nH, S, H = x.shape
    x1 = x[...,:H//2] # [B, nH, S, H/2]
    x2 = x[...,H//2:] # [B, nH, S, H/2]
    if positions is None:
        positions = torch.arange(S)
    cos_values = cos[positions,:].unsqueeze(0).unsqueeze(1) # [1,1,S,H]
    sin_values = sin[positions,:].unsqueeze(0).unsqueeze(1) # [1,1,S,H]
    rotated = torch.cat([-x2,x1],dim=-1)
    x_rope = (x * cos_values) + (rotated * sin_values)
    return x_rope.to(x.dtype)

# KV Cache

In [104]:
torch.arange(3,5)

tensor([3, 4])

In [105]:
class KVCache:
    def __init__(self, max_length, head_dim, n_heads, dtype=torch.float32, device='cpu'):
        self.max_length = max_length
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.device = device
        self.dtype = dtype
        self.reset()

    def reset(self):
        self.keys = torch.zeros((1, self.n_heads, self.max_length, self.head_dim),
                                device=self.device, dtype=self.dtype)
        self.values = torch.zeros((1, self.n_heads, self.max_length, self.head_dim),
                                  device=self.device, dtype=self.dtype)
        self.length = 0

    def update(self, new_key, new_value):
        # new_key/new_value: [B, n_heads, S, head_dim] (with B==1 during inference)
        S = new_key.shape[2]
        assert self.length + S <= self.max_length, "KV cache overflow"
        seq_start = self.length
        seq_end = seq_start + S
        self.keys[:, :, seq_start:seq_end, :] = new_key
        self.values[:, :, seq_start:seq_end, :] = new_value
        self.length = seq_end

    def get(self):
        if self.length == 0:
            return None, None
        return self.keys[:, :, :self.length, :], self.values[:, :, :self.length, :]

In [106]:
kvcache = KVCache(100,32,4)

In [107]:
kvcache.keys.shape

torch.Size([1, 4, 100, 32])

In [108]:
k = torch.rand(1,4,10,32)
v = torch.rand(1,4,10,32)
kvcache.update(k,v)

In [109]:
kvcache.length

10

In [110]:
past_k, past_v = kvcache.get()
past_k.shape, past_v.shape

(torch.Size([1, 4, 10, 32]), torch.Size([1, 4, 10, 32]))

In [111]:
torch.allclose(k,past_k), torch.allclose(v,past_v)

(True, True)

In [112]:
torch.arange(kvcache.length,kvcache.length+1)

tensor([10])

In [113]:
new_k = torch.rand(1,4,1,32)
new_v = torch.rand(1,4,1,32)
kvcache.update(new_k,new_v)

In [114]:
past_k, past_v = kvcache.get()
past_k.shape, past_v.shape

(torch.Size([1, 4, 11, 32]), torch.Size([1, 4, 11, 32]))

In [115]:
full_k = torch.cat([k,new_k],dim=2)
full_v = torch.cat([v,new_v],dim=2)
full_k.shape, full_v.shape

(torch.Size([1, 4, 11, 32]), torch.Size([1, 4, 11, 32]))

In [116]:
torch.allclose(past_k,full_k), torch.allclose(past_v,full_v)

(True, True)

In [117]:
cos, sin = precompute_rope(32,)

In [118]:
rope_k = apply_rope(full_k, cos, sin)

In [119]:
rope_k.shape

torch.Size([1, 4, 11, 32])

In [120]:
new_k_rope = apply_rope(new_k, cos, sin, positions=torch.arange(kvcache.length,kvcache.length+1)-1)

In [121]:
new_k_rope.shape

torch.Size([1, 4, 1, 32])

In [122]:
torch.allclose(rope_k[:,:,[-1],:],new_k_rope)

True

In [123]:
kvcache.length

11

# Integrating it into GQA

In [124]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embed_dim = config.embed_dim
        self.num_kv_heads = config.num_kv_heads
        self.num_q_heads = config.num_q_heads
        self.max_position_embeddings = config.max_position_embeddings

        assert self.embed_dim % self.num_q_heads == 0, "embed_dim must be divisible by num_q_heads"
        assert self.num_q_heads % self.num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"

        self.head_dim = self.embed_dim // self.num_q_heads

        self.q_proj = nn.Linear(self.embed_dim, self.head_dim * self.num_q_heads, bias=False, dtype=config.dtype)
        self.k_proj = nn.Linear(self.embed_dim, self.head_dim * self.num_kv_heads, bias=False, dtype=config.dtype)
        self.v_proj = nn.Linear(self.embed_dim, self.head_dim * self.num_kv_heads, bias=False, dtype=config.dtype)

        self.drop = nn.Dropout(config.attn_dropout)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False, dtype=config.dtype)

        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(config.max_position_embeddings, config.max_position_embeddings), diagonal=1)
        )

        cos, sin = precompute_rope(self.head_dim, base_theta=config.base_theta,  # Important: RoPE applies to half dimension
                                  context_length=self.max_position_embeddings)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

        self.kv_cache = None
        self.use_cache = False

    def enable_kv_cache(self, dtype=None):
        self.kv_cache = KVCache(self.max_position_embeddings, self.head_dim, self.num_kv_heads, dtype, self.rope_cos.device)
        self.use_cache = True
    
    def reset_kv_cache(self):
        if self.kv_cache is not None:
            self.kv_cache.reset()

    def forward(self, x):
        # x: [B, S, D]
        B, S, D = x.shape

        q = self.q_proj(x) # [B S H*nQ]
        k = self.k_proj(x) # [B S H*nKV]
        v = self.v_proj(x) # [B S H*nKV]

        q = q.view(B, S, self.num_q_heads, self.head_dim).transpose(1,2) # [B nQ S H]
        k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1,2) # [B nKV S H]
        v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1,2) # [B nKV S H]

        if self.use_cache and self.kv_cache is not None:
            assert B == 1, "Batch size must be 1 in inference when using KV cache."
            past_length = self.kv_cache.length
            positions = torch.arange(past_length, past_length + S, device=x.device)
        else:
            positions = torch.arange(0, S, device=x.device)

        # Apply RoPE
        q = apply_rope(q, self.rope_cos, self.rope_sin, positions)
        k = apply_rope(k, self.rope_cos, self.rope_sin, positions)

        if self.use_cache and self.kv_cache is not None:
            self.kv_cache.update(k, v)
            k, v = self.kv_cache.get()
            total_length = k.shape[2]
        else:
            total_length = S

        k = k.repeat_interleave(self.num_q_heads//self.num_kv_heads, dim=1) # [B nQ S H]
        v = v.repeat_interleave(self.num_q_heads//self.num_kv_heads, dim=1) # [B nQ S H]

        attn = q @ k.transpose(2,3) # [B nQ S1 H] @ [B nQ H S2] = [B nQ S1 S2]
        
        if self.use_cache and self.kv_cache is not None:
            mask = self.causal_mask[past_length:past_length+S, :total_length].bool()
        else:
            mask = self.causal_mask[:S, :S].bool()
            
        attn.masked_fill_(mask, -torch.inf)
        
        attn = F.softmax(attn / (self.head_dim ** 0.5), dim=-1)

        attn = self.drop(attn)

        out = attn @ v # [B nQ S S] @ [B nQ S H] = [B nQ S H]
        out = out.transpose(1,2) # [B S nQ H]
        out = out.reshape(B, S, D)

        proj = self.o_proj(out)
        
        return proj

In [125]:
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.self_attn = GroupedQueryAttention(config)
        self.mlp = LlamaMLP(config)
        
        self.input_layernorm = LlamaRMSNorm(config.embed_dim)
        self.post_attention_layernorm = LlamaRMSNorm(config.embed_dim)
        
        
    def forward(self, x):
        # x [B S D]
        skip = x
        x = self.input_layernorm(x)
        x = self.self_attn(x)
        x = x + skip
        
        skip = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = x + skip
        
        return x


class LLaMA(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(
            self.config.vocab_size, 
            self.config.embed_dim, 
            padding_idx=self.config.eos_token_id,
            dtype=self.config.dtype)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(self.config) for _ in range(self.config.num_layers)
        ])

        self.norm = LlamaRMSNorm(self.config.embed_dim)
        self.lm_head = nn.Linear(self.config.embed_dim, self.config.vocab_size, bias=False, dtype=self.config.dtype)

        self._tie_weights()

    def _tie_weights(self):
        self.lm_head.weight = self.embed_tokens.weight

    def enable_kv_cache(self):
        for layer in self.layers:
            layer.self_attn.enable_kv_cache(dtype=self.config.dtype)

    def reset_kv_cache(self):
        for layer in self.layers:
            layer.self_attn.kv_cache.reset()
        
    def forward(self, input_ids):
        # input_ids [B S]
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

In [126]:
config = SimpleNamespace(
    embed_dim = 576,
    intermediate_dim = 1536,
    max_position_embeddings = 8192,
    base_theta = 100000,
    num_q_heads = 9,
    num_kv_heads = 3,
    attn_dropout = 0.,
    num_layers = 30,
    vocab_size = 49152,
    dtype = torch.bfloat16,
    eos_token_id = 2
    )
model = LLaMA(config)
model.eval()
model.enable_kv_cache()

# Inference with KV Cache

In [127]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('./simpleVLM')
smol = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

smol_sd = smol.state_dict()
model_sd = model.state_dict()
smol_sd = {k:v for k,v in smol_sd.items() if not any([s in k for s in ['rope','causal_mask']])}

for smol_key,smol_value in smol_sd.items():
    model_key = smol_key.replace('model.','')
    model_sd[model_key] = smol_value.clone()

model.load_state_dict(model_sd)

<All keys matched successfully>

In [128]:
def get_input(text):
    messages = [{"role": "user", "content": text}]
    input_text=tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer.encode(input_text, return_tensors="pt")
    return inputs

In [147]:
def generate(
    model,
    config,
    input_ids,
    max_new_tokens=32,
    temperature=0.0,
):
    model.eval()
    
    context_length = model.config.max_position_embeddings
    eos_token_id = model.config.eos_token_id
    
    model.reset_kv_cache()
    
    inputs = input_ids.clone()
    if inputs.shape[1] > context_length:
        inputs = inputs[:, -context_length:]
    print(tokenizer.decode(inputs.flatten().tolist()))
    
    # Prefill
    with torch.inference_mode():
        _ = model(inputs)
    
    all_tokens = inputs
    for token_idx in range(max_new_tokens):
        with torch.inference_mode():
            last_token = all_tokens[:, [-1]]
            logits = model(last_token)
            next_token_logits = logits[:, -1, :]
            
            if temperature > 0.:
                next_token_logits = next_token_logits / temperature
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            
            if next_token.item() == eos_token_id:
                break
            
            print(tokenizer.decode(next_token.flatten().tolist()), end='')
            all_tokens = torch.cat([all_tokens, next_token], dim=1)
    print()
    return all_tokens

In [148]:
inputs = get_input('give me a random fact about llamas')
generated = generate(model, config, inputs, max_new_tokens=250, temperature=0.)

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
give me a random fact about llamas<|im_end|>

llama* is a large, four-legged mammal native to the Andes Mountains of South America. They are known for their long, muscular legs and powerful legs, which are used for jumping and running. Llamas are herbivores, and they eat a diet of grasses, leaves, and other vegetation. They are also known for their ability to climb trees, which they use to reach high branches and reach food sources. Llamas are also known for their unique ability to produce a strong, sticky mucus that helps them to climb and move through the Andes.


In [149]:
inputs = get_input('can you do function calling?')
generated = generate(model, config, inputs, max_new_tokens=250, temperature=0.)

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
can you do function calling?<|im_end|>

<|im_start|>assistant
 here's a simple function that calls another function:

```python
function(func):
 func()urn

(): main
function(lambda: print("Hello"))

name__ == "__main__":
    main()
```

call_function` is a function that takes a lambda function as an argument. The lambda function is called by `call_function` and returns the result of calling `call_function` with the lambda function as an argument.

When you run this code, it will print "Hello". The `lambda` function is a special function in Python that is used to define a small, one-time-use function.

 `call_function` is a function, not a method. It's a way to call a function without creating an instance of the class.
