# 🧠 Playing with Weights: Modifying the Math

An AI model isn't magic; under the hood, it is just a massive collection of numbers arranged in grids called **tensors** (a matrix is simply a two-dimensional tensor).

Let's peel back the curtain, extract the raw mathematical weights of Llama 3, look at them, and then intentionally break the model by mathematically destroying its vocabulary!

In [None]:
import mlx.core as mx
import mlx.utils as mux
from mlx_lm import load

# 1. Load the model into memory
print("Loading the model's brain...")
model_id = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model, tokenizer = load(model_id)

# The MLX model has two top-level parts:
#   model.model  -> The transformer body (embeddings, attention layers, etc.)
#   model.lm_head -> The final output projection (turns numbers back into words)
print("\nTop-level components:", list(model.children().keys()))

# Let's flatten all the parameters into a list so we can see them
flat_params = mux.tree_flatten(model.parameters())

print("\n--- First 10 Parameter Arrays in the Model ---")
for i, (name, parameter) in enumerate(flat_params[:10]):
    shape = parameter.shape if hasattr(parameter, 'shape') else 'N/A'
    dtype = parameter.dtype if hasattr(parameter, 'dtype') else 'N/A'
    print(f"  {i: >2}. {name: <50} Shape: {str(shape): <20} Dtype: {dtype}")

# The count includes the scales and biases stored alongside quantized weights
print(f"\n... and {len(flat_params)} total parameter arrays in the model!")

### Understanding the Layout

Notice two things above:
- `model.embed_tokens.weight` has dtype `float16` and shape `(128256, 4096)` (regular floats!)
- Most other layers have dtype `uint32` (these are **quantized**: several 4-bit values are packed into each 32-bit integer to save memory)
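
To make the `uint32` dtype less mysterious, here is a toy sketch in plain Python (no MLX needed) of how 4-bit quantization can work. Each group of weights is rounded to integer codes from 0 to 15 that share one scale and one offset, and eight such codes fit into a single 32-bit integer. The function names and the bit order are illustrative assumptions; MLX's real packing layout and group size may differ.

```python
def quantize_group(weights, bits=4):
    """Round a group of floats to integer codes sharing one scale and offset."""
    levels = 2 ** bits - 1               # 15 steps for 4 bits
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / levels or 1.0    # avoid a zero scale if all weights are equal
    codes = [round((w - lo) / scale) for w in weights]
    return codes, scale, lo              # lo acts as the offset ("bias")

def dequantize_group(codes, scale, bias):
    """Approximate the original floats from the integer codes."""
    return [c * scale + bias for c in codes]

def pack_4bit(codes):
    """Pack eight codes (each 0..15) into one 32-bit unsigned integer."""
    assert len(codes) == 8 and all(0 <= c <= 15 for c in codes)
    packed = 0
    for i, c in enumerate(codes):
        packed |= c << (4 * i)           # code i occupies bits 4*i .. 4*i+3 (assumed order)
    return packed

def unpack_4bit(packed):
    """Recover the eight 4-bit codes from one packed integer."""
    return [(packed >> (4 * i)) & 0xF for i in range(8)]

# A hypothetical group of eight weights
group = [-0.30, -0.10, 0.00, 0.05, 0.12, 0.20, 0.33, 0.45]
codes, scale, bias = quantize_group(group)
packed = pack_4bit(codes)
restored = dequantize_group(unpack_4bit(packed), scale, bias)

print("4-bit codes:", codes)
print("Packed into one integer:", packed)
print("Restored (approximate):", [round(r, 3) for r in restored])
```

The restored values are close to the originals but not identical, which is the trade-off quantization makes: less memory in exchange for a small rounding error.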

The embedding table maps each of the 128,256 tokens to a vector of 4,096 numbers.
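
As a rough sketch of what "maps each token to a vector" means, the toy example below treats the embedding table as a list of rows and looks up a token by picking its row. The tiny vocabulary, token IDs, and values are made up for illustration; only the table dimensions in the size calculation come from the model above.

```python
# Toy embedding table: 4 tokens, 3 numbers each (the real one is 128,256 x 4,096)
toy_vocab = {"apple": 0, "banana": 1, "cat": 2, "dog": 3}  # hypothetical token IDs
toy_table = [
    [0.12, -0.40, 0.33],
    [0.10, -0.35, 0.41],
    [-0.72, 0.05, 0.18],
    [-0.69, 0.11, 0.22],
]

def embed(token_id, table):
    """Embedding lookup: return row `token_id` of the table."""
    return table[token_id]

def embed_one_hot(token_id, table):
    """The same lookup written as a one-hot vector multiplied by the table."""
    one_hot = [1.0 if i == token_id else 0.0 for i in range(len(table))]
    dim = len(table[0])
    return [sum(one_hot[r] * table[r][c] for r in range(len(table))) for c in range(dim)]

banana_vec = embed(toy_vocab["banana"], toy_table)
print("Vector for 'banana':", banana_vec)

# Memory needed for the real table stored as float16 (2 bytes per number)
table_bytes = 128_256 * 4_096 * 2
print(f"Real embedding table size: {table_bytes / 1e9:.2f} GB")
```

Both functions return the same vector; the one-hot version just shows that a lookup is a special case of matrix multiplication.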

Let's look at the exact numbers that define the word "Apple"!

In [None]:
word = "Apple"
# Llama 3's tokenizer prepends a <|begin_of_text|> token by default, so we
# disable special tokens; otherwise [0] would pick that marker, not the word.
# (If a word splits into several tokens, [0] is only its first piece.)
token_id = tokenizer.encode(word, add_special_tokens=False)[0]

# The embedding table is inside model.model (not directly on model)
apple_weights = model.model.embed_tokens.weight[token_id]

print(f"Token ID for '{word}': {token_id}")
print(f"Dtype: {apple_weights.dtype}")
print(f"\nThe mathematical representation of '{word}':")
print(apple_weights)
print(f"\n(This array contains exactly {apple_weights.shape[0]} floating-point numbers)")

# Bonus: Let's compare it to another word to see that the numbers differ
word2 = "Banana"
token_id2 = tokenizer.encode(word2, add_special_tokens=False)[0]
banana_weights = model.model.embed_tokens.weight[token_id2]

print(f"\n--- Comparing '{word}' vs '{word2}' ---")
print(f"First 5 values of '{word}':  {apple_weights[:5].tolist()}")
print(f"First 5 values of '{word2}': {banana_weights[:5].tolist()}")
print("\nDifferent numbers = different meanings in the model's brain!")
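
Eyeballing the first five values only goes so far. A common way to compare two embedding vectors is cosine similarity, which measures how closely they point in the same direction (1 means identical direction, 0 means unrelated, negative means opposed). The sketch below uses made-up three-number vectors rather than real Llama 3 embeddings, and the helper name `cosine_similarity` is our own.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical toy vectors, NOT real model embeddings
apple_vec = [0.9, 0.1, 0.3]
banana_vec = [0.8, 0.2, 0.4]
car_vec = [-0.5, 0.9, -0.2]

print("apple vs banana:", round(cosine_similarity(apple_vec, banana_vec), 3))
print("apple vs car:   ", round(cosine_similarity(apple_vec, car_vec), 3))
```

The same calculation could be applied to the real `apple_weights` and `banana_weights` after converting them to Python lists with `.tolist()`.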

### Doing Brain Surgery

Since the embedding layer uses regular `float16` numbers (not quantized `uint32`), we can safely do math on it!

Let's zero out the entire embedding table. Every token the model knows will then map to the same all-zero vector, so the model can no longer tell one word from another, let alone what any of them means.
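
Here is a toy sketch, in plain Python, of why zeroing is so destructive. Once every row of the table is zero, every token produces the same input vector, so whatever the next layer computes is identical for every token. The table, the `project` helper, and the layer weights are hypothetical stand-ins, not the real model.

```python
# Toy embedding table: 3 tokens, 2 numbers each (hypothetical values)
toy_table = [
    [0.5, -1.0],
    [1.5, 0.25],
    [-0.75, 2.0],
]

# A stand-in for the first layer after the embeddings: a 2x2 weight matrix
layer_weights = [
    [1.0, 2.0],
    [-1.0, 0.5],
]

def project(vec, weights):
    """Multiply a vector by a weight matrix (one output per matrix row)."""
    return [sum(v * w for v, w in zip(vec, row)) for row in weights]

# Before surgery: each token gives a different result
before = [project(row, layer_weights) for row in toy_table]

# Brain surgery: replace every row with zeros of the same length
zeroed_table = [[0.0] * len(row) for row in toy_table]
after = [project(row, layer_weights) for row in zeroed_table]

print("Before:", before)
print("After: ", after)
```

After zeroing, the three outputs are identical, so nothing downstream can distinguish one token from another.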

In [None]:
from mlx_lm import generate

def test_model(test_model_instance, label):
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Say hello!"}],
        tokenize=False,
        add_generation_prompt=True
    )
    print(f"\nTesting {label}:")
    response = generate(test_model_instance, tokenizer, prompt=prompt, verbose=False, max_tokens=10)
    print(f"Response: '{response}'")

# 1. Test it while it's healthy
test_model(model, "Healthy Brain")

# 2. Perform brain surgery: Zero out the ENTIRE embedding table
# This erases what every single word means!
print("\nPerforming brain surgery... zeroing out all 128,256 word embeddings...")
original_embeddings = model.model.embed_tokens.weight
print(f"Original embedding dtype: {original_embeddings.dtype}, shape: {original_embeddings.shape}")

# Create a zero matrix of the same shape and dtype
model.model.embed_tokens.weight = mx.zeros_like(original_embeddings)

# 3. Test the broken model
test_model(model, "Brain after erasing all word meanings")

print("\nThe model can no longer tell words apart! It outputs garbage or nothing.")
print("This shows that the model relies on the embedding weights to know what each token means.")