# Watermark Ablation Analysis

This notebook analyzes results from the watermark ablation experiment.

## Objectives
1. Load all result JSONs
2. Compute ROC curves, AUC, TPR@FPR=0.01
3. Compute summary statistics
4. Create visualizations
5. Rank configurations


In [None]:
import json
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.metrics import roc_curve, auc, roc_auc_score

# Notebook-wide plotting theme: seaborn's white-grid background plus a
# 12x8-inch default size for any figure created without an explicit figsize.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 8)})


## 1. Load All Result JSONs


In [None]:
# Resolve the project root so the notebook works whether it is launched from
# the repository root, from analysis/, or from a sibling directory.
current_dir = Path.cwd()
if current_dir.name == "analysis":
    # Notebook lives in <project>/analysis/.
    project_root = current_dir.parent
elif (current_dir / "experiments").exists():
    # Already at the project root.
    project_root = current_dir
else:
    # Fallback: assume the project root is one level up.
    project_root = current_dir.parent

results_dir = project_root / "experiments" / "watermark_ablation" / "results"

# Final fallback: path relative to the current working directory.
if not results_dir.exists():
    results_dir = Path("../experiments/watermark_ablation/results")

# Find all result JSON files (sorted for a deterministic load order).
result_files = sorted(results_dir.glob("*.json"))
print(f"Found {len(result_files)} result files")

# Always define `results`, so later cells see an empty mapping instead of
# failing with a NameError when no result files exist.
results = {}

if not result_files:
    print("⚠️  No result files found. Please run the experiment first:")
    print("   python scripts/run_watermark_ablation.py --prompts-file <path> --master-key <key>")
else:
    for result_file in result_files:
        config_name = result_file.stem
        try:
            with open(result_file, "r", encoding="utf-8") as f:
                results[config_name] = json.load(f)
        except (OSError, ValueError) as e:
            # OSError: unreadable file; ValueError covers JSONDecodeError and
            # UnicodeDecodeError. Report and keep loading the remaining files.
            print(f"  ✗ Failed to load {config_name}: {e}")
        else:
            print(f"  ✓ Loaded: {config_name}")

    print(f"\n✓ Successfully loaded {len(results)} result files")


## 2. Compute Metrics for Each Config


In [None]:
def compute_roc_curve(scores: np.ndarray, labels: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the ROC curve ``(fpr, tpr, thresholds)`` for binary labels.

    Thin wrapper around ``sklearn.metrics.roc_curve``. Note that the argument
    order here is (scores, labels), the reverse of sklearn's (y_true, y_score).

    Args:
        scores: Detector scores; larger values are treated as more positive.
        labels: Ground-truth binary labels aligned with ``scores``.

    Returns:
        The ``(fpr, tpr, thresholds)`` triple produced by sklearn.
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true=labels, y_score=scores)
    return false_pos_rate, true_pos_rate, cutoffs


def find_tpr_at_fpr(fpr: np.ndarray, tpr: np.ndarray, target_fpr: float = 0.01) -> float:
    """Return the best TPR achievable without exceeding ``target_fpr``.

    This is the standard "TPR @ FPR" operating point: among all ROC points
    whose false-positive rate is at most ``target_fpr``, take the highest
    true-positive rate.

    A "closest FPR" lookup is deliberately not used. It can select a point
    whose FPR is above the target, overstating detection power. On a vertical
    ROC segment (several points sharing one FPR), ``argmin`` also returns the
    first, lowest-TPR point, such as the leading (0, 0) point of
    ``sklearn.metrics.roc_curve``.

    Args:
        fpr: False-positive rates of the ROC points.
        tpr: True-positive rates, aligned element-wise with ``fpr``.
        target_fpr: Maximum allowed false-positive rate.

    Returns:
        The highest TPR among points with FPR <= ``target_fpr``, or 0.0 when
        no point qualifies (including empty input).

    Raises:
        ValueError: If ``fpr`` and ``tpr`` have different shapes.
    """
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    if fpr.shape != tpr.shape:
        raise ValueError(f"fpr and tpr must have the same shape, got {fpr.shape} and {tpr.shape}")

    within_budget = fpr <= target_fpr
    if not np.any(within_budget):
        return 0.0
    return float(np.max(tpr[within_budget]))


