In [1]:
from transformers import AutoTokenizer
# Load the Kimi-K2 tokenizer from the Hugging Face Hub.
# NOTE(review): trust_remote_code=True executes Python code shipped in the model
# repository; consider pinning `revision=` to a reviewed commit.
tokenizer = AutoTokenizer.from_pretrained(
    "moonshotai/Kimi-K2-Instruct",
    trust_remote_code=True,
)

In [2]:
# Encode a sample string into a list of token ids (displayed as the cell output).
tokenizer.encode("Hello workd")

[19180, 1133, 67]

In [3]:
# Decode the ids produced by the previous cell; the result should round-trip to the input text.
tokenizer.decode([19180, 1133, 67])

'Hello workd'

In [4]:
# Vocabulary size reported by the tokenizer; it matches the default `vocab_size` of the Embeddings table.
# NOTE(review): for standard HF tokenizers `vocab_size` excludes added special tokens while
# `len(tokenizer)` includes them — confirm which count the embedding table must cover.
tokenizer.vocab_size

163842

In [9]:
import torch 
from torch import nn
from typing import Optional

In [10]:
class Embeddings(nn.Module):
    """Learned lookup table that maps integer token ids to dense vectors.

    Args:
        embed_dim: Length of each embedding vector.
        vocab_size: Number of rows in the table; every token id passed to
            ``forward`` must be smaller than this. The default equals the
            ``tokenizer.vocab_size`` value reported earlier in this notebook.
    """

    def __init__(self, embed_dim=768, vocab_size=163842):
        super().__init__()
        # The attribute name ``table`` is kept so the state_dict key stays "table.weight".
        self.table = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up integer ids ``x``; the result gains a trailing ``embed_dim`` axis."""
        looked_up = self.table(x)
        return looked_up


In [11]:
# Embed the sample sentence: one embed_dim-sized vector per token id.
embedder = Embeddings()
sample_ids = torch.LongTensor(tokenizer.encode("Hello workd"))
embedder(sample_ids).shape

torch.Size([3, 768])

# RoPE
Rotary positional embeddings (RoPE). The implementation below is copied from the torchtune repository because the `torchtune` package failed to install in this environment.

In [12]:
class RotaryPositionalEmbeddings(nn.Module):
    """Rotary positional embeddings (RoPE), adapted from torchtune.

    Each adjacent feature pair ``(x[..., 2i], x[..., 2i + 1])`` of the token at
    position ``m`` is rotated by the angle ``m * theta_i``, where
    ``theta_i = base ** (-2i / dim)``. The cos/sin values for positions
    ``0 .. max_seq_len - 1`` are precomputed into the non-persistent ``cache``
    buffer.

    Args:
        dim (int): Per-head embedding dimension; must be even.
        max_seq_len (int): Number of positions precomputed in the cache.
        base (int): The base for the geometric progression of frequencies.

    Raises:
        ValueError: If ``dim`` is odd (features are rotated in pairs).
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE rotates feature pairs, so dim must be even; got {dim}")
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self.rope_init()

    def rope_init(self) -> None:
        """Compute the per-pair frequencies ``theta`` and build the cos/sin cache."""
        theta = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
        )
        # Non-persistent: derived from constructor args, so not stored in state_dict.
        self.register_buffer("theta", theta, persistent=False)
        self.build_rope_cache(self.max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        """Precompute ``cache[m, i] = (cos(m * theta_i), sin(m * theta_i))``.

        Args:
            max_seq_len (int): Number of positions to precompute.

        The resulting ``cache`` buffer has shape ``[max_seq_len, dim // 2, 2]``.
        """
        seq_idx = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )
        # Outer product: idx_theta[m, i] = m * theta_i.
        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer("cache", cache, persistent=False)

    def forward(
        self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Rotate ``x`` according to each token's position.

        Args:
            x (torch.Tensor): Input of shape ``[b, s, n_h, dim]``.
            input_pos (Optional[torch.Tensor]): Position id of each token, of
                shape ``[s]`` or ``[b, s]``. When ``None``, positions
                ``0 .. s - 1`` are used. Ids must be smaller than the cache
                length; this is not checked here (doing so would force a
                device sync).

        Returns:
            torch.Tensor: Rotated tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D, its last dimension differs from
                ``dim``, or (when ``input_pos`` is None) its sequence length
                exceeds the number of cached positions.
        """
        # Without this check, a head_dim that differs from `dim` can make the
        # `view` below succeed with a bogus leading dim and silently broadcast.
        if x.dim() != 4 or x.size(-1) != self.dim:
            raise ValueError(
                f"expected x of shape [b, s, n_h, {self.dim}], got {tuple(x.shape)}"
            )
        seq_len = x.size(1)
        cached_positions = self.cache.size(0)
        if input_pos is None and seq_len > cached_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds the RoPE cache size "
                f"{cached_positions}; increase max_seq_len"
            )
        rope_cache = (
            self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
        )
        # Split the last dim into (dim // 2, 2) feature pairs; compute in fp32.
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
        # Insert a heads axis so the cache broadcasts over n_h: [b or 1, s, 1, dim // 2, 2].
        rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
        # 2-D rotation of each pair: (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).
        x_out = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        # Merge the pair axis back into the feature axis: [b, s, n_h, dim].
        x_out = x_out.flatten(3)
        return x_out.type_as(x)

In [13]:
batch_size = 2
seq_len = 1024
num_heads = 8
head_dim = 64
torch.manual_seed(0)  # reproducible random inputs on Restart & Run All

# --- Test ---
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len)
input_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim)
output_tensor = rope(input_tensor)

print(f"Input shape:  {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")
assert output_tensor.shape == input_tensor.shape
# RoPE only rotates feature pairs, so each per-token vector norm is unchanged.
assert torch.allclose(output_tensor.norm(dim=-1), input_tensor.norm(dim=-1), atol=1e-4)

# --- Test with 'input_pos' for inference-style caching ---
# Re-encode the token at position 5 on its own; it must match the full-sequence result.
input_pos_tensor = torch.tensor([5], dtype=torch.long)
single_token_tensor = input_tensor[:, 5:6]
output_single_token = rope(single_token_tensor, input_pos=input_pos_tensor)

print(f"\nSingle token input shape:  {single_token_tensor.shape}")
print(f"Single token output shape: {output_single_token.shape}")
assert torch.allclose(output_single_token, output_tensor[:, 5:6], atol=1e-6)

Input shape:  torch.Size([2, 1024, 8, 64])
Output shape: torch.Size([2, 1024, 8, 64])

Single token input shape:  torch.Size([2, 1, 8, 64])
Single token output shape: torch.Size([2, 1, 8, 64])


# Row and Column Parallelism Not Needed
We don't need row- and column-parallel layers; they are only needed for multi-GPU training.

# Multi-Head Attention Module
This will be a simplified attention module that skips model parallelism and low-rank adaptation.