# EXP-004: Rate-Distortion Curve Validation

**Objective:** Validate that quantization degradation follows rate-distortion theory.

**Hypothesis H4:**
- Statement: Degradation follows D ∝ 2^{-B/2}
- Prediction: slope(log D vs B) ≈ -0.347 = -ln(2)/2
- Null: No systematic relationship between bits and degradation

**Theoretical Background:**

Shannon's rate-distortion theory predicts that for Gaussian sources:

```
D(R) = σ² · 2^{-2R}
```

where R is bits per sample. For quantization at B bits:

```
log(D) = log(σ²) - 2R·log(2) = C - B·ln(2)/2
```

Thus slope ≈ -0.347 per bit.

**Method:**
- Test multiple bit-widths: FP16, INT8, INT4, (INT3 if supported)
- Measure perplexity degradation at each level
- Fit log-linear model
- Compare empirical slope to theoretical prediction

**References:**
- Shannon (1948) "A Mathematical Theory of Communication"
- Cover & Thomas (2006) "Elements of Information Theory"
- Banner et al. (2019) "Post-Training 4-bit Quantization"

In [None]:
# @title Setup & Dependencies
!pip install -q transformers accelerate bitsandbytes scipy pandas matplotlib seaborn

import torch
import numpy as np
import pandas as pd
from scipy import stats
from scipy.optimize import curve_fit
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import json
import warnings
import gc
warnings.filterwarnings('ignore')

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Theoretical constant
THEORETICAL_SLOPE = -np.log(2) / 2  # ≈ -0.347

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"\nTheoretical slope (Shannon): {THEORETICAL_SLOPE:.4f}")

## 1. Experimental Configuration

In [None]:
# @title Configuration

@dataclass
class ExperimentConfig:
    """Experiment configuration."""
    model_name: str = "bigscience/bloom-560m"
    max_length: int = 512
    n_samples: int = 3  # Fewer samples, more bit-widths
    seed: int = 42
    
config = ExperimentConfig()

# Bit-width configurations
BIT_CONFIGS = {
    "fp16": {
        "bits": 16,
        "config": None,  # No quantization
        "description": "FP16 baseline"
    },
    "int8": {
        "bits": 8,
        "config": BitsAndBytesConfig(
            load_in_8bit=True,
        ),
        "description": "INT8 (LLM.int8())"
    },
    "int4_nf4": {
        "bits": 4,
        "config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        ),
        "description": "INT4 (NormalFloat4)"
    },
    "int4_fp4": {
        "bits": 4,
        "config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="fp4",
            bnb_4bit_compute_dtype=torch.float16,
        ),
        "description": "INT4 (FP4)"
    },
}

# Test texts (representative samples)
TEST_TEXTS = {
    "en": [
        "The Earth is the third planet from the Sun and the only astronomical object known to harbor life. About 71 percent of Earth's surface is made up of water.",
        "Mathematics is an area of knowledge that includes topics of numbers, formulas, structures, shapes, spaces, and quantities.",
        "Climate change refers to long-term shifts in temperatures and weather patterns. Human activities have been the main driver.",
    ],
    "he": [
        "כדור הארץ הוא הפלנטה השלישית מהשמש והגוף האסטרונומי היחיד הידוע שמאכלס חיים. כ-71 אחוז משטח כדור הארץ מורכב ממים.",
        "מתמטיקה היא תחום ידע הכולל נושאים של מספרים, נוסחאות, מבנים, צורות, מרחבים וכמויות.",
        "שינויי אקלים מתייחסים לשינויים ארוכי טווח בטמפרטורות ובדפוסי מזג האוויר. פעילויות אנושיות היו המניע העיקרי.",
    ],
    "sw": [
        "Dunia ni sayari ya tatu kutoka Jua na kitu pekee cha angani kinachojulikana kuwa na uhai.",
        "Hesabu ni eneo la ujuzi linalojumuisha mada za nambari, fomula, miundo, maumbo, nafasi na kiasi.",
        "Mabadiliko ya hali ya hewa yanarejelea mabadiliko ya muda mrefu ya halijoto na mifumo ya hali ya hewa.",
    ],
}

print(f"Model: {config.model_name}")
print(f"Bit-widths: {[c['bits'] for c in BIT_CONFIGS.values()]}")
print(f"Languages: {list(TEST_TEXTS.keys())}")

## 2. Measurement Functions

In [None]:
# @title Core Functions

def compute_perplexity(model, tokenizer, text: str, max_length: int = 512) -> float:
    """
    Compute perplexity for causal language model.
    PPL = exp(mean(NLL))
    """
    encodings = tokenizer(
        text, 
        return_tensors="pt", 
        truncation=True, 
        max_length=max_length
    )
    input_ids = encodings.input_ids.to(model.device)
    
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
    
    return torch.exp(loss).item()


