# 🧠 Playing with Weights: Modifying the Math

An AI model isn't magic; under the hood, it is mostly a massive collection of numbers arranged in multi-dimensional grids called **tensors** (a 2-D tensor is a matrix).

Let's pull back the curtain, extract the raw mathematical weights of Llama 3, look at them, and then intentionally break the model by mathematically scrambling the weights tied to its vocabulary!

In [1]:
import mlx.core as mx
import mlx.utils as mux
from mlx_lm import load

# 1. Load the model into memory
print("Loading the model's brain...")
model_id = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model, tokenizer = load(model_id)

# The MLX model has two top-level parts:
#   model.model  -> The transformer body (embeddings, attention layers, etc.)
#   model.lm_head -> The final output projection (turns numbers back into words)
print("\nTop-level components:", list(model.children().keys()))

# Let's flatten all the parameters into a list so we can see them
flat_params = mux.tree_flatten(model.parameters())

print("\n--- First 10 Weight Matrices in the Model ---")
for i, (name, parameter) in enumerate(flat_params[:10]):
    shape = parameter.shape if hasattr(parameter, 'shape') else 'N/A'
    dtype = parameter.dtype if hasattr(parameter, 'dtype') else 'N/A'
    print(f"  {i: >2}. {name: <50} Shape: {str(shape): <20} Dtype: {dtype}")

print(f"\n... and {len(flat_params)} total weight matrices in the model!")

Loading the model's brain...


Top-level components: ['model', 'lm_head']

--- First 10 Weight Matrices in the Model ---
   0. model.embed_tokens.weight                          Shape: (128256, 4096)       Dtype: mlx.core.float16
   1. model.layers.0.self_attn.q_proj.weight             Shape: (4096, 512)          Dtype: mlx.core.uint32
   2. model.layers.0.self_attn.q_proj.scales             Shape: (4096, 64)           Dtype: mlx.core.float16
   3. model.layers.0.self_attn.q_proj.biases             Shape: (4096, 64)           Dtype: mlx.core.float16
   4. model.layers.0.self_attn.k_proj.weight             Shape: (1024, 512)          Dtype: mlx.core.uint32
   5. model.layers.0.self_attn.k_proj.scales             Shape: (1024, 64)           Dtype: mlx.core.float16
   6. model.layers.0.self_attn.k_proj.biases             Shape: (1024, 64)           Dtype: mlx.core.float16
   7. model.layers.0.self_attn.v_proj.weight             Shape: (1024, 512)          Dtype: mlx.core.uint32
   8. model.layers.0.self_attn.v_proj.sc

### Understanding the Layout

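Notice two things above:
- `model.embed_tokens.weight` has dtype `float16` and shape `(128256, 4096)`: regular floating-point numbers!
- Most other layers, such as the attention projections `q_proj`, `k_proj`, and `v_proj`, store their `weight` as `uint32`. These layers are **quantized** (compressed to save memory): each `uint32` packs eight 4-bit integers, so a row of 4,096 weights is stored as 512 `uint32` values. The accompanying `float16` `scales` and `biases` arrays hold one scale and one bias per group of 64 weights, which is why they have 4096 / 64 = 64 columns. (This also means the listing above contains these helper arrays, not only weight matrices.)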
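To make the `uint32` layout concrete, here is a pure-Python sketch of 4-bit affine quantization as MLX appears to store it: eight 4-bit integers `q` per `uint32`, with each weight recovered as `q * scale + bias` using its group's scale and bias. The helper names are hypothetical (MLX's own `mx.dequantize` does this for real arrays), and the bit order, with the first weight in the lowest 4 bits, is an assumption rather than something this notebook confirms.

```python
# Hypothetical helpers (not part of mlx or mlx_lm): a pure-Python sketch of
# 4-bit affine quantization. Assumption: weight 0 sits in the lowest 4 bits.
BITS = 4
PER_WORD = 32 // BITS    # eight 4-bit values per uint32
MASK = (1 << BITS) - 1   # 0b1111 == 15


def unpack_uint32(word):
    """Split one packed uint32 into its eight 4-bit integers (each 0..15)."""
    return [(word >> (BITS * i)) & MASK for i in range(PER_WORD)]


def pack_uint32(nibbles):
    """Inverse of unpack_uint32: pack eight integers in 0..15 into one uint32."""
    word = 0
    for i, q in enumerate(nibbles):
        word |= (q & MASK) << (BITS * i)
    return word


def dequantize_group(words, scale, bias):
    """Approximate the original weights of one group: w = q * scale + bias."""
    return [q * scale + bias for word in words for q in unpack_uint32(word)]


# First packed word and first scale/bias of lm_head row 0, from the printout later
# in this notebook. Assuming row-major storage, all three belong to the same group.
first_word = 1351726744
print(unpack_uint32(first_word))  # [8, 9, 6, 11, 1, 9, 0, 5]
print(dequantize_group([first_word], scale=0.004360198974609375, bias=-0.02392578125))
```

A group of 64 weights spans eight packed words, so a full group would be `dequantize_group(words_0_to_7, scale, bias)`.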

The embedding table maps each of the 128,256 tokens to a vector of 4,096 numbers.

Let's look at the exact numbers that define the word "Apple"!

In [2]:
word = "Apple"
# add_special_tokens=False: we want the word's own token, not a beginning-of-text marker
token_id = tokenizer.encode(word, add_special_tokens=False)[0]

# The embedding table is inside model.model (not directly on model)
apple_weights = model.model.embed_tokens.weight[token_id]

print(f"Token ID for '{word}': {token_id}")
print(f"Dtype: {apple_weights.dtype}")
print(f"\nThe mathematical representation of '{word}':")
print(apple_weights)
print(f"\n(This array contains exactly {apple_weights.shape[0]} floating-point numbers)")

# Bonus: compare it to another word to see that its numbers are completely different
word2 = "Banana"
# Note: a word may be split into several tokens; [0] keeps only the first one
token_id2 = tokenizer.encode(word2, add_special_tokens=False)[0]
banana_weights = model.model.embed_tokens.weight[token_id2]

print(f"\n--- Comparing '{word}' vs '{word2}' ---")
print(f"First 5 values of 'Apple':  {apple_weights[:5].tolist()}")
print(f"First 5 values of 'Banana': {banana_weights[:5].tolist()}")
print("\nDifferent vectors = different representations in the model's brain!")

Token ID for 'Apple': 27665
Dtype: mlx.core.float16

The mathematical representation of 'Apple':
array([-0.0201416, -0.000938416, -0.0100708, ..., -0.000482559, 0.00113678, 0.00567627], dtype=float16)

(This array contains exactly 4096 floating-point numbers)

--- Comparing 'Apple' vs 'Banana' ---
First 5 values of 'Apple':  [-0.0201416015625, -0.00093841552734375, -0.01007080078125, -0.00102996826171875, -0.01190185546875]
First 5 values of 'Banana': [-0.006072998046875, 0.01361083984375, 0.0003070831298828125, 0.005859375, -0.0023040771484375]

Different vectors = different representations in the model's brain!
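Printing the first five values only shows that the two vectors differ, which is true of almost any two rows of the table. To ask how *related* two embeddings are, a standard measure is cosine similarity: the dot product of the vectors divided by the product of their lengths. A minimal pure-Python sketch (the function name is our own, not an MLX API):

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors: 1 = same direction, 0 = orthogonal."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy vectors for illustration; in the notebook you could try
# cosine_similarity(apple_weights.tolist(), banana_weights.tolist()).
print(cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # same direction, ~1.0
print(cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]))  # orthogonal, 0.0
```

The zero-vector check matters for the next experiment: once every embedding is zeroed, similarity between tokens is no longer defined at all.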


### Doing Brain Surgery

Since the embedding layer uses regular `float16` numbers (not quantized `uint32`), we can do ordinary arithmetic on it and the result means exactly what we expect!

Let's zero out the entire embedding table. *Every single token* the model knows then maps to the same all-zero vector, so the model can no longer tell one word from another. (The cell below is left commented out; uncomment it to run the experiment.)

In [3]:
# from mlx_lm import generate

# def test_model(test_model_instance, label):
#     messages = [
#         {"role": "system", "content": "You are a helpful, smart AI assistant. Answer the user's questions clearly & concisely."},
#         {"role": "user", "content": "What is 2+2?"}
#     ]
#     # Build the prompt string only; generation limits are passed to generate() below
#     prompt = tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True,
#     )
#     print(f"\nTesting {label}:")
#     response = generate(test_model_instance, tokenizer, prompt=prompt, verbose=False, max_tokens=10)
#     print(f"Response: '{response}'")

# # 1. Test it while it's healthy
# test_model(model, "Healthy Brain")

# # 2. Perform brain surgery: Zero out the ENTIRE embedding table
# # This erases what every single word means!
# print("\nPerforming brain surgery... zeroing out all 128,256 word embeddings...")
# original_embeddings = model.model.embed_tokens.weight
# print(f"Original embedding dtype: {original_embeddings.dtype}, shape: {original_embeddings.shape}")

# # Create a zero matrix of the same shape and dtype
# model.model.embed_tokens.weight = mx.zeros_like(original_embeddings)

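# # 3. Test the broken model
# test_model(model, "Brain after erasing all word meanings")

# print("\nEvery token now maps to the same zero vector, so the model can no longer tell words apart.")
# print("The embedding weights are how the model represents what each input token means.")

# # 4. Restore the original embeddings so later cells run on a healthy model
# model.model.embed_tokens.weight = original_embeddings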
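The remaining cells move from the input embeddings to `lm_head`, the output projection that turns the model's final hidden vector back into words. Conceptually it holds one row of 4,096 weights per vocabulary token (stored packed, hence the `(128256, 512)` `uint32` shape printed below). Each token's score, or logit, is the dot product of its row with the hidden vector, and greedy decoding emits the highest-scoring token. A toy sketch with a hypothetical 3-token vocabulary and 2-D hidden state, ignoring quantization and the final normalization layer:

```python
# Toy sketch of an output projection (lm_head); all values are made up.
vocab = ["4", "five", "banana"]   # hypothetical 3-token vocabulary
lm_head_rows = [
    [2.0, 1.0],                   # one row of weights per vocabulary token
    [0.5, 0.5],
    [-1.0, 0.0],
]
hidden = [1.0, 0.5]               # hypothetical final hidden state


def compute_logits(rows, h):
    """One score per token: the dot product of each row with the hidden state."""
    return [sum(w * x for w, x in zip(row, h)) for row in rows]


def greedy_pick(rows, h, tokens):
    """Return the token with the largest logit (greedy decoding)."""
    scores = compute_logits(rows, h)
    best = max(range(len(scores)), key=lambda i: scores[i])
    return tokens[best]


print(compute_logits(lm_head_rows, hidden))      # [2.5, 0.75, -1.0]
print(greedy_pick(lm_head_rows, hidden, vocab))  # 4
```

Corrupting the rows of `lm_head` changes these scores for every token, so the wrong tokens can end up winning.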

In [4]:
import mlx.core as mx

# ===== 1. Store the originals so we can restore them later =====
original_weight = model.lm_head.weight
original_scales = model.lm_head.scales
original_biases = model.lm_head.biases

# ===== 2. Print first 10 values of each (before surgery) =====
print("=== ORIGINAL lm_head values ===")
print(f"weight  (dtype: {original_weight.dtype}, shape: {original_weight.shape})")
print(f"  First 10: {original_weight.reshape(-1)[:10].tolist()}\n")

print(f"scales  (dtype: {original_scales.dtype}, shape: {original_scales.shape})")
print(f"  First 10: {original_scales.reshape(-1)[:10].tolist()}\n")

print(f"biases  (dtype: {original_biases.dtype}, shape: {original_biases.shape})")
print(f"  First 10: {original_biases.reshape(-1)[:10].tolist()}\n")

=== ORIGINAL lm_head values ===
weight  (dtype: mlx.core.uint32, shape: (128256, 512))
  First 10: [1351726744, 1180922772, 658999370, 2203367800, 1717855591, 716830345, 2186301553, 356746582, 1857407783, 1819494979]

scales  (dtype: mlx.core.float16, shape: (128256, 64))
  First 10: [0.004360198974609375, 0.0035076141357421875, 0.0073394775390625, 0.0036296844482421875, 0.00487518310546875, 0.00571441650390625, 0.004215240478515625, 0.004116058349609375, 0.0035724639892578125, 0.0041351318359375]

biases  (dtype: mlx.core.float16, shape: (128256, 64))
  First 10: [-0.02392578125, -0.0240478515625, -0.045654296875, -0.0283203125, -0.054931640625, -0.04345703125, -0.025634765625, -0.027587890625, -0.0281982421875, -0.031494140625]



In [22]:


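# ===== 3. Shift every packed uint32 weight word by a constant =====
# Pick any value between 0 and 4294967295 (2**32 - 1).
# Each uint32 packs eight 4-bit weights, so this does NOT lower each weight by the
# same amount: it scrambles the packed 4-bit fields (with borrows between them).
custom_value = 1000004

# Subtract in int64 so the intermediate result can go negative, then cast back to uint32
model.lm_head.weight = (original_weight.astype(mx.int64) - custom_value).astype(mx.uint32)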
model.lm_head.scales = original_scales
model.lm_head.biases = original_biases

# ===== 4. Print first 10 values after surgery =====
print(f"=== MODIFIED lm_head (all weights reduced by {custom_value}) ===")
print(f"weight  First 10: {model.lm_head.weight.reshape(-1)[:10].tolist()}")
print(f"scales  First 10: {model.lm_head.scales.reshape(-1)[:10].tolist()}")
print(f"biases  First 10: {model.lm_head.biases.reshape(-1)[:10].tolist()}\n")

# ===== 5. Test the broken model =====
from mlx_lm import generate
messages = [
    {"role": "system", "content": "You are a helpful, smart AI assistant. Answer the user's questions clearly & concisely."},
    {"role": "user", "content": "What is 2+2?"}
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(f"Generating with lm_head weights shift = {custom_value}...")
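response = generate(model, tokenizer, prompt=prompt, verbose=False, max_tokens=10)
print(f"Response: '{response}'")

# ===== 6. Restore the original lm_head so the model works again =====
model.lm_head.weight = original_weight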


=== MODIFIED lm_head (all weights reduced by 1000004) ===
weight  First 10: [1350726740, 1179922768, 657999366, 2202367796, 1716855587, 715830341, 2185301549, 355746578, 1856407779, 1818494975]
scales  First 10: [0.004360198974609375, 0.0035076141357421875, 0.0073394775390625, 0.0036296844482421875, 0.00487518310546875, 0.00571441650390625, 0.004215240478515625, 0.004116058349609375, 0.0035724639892578125, 0.0041351318359375]
biases  First 10: [-0.02392578125, -0.0240478515625, -0.045654296875, -0.0283203125, -0.054931640625, -0.04345703125, -0.025634765625, -0.027587890625, -0.0281982421875, -0.031494140625]

Generating with lm_head weights shift = 1000004...
Response: 'The answer to -…and…and//{{—aigtizmet'
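The printout labels the result `all weights reduced by 1000004`, but the subtraction acts on the packed integers, not on the weights themselves. Because each `uint32` holds eight 4-bit weights, subtracting one constant changes the individual weights by different amounts, some down and one even up, as the borrow from one 4-bit field ripples into the next. A pure-Python sketch on the first packed value from the output above, assuming (not confirmed here) that the first weight occupies the lowest 4 bits:

```python
# Pure-Python sketch. Assumption: weight 0 sits in the lowest 4 bits of each uint32.
def unpack_uint32(word, bits=4):
    """Split one packed uint32 into 32 // bits integers, lowest bits first."""
    mask = (1 << bits) - 1
    return [(word >> (bits * i)) & mask for i in range(32 // bits)]


original = 1351726744                    # first lm_head word printed above
shifted = (original - 1000004) % 2**32   # the notebook's shift, wrapped to uint32 range

before = unpack_uint32(original)
after = unpack_uint32(shifted)
print(shifted)                                  # 1350726740, matching the printout
print(before)                                   # [8, 9, 6, 11, 1, 9, 0, 5]
print(after)                                    # [4, 5, 4, 7, 2, 8, 0, 5]
print([a - b for a, b in zip(after, before)])   # [-4, -4, -2, -4, 1, -1, 0, 0]
```

Under this layout each unit of change is one quantization step, worth one `scale` (about 0.0044 for this row's first group), so the dequantized weights shift unevenly rather than by a single constant.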
