## Hacking 3 - DeepSeek v3 with Random weights - more tracing

So, I think I can just start with hacking/tracing through the MLA class...but where do i find the hyperparams...

In [1]:
import os
import sys

import torch
from transformers import AutoTokenizer
sys.path.append('DeepSeek-V3/inference') #Github slightly newer
# sys.path.append('DeepSeek-V3/DeepSeek-V3/inference') #Hugging face

from model import Transformer, MLA, ModelArgs, apply_rotary_emb

In [2]:
# import sys
# print(sys.executable)

In [3]:
# Hyperparameters of the full-size ("Big model") configuration listed at the end of the
# notebook, passed as plain keyword arguments.
args = ModelArgs(
    vocab_size=129280,
    dim=7168,
    inter_dim=18432,
    moe_inter_dim=2048,
    n_layers=61,
    n_dense_layers=3,
    n_heads=128,
    n_routed_experts=256,
    n_shared_experts=1,
    n_activated_experts=8,
    n_expert_groups=8,
    n_limited_groups=4,
    route_scale=2.5,
    score_func="sigmoid",
    q_lora_rank=1536,
    kv_lora_rank=512,
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
    dtype="bf16",  # fp8 seems out due to my hardware? bf16
)

In [4]:
args

ModelArgs(max_batch_size=8, max_seq_len=16384, dtype='bf16', vocab_size=129280, dim=7168, inter_dim=18432, moe_inter_dim=2048, n_layers=61, n_dense_layers=3, n_heads=128, n_routed_experts=256, n_shared_experts=1, n_activated_experts=8, n_expert_groups=8, n_limited_groups=4, score_func='sigmoid', route_scale=2.5, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0)

In [5]:
# mla=MLA(args)
# Build the whole Transformer from the hyperparameters above; no checkpoint is loaded.
# NOTE(review): every output further down is all zeros, so the parameters presumably hold
# whatever the allocator returned rather than an explicit initialization — confirm in model.py.
model=Transformer(args)

In [6]:
# mla

Ok that's cool - can I get an actual text encoding and then embedding? That would be nice.

In [7]:
# Tokenizer directory: defaults to the original local checkout, but can be redirected on
# another machine by setting the DEEPSEEK_TOKENIZER_DIR environment variable.
tokenizer_dir = os.environ.get('DEEPSEEK_TOKENIZER_DIR', '/home/stephen/deepseek/DeepSeek-V3/DeepSeek-V3')
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

In [8]:
tokens=tokenizer.encode("The American flag is red, white, and")
tokens

[0, 671, 3707, 14364, 344, 4332, 14, 5403, 14, 305]

In [9]:
tokenizer.decode(tokens)

'<｜begin▁of▁sentence｜>The American flag is red, white, and'

In [10]:
len(tokens)

10

In [11]:
# Embed the token ids produced above instead of re-tokenizing a second copy of the prompt
# string, so the prompt is defined in exactly one place.
embedded_prompt = model.embed(torch.tensor(tokens))

In [12]:
embedded_prompt.shape #Ok, guessing that the weights are just zeros right now or something??

torch.Size([10, 7168])

In [13]:
# Call the module itself rather than .forward() so any registered hooks are honored.
model(torch.tensor([tokens]), start_pos=0) #Dope! Ran on CPU somehow - took a couple minutes. 

