# Phase 2 — Feature Analysis & Ablation

**CSP-Ablation-Project** · T1–T5

Builds on the probe trained in Phase 1. Pipeline:
- **T1** Probe weight extraction & feature ranking
- **T2** Feature activation profiling (AUC, histograms)
- **T3** Ablation engine (zero + mean)
- **T4** Ablation threshold sweep → ablation curve
- **T5** Generation sanity check (original vs ablated)

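**Prerequisite:** Run `phase1_probing.ipynb` first. Its artifacts are read from the versioned directory returned by `artifacts_dir(SPRINT, VERSION)`, falling back to the legacy flat folder `DATA/CSP-Ablation-Project/artifacts/` if the probe file is not found there.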

---
## Setup

In [None]:
# Run ONCE, then: Runtime → Restart session. Skip afterwards.
#!pip install -q torch transformers accelerate scikit-learn matplotlib pandas

In [None]:
import os, sys, json, pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import roc_auc_score

from google.colab import drive
# Mount Google Drive so that code and data under MyDrive are reachable from this runtime.
drive.mount("/content/drive", force_remount=True)

# Project layout on Drive: the repository checkout lives under CODE/, generated data under DATA/.
DRIVE_ROOT = "/content/drive/MyDrive"
CODE_DIR   = os.path.join(DRIVE_ROOT, "CODE", "CSP-Ablation-Project")
DATA_DIR   = os.path.join(DRIVE_ROOT, "DATA", "CSP-Ablation-Project")
# Sprint/version pair passed to artifacts_dir() below and reused for the upload prefix at the end.
SPRINT, VERSION = "sprint1", "v1.0"

# First run: clone the repository into CODE_DIR; later runs: pull the latest commit.
# The `!` lines are IPython shell escapes; {CODE_DIR} is filled in from the Python variable.
if not os.path.isdir(CODE_DIR):
    !git clone https://github.com/piotrwilam/CSP-Ablation-Project.git "{CODE_DIR}"
else:
    !cd "{CODE_DIR}" && git pull

# Make the repository's `src` package importable (must happen before the import below).
if CODE_DIR not in sys.path:
    sys.path.insert(0, CODE_DIR)

from src.config import artifacts_dir
ARTIFACTS = artifacts_dir(SPRINT, VERSION)
# Fallback: legacy flat artifacts (if phase1 ran before versioning).
# Only switches when the probe pickle is missing from the versioned directory
# AND present in DATA_DIR/artifacts; otherwise ARTIFACTS is left unchanged.
if not os.path.exists(os.path.join(ARTIFACTS, "code_vuln_probe.pkl")):
    legacy = os.path.join(DATA_DIR, "artifacts")
    if os.path.exists(os.path.join(legacy, "code_vuln_probe.pkl")):
        ARTIFACTS = legacy
        print("Using legacy artifacts path")

print(f"ARTIFACTS: {ARTIFACTS}")

In [None]:
# Restore the Phase 1 probe bundle plus the X_train / y_train arrays saved next to it.
# NOTE: unpickling executes arbitrary code — only load artifacts you produced yourself.
probe_path = os.path.join(ARTIFACTS, "code_vuln_probe.pkl")
with open(probe_path, "rb") as fh:
    probe_dump = pickle.load(fh)

# The bundle is a dict; pull out the fitted probe, its scaler and the probed layer index.
probe, scaler, PROBE_LAYER = (
    probe_dump[key] for key in ("probe", "scaler", "probe_layer")
)

X, y = (
    np.load(os.path.join(ARTIFACTS, filename))
    for filename in ("X_train.npy", "y_train.npy")
)

print(f"Probe layer: {PROBE_LAYER} | X: {X.shape} | y: {y.shape}")

---
## T1. Probe Weight Extraction & Feature Ranking

In [None]:
# First row of the probe's coefficient matrix: one weight per feature.
weights = probe.coef_[0]
n_features = len(weights)

# order[r] is the feature index at magnitude-rank r (rank 0 = largest |weight|);
# ranks is its inverse permutation, i.e. ranks[i] is the rank of feature i.
order = np.argsort(np.abs(weights))[::-1]
ranks = np.argsort(order)

# Split features by weight sign: positive = insecurity-associated,
# negative = security-associated. Zero weights belong to neither group.
pos_idx = np.where(weights > 0)[0]
neg_idx = np.where(weights < 0)[0]

# Up to 20 most positive weights (descending) ...
if len(pos_idx) > 0:
    top20_pos = pos_idx[np.argsort(weights[pos_idx])[::-1][:20]]
else:
    top20_pos = np.array([])
# ... and up to 20 most negative weights (most negative first).
if len(neg_idx) > 0:
    top20_neg = neg_idx[np.argsort(weights[neg_idx])[:20]]
else:
    top20_neg = np.array([])


def _feature_records(indices):
    """Return JSON-serialisable {idx, weight, rank} records for the given feature indices."""
    return [
        {"idx": int(i), "weight": float(weights[i]), "rank": int(ranks[i])}
        for i in indices
    ]


top_features = {
    "top20_insecurity": _feature_records(top20_pos),
    "top20_security": _feature_records(top20_neg),
    "n_features": n_features,
    "weight_stats": {
        "mean_abs": float(np.mean(np.abs(weights))),
        "std": float(np.std(weights)),
        "max": float(np.max(weights)),
        "min": float(np.min(weights)),
    },
}

top_features_path = os.path.join(ARTIFACTS, "top_features.json")
with open(top_features_path, "w") as fh:
    json.dump(top_features, fh, indent=2)

for heading, key in (
    ("Top 5 insecurity-associated:", "top20_insecurity"),
    ("Top 5 security-associated:", "top20_security"),
):
    print(heading, [f"{t['idx']}:{t['weight']:.4f}" for t in top_features[key][:5]])
print(f"Saved → {ARTIFACTS}/top_features.json")

In [None]:
# Histogram of |weight| over all probe features: a long right tail means the
# signal sits in a few features, a compact bump means it is spread out.
histogram_path = os.path.join(ARTIFACTS, "weight_distribution_histogram.png")

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(np.abs(weights), bins=80, edgecolor="black", alpha=0.7)
ax.set(
    xlabel="|Weight|",
    ylabel="Count",
    title="Probe weight magnitude distribution (is signal concentrated or spread?)",
)
fig.tight_layout()
fig.savefig(histogram_path, dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved → {ARTIFACTS}/weight_distribution_histogram.png")

---
## T2. Feature Activation Profiling

In [None]:
# Single-feature AUC for the top-weighted features (up to 20 positive + 20 negative).
# Each feature's raw activation column is used directly as the score for class 1.
all_top_idx = [int(i) for i in top20_pos] + [int(i) for i in top20_neg]

# ROC AUC needs both classes in y. That is a property of y alone, so decide it
# once instead of re-running np.unique for every feature; fall back to chance (0.5).
has_both_classes = len(np.unique(y)) >= 2
aucs = [
    float(roc_auc_score(y, X[:, idx])) if has_both_classes else 0.5
    for idx in all_top_idx
]

auc_table = pd.DataFrame({
    "feature_idx": all_top_idx,
    "weight": [float(weights[i]) for i in all_top_idx],
    "auc": aucs,
})
# Highest AUC first. Note that values far BELOW 0.5 are just as informative
# (the feature is lower for class 1), they simply sort to the bottom here.
auc_table = auc_table.sort_values("auc", ascending=False).reset_index(drop=True)
auc_table.to_csv(os.path.join(ARTIFACTS, "per_feature_auc.csv"), index=False)
print(auc_table.head(10))
print(f"\nSaved → {ARTIFACTS}/per_feature_auc.csv")

In [None]:
# Top 5 most discriminative features — overlapping Secure/Insecure histograms.
# AUC is symmetric around 0.5: a feature with AUC 0.1 separates the classes as
# well as one with AUC 0.9 (just in the opposite direction). Ranking by raw AUC
# would therefore never pick a negative-weight feature, so rank by |AUC - 0.5|.
separation = (auc_table["auc"] - 0.5).abs()
top5_rows = auc_table.loc[
    separation.sort_values(ascending=False, kind="stable").index[:5]
]
top5_idx = [int(v) for v in top5_rows["feature_idx"]]
top5_auc = top5_rows["auc"].tolist()

fig, axes = plt.subplots(2, 3, figsize=(12, 7))
axes = axes.flatten()
for ax, idx, auc_val in zip(axes, top5_idx, top5_auc):
    secure = X[y == 0, idx]
    insecure = X[y == 1, idx]
    # density=True so the two classes are comparable even when their sizes differ.
    ax.hist(secure, bins=25, alpha=0.6, label="Secure", color="green", density=True)
    ax.hist(insecure, bins=25, alpha=0.6, label="Insecure", color="red", density=True)
    ax.set_xlabel(f"Feature {idx} activation")
    ax.set_ylabel("Density")
    ax.set_title(f"Feature {idx} (AUC={auc_val:.3f})")
    ax.legend()
# Hide every unused panel (the 6th one, plus any others when fewer than 5 features exist).
for unused_ax in axes[len(top5_idx):]:
    unused_ax.axis("off")
plt.suptitle("Top 5 discriminative features: Secure vs Insecure activation distributions")
plt.tight_layout()
plt.savefig(os.path.join(ARTIFACTS, "activation_histograms_top5.png"), dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved → {ARTIFACTS}/activation_histograms_top5.png")

---
## T3. Ablation Engine

In [None]:
from src.model_loader import load_model_and_tokenizer
from src.ablation import CircuitAblator
from src.hidden_states import collect_resid_all_layers
from src.data_loader import load_minimal_pairs, prompt_from_scenario

# Load the model under study. The helper's return value is unpacked into three
# objects; `layers` is indexed by PROBE_LAYER below when constructing CircuitAblator.
model, tokenizer, layers = load_model_and_tokenizer()
# Device of the model's first parameter; tokenized prompts are moved here before generation.
device = next(model.parameters()).device

In [None]:
# Smoke-test the ablation engine on the first 5 minimal pairs.
from src.data_utils import get_dataset_path
PAIRS_PATH = get_dataset_path(SPRINT, CODE_DIR, DATA_DIR)
examples = load_minimal_pairs(PAIRS_PATH)[:5]

# Reference pass: hidden states at the probed layer without any ablation.
all_data = collect_resid_all_layers(examples, model, tokenizer, layers, max_samples=5)
X_orig, y_orig = all_data[PROBE_LAYER]

# Zero-ablate the 5 highest-weight insecurity features and collect again.
top5_ablate = [t["idx"] for t in top_features["top20_insecurity"][:5]]
ablator = CircuitAblator(layers[PROBE_LAYER], PROBE_LAYER, top5_ablate, strategy="zero")
ablator.enable()
try:
    all_data_abl = collect_resid_all_layers(examples, model, tokenizer, layers, max_samples=5)
finally:
    # Always switch the ablation off again, even if collection fails: the model
    # object is shared by every later cell and must not stay silently ablated.
    ablator.disable()
X_abl, y_abl = all_data_abl[PROBE_LAYER]

# Compare: ablated features should be zero
for idx in top5_ablate:
    print(f"Feature {idx}: orig mean={X_orig[:, idx].mean():.4f}, ablated mean={X_abl[:, idx].mean():.6f}")
# NOTE(review): this message is printed unconditionally — check the "ablated mean"
# values above by eye; nothing here asserts that they are actually zero.
print("\nAblation engine OK.")

---
## T4. Ablation Threshold Sweep

In [None]:
# Probe-level ablation sweep: zero the k highest-|weight| features in the stored
# raw activations, re-apply the scaler, and re-score the probe for each k.
order = np.argsort(np.abs(weights))[::-1]
k_values = [1, 3, 5, 10, 20, 50, 100, 200]

X_s = scaler.transform(X)
acc_baseline = probe.score(X_s, y)
print(f"Baseline probe accuracy (no ablation): {acc_baseline:.2%}")


def _accuracy_with_zeroed(feature_ids):
    """Return probe accuracy on a copy of X whose given columns are set to 0."""
    zeroed = X.copy()  # never touch X itself; every k starts from the originals
    zeroed[:, feature_ids] = 0
    return float(probe.score(scaler.transform(zeroed), y))


results = []
for k in k_values:
    acc = _accuracy_with_zeroed(order[:k].tolist())
    results.append({"k": k, "accuracy": acc})
    print(f"  k={k:>3}: {acc:.2%}")

sweep_path = os.path.join(ARTIFACTS, "ablation_sweep_results.json")
with open(sweep_path, "w") as fh:
    json.dump({"baseline": acc_baseline, "sweep": results}, fh, indent=2)
print(f"\nSaved → {ARTIFACTS}/ablation_sweep_results.json")

In [None]:
# Plot probe accuracy against the number of zero-ablated features, with the
# un-ablated baseline and the 50% chance level drawn as reference lines.
ks = [row["k"] for row in results]
accs = [row["accuracy"] for row in results]
curve_path = os.path.join(ARTIFACTS, "ablation_curve.png")

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(ks, accs, "o-", linewidth=2, markersize=8)
for level, line_style in (
    (acc_baseline, {"color": "gray", "linestyle": "--", "alpha": 0.7, "label": "Baseline"}),
    (0.5, {"color": "red", "linestyle": ":", "alpha": 0.5, "label": "Chance"}),
):
    ax.axhline(level, **line_style)
ax.set(
    xlabel="Number of ablated features (k)",
    ylabel="Probe accuracy",
    title="Ablation curve: probe accuracy vs top-k zero ablation",
)
ax.legend()
ax.set_ylim(0.4, 1.05)
fig.tight_layout()
fig.savefig(curve_path, dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved → {ARTIFACTS}/ablation_curve.png")

---
## T5. Generation Sanity Check

In [None]:
# Pick the ablation size from the T4 sweep: the smallest k whose accuracy is
# more than 5 points below the baseline; default to 20 if no k drops that far.
OPTIMAL_K = next(
    (r["k"] for r in results if r["accuracy"] < acc_baseline - 0.05),
    20,
)
print(f"Using k={OPTIMAL_K} for ablation.")

# Build up to 40 prompts from the minimal pairs (up to 20 insecure + 20 secure).
N_PER_CLASS = 20
PROMPT_FRACTION = 0.7  # keep ~70% of each snippet; the model completes the rest

with open(PAIRS_PATH, encoding="utf-8") as f:
    pairs = json.load(f)


def _prompt_prefix(code, fraction=PROMPT_FRACTION):
    """Return the leading `fraction` of `code` (at least one character) as a prompt."""
    return code[:max(1, int(len(code) * fraction))]


# Both prompt lists are drawn from the same leading pairs, so they stay aligned.
selected_pairs = pairs[:N_PER_CLASS]
prompts_insecure = [_prompt_prefix(p["corrupted"]) for p in selected_pairs]
prompts_secure = [_prompt_prefix(p["clean"]) for p in selected_pairs]

prompts = prompts_insecure + prompts_secure
# Size the labels from the lists actually built: a dataset with fewer than 20
# pairs would otherwise give 40 labels for fewer prompts and break the table.
labels = ["insecure"] * len(prompts_insecure) + ["secure"] * len(prompts_secure)
print(f"{len(prompts)} prompts ({len(prompts_insecure)} insecure, {len(prompts_secure)} secure)")

In [None]:
import torch

# Greedy-decode every prompt twice — once with the untouched model and once with
# the top-k probe features zero-ablated at the probed layer — and keep only the
# newly generated continuation of each run.
GEN_MAX_NEW = 50
top_k_indices = order[:OPTIMAL_K].tolist()

# Ablator for generation
ablator = CircuitAblator(layers[PROBE_LAYER], PROBE_LAYER, top_k_indices, strategy="zero")

# Identical decoding settings for both runs so that ablation is the only difference.
gen_kwargs = {
    "max_new_tokens": GEN_MAX_NEW,
    "do_sample": False,
    "pad_token_id": tokenizer.eos_token_id,
}

outputs_orig = []
outputs_abl = []

for i, prompt in enumerate(prompts):
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400).to(device)
    input_ids = enc["input_ids"]
    prompt_len = input_ids.shape[1]
    # Forward the tokenizer's attention mask when it provides one, so generate()
    # does not have to infer it from pad_token_id (which equals EOS here).
    mask_kwargs = {"attention_mask": enc["attention_mask"]} if "attention_mask" in enc else {}

    with torch.no_grad():
        out_orig = model.generate(input_ids, **mask_kwargs, **gen_kwargs)

        ablator.enable()
        try:
            out_abl = model.generate(input_ids, **mask_kwargs, **gen_kwargs)
        finally:
            # Never leave the ablation active if generation fails: the next
            # "original" run (or any later cell) would be silently ablated.
            ablator.disable()

    # Strip the prompt tokens; decode only what the model added.
    text_orig = tokenizer.decode(out_orig[0][prompt_len:], skip_special_tokens=True)
    text_abl = tokenizer.decode(out_abl[0][prompt_len:], skip_special_tokens=True)
    outputs_orig.append(text_orig)
    outputs_abl.append(text_abl)
    if (i + 1) % 10 == 0:
        print(f"  {i+1}/{len(prompts)} done")

print("Generation complete.")

In [None]:
# Side-by-side comparison table: one row per prompt with both continuations.
df = pd.DataFrame({
    "label": labels,
    # Truncate long prompts to a 60-character preview.
    "prompt_preview": [(p[:60] + "...") if len(p) > 60 else p for p in prompts],
    "original_output": outputs_orig,
    "ablated_output": outputs_abl,
})
df.to_csv(os.path.join(ARTIFACTS, "generation_comparison.csv"), index=False)
# A bare `df.head(10)` in the middle of a cell displays nothing (only a cell's
# last expression is rendered), so print the preview explicitly.
print(df.head(10))
print(f"\nFull table saved → {ARTIFACTS}/generation_comparison.csv")

# Push all Phase 2 artifacts to Hugging Face
# NOTE(review): this uploads every regular file in ARTIFACTS, which also includes
# whatever Phase 1 left there (probe pickle, X/y arrays) — confirm that is intended.
from src.utils import save_to_hub
from src.config import HF_REPO_ID
hf_prefix = f"artifacts/{SPRINT}/{VERSION}"
for artifact_name in os.listdir(ARTIFACTS):
    artifact_path = os.path.join(ARTIFACTS, artifact_name)
    if os.path.isfile(artifact_path):  # skip sub-directories
        save_to_hub(artifact_path, f"{hf_prefix}/{artifact_name}", HF_REPO_ID)

---
## T5. Notes (manual)

Review `generation_comparison.csv` and note:
- Does the ablated model still produce coherent code?
- Does it avoid insecure patterns more often?
- Any obvious breakage or degradation?