def compute_predicted_separation(n_eff: float, delta_hat: float) -> float:
    """
    Heuristic (approximate) predicted separation: Δμ ≈ 4 * N_eff * δ_hat^2.

    NOTE: exploratory approximation only; do not treat it as a primary metric.

    Args:
        n_eff: Mean effective number of positions.
        delta_hat: Mean(p_hat - 0.5) over the watermarked samples.

    Returns:
        4 * n_eff * delta_hat**2, or 0.0 when n_eff is non-positive or
        |delta_hat| is below 1e-10 (treated as "no signal").
    """
    # Guard clauses, checked in the same order as a short-circuiting `or`.
    if n_eff <= 0:
        return 0.0
    if abs(delta_hat) < 1e-10:
        return 0.0

    squared_bias = delta_hat ** 2
    return 4.0 * n_eff * squared_bias


def _require_same_length(config_name: str, name_a: str, arr_a: np.ndarray, name_b: str, arr_b: np.ndarray) -> None:
    """Raise ValueError when two per-sample arrays of one config differ in length."""
    if len(arr_a) != len(arr_b):
        raise ValueError(
            f"Length mismatch for {config_name}: "
            f"{name_a} ({len(arr_a)}) != {name_b} ({len(arr_b)})"
        )


# Compute metrics for each config
summary_data = []

for config_name, data in results.items():
    # Per-sample arrays, stored separately for watermarked and clean samples.
    log_odds_wm = np.array(data["log_odds_wm"])
    log_odds_clean = np.array(data["log_odds_clean"])
    n_eff_wm = np.array(data["N_eff_wm"])
    n_eff_clean = np.array(data["N_eff_clean"])
    p_hat_wm = np.array(data["p_hat_wm"])
    p_hat_clean = np.array(data["p_hat_clean"])

    # Data-integrity checks. These are explicit raises rather than `assert`,
    # which is stripped under `python -O`. Arrays of one class must align
    # sample-wise, and both classes must be non-empty for ROC/AUC and the
    # means below to be defined (otherwise they silently become NaN).
    _require_same_length(config_name, "log_odds_wm", log_odds_wm, "N_eff_wm", n_eff_wm)
    _require_same_length(config_name, "log_odds_wm", log_odds_wm, "p_hat_wm", p_hat_wm)
    _require_same_length(config_name, "log_odds_clean", log_odds_clean, "N_eff_clean", n_eff_clean)
    _require_same_length(config_name, "log_odds_clean", log_odds_clean, "p_hat_clean", p_hat_clean)
    if len(log_odds_wm) == 0 or len(log_odds_clean) == 0:
        raise ValueError(
            f"Empty arrays for {config_name}: wm={len(log_odds_wm)}, clean={len(log_odds_clean)}"
        )

    # ROC over pooled scores: watermarked = label 1, clean = label 0.
    scores = np.concatenate([log_odds_wm, log_odds_clean])
    labels = np.concatenate([np.ones(len(log_odds_wm)), np.zeros(len(log_odds_clean))])
    fpr, tpr, thresholds = compute_roc_curve(scores, labels)
    roc_auc = auc(fpr, tpr)
    tpr_at_1pct = find_tpr_at_fpr(fpr, tpr, target_fpr=0.01)

    # Per-class log-odds statistics (np.std default: population std, ddof=0).
    mean_log_odds_wm = float(np.mean(log_odds_wm))
    std_log_odds_wm = float(np.std(log_odds_wm))
    mean_log_odds_clean = float(np.mean(log_odds_clean))
    std_log_odds_clean = float(np.std(log_odds_clean))

    # All arrays are non-empty here (validated above), so no empty-array guards.
    mean_n_eff_wm = float(np.mean(n_eff_wm))
    mean_n_eff_clean = float(np.mean(n_eff_clean))
    mean_p_hat_wm = float(np.mean(p_hat_wm))
    mean_p_hat_clean = float(np.mean(p_hat_clean))

    # delta_hat = mean(p_hat_wm - 0.5), watermarked samples only.
    delta_hat = float(np.mean(p_hat_wm - 0.5))

    # Heuristic predicted separation (approximate; exploratory visualization only).
    heuristic_predicted_separation = compute_predicted_separation(mean_n_eff_wm, delta_hat)

    # Empirical separation (difference of mean log-odds): the primary separation metric.
    empirical_separation = mean_log_odds_wm - mean_log_odds_clean

    # Run metadata; any missing field becomes None.
    metadata_fields = {
        "config_path": data.get("config_path"),
        "num_samples": data.get("num_samples"),
        "num_inversion_steps": data.get("num_inversion_steps"),
        "likelihood_model_path": data.get("likelihood_model_path"),
        "timestamp": data.get("timestamp"),
        "git_commit": data.get("git_commit"),
    }

    summary_data.append({
        "config_name": config_name,
        "auc": roc_auc,
        "tpr_at_1pct_fpr": tpr_at_1pct,
        "mean_log_odds_wm": mean_log_odds_wm,
        "std_log_odds_wm": std_log_odds_wm,
        "mean_log_odds_clean": mean_log_odds_clean,
        "std_log_odds_clean": std_log_odds_clean,
        "mean_n_eff_wm": mean_n_eff_wm,
        "mean_n_eff_clean": mean_n_eff_clean,
        "mean_p_hat_wm": mean_p_hat_wm,
        "mean_p_hat_clean": mean_p_hat_clean,
        "delta_hat": delta_hat,
        "heuristic_predicted_separation": heuristic_predicted_separation,
        "empirical_separation": empirical_separation,
        "fpr": fpr.tolist(),
        "tpr": tpr.tolist(),
        "thresholds": thresholds.tolist(),
        **metadata_fields,
    })