tensor([[0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.bfloat16)

- Ok, now how do I get into the first (or any really) MLA layer?
- Maybe it's just like, can i grab or replicate the args?
- Ideally without having to wait for forward each time?

```
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
```

In [14]:
# The commented-out original read `tokens.size(1)` from a batched tensor. Here `tokens` is
# a flat Python list of ids for a single prompt, so its length is the sequence length.
seqlen = len(tokens)
seqlen

10

In [15]:
h = model.embed(torch.tensor([tokens]))

In [16]:
# torch.tensor([tokens]).device

In [17]:
h.shape

torch.Size([1, 10, 7168])

In [18]:
start_pos=0

In [19]:
freqs_cis = model.freqs_cis[start_pos:start_pos+seqlen]

In [20]:
freqs_cis.shape #Love that this is complex, pretty cool. 

torch.Size([10, 32])

In [21]:
# Causal mask: -inf strictly above the diagonal, 0 on and below it. It is only built when
# more than one position is processed at once.
mask = None
if seqlen > 1:
    # Allocate on the same device as the activations instead of hard-coding 'cpu'.
    mask = torch.full((seqlen, seqlen), float("-inf"), device=h.device).triu_(1)

In [22]:
mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [23]:
# Run the first transformer block by calling the module (not .forward) so hooks are honored.
h_out = model.layers[0](h, start_pos, freqs_cis, mask)

In [24]:
h_out.shape

torch.Size([1, 10, 7168])

In [25]:
h_out

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<AddBackward0>)

In [26]:
model.layers[0].attn

MLA(
  (wq_a): Linear()
  (q_norm): RMSNorm()
  (wq_b): ColumnParallelLinear()
  (wkv_a): Linear()
  (kv_norm): RMSNorm()
  (wkv_b): ColumnParallelLinear()
  (wo): RowParallelLinear()
)

In [27]:
# Run only the attention sub-module of the first block, via the module call so hooks are honored.
h_out = model.layers[0].attn(h, start_pos, freqs_cis, mask)

In [28]:
h_out.shape

torch.Size([1, 10, 7168])

- Ok, so there are my arguments - maybe I pickle them real quick so I can pick back up if needed?
- From there I should be able to replicate/walk through the MLA forward pass, right?

In [29]:
# import pickle

# # Pickling
# with open('mla_dummy_inputs.p', 'wb') as file:
#     pickle.dump((h, start_pos, freqs_cis, mask), file)

# # Unpickling
# with open('mla_dummy_inputs.p', 'rb') as file:
#     (h, start_pos, freqs_cis, mask) = pickle.load(file)

In [30]:
# Same attention call as before, repeated with the (h, start_pos, freqs_cis, mask) inputs
# defined above; invoked through the module call so hooks are honored.
h_out = model.layers[0].attn(h, start_pos, freqs_cis, mask)

In [31]:
h_out.shape

torch.Size([1, 10, 7168])

- ok ok ok ok I think it's probably worth spending a little time understanding what's going on in `__init__`

In [32]:
model.layers[0].attn.dim

7168

In [33]:
model.layers[0].attn.n_heads

128

In [34]:
model.layers[0].attn.n_local_heads

128

In [35]:
model.layers[0].attn.q_lora_rank #q_lora_rank (int): Rank for low-rank query projection.

1536

In [36]:
model.layers[0].attn.kv_lora_rank #Rank for low-rank key/value projection.

512

In [37]:
model.layers[0].attn.qk_nope_head_dim #Dimensionality of non-positional query/key projections.

128

In [38]:
model.layers[0].attn.qk_rope_head_dim #(int): Dimensionality of rotary-positional query/key projections.

64

In [39]:
model.layers[0].attn.v_head_dim #(int): Dimensionality of value projections.

128

In [40]:
model.layers[0].attn.softmax_scale #(float): Scaling factor for softmax in attention computation.

0.1352337788608801

In [41]:
model.layers[0].attn.qk_head_dim #(int): Total dimensionality of query/key projections.

192

```
if self.q_lora_rank == 0:
    self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
else:
    self.wq_a = Linear(self.dim, self.q_lora_rank)
    self.q_norm = RMSNorm(self.q_lora_rank)
    self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
```

In [42]:
model.layers[0].attn.wq_a.weight.shape

torch.Size([1536, 7168])

In [43]:
model.layers[0].attn.wq_b.weight.shape

torch.Size([24576, 1536])

```
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
```

In [44]:
model.layers[0].attn.wkv_a.weight.shape

torch.Size([576, 7168])

In [45]:
model.layers[0].attn.wkv_b.weight.shape

torch.Size([32768, 512])

In [46]:
model.layers[0].attn.wo.weight.shape

torch.Size([7168, 16384])

```
if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
```

In [47]:
model.layers[0].attn.register_buffer

<bound method Module.register_buffer of MLA(
  (wq_a): Linear()
  (q_norm): RMSNorm()
  (wq_b): ColumnParallelLinear()
  (wkv_a): Linear()
  (kv_norm): RMSNorm()
  (wkv_b): ColumnParallelLinear()
  (wo): RowParallelLinear()
)>

```
register_buffer in PyTorch is a method used to add a tensor to a module that should not be considered a model parameter. This means that the tensor is moved to the correct device along with the module and, by default, is saved and loaded with the model's state dictionary, but it will not be updated by the optimizer during training. It is typically used for tensors that are part of the model's state but are not learned parameters. Note that the caches above are registered with persistent=False, so they are NOT included in the state dictionary.
```

Cache size for naive key value caching:
```
args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim
args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim
```

For MLA:
```
args.max_batch_size, args.max_seq_len, self.kv_lora_rank
args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim
```

Ok wow yeah cools

In [48]:
model.layers[0].attn.kv_cache.shape

torch.Size([8, 16384, 512])

In [49]:
model.layers[0].attn.pe_cache.shape

torch.Size([8, 16384, 64])

Ok, now forward pass

In [50]:
# Output of the attention module for the same inputs that the following cells walk through
# step by step; invoked through the module call so hooks are honored.
h_out = model.layers[0].attn(h, start_pos, freqs_cis, mask)

In [51]:
x=h
mla=model.layers[0].attn

In [52]:
# Query projection, replayed on the notebook-level `mla` module.
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
if mla.q_lora_rank == 0:
    # No low-rank query compression: one full projection. This must go through `mla`;
    # `self` does not exist at notebook scope and would raise NameError for configs
    # with q_lora_rank == 0.
    q = mla.wq(x)
else:
    # Low-rank path: down-project, normalize, then up-project to all heads.
    q = mla.wq_b(mla.q_norm(mla.wq_a(x)))
    print('hay')

hay


In [53]:
# Same three-stage query projection, keeping each intermediate and feeding it to the next
# stage so no projection is computed more than once.
h1 = mla.wq_a(x)     # down-projection to q_lora_rank
h2 = mla.q_norm(h1)  # normalization in the low-rank space
q = mla.wq_b(h2)     # up-projection to n_heads * qk_head_dim

In [54]:
x.shape, h1.shape, h2.shape, q.shape

(torch.Size([1, 10, 7168]),
 torch.Size([1, 10, 1536]),
 torch.Size([1, 10, 1536]),
 torch.Size([1, 10, 24576]))

- Ok, so those are the queries for all the heads? I guess I didn't realize (or maybe I did) that they were also doing query compression. How important is this? I guess it's less compute than going straight to 7168x24576...

In [55]:
q = q.view(bsz, seqlen, mla.n_local_heads, mla.qk_head_dim) #Ok splitting out the queries across all the heads

In [56]:
q.shape

torch.Size([1, 10, 128, 192])

In [57]:
q_nope, q_pe = torch.split(q, [mla.qk_nope_head_dim, mla.qk_rope_head_dim], dim=-1)

In [58]:
q_nope.shape, q_pe.shape

(torch.Size([1, 10, 128, 128]), torch.Size([1, 10, 128, 64]))

In [59]:
# Rotate the positional slice of the queries. Taking the slice from `q` (the trailing
# qk_rope_head_dim features of each head) instead of reassigning q_pe from itself keeps the
# cell idempotent: re-running it cannot apply the rotation twice.
q_pe = apply_rotary_emb(q[..., mla.qk_nope_head_dim:], freqs_cis)

In [60]:
q_pe.shape

torch.Size([1, 10, 128, 64])

In [61]:
kv = mla.wkv_a(x)

In [62]:
kv.shape

torch.Size([1, 10, 576])

In [63]:
x.shape

torch.Size([1, 10, 7168])

That's a lot of squishing! No wonder it saves so much memory bandwidth.

In [64]:
# Split the joint down-projection into the latent kv part and the positional key part.
# Projecting from `x` here, rather than re-splitting `kv`, lets the cell be re-run safely:
# once `kv` has been narrowed to kv_lora_rank it can no longer be split with these sizes.
kv, k_pe = torch.split(mla.wkv_a(x), [mla.kv_lora_rank, mla.qk_rope_head_dim], dim=-1)
# Add a size-1 head axis before applying the rotary embedding.
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

In [65]:
kv.shape, k_pe.shape

(torch.Size([1, 10, 512]), torch.Size([1, 10, 1, 64]))

- Ok cool cool so we're left with this last big lift, probably will be worth understanding the naive option:

```
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale

```

- Main focus is of course the latent option

```
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
```

In [66]:
if mla.wkv_b.scale is None:
    # No quantization scale attached: use the weight as stored.
    wkv_b = mla.wkv_b.weight
else:
    # A scale is attached, so the weight must be dequantized first. Neither name below was
    # imported at the top of the notebook, so this branch used to raise NameError.
    # NOTE(review): assumes the `model` module exposes `weight_dequant` and `block_size`
    # at module level — confirm against model.py.
    from model import block_size, weight_dequant
    wkv_b = weight_dequant(mla.wkv_b.weight, mla.wkv_b.scale, block_size)
# Regroup the rows per head: (n_local_heads, rows_per_head, kv_lora_rank).
wkv_b = wkv_b.view(mla.n_local_heads, -1, mla.kv_lora_rank)

In [67]:
wkv_b.shape

torch.Size([128, 256, 512])

In [68]:
# Project the non-positional query features through the first qk_nope_head_dim rows of each
# head's wkv_b slab, taking them into the kv_lora_rank-wide latent space. The features are
# sliced from `q` (not taken from q_nope itself) so the cell can be re-run without a shape error.
q_nope = torch.einsum("bshd,hdc->bshc", q[..., :mla.qk_nope_head_dim], wkv_b[:, :mla.qk_nope_head_dim])
q_nope.shape

torch.Size([1, 10, 128, 512])

In [69]:
# Write this step's entries into the two caches at positions start_pos:end_pos:
# the normalized latent kv, and the rotated positional key with its size-1 head axis dropped.
mla.kv_cache[:bsz, start_pos:end_pos] = mla.kv_norm(kv)
mla.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

In [70]:
# Attention logits = latent (non-positional) term + rotary (positional) term, then scaled.
# Each cache slice is cast to the dtype of the query it is contracted with, rather than to
# a hard-coded bfloat16, so the cell keeps working if the query dtype changes.
latent_keys = mla.kv_cache[:bsz, :end_pos].to(q_nope.dtype)
rope_keys = mla.pe_cache[:bsz, :end_pos].to(q_pe.dtype)
scores = (torch.einsum("bshc,btc->bsht", q_nope, latent_keys) +
          torch.einsum("bshr,btr->bsht", q_pe, rope_keys)) * mla.softmax_scale

In [71]:
scores.shape

torch.Size([1, 10, 128, 10])

- Hmm, I thought that would be the attention pattern? I guess maybe it is?
- In my head it was square, like keys by queries? Hmm, is it tokens by queries? Need to think about that a little.

In [72]:
scores[0, :, 0, :]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)

