# 01: Quick Validation - Does L0+L31 Work on Llama?

**Goal:** Test whether GPT-2's gateway-layer pattern (protecting the first and last layers) transfers to Llama-2-7B

**Time:** ~30 minutes on a T4 GPU

**Key question:** Does protecting the input (L0) and output (L31) layers from INT4 quantization reduce multilingual disparity?

In [None]:
# Setup
!pip install -q transformers accelerate

import sys
sys.path.append('..')

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.disparity import measure_disparity, DEFAULT_TEXTS
from src.quantize import simulate_int4, save_state, restore_model, get_num_layers
from src.utils import print_results

In [None]:
# Load Llama-2-7B from the NousResearch mirror (no gated-access approval needed)
MODEL_ID = "NousResearch/Llama-2-7b-hf"
print(f"Loading {MODEL_ID}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse EOS as the pad token (assumes the checkpoint ships without one — TODO confirm)
tokenizer.pad_token = tokenizer.eos_token

# FP16 weights with automatic device placement; nn.Module.eval() returns the module itself,
# so the chained call leaves `model` bound to the loaded model in eval mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
).eval()

num_layers = get_num_layers(model)
print(f"Model loaded: {num_layers} layers")

In [None]:
# Evaluation texts: a four-language subset of DEFAULT_TEXTS, to keep the run short
from src.disparity import perplexity

TEXTS = {lang: DEFAULT_TEXTS[lang] for lang in ("en", "he", "ar", "zh")}

# Snapshot the model state before any quantization so each config can be restored from it
state = save_state(model)

# Reference (FP16) perplexity per language
baseline = {}
for lang, text in TEXTS.items():
    baseline[lang] = perplexity(model, tokenizer, text)

# Report languages from lowest to highest perplexity
print("Baseline PPL (FP16):")
for lang in sorted(baseline, key=baseline.get):
    print(f"  {lang}: {baseline[lang]:.1f}")

In [None]:
# Test configurations
# Protected layer indices are derived from the model depth instead of being
# hard-coded, so the sweep stays valid for checkpoints with a different layer
# count. For Llama-2-7B (32 layers) they resolve to:
#   GPT-2: L0+L11 (12 layers)     -> Llama: L0+L31
#   GPT-2: L0+L9+L11 (75% = L9)   -> Llama: L0+L24+L31
#   50% layer                     -> Llama: L0+L16+L31
last_layer = num_layers - 1
layer_75 = 3 * num_layers // 4  # 24 for 32 layers, 9 for GPT-2's 12
layer_50 = num_layers // 2      # 16 for 32 layers

CONFIGS = {
    "no_protection": [],
    f"L0+L{last_layer} (GPT-2 equivalent)": [0, last_layer],
    f"L0+L{layer_75}+L{last_layer} (with 75% layer)": [0, layer_75, last_layer],
    f"L0+L{layer_50}+L{last_layer} (with 50% layer)": [0, layer_50, last_layer],
}

results = {}
try:
    for config_name, protect in CONFIGS.items():
        print(f"\nTesting: {config_name}")
        # Reset to the state saved before quantization so configs don't compound
        restore_model(model, state)
        # Quantize every layer except the protected ones
        simulate_int4(model, exclude=set(protect))

        metrics = measure_disparity(model, tokenizer, TEXTS, baseline)
        results[config_name] = metrics

        print(f"  Avg disparity: {metrics['avg_disparity']:.2f}x")
finally:
    # Leave the model in its pre-quantization state for later cells,
    # even if one of the configs raised mid-sweep.
    restore_model(model, state)

In [None]:
# Summary
# Average-disparity level below which the gateway pattern is judged to transfer.
DISPARITY_THRESHOLD = 2.0

print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

print(f"\n{'Config':<35} {'Avg Disparity':>15}")
print("-" * 51)  # 35 + 1 + 15: same width as the header and rows
for config, metrics in results.items():
    print(f"{config:<35} {metrics['avg_disparity']:>14.2f}x")

# Interpretation: the config with the lowest average disparity wins
# (on ties, min() keeps the first entry, i.e. "no_protection").
best_config = min(results.items(), key=lambda x: x[1]['avg_disparity'])
best_name, best_metrics = best_config
print(f"\nBest config: {best_name}")
print(f"Disparity: {best_metrics['avg_disparity']:.2f}x")

if best_name == "no_protection":
    # No protected config beat plain INT4, so protection cannot be credited.
    print("\n✗ NO BENEFIT: Protecting these layers does not reduce disparity vs. unprotected INT4")
elif best_metrics['avg_disparity'] < DISPARITY_THRESHOLD:
    print("\n✓ PATTERN TRANSFERS: Gateway layer protection works on Llama!")
else:
    print("\n✗ PATTERN DIFFERS: Need full layer sweep to find Llama's critical layers")

In [None]:
# Per-language breakdown for the lowest-disparity config chosen in the summary cell
best_name, best_metrics = best_config
print_results(best_metrics, f"Best Config: {best_name}")