# Create summary DataFrame
summary_df = pd.DataFrame(summary_data)
print("Summary Statistics:")
# The per-config ROC arrays stay in the DataFrame (and are exported later),
# but are omitted from the console table, where they would print in full.
print(summary_df.drop(columns=["fpr", "tpr", "thresholds"], errors="ignore").to_string(index=False))


## 3. Ranking Configurations


In [None]:
# Rank by primary metric: TPR @ FPR = 0.01 (or empirical separation as secondary)
summary_df_ranked = summary_df.sort_values("tpr_at_1pct_fpr", ascending=False).reset_index(drop=True)

print("\nRanking by TPR @ FPR = 0.01 (Primary Metric):")
print("=" * 80)
for idx, row in summary_df_ranked.iterrows():
    print(f"{idx+1:2d}. {row['config_name']:40s} | TPR@1%: {row['tpr_at_1pct_fpr']:.4f} | AUC: {row['auc']:.4f} | N_eff_wm: {row['mean_n_eff_wm']:.1f}")

# Display full summary table with metadata and separated stats
print("\n\nFull Summary Table:")
print("=" * 80)
display_cols = [
    "config_name",
    "num_samples",
    "num_inversion_steps",
    "likelihood_model_path",
    "tpr_at_1pct_fpr",
    "auc",
    "empirical_separation",
    "mean_n_eff_wm",
    "mean_n_eff_clean",
    "mean_p_hat_wm",
    "mean_p_hat_clean",
    "delta_hat",
]
print(summary_df_ranked[display_cols].to_string(index=False))


## 4. Visualizations


In [None]:
# Plots live next to the results directory, so reuse its resolution, including
# any relative-path fallback chosen when loading results. Do not fall back
# merely because plots/ does not exist yet: on a fresh run it never does, and
# the fallback would write plots to a CWD-relative path outside the project.
plots_dir = results_dir.parent / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)