In [73]:
mask.shape

torch.Size([10, 10])

Ah ok, right, it's 128 10x10 attention patterns -> that could be kinda nice to visualize! Maybe the mech interp stuff could be good to review quickly - we'll see where writing leads

In [74]:
# Add the causal mask in place. unsqueeze(1) inserts an axis at the head position so the
# (seqlen, seqlen) mask broadcasts over batch and heads of the (b, s, h, t) scores.
if mask is not None:
    scores += mask.unsqueeze(1)

In [75]:
scores[0, :, 0, :]

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)

Yep, cool. So attention patterns are tokens by tokens -> I guess that makes sense.

In [76]:
# Softmax over the last axis, computed in float32 and then cast to the dtype of `x`.
# Note: this reassigns `scores`, so re-running the cell would softmax the probabilities again.
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

In [77]:
scores[0, :, 0, :]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000, 0.0000,
         0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.0000,
         0.0000],
        [0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111,
         0.0000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000]], grad_fn=<

I do feel like all zero weights makes this kinda tricky. Maybe I can switch to random at least - that shouldn't be bad actually right?

In [90]:
# Weight the cached latent kv entries by the attention pattern.
# Note: `x` (the layer input until now) is overwritten from here on.
x = torch.einsum("bsht,btc->bshc", scores, mla.kv_cache[:bsz, :end_pos]) # Oh, that's interesting - just the cache and attention pattern here.

