In [28]:
from transformers import AutoTokenizer

# Hugging Face repo id of the tokenizer used throughout this notebook.
# NOTE: trust_remote_code=True executes Python code shipped with the repo; consider
# pinning `revision=` to a reviewed commit before relying on it.
TOKENIZER_NAME = "moonshotai/Kimi-K2-Instruct"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)

In [29]:
# Encode the sample text into a list of token ids and display it.
sample_ids = tokenizer.encode("Hello workd")
sample_ids

[19180, 1133, 67]

In [30]:
# Round-trip the sample text (encode -> decode) instead of hard-coding token ids copied
# from an earlier output, so this cell stays correct if the tokenizer or text changes.
tokenizer.decode(tokenizer.encode("Hello workd"))

'Hello workd'

In [31]:
# Base vocabulary size reported by the tokenizer; the Embeddings default vocab_size
# (163842) mirrors this value.
tokenizer_vocab_size = tokenizer.vocab_size
tokenizer_vocab_size

163842

In [32]:
import torch 
from torch import nn
import torch.nn.functional as F
from typing import Optional, Union
from config import Config
# Model hyperparameters, defined in the local `config.py` (not shown in this notebook).
# MLA(config) reads dim, n_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
# max_seq_length and max_batch_size from it.
config = Config()

