In [1]:
from sentence_transformers import SentenceTransformer

# Fetch the EmbeddingGemma checkpoint from the Hugging Face Hub.
emb_model = SentenceTransformer("google/embeddinggemma-300m")

# Module 0 of the SentenceTransformer pipeline exposes the underlying
# transformer as `auto_model`; the cells below export its weights.
transformer_module = emb_model[0]
model = transformer_module.auto_model

# Bare last expression so the notebook renders the module tree.
model

W1116 10:13:28.884000 2052 site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


Gemma3TextModel(
  (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 768, padding_idx=0)
  (layers): ModuleList(
    (0-23): 24 x Gemma3DecoderLayer(
      (self_attn): Gemma3Attention(
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=256, bias=False)
        (v_proj): Linear(in_features=768, out_features=256, bias=False)
        (o_proj): Linear(in_features=768, out_features=768, bias=False)
        (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
        (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
      )
      (mlp): Gemma3MLP(
        (gate_proj): Linear(in_features=768, out_features=1152, bias=False)
        (up_proj): Linear(in_features=768, out_features=1152, bias=False)
        (down_proj): Linear(in_features=1152, out_features=768, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma3RMSNorm((768,), eps=1e-06)
      (post_attention_layernorm): Gemma3RMSNorm((768,), eps=

### Save the norm weights and the token-embedding table

In [2]:
import math
import os
import torch
import numpy as np
from tqdm import tqdm

# Export the model's top-level `norm` weight and the token-embedding table as
# raw float32 binaries under params/.
norm = model.norm
os.makedirs("params", exist_ok=True)

# `.float()` casts to float32 inside torch *before* the NumPy conversion:
# `Tensor.numpy()` raises TypeError for bfloat16 tensors, so a cast done
# afterwards with `ndarray.astype` would never be reached for a bf16 checkpoint.
with open("params/norm.bin", "wb") as f:
    f.write(norm.weight.detach().float().cpu().numpy().tobytes())

# Embedding table, flattened so it can be split into raw shards.
# NOTE(review): the module is a Gemma3TextScaledWordEmbedding, which presumably
# applies a scale factor at lookup time; the raw (unscaled) weights are written
# here — confirm the consumer of these files applies that scale.
embed_tokens = model.embed_tokens.weight.detach()
embed_tokens_flat = embed_tokens.flatten()

os.makedirs("params/embed_tokens", exist_ok=True)

# Number of shard files requested; torch.chunk splits the flat tensor into
# (at most) this many contiguous pieces.
num_chunks = 16
chunks = torch.chunk(embed_tokens_flat, num_chunks)

for idx, chunk in enumerate(chunks):
    # Cast per shard so only one shard-sized float32 copy exists at a time.
    np_chunk = chunk.float().cpu().numpy()

    # Write the shard as raw bytes (native byte order, no header).
    with open(f"params/embed_tokens/part_{idx}.bin", "wb") as f:
        f.write(np_chunk.tobytes())

    print(f"Saved chunk {idx} with {np_chunk.size} weights")

Saved chunk 0 with 12582912 weights
Saved chunk 1 with 12582912 weights
Saved chunk 2 with 12582912 weights
Saved chunk 3 with 12582912 weights
Saved chunk 4 with 12582912 weights
Saved chunk 5 with 12582912 weights
Saved chunk 6 with 12582912 weights
Saved chunk 7 with 12582912 weights
Saved chunk 8 with 12582912 weights
Saved chunk 9 with 12582912 weights
Saved chunk 10 with 12582912 weights
Saved chunk 11 with 12582912 weights
Saved chunk 12 with 12582912 weights
Saved chunk 13 with 12582912 weights
Saved chunk 14 with 12582912 weights
Saved chunk 15 with 12582912 weights


### Save the layers

In [None]:
import os
import numpy as np
from tqdm import tqdm
# Weights exported for every decoder layer, as (output file stem, submodule
# path relative to the layer). Each entry is written to
# params/layer_{idx}/{stem}.bin as raw float32 bytes.
LAYER_WEIGHTS = (
    # Attention projections and q/k norms
    ("self_attn_q_proj", "self_attn.q_proj"),
    ("self_attn_k_proj", "self_attn.k_proj"),
    ("self_attn_v_proj", "self_attn.v_proj"),
    ("self_attn_o_proj", "self_attn.o_proj"),
    ("self_attn_q_norm", "self_attn.q_norm"),
    ("self_attn_k_norm", "self_attn.k_norm"),
    # MLP projections
    ("mlp_gate_proj", "mlp.gate_proj"),
    ("mlp_up_proj", "mlp.up_proj"),
    ("mlp_down_proj", "mlp.down_proj"),
    # RMS norms around attention and feed-forward
    ("input_layernorm", "input_layernorm"),
    ("post_attention_layernorm", "post_attention_layernorm"),
    ("pre_feedforward_layernorm", "pre_feedforward_layernorm"),
    ("post_feedforward_layernorm", "post_feedforward_layernorm"),
)

# Wrap the module list (which has a length) with tqdm and enumerate the result:
# tqdm(enumerate(...)) hides the length, so the bar would show no total or ETA.
for idx, layer in enumerate(tqdm(model.layers, desc="layers")):
    layer_dir = f"params/layer_{idx}"
    os.makedirs(layer_dir, exist_ok=True)

    for stem, submodule_path in LAYER_WEIGHTS:
        weight = layer.get_submodule(submodule_path).weight
        # Cast to float32 in torch first: Tensor.numpy() raises TypeError on
        # bfloat16 tensors, so casting after the NumPy conversion is too late.
        # tobytes() emits the values in C (row-major) order.
        with open(f"{layer_dir}/{stem}.bin", "wb") as f:
            f.write(weight.detach().float().cpu().numpy().tobytes())

### Write the weights of the linear (dense) modules (these modules have no bias, so only the weights are saved)

In [None]:
# Linear layers held by pipeline modules 2 and 3 of the SentenceTransformer.
# Only their weight matrices are exported, as raw float32 bytes in row-major order.
dense1 = emb_model[2].linear
dense2 = emb_model[3].linear

for out_path, dense in (
    ("params/dense_1.bin", dense1),
    ("params/dense_2.bin", dense2),
):
    # Cast to float32 in torch first: Tensor.numpy() raises TypeError on
    # bfloat16 tensors, so casting after the NumPy conversion is too late.
    with open(out_path, "wb") as f:
        f.write(dense.weight.detach().float().cpu().numpy().tobytes())