In [91]:
x.shape

torch.Size([1, 10, 128, 512])

In [92]:
# Per-head value up-projection. In the naive path, wkv_b's per-head output is split into
# (k_nope, v), so the last v_head_dim rows of each head's slab are the value rows. Applying
# them here maps the attention-weighted latent (kv_lora_rank wide) to v_head_dim per head.
# The input is cast to wkv_b's dtype rather than a hard-coded bfloat16.
x = torch.einsum("bshc,hdc->bshd", x.to(wkv_b.dtype), wkv_b[:, -mla.v_head_dim:])

In [93]:
x.shape

torch.Size([1, 10, 128, 128])

In [94]:
x.flatten(2).shape

torch.Size([1, 10, 16384])

In [95]:
x = mla.wo(x.flatten(2))

In [96]:
x.shape

torch.Size([1, 10, 7168])

And that's it. Ok I don't understand everything yet, but it doesn't seem so bad, and it runs! This is great. Back to writing next. 

In [84]:
q.shape #There are the queries for all the heads?

torch.Size([1, 10, 128, 192])

In [85]:
# model

In [86]:
model

Transformer(
  (embed): ParallelEmbedding()
  (layers): ModuleList(
    (0-2): 3 x Block(
      (attn): MLA(
        (wq_a): Linear()
        (q_norm): RMSNorm()
        (wq_b): ColumnParallelLinear()
        (wkv_a): Linear()
        (kv_norm): RMSNorm()
        (wkv_b): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (ffn): MLP(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
    (3-60): 58 x Block(
      (attn): MLA(
        (wq_a): Linear()
        (q_norm): RMSNorm()
        (wq_b): ColumnParallelLinear()
        (wkv_a): Linear()
        (kv_norm): RMSNorm()
        (wkv_b): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (ffn): MoE(
        (gate): Gate()
        (experts): ModuleList(
          (0-255): 256 x Expert(
            (w1): Linear()
            (w2): Linear()
            (w3): Linear()
       

In [87]:
torch.bfloat16

torch.bfloat16

In [88]:
# mla(embedded_prompt)

```
Big model:
{
    "vocab_size": 129280,
    "dim": 7168,
    "inter_dim": 18432,
    "moe_inter_dim": 2048,
    "n_layers": 61,
    "n_dense_layers": 3,
    "n_heads": 128,
    "n_routed_experts": 256,
    "n_shared_experts": 1,
    "n_activated_experts": 8,
    "n_expert_groups": 8,
    "n_limited_groups": 4,
    "route_scale": 2.5,
    "score_func": "sigmoid",
    "q_lora_rank": 1536,
    "kv_lora_rank": 512,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
    "dtype": "fp8"
}

16B Model
{
    "vocab_size": 102400,
    "dim": 2048,
    "inter_dim": 10944,
    "moe_inter_dim": 1408,
    "n_layers": 27,
    "n_dense_layers": 1,
    "n_heads": 16,
    "n_routed_experts": 64,
    "n_shared_experts": 2,
    "n_activated_experts": 6,
    "route_scale": 1.0,
    "q_lora_rank": 0,
    "kv_lora_rank": 512,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
    "mscale": 0.707
}
```