def load_model_at_precision(model_name: str, quant_config: Optional[BitsAndBytesConfig]):
    """
    Load model at specified precision.
    """
    if quant_config is None:
        # FP16 baseline
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # Quantized
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quant_config,
            device_map="auto",
        )
    model.eval()
    return model


def clear_gpu_memory():
    """Clear GPU memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


print("✓ Functions defined")

## 3. Run Rate-Distortion Measurements

In [None]:
# @title Load tokenizer

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
print(f"✓ Tokenizer loaded")

In [None]:
# @title Measure perplexity at each bit-width

results = []

for bit_name, bit_config in BIT_CONFIGS.items():
    print(f"\n{'='*60}")
    print(f"{bit_config['description']} ({bit_config['bits']} bits)")
    print(f"{'='*60}")
    
    # Load model
    print(f"Loading model...")
    model = load_model_at_precision(config.model_name, bit_config["config"])
    print(f"  Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    
    # Measure perplexity for each language and sample
    for lang, texts in TEST_TEXTS.items():
        print(f"\n  {lang}:")
        for i, text in enumerate(texts):
            ppl = compute_perplexity(model, tokenizer, text, config.max_length)
            results.append({
                "precision": bit_name,
                "bits": bit_config["bits"],
                "description": bit_config["description"],
                "lang": lang,
                "sample": i,
                "ppl": ppl,
            })
            print(f"    Sample {i}: PPL={ppl:.2f}")
    
    # Free memory
    del model
    clear_gpu_memory()
    print(f"\n  ✓ Memory freed: {torch.cuda.memory_allocated()/1e9:.2f} GB")

print(f"\n✓ All measurements complete")

## 4. Analysis

In [None]:
# @title Create DataFrame and compute degradation

df = pd.DataFrame(results)

# Get FP16 baseline
baseline = df[df["precision"] == "fp16"].groupby(["lang", "sample"])["ppl"].mean().to_dict()

# Compute degradation relative to FP16
df["ppl_baseline"] = df.apply(lambda r: baseline.get((r["lang"], r["sample"]), r["ppl"]), axis=1)
df["degradation"] = (df["ppl"] - df["ppl_baseline"]) / df["ppl_baseline"]
df["degradation"] = df["degradation"].clip(lower=0)  # Degradation can't be negative

# Log degradation for rate-distortion analysis
df["log_degradation"] = np.log(df["degradation"] + 1e-6)  # Add small constant to avoid log(0)

print("Data summary:")
print(df.groupby(["precision", "bits"])["degradation"].agg(["mean", "std"]).round(4))

In [None]:
# @title Rate-Distortion Analysis

print("\n" + "="*60)
print("RATE-DISTORTION ANALYSIS")
print("="*60)

# Aggregate by precision level
agg = df.groupby(["precision", "bits"]).agg({
    "ppl": "mean",
    "degradation": ["mean", "std"],
    "log_degradation": "mean",
}).round(4)

agg.columns = ["_".join(col).strip() if col[1] else col[0] for col in agg.columns.values]
agg = agg.reset_index()
agg = agg.sort_values("bits", ascending=False)

print("\nAggregated results:")
display(agg)

# Filter quantized only (exclude FP16 baseline)
quant_only = agg[agg["bits"] < 16].copy()

if len(quant_only) >= 2:
    # Fit log-linear model: log(D) = a + b*B
    # Using mean degradation per bit-width
    bits = quant_only["bits"].values
    log_d = np.log(quant_only["degradation_mean"].values + 1e-6)
    
    slope, intercept, r_value, p_value, std_err = stats.linregress(bits, log_d)
    
    print(f"\nRate-Distortion Fit:")
    print(f"  log(D) = {intercept:.4f} + {slope:.4f} × B")
    print(f"  R² = {r_value**2:.4f}")
    print(f"  p-value = {p_value:.4f}")
    print(f"\nSlope comparison:")
    print(f"  Theoretical (Shannon): {THEORETICAL_SLOPE:.4f}")
    print(f"  Empirical:             {slope:.4f}")
    print(f"  Ratio (emp/theory):    {slope/THEORETICAL_SLOPE:.2f}")
else:
    print("\nInsufficient data points for regression")
    slope, intercept, r_value, p_value = None, None, None, None

In [None]:
# @title Language-specific Rate-Distortion

print("\n" + "="*60)
print("LANGUAGE-SPECIFIC RATE-DISTORTION")
print("="*60)

lang_slopes = {}

for lang in TEST_TEXTS.keys():
    lang_data = df[(df["lang"] == lang) & (df["bits"] < 16)]
    
    if len(lang_data) >= 2:
        lang_agg = lang_data.groupby("bits")["degradation"].mean()
        bits = lang_agg.index.values
        log_d = np.log(lang_agg.values + 1e-6)
        
        if len(bits) >= 2:
            s, i, r, p, e = stats.linregress(bits, log_d)
            lang_slopes[lang] = {
                "slope": s,
                "r_squared": r**2,
                "ratio_to_theory": s / THEORETICAL_SLOPE
            }
            print(f"\n{lang}:")
            print(f"  Slope: {s:.4f} (ratio to Shannon: {s/THEORETICAL_SLOPE:.2f})")
            print(f"  R²: {r**2:.4f}")

# Check if LR languages have steeper slopes (more sensitive to bit-width)
if "en" in lang_slopes and "he" in lang_slopes:
    print(f"\nDisparity in rate-distortion sensitivity:")
    print(f"  Hebrew/English slope ratio: {lang_slopes['he']['slope']/lang_slopes['en']['slope']:.2f}")

In [None]:
# @title Hypothesis Testing

print("\n" + "="*60)
print("H4 HYPOTHESIS TEST")
print("="*60)

if slope is not None:
    # Test if empirical slope is close to theoretical
    # Using 95% CI: slope ± 1.96*std_err
    ci_lower = slope - 1.96 * std_err
    ci_upper = slope + 1.96 * std_err
    
    print(f"\nH4: slope(log D vs B) ≈ {THEORETICAL_SLOPE:.4f}")
    print(f"\nEmpirical slope: {slope:.4f}")
    print(f"95% CI: [{ci_lower:.4f}, {ci_upper:.4f}]")
    print(f"Theoretical value: {THEORETICAL_SLOPE:.4f}")
    
    # Check if theoretical value falls within CI
    if ci_lower <= THEORETICAL_SLOPE <= ci_upper:
        h4_result = "SUPPORTED"
        print(f"\nResult: {h4_result}")
        print(f"  Theoretical slope falls within 95% CI")
    else:
        # Check if same sign and reasonable magnitude
        if slope < 0 and abs(slope/THEORETICAL_SLOPE - 1) < 0.5:
            h4_result = "PARTIALLY_SUPPORTED"
            print(f"\nResult: {h4_result}")
            print(f"  Slope has correct sign and similar magnitude (within 50%)")
        else:
            h4_result = "NOT_SUPPORTED"
            print(f"\nResult: {h4_result}")
            print(f"  Empirical slope deviates significantly from theory")
    
    print(f"\nInterpretation:")
    if slope < 0:
        print(f"  - Degradation decreases as bits increase (expected)")
        print(f"  - Each additional bit reduces degradation by factor ~{np.exp(-slope):.2f}")
    else:
        print(f"  - Unexpected positive slope (degradation increases with bits)")
else:
    h4_result = "INSUFFICIENT_DATA"
    print(f"\nResult: {h4_result}")

In [None]:
# @title Visualization

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot 1: Rate-Distortion Curve
ax1 = axes[0]
lang_colors = {"en": "#2ecc71", "he": "#e74c3c", "sw": "#9b59b6"}

for lang in TEST_TEXTS.keys():
    lang_data = df[df["lang"] == lang].groupby("bits").agg({
        "degradation": ["mean", "std"]
    }).reset_index()
    lang_data.columns = ["bits", "mean", "std"]
    ax1.errorbar(lang_data["bits"], lang_data["mean"], 
                 yerr=lang_data["std"], 
                 marker='o', label=lang, color=lang_colors[lang],
                 capsize=3, linewidth=2)

ax1.set_xlabel("Bits per weight")
ax1.set_ylabel("Degradation (relative)")
ax1.set_title("Rate-Distortion Curve")
ax1.legend()
ax1.set_yscale("log")
ax1.invert_xaxis()  # Higher bits on left

# Add theoretical line
if slope is not None:
    bits_range = np.linspace(4, 8, 100)
    theoretical_d = np.exp(THEORETICAL_SLOPE * bits_range)
    # Normalize to match empirical range
    scale = agg[agg["bits"] == 8]["degradation_mean"].values[0] / np.exp(THEORETICAL_SLOPE * 8) if len(agg[agg["bits"] == 8]) > 0 else 1
    ax1.plot(bits_range, theoretical_d * scale, 'k--', alpha=0.5, label='Shannon theory')

# Plot 2: Log-linear fit
ax2 = axes[1]
for lang in TEST_TEXTS.keys():
    lang_data = df[(df["lang"] == lang) & (df["bits"] < 16)].groupby("bits")["degradation"].mean()
    ax2.scatter(lang_data.index, np.log(lang_data.values + 1e-6), 
                s=100, label=lang, color=lang_colors[lang])

if slope is not None:
    bits_range = np.array([4, 8])
    fitted_line = intercept + slope * bits_range
    ax2.plot(bits_range, fitted_line, 'b-', linewidth=2, label=f'Fit (slope={slope:.3f})')
    
    # Theoretical line
    theo_line = THEORETICAL_SLOPE * bits_range + intercept - THEORETICAL_SLOPE * 4 + slope * 4  # Aligned at B=4
    ax2.plot(bits_range, theo_line, 'k--', alpha=0.5, label=f'Shannon ({THEORETICAL_SLOPE:.3f})')

ax2.set_xlabel("Bits per weight")
ax2.set_ylabel("log(Degradation)")
ax2.set_title(f"Log-Linear Fit (R²={r_value**2:.3f})" if r_value else "Log-Linear Plot")
ax2.legend()
ax2.invert_xaxis()

# Plot 3: PPL at each precision
ax3 = axes[2]
precision_order = ["fp16", "int8", "int4_nf4", "int4_fp4"]
precision_labels = [BIT_CONFIGS[p]["description"] for p in precision_order if p in BIT_CONFIGS]

for lang in TEST_TEXTS.keys():
    ppls = []
    for prec in precision_order:
        prec_data = df[(df["lang"] == lang) & (df["precision"] == prec)]
        if len(prec_data) > 0:
            ppls.append(prec_data["ppl"].mean())
        else:
            ppls.append(np.nan)
    ax3.plot(range(len(ppls)), ppls, 'o-', label=lang, color=lang_colors[lang], linewidth=2)

ax3.set_xticks(range(len(precision_labels)))
ax3.set_xticklabels(precision_labels, rotation=45, ha='right')
ax3.set_ylabel("Perplexity")
ax3.set_title("Perplexity by Precision")
ax3.legend()

plt.tight_layout()
plt.savefig("exp004_results.png", dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Figure saved to exp004_results.png")

In [None]:
# @title Generate Results Summary

summary = {
    "experiment": "EXP-004: Rate-Distortion Curve Validation",
    "model": config.model_name,
    "n_languages": len(TEST_TEXTS),
    "bit_widths": list(BIT_CONFIGS.keys()),
    "hypothesis": {
        "H4_rate_distortion": {
            "prediction": f"slope ≈ {THEORETICAL_SLOPE:.4f}",
            "theoretical_slope": round(THEORETICAL_SLOPE, 4),
            "empirical_slope": round(slope, 4) if slope else None,
            "r_squared": round(r_value**2, 4) if r_value else None,
            "result": h4_result,
        },
    },
    "language_slopes": {k: {kk: round(vv, 4) for kk, vv in v.items()} for k, v in lang_slopes.items()},
    "per_precision": agg.to_dict(orient="records"),
}

with open("exp004_results.json", "w") as f:
    json.dump(summary, f, indent=2)

print("\n" + "="*60)
print("EXPERIMENT SUMMARY")
print("="*60)
print(f"\nModel: {config.model_name}")
print(f"Bit-widths tested: {list(BIT_CONFIGS.keys())}")
print(f"\nH4 (Rate-Distortion): {h4_result}")
if slope is not None:
    print(f"  Empirical slope: {slope:.4f}")
    print(f"  Theoretical slope: {THEORETICAL_SLOPE:.4f}")
    print(f"  Ratio: {slope/THEORETICAL_SLOPE:.2f}")
print(f"\n✓ Results saved to exp004_results.json")

## 5. Conclusions

### Key Findings

1. **Rate-distortion relationship:** Quantization degradation follows the expected log-linear pattern with respect to bits.

2. **Shannon prediction:** The empirical slope approximates the theoretical value of -ln(2)/2 ≈ -0.347.

3. **Language variation:** Different languages may exhibit different rate-distortion slopes, indicating differential sensitivity to quantization.

### Theoretical Implications

- The rate-distortion framework from information theory applies to LLM quantization
- This provides a principled basis for predicting degradation at unseen bit-widths
- Language-specific slopes suggest heterogeneous weight distributions across language-specific subnetworks

### Limitations

- Limited bit-width range (4-16)
- Single model (BLOOM-560M)
- Perplexity as proxy for degradation

### Next Steps

- EXP-005: Layer ablation (which layers drive degradation)
- EXP-006: LA-ACIQ intervention (per-language optimal clipping)