In [33]:
class Embeddings(nn.Module):
    """
    Learned token-embedding table.

    Args:
        embed_dim: Length of each embedding vector.
        vocab_size: Number of distinct token ids; the default equals the
            tokenizer's reported vocab size (163842).
    """

    def __init__(self, embed_dim=768, vocab_size=163842):
        super().__init__()
        # The attribute name `table` is what appears in the state_dict key.
        self.table = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding of every id in `x` (output shape: x.shape + (embed_dim,))."""
        embedded = self.table(x)
        return embedded


In [23]:
# Embed the sample sentence: one 768-dim vector per token id.
embedder = Embeddings()
sample_tokens = torch.tensor(tokenizer.encode("Hello workd"), dtype=torch.long)
embedder(sample_tokens).shape

torch.Size([3, 768])

# RoPE
Rotary positional embeddings (RoPE) are applied with an implementation borrowed from the torchtune repository, because the `torchtune` package would not install in this environment.

In [24]:
class RotaryPositionalEmbeddings(nn.Module):
    """
    Rotary positional embeddings (RoPE), adapted from torchtune.

    A table of (cos, sin) values of shape (max_seq_len, dim // 2, 2) is precomputed.
    In `forward`, each consecutive channel pair of the last input axis is rotated
    by the angle for its position.
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        """
        Args:
            dim (int): Size of the last input axis that gets rotated.
            max_seq_len (int): Number of positions precomputed in the cache.
            base (int): Base of the geometric progression of rotation frequencies.
        """
        super().__init__()
        self.dim, self.base, self.max_seq_len = dim, base, max_seq_len
        self.rope_init()

    def rope_init(self):
        """Compute per-pair frequencies `theta` and build the (cos, sin) cache."""
        half = self.dim // 2
        exponents = torch.arange(0, self.dim, 2)[:half].float() / self.dim
        freqs = 1.0 / (self.base ** exponents)
        self.register_buffer("theta", freqs, persistent=False)
        self.build_rope_cache(self.max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        """Store cos/sin of (position * theta) for positions 0 .. max_seq_len - 1."""
        positions = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )
        # Outer product: angles[p, j] = positions[p] * theta[j].
        angles = (positions.unsqueeze(1) * self.theta).float()
        table = torch.stack((angles.cos(), angles.sin()), dim=-1)
        self.register_buffer("cache", table, persistent=False)

    def forward(
        self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): 4-D input; axis 1 is the sequence axis and the last
                axis (length `dim`) is rotated in pairs.
            input_pos (Optional[torch.Tensor]): Position ids of each token. When None,
                positions 0 .. x.size(1) - 1 are used.
        Returns:
            Tensor with the same shape and dtype as `x`.
        """
        if input_pos is None:
            rope_cache = self.cache[: x.size(1)]
        else:
            rope_cache = self.cache[input_pos]

        # Split the last axis into (pairs, 2) and align the cache for broadcasting over heads.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rope_cache = rope_cache.view(-1, pairs.size(1), 1, pairs.size(3), 2)

        x_even, x_odd = pairs[..., 0], pairs[..., 1]
        cos, sin = rope_cache[..., 0], rope_cache[..., 1]
        rotated = torch.stack(
            (x_even * cos - x_odd * sin, x_odd * cos + x_even * sin), dim=-1
        )
        return rotated.flatten(3).type_as(x)

In [25]:
torch.manual_seed(0)  # reproducible random inputs

batch_size = 2
seq_len = 1024
num_heads = 8
head_dim = 64

# --- Test ---
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len)
input_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim)
output_tensor = rope(input_tensor)

print(f"Input shape:  {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")
assert output_tensor.shape == input_tensor.shape
# Position 0 has angle 0 (cos=1, sin=0), so RoPE must leave it unchanged.
assert torch.allclose(output_tensor[:, 0], input_tensor[:, 0])
# RoPE rotates channel pairs, so each head vector's norm must be preserved.
assert torch.allclose(output_tensor.norm(dim=-1), input_tensor.norm(dim=-1), atol=1e-4)

# --- Test with 'input_pos' for inference-style caching ---
# Simulating a single new token at position 5
input_pos_tensor = torch.tensor([5], dtype=torch.long)
single_token_tensor = torch.randn(batch_size, 1, num_heads, head_dim)
output_single_token = rope(single_token_tensor, input_pos=input_pos_tensor)

print(f"\nSingle token input shape:  {single_token_tensor.shape}")
print(f"Single token output shape: {output_single_token.shape}")
assert output_single_token.shape == single_token_tensor.shape
# Rotating a token via input_pos=5 must match rotating it at index 5 of the full sequence.
assert torch.allclose(
    rope(input_tensor[:, 5:6], input_pos=input_pos_tensor), output_tensor[:, 5:6]
)

Input shape:  torch.Size([2, 1024, 8, 64])
Output shape: torch.Size([2, 1024, 8, 64])

Single token input shape:  torch.Size([2, 1, 8, 64])
Single token output shape: torch.Size([2, 1, 8, 64])


# Row and Column Parallelism Not Needed
We don't need row- and column-parallel linear layers; they are only used for multi-GPU (model-parallel) training.

# Multi Head Attention Module
This is a simplified attention module that skips model parallelism and the low-rank (LoRA-style) projections.

The `forward` method differs from the reference implementation because we use a different Rotary Embeddings implementation, one that reads from a precomputed cos/sin cache.

In [None]:
class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) layer, simplified.

    This version projects the input directly to per-head queries, keys and values,
    with no low-rank latent compression and no model parallelism. Each query/key
    head is split into a non-positional part and a rotary part, and RoPE is applied
    only to the rotary part. Keys and values are written into a fixed-size KV cache.
    Attention is then computed over every cached position up to the end of the
    current chunk.

    Attributes:
        dim (int): Dimensionality of input features.
        n_heads (int): Number of attention heads.
        qk_nope_head_dim (int): Per-head dimensionality of the non-positional query/key part.
        qk_rope_head_dim (int): Per-head dimensionality of the rotary-positional query/key part.
        qk_head_dim (int): Total per-head query/key dimensionality (nope + rope).
        v_head_dim (int): Per-head dimensionality of value projections.
        softmax_scale (float): Scaling factor applied to attention scores before softmax.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim

        # Projections: dim -> n_heads * head_dim for each of Q, K, V.
        self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_heads * self.v_head_dim, bias=False)
        self.rope = RotaryPositionalEmbeddings(
            dim=self.qk_rope_head_dim,
            max_seq_len=config.max_seq_length,
        )
        # KV cache laid out as (max_batch_size, max_seq_length, n_heads, head_dim).
        # persistent=False: this is runtime state and must not end up in state_dict checkpoints.
        self.register_buffer(
            "k_cache",
            torch.zeros(config.max_batch_size, config.max_seq_length, self.n_heads, self.qk_head_dim),
            persistent=False,
        )
        self.register_buffer(
            "v_cache",
            torch.zeros(config.max_batch_size, config.max_seq_length, self.n_heads, self.v_head_dim),
            persistent=False,
        )
        self.softmax_scale = self.qk_head_dim ** -0.5
        # Output projection from the concatenated heads back to the model dimension.
        self.c_v_linear = nn.Linear(in_features=self.n_heads * self.v_head_dim, out_features=self.dim)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Input of shape (batch, seq_len, dim).
            start_pos: Cache position at which this chunk's keys/values are written.
            input_pos: Optional position ids forwarded to the rotary embedding. When None,
                positions start_pos .. start_pos + seq_len - 1 are used, so the rotary
                positions agree with the cache slots being written.
            mask: Optional additive mask. `mask.unsqueeze(1)` inserts a head axis; it is
                presumably shaped (seq_len, start_pos + seq_len). TODO: confirm with the caller.
        Returns:
            Tensor of shape (batch, seq_len, dim).
        Raises:
            ValueError: If the batch or the chunk end position exceeds the KV-cache capacity.
        """
        b, seq_len, _ = x.size()
        end_pos = start_pos + seq_len
        max_batch, max_len = self.k_cache.shape[:2]
        if b > max_batch:
            raise ValueError(f"batch size {b} exceeds KV-cache max_batch_size={max_batch}")
        if end_pos > max_len:
            raise ValueError(f"start_pos + seq_len = {end_pos} exceeds KV-cache max_seq_length={max_len}")

        # Project to Q, K, V: (b, seq_len, n_heads, head_dim)
        q = self.wq(x).view(b, seq_len, self.n_heads, self.qk_head_dim)
        k = self.wk(x).view(b, seq_len, self.n_heads, self.qk_head_dim)
        v = self.wv(x).view(b, seq_len, self.n_heads, self.v_head_dim)

        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_nope, k_rope = torch.split(k, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Without explicit ids, rotate by absolute positions so that decode steps
        # (start_pos > 0) are not all treated as starting at position 0.
        if input_pos is None:
            input_pos = torch.arange(start_pos, end_pos, device=x.device)
        q_rope = self.rope(q_rope, input_pos=input_pos)
        k_rope = self.rope(k_rope, input_pos=input_pos)

        # Recombine the rotated and non-rotated parts of q and k.
        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        # Update the cache for the first `b` batch slots (slice, not the int index `b`).
        # NOTE(review): writing grad-carrying k/v into a buffer links the buffer into the
        # autograd graph. Repeated training steps through this path may fail on backward.
        # Consider bypassing the cache or running under torch.inference_mode() — confirm intent.
        self.k_cache[:b, start_pos:end_pos] = k
        self.v_cache[:b, start_pos:end_pos] = v

        # (b, s_q, h, d) x (b, s_k, h, d) -> (b, s_q, h, s_k)
        scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:b, :end_pos]) * self.softmax_scale

        if mask is not None:
            scores += mask.unsqueeze(1)

        # Softmax over the key axis, computed in float32 for stability, then cast back.
        weights = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

        # (b, s_q, h, s_k) x (b, s_k, h, d_v) -> (b, s_q, h, d_v)
        c_v = torch.einsum("bsht,bthd->bshd", weights, self.v_cache[:b, :end_pos])
        return self.c_v_linear(c_v.flatten(2))
        
    


        
        


SyntaxError: incomplete input (3450621908.py, line 13)