# 🔬 Quantized Weight Sensitivity Analysis

This notebook systematically tests how much we can perturb the quantized `lm_head` weights of a 4-bit Llama 3 8B model before its output collapses.

### Key Concepts
- The `lm_head` is the final output layer that converts internal representations into word probabilities
- In a 4-bit quantized model, weights are packed as `uint32` integers (8 x 4-bit values per uint32)
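- Each real weight is recovered from its unpacked 4-bit value `q` (0 to 15), not from the whole packed word: `real_weight = (q * scale) + bias`, with `scale` and `bias` shared by a group of weights
- We subtract increasing values from the raw packed `uint32` words, which can change several 4-bit values at once, to find the breaking point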
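
The packing and dequantization described above can be sketched in plain Python. This is a minimal illustration, not MLX's actual kernel: it assumes the eight 4-bit values are packed lowest-order nibble first, and the helper names `unpack_nibbles` and `dequantize`, as well as the sample `scale` and `bias`, are hypothetical.

```python
def unpack_nibbles(word):
    """Split one packed 32-bit word into eight 4-bit values (0..15).

    Assumption: the lowest-order nibble comes first.
    """
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def dequantize(word, scale, bias):
    """Recover eight real-valued weights from one packed word."""
    return [q * scale + bias for q in unpack_nibbles(word)]


packed = 0x76543210        # nibbles 0, 1, 2, ..., 7 from low to high
scale, bias = 0.5, -2.0    # illustrative values, not taken from the model

print(unpack_nibbles(packed))           # [0, 1, 2, 3, 4, 5, 6, 7]
print(dequantize(packed, scale, bias))  # [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5]

# Subtracting 1 from the packed word borrows through the low nibbles
print(unpack_nibbles(packed - 1))       # [15, 0, 2, 3, 4, 5, 6, 7]
```

The last line shows why a shift on the packed word is not a uniform shift on the weights: a borrow can move one 4-bit value from 0 to 15 while its neighbour drops by 1.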

In [None]:
# Setup: Load the model and store original weights
import mlx.core as mx
import mlx.utils as mux
from mlx_lm import load, generate

print("Loading Llama 3 8B (4-bit quantized)...")
model_id = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model, tokenizer = load(model_id)

# Store originals for restoring between tests
original_weight = model.lm_head.weight
original_scales = model.lm_head.scales
original_biases = model.lm_head.biases

print(f"lm_head.weight: shape={original_weight.shape}, dtype={original_weight.dtype}")
print(f"lm_head.scales: shape={original_scales.shape}, dtype={original_scales.dtype}")
print(f"lm_head.biases: shape={original_biases.shape}, dtype={original_biases.dtype}")
print(f"\nFirst 10 original weight values: {original_weight.reshape(-1)[:10].tolist()}")
print("\n✅ Model loaded. Ready for experiments.")

In [None]:
# Helper function: Perturb weights and test
def sensitivity_test(shift_amount):
    """Apply a shift to the packed lm_head weights, generate a response, then restore the originals."""
    # Apply the shift to the original packed words. The subtraction runs in
    # int64 so it cannot underflow; the cast back to uint32 is assumed to wrap
    # modulo 2**32 when the result is negative.
    model.lm_head.weight = (original_weight.astype(mx.int64) - shift_amount).astype(mx.uint32)

    # Record the first 5 modified packed values for comparison
    modified_first_5 = model.lm_head.weight.reshape(-1)[:5].tolist()

    try:
        # Generate a short response with the perturbed output layer
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "What is 2+2?"}],
            tokenize=False, add_generation_prompt=True
        )
        response = generate(model, tokenizer, prompt=prompt, verbose=False, max_tokens=10)
    finally:
        # Restore the originals so every test starts from a clean model,
        # even if generation raises
        model.lm_head.weight = original_weight
        model.lm_head.scales = original_scales
        model.lm_head.biases = original_biases

    return response, modified_first_5

print("Helper function ready. Running experiments below...")
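
To reason about what a given shift does to one packed word, the subtraction can be modeled in plain Python. The sketch below assumes the cast back to `uint32` wraps modulo 2**32 and that nibbles are packed lowest-order first; `shift_word` and `changed_nibbles` are hypothetical helpers, not part of MLX, and the sample word is made up.

```python
MASK32 = 0xFFFFFFFF


def shift_word(word, shift):
    """Subtract `shift` from a packed word with uint32 wraparound (assumed)."""
    return (word - shift) & MASK32


def changed_nibbles(word, shift):
    """Return the indices (0 = lowest nibble) of the 4-bit values the shift alters."""
    diff = word ^ shift_word(word, shift)
    return [i for i in range(8) if (diff >> (4 * i)) & 0xF]


word = 0x89ABCDEF                        # made-up packed word
print(hex(shift_word(word, 1)))          # 0x89abcdee
print(changed_nibbles(word, 1))          # [0]
print(hex(1_000_000))                    # 0xf4240
print(changed_nibbles(word, 1_000_000))  # [1, 2, 3, 4, 5]
print(hex(shift_word(5, 10)))            # 0xfffffffb (wraps instead of going negative)
```

Since 1,000,000 is `0xF4240`, its lowest hex digit is 0, so that shift leaves the lowest nibble of every word untouched, while 1,000,001 changes it as well. Whether this accounts for the behaviour seen in the experiments is not established here.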

---
## Experiment 1: Baseline (No Shift)
The unmodified model should respond correctly.

In [None]:
response, vals = sensitivity_test(0)
print(f"Shift: 0")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: HEALTHY ✅")

---
## Experiment 2: Shift = 1 (Minimal Perturbation)

In [None]:
response, vals = sensitivity_test(1)
print(f"Shift: 1")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: HEALTHY ✅ - Model is robust to a shift of 1")

---
## Experiment 3: Shift = 1,000,000 (Last Stable Point)

In [None]:
response, vals = sensitivity_test(1_000_000)
print(f"Shift: 1,000,000")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: HEALTHY ✅ - Still coherent at 1M shift")

---
## Experiment 4: Shift = 1,000,001 (First Signs of Degradation)

In [None]:
response, vals = sensitivity_test(1_000_001)
print(f"Shift: 1,000,001")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: WOBBLING ⚠️ - Response is slightly different but still coherent")

---
## Experiment 5: Shift = 1,000,003 (Dying Mid-Sentence)

In [None]:
response, vals = sensitivity_test(1_000_003)
print(f"Shift: 1,000,003")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: DYING 🟠 - Correct start, then collapses into repeated tokens")

---
## Experiment 6: Shift = 1,000,004 (Collapse After First Words)

In [None]:
response, vals = sensitivity_test(1_000_004)
print(f"Shift: 1,000,004")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: COLLAPSING 🔴 - Only first words survive, then symbol soup")

---
## Experiment 7: Shift = 1,000,005 (Almost Fully Dead)

In [None]:
response, vals = sensitivity_test(1_000_005)
print(f"Shift: 1,000,005")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: DEAD 💀 - Collapses after 2 words")

---
## Experiment 8: Shift = 1,000,010 (Instant Gibberish)

In [None]:
response, vals = sensitivity_test(1_000_010)
print(f"Shift: 1,000,010")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: GONE ☠️ - No coherent output from the start")

---
## Experiment 9: Shift = 10,000,000 (Multilingual Chaos)

In [None]:
response, vals = sensitivity_test(10_000_000)
print(f"Shift: 10,000,000")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: GONE ☠️ - Outputs random Chinese, code tokens, multilingual gibberish")

---
## Experiment 10: Shift = 100,000,000 (Stuck in a Loop)

In [None]:
response, vals = sensitivity_test(100_000_000)
print(f"Shift: 100,000,000")
print(f"Weights: {vals}")
print(f"Response: '{response}'")
print("Status: GONE ☠️ - So broken it gets stuck repeating one broken token")
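
The ten experiment cells differ only in the shift value, so the same measurements could also be collected in one loop. The sketch below is one possible way to do that; `run_sweep` is a hypothetical helper, and a stub stands in for `sensitivity_test` so the example runs without loading the model.

```python
def run_sweep(test_fn, shifts):
    """Call `test_fn(shift)` for each shift and collect the results as rows."""
    rows = []
    for shift in shifts:
        response, first_vals = test_fn(shift)
        rows.append({"shift": shift, "response": response, "weights": first_vals})
    return rows


def fake_sensitivity_test(shift):
    """Stub with the same return shape as sensitivity_test: (response, values)."""
    return f"response for shift {shift}", [(100 - shift) & 0xFFFFFFFF]


for row in run_sweep(fake_sensitivity_test, [0, 1, 1_000_000]):
    print(f"{row['shift']:>12,} | {row['weights']} | {row['response']!r}")
```

In the notebook, passing `sensitivity_test` in place of the stub would produce one row per shift, which makes it easier to compare responses side by side.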

---
## Cleanup: Restore the Original Model

In [None]:
# Restore original weights
model.lm_head.weight = original_weight
model.lm_head.scales = original_scales
model.lm_head.biases = original_biases

# Verify restoration
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2+2?"}],
    tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=False, max_tokens=10)
print(f"Restored model response: '{response}'")
print("\n✅ Original weights restored. Model is healthy again!")