# Plot 1: ROC curves for all configs on a single axis.
plt.figure(figsize=(10, 8))
for config_name, data in results.items():
    log_odds_wm = np.array(data["log_odds_wm"])
    log_odds_clean = np.array(data["log_odds_clean"])

    # A ROC curve needs samples from both classes. Raise explicitly rather than
    # using `assert`, which is stripped under `python -O`.
    if len(log_odds_wm) == 0 or len(log_odds_clean) == 0:
        raise ValueError(
            f"Empty arrays for {config_name}: wm={len(log_odds_wm)}, clean={len(log_odds_clean)}"
        )

    # Watermarked = label 1, clean = label 0.
    scores = np.concatenate([log_odds_wm, log_odds_clean])
    labels = np.concatenate([np.ones(len(log_odds_wm)), np.zeros(len(log_odds_clean))])

    # Same helper as the metrics computation, so plotted curves match the table.
    fpr, tpr, _ = compute_roc_curve(scores, labels)
    roc_auc = auc(fpr, tpr)

    plt.plot(fpr, tpr, label=f"{config_name} (AUC={roc_auc:.3f})", linewidth=2)

plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate", fontsize=12)
plt.ylabel("True Positive Rate", fontsize=12)
plt.title("ROC Curves - All Configurations", fontsize=14)
plt.legend(loc="lower right", fontsize=8, ncol=2)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(plots_dir / "roc_curves_all.png", dpi=300, bbox_inches="tight")
plt.show()
print(f"✓ Saved: {plots_dir / 'roc_curves_all.png'}")


