In [13]:
from transformers import AutoModelForCausalLM
import torch

# Load the MobileLLM-125M checkpoint from the Hugging Face Hub.
# NOTE(review): trust_remote_code=True runs the modelling code shipped in the
# Hub repo (the class later reports itself as
# transformers_modules...modeling_mobilellm); consider pinning `revision=`
# to a reviewed commit before relying on this in anything load-bearing.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/MobileLLM-125M",
    trust_remote_code=True,
)

from flashml.legacy import inspect_model

# Project-local helper; presumably summarises the model's structure — confirm
# against flashml before relying on its output.
inspect_model(model)

Some weights of the model checkpoint at facebook/MobileLLM-125M were not used when initializing MobileLLMForCausalLM: ['lm_head.weight']
- This IS expected if you are initializing MobileLLMForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MobileLLMForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
# Dump the module tree as plain text: submodule names and the in/out feature
# sizes of every projection in the network.
print(model)

MobileLLMForCausalLM(
  (model): MobileLLMModel(
    (embed_tokens): Embedding(32000, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
)


In [14]:
# Fetch the rotary-embedding module of the first decoder layer and print its
# concrete class, to see which implementation the checkpoint's code provides.
# (The previous bare attribute-access line was dropped: it was not the last
# expression of the cell, so it was neither displayed nor stored.)
rotary_emb = model.model.layers[0].self_attn.rotary_emb
print(type(rotary_emb))

<class 'transformers_modules.facebook.MobileLLM-125M.f6820a829b347aaa6674a705c82be3db47b1f9ee.modeling_mobilellm.LlamaRotaryEmbedding'>


In [None]:
import math
import os
import torch
import numpy as np
from tqdm import tqdm
# Export the final norm weight and the token-embedding table as raw float32
# binaries under ./params.

# Create the output directories before the first write. makedirs also creates
# the parent "params" directory, which "params/norm.bin" needs on a fresh run.
os.makedirs("params/embed_tokens", exist_ok=True)

# ---- final norm -----------------------------------------------------------
norm = model._modules["model"].norm
with open("params/norm.bin", "wb") as f:
    # Cast to float32 in torch before handing over to NumPy, so checkpoints
    # loaded in a reduced precision that NumPy has no dtype for (e.g.
    # bfloat16) still export. For float32 weights this is a no-op.
    f.write(norm.weight.detach().cpu().float().numpy().tobytes())

# ---- token embeddings -----------------------------------------------------
# The embedding matrix is flattened (row-major) and split into several files
# so that no single binary holds the whole table.
embed_tokens = model._modules["model"].embed_tokens.weight.detach()
embed_tokens_flat = embed_tokens.flatten()

num_chunks = 8
chunks = torch.chunk(embed_tokens_flat, num_chunks)

for idx, chunk in enumerate(chunks):
    # float32 NumPy view of this slice of the flattened table
    np_chunk = chunk.cpu().float().numpy()

    # Write raw bytes: no header, native byte order
    with open(f"params/embed_tokens/part_{idx}.bin", "wb") as f:
        f.write(np_chunk.tobytes())

    print(f"Saved chunk {idx} with {np_chunk.size} weights")

In [None]:
import os
import numpy as np
from tqdm import tqdm
def _write_float32(tensor, path):
    """Write `tensor` to `path` as raw float32 bytes.

    The tensor is detached, moved to CPU, cast to float32 and flattened in
    row-major order; the file has no header and uses native byte order.
    """
    with open(path, "wb") as f:
        # Cast in torch (not NumPy) so dtypes NumPy cannot represent, such as
        # bfloat16, are still exportable. No-op for float32 weights.
        f.write(tensor.detach().cpu().float().flatten().numpy().tobytes())


decoder_layers = model._modules["model"].layers

# Wrap the layer list itself (not the enumerate object) so tqdm can read its
# length and show a total / ETA.
for idx, layer in enumerate(tqdm(decoder_layers)):
    layer_dir = f"params/layer_{idx}"
    os.makedirs(layer_dir, exist_ok=True)

    self_attn = layer.self_attn
    mlp = layer.mlp

    # Output file stem -> parameter tensor; one .bin file is written per entry.
    layer_weights = {
        # attention projections
        "self_attn_q_proj": self_attn.q_proj.weight,
        "self_attn_k_proj": self_attn.k_proj.weight,
        "self_attn_v_proj": self_attn.v_proj.weight,
        "self_attn_o_proj": self_attn.o_proj.weight,
        # MLP projections
        "mlp_gate_proj": mlp.gate_proj.weight,
        "mlp_up_proj": mlp.up_proj.weight,
        "mlp_down_proj": mlp.down_proj.weight,
        # norm weights
        "input_layernorm": layer.input_layernorm.weight,
        "post_attention_layernorm": layer.post_attention_layernorm.weight,
    }

    for file_stem, param in layer_weights.items():
        _write_float32(param, f"{layer_dir}/{file_stem}.bin")

