# CKA Transferability Analysis (Framework v3)

This notebook analyzes attack transferability between open-weight and closed-weight LLMs
using Centered Kernel Alignment (CKA) to identify optimal surrogate models.

**Phases:**
1. Phase 1: Compute CKA similarity matrix between open models
2. Phase 2: Run attack suite on all models (open + closed)
3. Phase 3: Correlate CKA with attack transfer, recommend surrogates

**Requirements:**
- HuggingFace token (for Gemma, Llama)
- OpenAI API key (optional, for Phase 2)
- Anthropic API key (optional, for Phase 2)

In [None]:
# ==========================================
# CELL 1: Install Dependencies
# ==========================================
# Installs the third-party packages the later cells import. The leading "!"
# is notebook syntax that runs the line in a shell; -q reduces pip's output.
!pip install -q transformers accelerate torch scipy numpy huggingface_hub
# openai / anthropic are only needed when closed models are tested in Phase 2.
!pip install -q openai anthropic  # Optional for Phase 2

# NOTE(review): this message is printed unconditionally; a failed pip command
# does not appear to stop the cell, so scroll up if later imports fail.
print("Dependencies installed.")

In [None]:
# ==========================================
# CELL 2: Clone Framework from GitHub
# ==========================================
import os
import sys
import shutil

# Clear any cached imports
for mod in list(sys.modules.keys()):
    if 'transferability' in mod:
        del sys.modules[mod]

# Clone the repository
if not os.path.exists('/content/AI-SecOps'):
    !git clone -b framework-v3 https://github.com/zbovaird/AI-SecOps.git /content/AI-SecOps
else:
    !cd /content/AI-SecOps && git fetch origin framework-v3 && git checkout framework-v3 && git pull origin framework-v3

# Copy framework to content for easy imports
framework_src = '/content/AI-SecOps/transferability_framework'
framework_dest = '/content/transferability_framework'

if os.path.exists(framework_src):
    if os.path.exists(framework_dest):
        shutil.rmtree(framework_dest)
    shutil.copytree(framework_src, framework_dest)
    print("[OK] Framework copied to /content/transferability_framework")

# Add to path
if '/content' not in sys.path:
    sys.path.insert(0, '/content')

# Verify core modules
print("\nVerifying modules:")
for module in ['cka', 'model_loader', 'attack_suite', 'api_clients']:
    path = f'/content/transferability_framework/core/{module}.py'
    status = "OK" if os.path.exists(path) else "MISSING"
    print(f"  {module}: [{status}]")

In [None]:
# ==========================================
# CELL 3: Configure API Keys and Tokens
# ==========================================
# Reads HF_TOKEN, OPENAI_API_KEY and ANTHROPIC_API_KEY from Colab secrets and
# logs in to HuggingFace when a token is available. Leaves HF_TOKEN as "" and
# the two API keys as None when they cannot be read.
from google.colab import userdata
import os


def _read_colab_secret(name):
    """Return the Colab secret ``name``, or None if it is unavailable or empty.

    Args:
        name: Key of the secret as stored in the Colab secrets panel.

    Returns:
        The secret string, or None when the lookup raised or returned an
        empty value.
    """
    try:
        value = userdata.get(name)
    except Exception as exc:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
        # propagate; the exception type tells the user why the lookup failed
        # (e.g. secret missing vs. notebook access not granted).
        print(f"[INFO] Secret {name} not available ({type(exc).__name__})")
        return None
    # An empty secret is as unusable as a missing one.
    return value or None


# HuggingFace token (REQUIRED for Gemma, Llama)
# Set in Colab secrets or paste here
HF_TOKEN = _read_colab_secret('HF_TOKEN')
if HF_TOKEN:
    print("[OK] HuggingFace token loaded from secrets")
else:
    HF_TOKEN = ""  # Paste your token here if not using secrets
    if HF_TOKEN:
        print("[OK] HuggingFace token set manually")
    else:
        print("[WARNING] No HuggingFace token - some models may not load")

# OpenAI API key (OPTIONAL for Phase 2)
OPENAI_KEY = _read_colab_secret('OPENAI_API_KEY')
if OPENAI_KEY:
    print("[OK] OpenAI API key loaded")
else:
    print("[INFO] No OpenAI key - skipping GPT-4 testing")

# Anthropic API key (OPTIONAL for Phase 2)
ANTHROPIC_KEY = _read_colab_secret('ANTHROPIC_API_KEY')
if ANTHROPIC_KEY:
    print("[OK] Anthropic API key loaded")
else:
    print("[INFO] No Anthropic key - skipping Claude testing")

# Login to HuggingFace
if HF_TOKEN:
    from huggingface_hub import login
    login(token=HF_TOKEN, add_to_git_credential=True)
    print("\n[OK] Logged in to HuggingFace")

In [None]:
# ==========================================
# CELL 4: Verify GPU and Setup Logging
# ==========================================
# Reports the first CUDA device (name and total memory) when one is visible,
# and configures root logging at INFO level.
import torch
import logging

# Check GPU
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is in bytes; dividing by 1e9 gives decimal gigabytes.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"[OK] GPU: {gpu_name} ({gpu_memory:.1f} GB)")
else:
    print("[WARNING] No GPU detected - analysis will be slow")

# Setup logging
# force=True (Python 3.8+) replaces handlers already attached to the root
# logger. Without it basicConfig is a silent no-op whenever the host
# environment has configured logging first, and INFO messages never appear.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True,
)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# ==========================================
# CELL 5: Mount Google Drive for Results
# ==========================================
# Mounts Drive at /content/drive and creates one timestamped results folder
# per run under MyDrive/transferability_results.
from google.colab import drive
import os
from datetime import datetime

drive.mount('/content/drive')

# The folder name is the local wall-clock time, so each run gets its own directory.
timestamp = format(datetime.now(), '%Y%m%d_%H%M%S')
RESULTS_DIR = f'/content/drive/MyDrive/transferability_results/{timestamp}'
os.makedirs(RESULTS_DIR, exist_ok=True)

print(f"[OK] Results will be saved to: {RESULTS_DIR}")

In [None]:
# ==========================================
# CELL 6: Import Framework Components
# ==========================================
# Imports everything the three phases use from the transferability_framework
# package that CELL 2 copied to /content. An ImportError here most likely
# means CELL 2 did not complete - verify its "[OK]" / "MISSING" output.

# Core building blocks
from transferability_framework.core.cka import CKA, linear_cka
from transferability_framework.core.model_loader import ModelLoader, SUPPORTED_MODELS
from transferability_framework.core.attack_suite import AttackSuite, STANDARD_ATTACKS
from transferability_framework.core.api_clients import OpenAIClient, AnthropicClient

# One experiment class per phase (Phase 1, Phase 2, Phase 3 respectively)
from transferability_framework.experiments.open_model_cka import OpenModelCKAExperiment, CKA_PROMPTS
from transferability_framework.experiments.attack_testing import AttackTestingExperiment
from transferability_framework.experiments.correlation import CorrelationAnalysis

print("[OK] All framework components imported")
# The keys printed here are the names accepted in OPEN_MODELS (CELL 7).
print(f"\nSupported open models: {list(SUPPORTED_MODELS.keys())}")
print(f"Standard attacks: {len(STANDARD_ATTACKS)} prompts")

---
## Phase 1: CKA Similarity Matrix

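Compute CKA between open-weight models to identify similarities in their internal representations.
This phase needs no OpenAI or Anthropic API key; a HuggingFace token is still required for gated models such as Gemma and Llama.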

In [None]:
# ==========================================
# CELL 7: Configure Phase 1
# ==========================================
# Chooses which open-weight models are compared and which prompts are used.

# Models to compare, by framework key. Trim this list when GPU memory is tight;
# uncomment an entry to include it.
OPEN_MODELS = [
    "gemma2",    # google/gemma-2-2b-it
    "mistral",   # mistralai/Ministral-3-8B-Instruct-2512
    "llama",     # meta-llama/Llama-3.2-3B-Instruct
    # "phi",     # microsoft/phi-2 (optional)
    # "qwen",    # Qwen/Qwen2.5-3B-Instruct (optional)
]

# Prompts fed to each model for hidden state extraction: the first 20 of the
# framework's standard set, which is usually sufficient.
_cka_prompt_limit = 20
prompts_for_cka = CKA_PROMPTS[:_cka_prompt_limit]

print(f"Models to compare: {OPEN_MODELS}")
print(f"Prompts for CKA: {len(prompts_for_cka)}")

In [None]:
# ==========================================
# CELL 8: Run Phase 1 - CKA Computation
# ==========================================
# Builds the Phase 1 experiment, runs it over OPEN_MODELS with the prompts
# chosen in CELL 7, writes cka_matrix.json to RESULTS_DIR and prints a summary.

# Experiment settings; n_layer_samples=5 samples 5 layers per model for efficiency.
cka_experiment = OpenModelCKAExperiment(
    n_layer_samples=5,
    kernel="linear",
    dtype="bfloat16",
    device="auto",
    hf_token=HF_TOKEN,
)

_phase1_rule = "=" * 60
print(_phase1_rule)
print("PHASE 1: Computing CKA Similarity Matrix")
print(_phase1_rule)

# Result is kept in cka_matrix for Phase 3 and the final summary.
cka_matrix = cka_experiment.run(
    save_path=f"{RESULTS_DIR}/cka_matrix.json",
    prompts=prompts_for_cka,
    model_names=OPEN_MODELS,
)

print("\n" + cka_matrix.summary())

---
## Phase 2: Attack Testing

Run the standard attack suite on all models (open and closed). Closed models are tested only when an OpenAI or Anthropic API key is configured.

In [None]:
# ==========================================
# CELL 9: Run Phase 2 - Attack Testing (Open Models)
# ==========================================
# Creates the Phase 2 experiment (reused by CELL 10 for closed models) and
# runs it against the open-weight models, saving open_model_attacks.json.

# The API keys may be None; they are only consulted for closed-model testing.
attack_experiment = AttackTestingExperiment(
    dtype="bfloat16",
    device="auto",
    anthropic_key=ANTHROPIC_KEY,
    openai_key=OPENAI_KEY,
    hf_token=HF_TOKEN,
)

_phase2_rule = "=" * 60
print(_phase2_rule)
print("PHASE 2: Attack Testing")
print(_phase2_rule)

# Open models go first
print("\nTesting open-weight models...")
open_attack_results = attack_experiment.run_open_models(
    save_path=f"{RESULTS_DIR}/open_model_attacks.json",
    model_names=OPEN_MODELS,
)

In [None]:
# ==========================================
# CELL 10: Run Phase 2 - Attack Testing (Closed Models)
# ==========================================
# Tests the closed-weight models when at least one API key is set and builds
# all_attack_results, the combined open + closed results used by Phase 3.
import copy

# Test closed models if API keys are available
if OPENAI_KEY or ANTHROPIC_KEY:
    print("\nTesting closed-weight models via API...")
    closed_attack_results = attack_experiment.run_closed_models(
        save_path=f"{RESULTS_DIR}/closed_model_attacks.json",
    )

    # Merge into a copy. Assigning open_attack_results directly would alias it,
    # so the update below would silently add the closed models to
    # open_attack_results as well.
    all_attack_results = copy.deepcopy(open_attack_results)
    all_attack_results.models.update(closed_attack_results.models)
else:
    print("\n[INFO] Skipping closed model testing - no API keys")
    # Nothing is merged on this path, so sharing the object is harmless.
    all_attack_results = open_attack_results

print("\n" + all_attack_results.summary())

---
## Phase 3: Correlation Analysis

Correlate CKA similarity with attack transferability and generate surrogate recommendations.

In [None]:
# ==========================================
# CELL 11: Run Phase 3 - Correlation Analysis
# ==========================================
# Feeds the Phase 1 CKA matrix and the Phase 2 attack results into the
# correlation analysis and saves transferability_report.json.

correlation = CorrelationAnalysis()

_phase3_rule = "=" * 60
print(_phase3_rule)
print("PHASE 3: Correlation Analysis")
print(_phase3_rule)

# Closed models are looked up only when a key was configured, and only those
# that actually appear in the attack results are passed on.
if OPENAI_KEY or ANTHROPIC_KEY:
    closed_models = [
        name for name in ("openai", "anthropic")
        if name in all_attack_results.models
    ]
else:
    closed_models = []

# Result is kept in report for the summary cells that follow.
report = correlation.run(
    save_path=f"{RESULTS_DIR}/transferability_report.json",
    closed_model_keys=closed_models,
    attack_results=all_attack_results,
    cka_matrix=cka_matrix,
)

print("\n" + report.summary())

---
## Summary and Recommendations

In [None]:
# ==========================================
# CELL 12: Generate Final Summary
# ==========================================
# Prints a five-part recap from the earlier phases: most similar model pair,
# per-model compliance rates, the CKA/attack correlation with an
# interpretation, surrogate recommendations, and the files written to
# RESULTS_DIR.
import os  # imported here so the cell does not depend on an earlier cell's import

# Correlation strength / significance cut-offs used for the interpretation.
_STRONG_CORRELATION = 0.5
_SIGNIFICANCE_LEVEL = 0.05

print("="*60)
print("FINAL ANALYSIS SUMMARY")
print("="*60)

# CKA findings
print("\n1. CKA SIMILARITY (Structural)")
print("-" * 40)
m1, m2, sim = cka_matrix.get_most_similar_pair()
print(f"   Most similar pair: {m1} <-> {m2}")
print(f"   CKA similarity: {sim:.3f}")

# Attack findings, most compliant model first
print("\n2. ATTACK SUCCESS RATES")
print("-" * 40)
for model, results in sorted(
    all_attack_results.models.items(),
    key=lambda x: x[1].compliance_rate,
    reverse=True
):
    print(f"   {model}: {results.compliance_rate:.1%} compliance")

# Correlation findings
print("\n3. CKA-ATTACK CORRELATION")
print("-" * 40)
print(f"   Pearson correlation: {report.cka_attack_correlation:.3f}")
print(f"   P-value: {report.correlation_p_value:.4f}")

# A high coefficient alone is not evidence: with only a few model pairs it can
# easily arise by chance, so "GOOD" also requires a significant p-value.
# NaN values compare False on both tests and fall through to the weaker verdicts.
if report.cka_attack_correlation > _STRONG_CORRELATION:
    if report.correlation_p_value < _SIGNIFICANCE_LEVEL:
        print("   Interpretation: CKA is a GOOD predictor of attack transfer")
    else:
        print("   Interpretation: CKA correlates with attack transfer, but the "
              "result is NOT statistically significant - treat as tentative")
else:
    print("   Interpretation: CKA has LIMITED predictive value")

# Surrogate recommendations (section is omitted when there are none)
if report.surrogate_recommendations:
    print("\n4. SURROGATE RECOMMENDATIONS")
    print("-" * 40)
    for target, rec in report.surrogate_recommendations.items():
        print(f"   For {target}: Use {rec.best_surrogate}")
        print(f"      Expected prediction error: {rec.prediction_error:.1%}")

# Files saved
print("\n5. RESULTS SAVED TO")
print("-" * 40)
print(f"   {RESULTS_DIR}")
# os.listdir returns entries in arbitrary order; sort for a stable listing.
for fname in sorted(os.listdir(RESULTS_DIR)):
    size = os.path.getsize(os.path.join(RESULTS_DIR, fname))
    print(f"   - {fname} ({size:,} bytes)")

In [None]:
# ==========================================
# CELL 13: Red Team Action Items
# ==========================================
# Turns the Phase 3 report and the Phase 2 per-category metrics into a short
# list of suggested follow-up actions.

_actions_rule = "=" * 60
print(_actions_rule)
print("RED TEAM ACTION ITEMS")
print(_actions_rule)

# One surrogate -> target plan per recommendation in the report
print("\n1. SURROGATE STRATEGY")
if not report.surrogate_recommendations:
    print("   - No closed models tested; use most compliant open model as baseline")
else:
    for target, rec in report.surrogate_recommendations.items():
        print(f"   - Test attacks on {rec.best_surrogate} first")
        print(f"     Then transfer successful attacks to {target}")
        print(f"     Expected transfer accuracy: {1 - rec.prediction_error:.0%}")

print("\n2. HIGH-VALUE ATTACK CATEGORIES")
# Gather each category's compliance rate across all models
# (a missing 'compliance_rate' entry counts as 0).
category_success = {}
for results in all_attack_results.models.values():
    for cat, metrics in results.category_metrics.items():
        category_success.setdefault(cat, []).append(metrics.get('compliance_rate', 0))

# Rank categories by their best rate on any model; skip those that never succeeded.
_best_rate = {cat: max(rates) for cat, rates in category_success.items()}
for cat in sorted(_best_rate, key=_best_rate.get, reverse=True):
    max_success = _best_rate[cat]
    if max_success > 0:
        print(f"   - {cat}: up to {max_success:.0%} success (focus area)")

print("\n3. NEXT STEPS")
for _step in (
    "   - Develop variations of successful attack categories",
    "   - Test on surrogate model first to save API costs",
    "   - Track CKA changes if models are updated",
    "   - Expand attack suite for categories with low coverage",
):
    print(_step)