In [None]:
# Plot 2: per-config histogram overlay of log-odds, clean vs watermarked.
for config_name, data in results.items():
    log_odds_wm = np.array(data["log_odds_wm"])
    log_odds_clean = np.array(data["log_odds_clean"])

    # Shared bin edges (from the finite values of both classes) so the two
    # density histograms use identical bin widths and are directly comparable.
    # Passing bins=50 separately would give each class its own bins.
    pooled = np.concatenate([log_odds_clean, log_odds_wm]).astype(float)
    bin_edges = np.histogram_bin_edges(pooled[np.isfinite(pooled)], bins=50)

    plt.figure(figsize=(10, 6))
    plt.hist(log_odds_clean, bins=bin_edges, alpha=0.6, label="Clean", color="blue", density=True)
    plt.hist(log_odds_wm, bins=bin_edges, alpha=0.6, label="Watermarked", color="red", density=True)

    plt.xlabel("Log-Odds", fontsize=12)
    plt.ylabel("Density", fontsize=12)
    plt.title(f"Log-Odds Distribution: {config_name}", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    # Save, then close so per-config figures do not accumulate in memory.
    plot_filename = f"histogram_{config_name}.png"
    plt.savefig(plots_dir / plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"✓ Saved: {plots_dir / plot_filename}")


In [None]:
# Plot 3: TPR @ 1% FPR comparison as horizontal bars.
# Sorted ascending so the best configuration is drawn at the top. This uses a
# local name: reassigning `summary_df_ranked` here would replace the descending
# ranking table built earlier, which later cells may read.
tpr_plot_df = summary_df.sort_values("tpr_at_1pct_fpr", ascending=True)

plt.figure(figsize=(12, 8))
y_pos = np.arange(len(tpr_plot_df))
plt.barh(y_pos, tpr_plot_df["tpr_at_1pct_fpr"], alpha=0.7)
plt.yticks(y_pos, tpr_plot_df["config_name"])
plt.xlabel("TPR @ FPR = 0.01", fontsize=12)
plt.title("TPR @ 1% FPR Comparison (Primary Ranking Metric)", fontsize=14)
plt.grid(True, alpha=0.3, axis="x")
plt.tight_layout()
plt.savefig(plots_dir / "tpr_comparison.png", dpi=300, bbox_inches="tight")
plt.show()
print(f"✓ Saved: {plots_dir / 'tpr_comparison.png'}")


In [None]:
# Plot 4: mean N_eff vs TPR @ 1% FPR, one point per config for each class
# (default marker = watermarked mean N_eff, triangles = clean mean N_eff).
# Config-name labels are attached to the watermarked points only.
plt.figure(figsize=(10, 6))
plt.scatter(summary_df["mean_n_eff_wm"], summary_df["tpr_at_1pct_fpr"], s=100, alpha=0.6, label="Watermarked")
plt.scatter(summary_df["mean_n_eff_clean"], summary_df["tpr_at_1pct_fpr"], s=100, alpha=0.6, marker="^", label="Clean")
for _, row in summary_df.iterrows():
    plt.annotate(row["config_name"], (row["mean_n_eff_wm"], row["tpr_at_1pct_fpr"]),
                 fontsize=8, alpha=0.7)
plt.xlabel("Mean N_eff", fontsize=12)
plt.ylabel("TPR @ FPR = 0.01", fontsize=12)
plt.title("N_eff vs TPR @ 1% FPR", fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(plots_dir / "n_eff_vs_tpr.png", dpi=300, bbox_inches="tight")
plt.show()
print(f"✓ Saved: {plots_dir / 'n_eff_vs_tpr.png'}")


In [None]:
# Plot 5: heuristic predicted vs empirical separation (exploratory only).
# Points on the dashed y=x line would mean the heuristic matches the data.
plt.figure(figsize=(10, 6))
plt.scatter(summary_df["heuristic_predicted_separation"], summary_df["empirical_separation"], s=100, alpha=0.6)
for _, row in summary_df.iterrows():
    plt.annotate(row["config_name"], (row["heuristic_predicted_separation"], row["empirical_separation"]),
                 fontsize=8, alpha=0.7)

# y = x reference line spanning the range of both quantities.
min_val = min(summary_df["heuristic_predicted_separation"].min(), summary_df["empirical_separation"].min())
max_val = max(summary_df["heuristic_predicted_separation"].max(), summary_df["empirical_separation"].max())
plt.plot([min_val, max_val], [min_val, max_val], "k--", alpha=0.5, label="y=x")

plt.xlabel("Heuristic Predicted Separation (approximate: Δμ ≈ 4 * N_eff * δ_hat²)", fontsize=12)
plt.ylabel("Empirical Separation (mean log_odds_wm - mean log_odds_clean)", fontsize=12)
plt.title("Heuristic Predicted vs Empirical Separation (Approximate - Exploratory Only)", fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(plots_dir / "predicted_vs_empirical_separation.png", dpi=300, bbox_inches="tight")
plt.show()
print(f"✓ Saved: {plots_dir / 'predicted_vs_empirical_separation.png'}")


## 🔥 Diagnostic Visualizations (Fix 4)


In [None]:
# 📊 Plot A — N_eff Distributions
# Per-config histogram overlay of N_eff_wm and N_eff_clean.
for config_name, data in results.items():
    n_eff_wm = np.array(data["N_eff_wm"])
    n_eff_clean = np.array(data["N_eff_clean"])

    # Shared bin edges (from the finite values of both classes) so the two
    # density histograms use identical bin widths and are directly comparable.
    pooled = np.concatenate([n_eff_clean, n_eff_wm]).astype(float)
    bin_edges = np.histogram_bin_edges(pooled[np.isfinite(pooled)], bins=30)

    plt.figure(figsize=(10, 6))
    plt.hist(n_eff_clean, bins=bin_edges, alpha=0.6, label="Clean", color="blue", density=True)
    plt.hist(n_eff_wm, bins=bin_edges, alpha=0.6, label="Watermarked", color="red", density=True)

    plt.xlabel("N_eff", fontsize=12)
    plt.ylabel("Density", fontsize=12)
    plt.title(f"N_eff Distribution: {config_name}", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_filename = f"n_eff_distribution_{config_name}.png"
    plt.savefig(plots_dir / plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"✓ Saved: {plots_dir / plot_filename}")


In [None]:
# 📊 Plot B — p_hat Distributions
# Per-config histogram overlay of p_hat_wm and p_hat_clean.
for config_name, data in results.items():
    p_hat_wm = np.array(data["p_hat_wm"])
    p_hat_clean = np.array(data["p_hat_clean"])

    # Shared bin edges (from the finite values of both classes) so the two
    # density histograms use identical bin widths and are directly comparable.
    pooled = np.concatenate([p_hat_clean, p_hat_wm]).astype(float)
    bin_edges = np.histogram_bin_edges(pooled[np.isfinite(pooled)], bins=30)

    plt.figure(figsize=(10, 6))
    plt.hist(p_hat_clean, bins=bin_edges, alpha=0.6, label="Clean", color="blue", density=True)
    plt.hist(p_hat_wm, bins=bin_edges, alpha=0.6, label="Watermarked", color="red", density=True)

    # Vertical reference line at 0.5 (expected for clean)
    plt.axvline(x=0.5, color="gray", linestyle="--", alpha=0.5, label="Expected (p=0.5)")

    plt.xlabel("p_hat", fontsize=12)
    plt.ylabel("Density", fontsize=12)
    plt.title(f"p_hat Distribution: {config_name}", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_filename = f"p_hat_distribution_{config_name}.png"
    plt.savefig(plots_dir / plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"✓ Saved: {plots_dir / plot_filename}")


In [None]:
# 📊 Plot C — Signal Geometry Scatter
# Per-config scatter of x = N_eff, y = log_odds, with watermarked and clean
# samples drawn separately. Intended to help spot mask collapse, weak-watermark
# regimes and instability.
for config_name, data in results.items():
    log_odds_wm = np.array(data["log_odds_wm"])
    log_odds_clean = np.array(data["log_odds_clean"])
    n_eff_wm = np.array(data["N_eff_wm"])
    n_eff_clean = np.array(data["N_eff_clean"])

    plt.figure(figsize=(10, 6))

    # Clean = blue circles, watermarked = red squares.
    plt.scatter(n_eff_clean, log_odds_clean, alpha=0.6, s=50, label="Clean", color="blue", marker="o")
    plt.scatter(n_eff_wm, log_odds_wm, alpha=0.6, s=50, label="Watermarked", color="red", marker="s")

    plt.xlabel("N_eff", fontsize=12)
    plt.ylabel("Log-Odds", fontsize=12)
    plt.title(f"Signal Geometry: {config_name}\n(Helps detect mask collapse, weak watermark regimes, instability)", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_filename = f"signal_geometry_{config_name}.png"
    plt.savefig(plots_dir / plot_filename, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"✓ Saved: {plots_dir / plot_filename}")


## 5. Optional: QQ Plot for Gaussianity Check


In [None]:
# QQ plot of the standardized clean log-odds against a standard normal
# (the clean distribution should be approximately Gaussian).
# NOTE: Normality assumption may not hold for small N_eff or binary mapping
for config_name, data in results.items():
    log_odds_clean = np.array(data["log_odds_clean"])
    if len(log_odds_clean) == 0:
        # No clean samples: nothing meaningful to plot for this config.
        print(f"⚠️  Skipping QQ plot for {config_name}: no clean samples")
        continue

    # Standardize; leave values as-is when the spread is zero (avoids dividing by 0).
    mean_clean = np.mean(log_odds_clean)
    std_clean = np.std(log_odds_clean)
    standardized = (log_odds_clean - mean_clean) / std_clean if std_clean > 0 else log_odds_clean

    # Draw on an explicit figure/axes rather than pyplot's implicit
    # "current figure", which depends on what earlier cells left open.
    fig, ax = plt.subplots()
    stats.probplot(standardized, dist="norm", plot=ax)
    ax.set_title(
        f"QQ Plot (Clean Distribution): {config_name}\n"
        "(Note: Normality assumption may not hold for small N_eff or binary mapping)",
        fontsize=14
    )
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    plot_filename = f"qq_plot_{config_name}.png"
    fig.savefig(plots_dir / plot_filename, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"✓ Saved: {plots_dir / plot_filename}")


## 6. Save Summary Table


In [None]:
# Export the ranked summary, best configuration first: TPR @ FPR = 0.01
# descending, ties broken by empirical separation. The ranking is re-derived
# here instead of reusing `summary_df_ranked`, because other cells may re-sort
# that variable (e.g. ascending, for plotting).
export_df = summary_df.sort_values(
    ["tpr_at_1pct_fpr", "empirical_separation"], ascending=False
).reset_index(drop=True)

# Save summary table to CSV (written next to the plots directory).
summary_csv_path = plots_dir.parent / "summary_table.csv"
export_df.to_csv(summary_csv_path, index=False)
print(f"✓ Saved summary table: {summary_csv_path}")

# Also save as JSON (one record per configuration).
summary_json_path = plots_dir.parent / "summary_table.json"
export_df.to_json(summary_json_path, indent=2, orient="records")
print(f"✓ Saved summary table: {summary_json_path}")
