# 02: Full Layer Sweep on Llama-2-7B

**Goal:** Find Llama's critical layers for multilingual fairness

**Time:** ~2 hours on a T4 GPU

**Method:** Protect one layer at a time and measure the resulting disparity across languages

In [None]:
!pip install -q transformers accelerate tqdm

import sys
sys.path.append('..')

import torch
import json
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.disparity import full_sweep, DEFAULT_TEXTS
from src.quantize import get_num_layers
from src.utils import save_results

In [None]:
# Hugging Face checkpoint that every later cell refers to.
MODEL_ID = "NousResearch/Llama-2-7b-hf"

# Tokenizer: reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# Half-precision weights, with device placement delegated to `device_map="auto"`.
load_kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
model.eval()  # switch to inference mode

# Layer count is used below for relative-position percentages and reference picks.
num_layers = get_num_layers(model)
print(f"Sweeping {num_layers} layers...")

In [None]:
# Use subset of languages for speed
LANGUAGES = ['en', 'he', 'ar', 'zh']

# Filter while iterating DEFAULT_TEXTS so its original key order is preserved.
TEXTS = {lang: text for lang, text in DEFAULT_TEXTS.items() if lang in LANGUAGES}

# Validate the selection before starting the long-running sweep: a language code
# absent from DEFAULT_TEXTS would otherwise be dropped without any notice.
missing_langs = [lang for lang in LANGUAGES if lang not in TEXTS]
if not TEXTS:
    raise ValueError(
        f"None of the requested languages {LANGUAGES} are present in DEFAULT_TEXTS "
        f"(available: {sorted(DEFAULT_TEXTS)})"
    )
if missing_langs:
    print(f"WARNING: languages missing from DEFAULT_TEXTS and skipped: {missing_langs}")

# Full sweep
results = full_sweep(model, tokenizer, TEXTS, verbose=True)

In [None]:
# Results analysis: print the first ten sweep entries as a ranking table.
# NOTE(review): this treats `results` as already ordered most-critical first
# by full_sweep — confirm against src.disparity.
banner = "=" * 60
print("\n" + banner)
print("LAYER CRITICALITY RANKING")
print(banner)

print(f"\n{'Rank':<6} {'Layer':<8} {'Disparity':>12} {'Position':>10}")
print("-" * 40)

for rank, row in enumerate(results[:10], start=1):
    layer_idx, disparity, _ = row
    # Relative depth of the layer: 0% is the first layer, 100% is the last.
    depth_pct = layer_idx / (num_layers - 1) * 100
    print(f"{rank:<6} L{layer_idx:<7} {disparity:>11.2f}x {depth_pct:>9.0f}%")

# Layer indices of the three leading entries.
top3 = [row[0] for row in results[:3]]
print(f"\nTop-3 critical layers: {top3}")

# Reference selection for comparison: first layer, the layer at 75% depth,
# and the last layer.
gpt2_equiv = [0, int(num_layers * 0.75), num_layers - 1]
print(f"GPT-2 equivalent (L0+75%+last): {gpt2_equiv}")

In [None]:
# Persist the sweep summary via the project's save_results helper.
# Only (layer, disparity) pairs are kept; the third field of each row is dropped.
sweep_pairs = [(layer_idx, disparity) for layer_idx, disparity, _ in results]
run_timestamp = datetime.now().isoformat()

output = {
    'model': MODEL_ID,
    'timestamp': run_timestamp,
    'num_layers': num_layers,
    'sweep_results': sweep_pairs,
    'top_3': top3,
    'gpt2_equivalent': gpt2_equiv,
}

save_results(output, 'layer_sweep_results.json')
print("Results saved to layer_sweep_results.json")

In [None]:
# Visualization: bar chart of disparity per layer, in layer order.
import matplotlib.pyplot as plt

# Order the (layer, disparity) pairs by layer index so the bars run left to
# right through the model.
sorted_data = sorted((row[0], row[1]) for row in results)
layers_sorted = [pair[0] for pair in sorted_data]
disp_sorted = [pair[1] for pair in sorted_data]

fig, ax = plt.subplots(figsize=(12, 5))
ax.bar(layers_sorted, disp_sorted)
# Dashed reference line at a disparity of 1.0x.
ax.axhline(y=1.0, color='r', linestyle='--', label='Parity (1.0x)')
ax.set_xlabel('Layer Index')
ax.set_ylabel('Disparity (lower = more critical)')
ax.set_title(f'Layer Criticality: {MODEL_ID}')
ax.legend()
fig.tight_layout()
# Write the image before showing the figure.
fig.savefig('layer_sweep_plot.png')
plt.show()