In [6]:
# Route HTTP(S) requests through the institute proxy; presumably needed so the
# model/tokenizer downloads in load_model_and_sae can reach the network.
# NOTE(review): hardcoded, site-specific proxy - prefer reading it from the
# environment so the notebook also runs outside this network.
import os 
os.environ["https_proxy"] = "http://xen03.iitd.ac.in:3128"
os.environ["http_proxy"] = "http://xen03.iitd.ac.in:3128"

import sys
# sys.path.append('../sae')
from sae import Sae
# NOTE(review): star import; load_model_and_sae, extract_sae_features and
# get_overlap used below presumably come from here - import them explicitly.
from utils import *
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
import torch
import torch.nn.functional as F
import os  # duplicate of the import above (harmless)
import pandas as pd
import numpy as np
import time

In [7]:
# Load the language model, its tokenizer and the SAE for the chosen layer.
device = 'cpu'            # everything below runs on CPU
layer_num = 20            # layer whose hidden states / SAE are analysed
model_type = 'gemma2-2b'  # identifier understood by load_model_and_sae
model, tokenizer, sae = load_model_and_sae(model_type, layer_num, device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

# Random analysis of the SAE robustness

In [15]:
# Probe texts: x1 (film review), x2 (x1 with "Hacker Encryption implementations"
# appended without a separator), x3 (encryption) and a neutral base sentence.
# Each is summarised by the hidden state of the last token of "<text>" + suffix.
PROBE_SUFFIX = "\nThe previous text is about"


def last_token_hidden(text, suffix=PROBE_SUFFIX):
    """Encode ``text`` and return ``(plain_ids, prompt_ids, hidden)``.

    plain_ids  -- token ids of ``text`` alone, moved to ``device``.
    prompt_ids -- token ids of ``text + suffix``, moved to ``device``.
    hidden     -- ``hidden_states[layer_num + 1]`` of the model run on
                  ``prompt_ids``, first batch element, final token position.
    """
    plain_ids = tokenizer(text, return_tensors="pt")['input_ids'].to(device)
    prompt_ids = tokenizer(text + suffix, return_tensors="pt")['input_ids'].to(device)
    # Inference only: without no_grad the forward pass records an autograd graph
    # for the entire model, which is wasted memory (the result is detached anyway).
    with torch.no_grad():
        outputs = model(prompt_ids, output_hidden_states=True)
    hidden = outputs.hidden_states[layer_num + 1][0][-1].detach()
    return plain_ids, prompt_ids, hidden


x1_raw_text = 'The film explores love and trauma through non-linear storytelling, blending magical realism with emotionally raw performances'
x1_raw, x1_raw_preprocessed, h1_raw = last_token_hidden(x1_raw_text)

x2_raw_text = "The film explores love and trauma through non-linear storytelling, blending magical realism with emotionally raw performancesHacker Encryption implementations"
x2_raw, x2_raw_preprocessed, h2_raw = last_token_hidden(x2_raw_text)

x3_raw_text = "Encryption secures sensitive digital communication by converting readable data into unreadable ciphertext"
x3_raw, x3_raw_preprocessed, h3_raw = last_token_hidden(x3_raw_text)

base_raw_sent = "This is a base sentence for comparison"
base_raw, base_raw_preprocessed, h_base_raw = last_token_hidden(base_raw_sent)

In [13]:
# Standardise each hidden vector over its last axis (layer norm with no learned
# scale/shift) before comparing them.
# NOTE(review): this rebinds the h*_raw names in place. The recorded execution
# counts show this cell ran *before* the extraction cell, so the printed results
# below were computed on un-normalised vectors - confirm the intended order.
h1_raw, h2_raw, h3_raw, h_base_raw = [
    F.layer_norm(vec, vec.shape[-1:])
    for vec in (h1_raw, h2_raw, h3_raw, h_base_raw)
]

In [16]:
# Cosine similarity between every pair of probe representations.
cos_sim_h1_h2, cos_sim_h1_h3, cos_sim_h2_h3 = (
    F.cosine_similarity(lhs, rhs, dim=-1)
    for lhs, rhs in ((h1_raw, h2_raw), (h1_raw, h3_raw), (h2_raw, h3_raw))
)
print(f"Cosine similarity between h1_raw and h2_raw: {cos_sim_h1_h2.item()}")
print(f"Cosine similarity between h1_raw and h3_raw: {cos_sim_h1_h3.item()}")
print(f"Cosine similarity between h2_raw and h3_raw: {cos_sim_h2_h3.item()}")

print('---' * 20)

# ... and between each probe and the neutral base sentence.
cos_sim_h1_base, cos_sim_h2_base, cos_sim_h3_base = (
    F.cosine_similarity(probe, h_base_raw, dim=-1)
    for probe in (h1_raw, h2_raw, h3_raw)
)
print(f"Cosine similarity between h1_raw and base_raw: {cos_sim_h1_base.item()}")
print(f"Cosine similarity between h2_raw and base_raw: {cos_sim_h2_base.item()}")
print(f"Cosine similarity between h3_raw and base_raw: {cos_sim_h3_base.item()}")



def polytope(x):
    """
    Binary activation pattern of ``x``: 1.0 where an entry is strictly
    positive, 0.0 otherwise (zeros and negatives both map to 0.0).
    The result is a float32 tensor with the same shape as ``x``.
    """
    positive_mask = torch.gt(x, 0)
    return positive_mask.to(torch.float32)

# Sign pattern (1 where positive, else 0) of each hidden state.
h1_plt, h2_plt, h3_plt, base_plt = (
    polytope(vec) for vec in (h1_raw, h2_raw, h3_raw, h_base_raw)
)

# Print the shapes of the polytopes
print(f"Shape of h1_plt: {h1_plt.shape}")

# Hamming distance = number of coordinates whose sign pattern differs.
hamming_distance = (h1_plt != h2_plt).sum().item()
print(f"Hamming distance between h1 and h2: {hamming_distance}")

hamming_distance = (h1_plt != h3_plt).sum().item()
print(f"Hamming distance between h1 and h3: {hamming_distance}")

hamming_distance = (h2_plt != h3_plt).sum().item()
print(f"Hamming distance between h2 and h3: {hamming_distance}")

hamming_distance = (h1_plt != base_plt).sum().item()
print(f"Hamming distance between h1 and base: {hamming_distance}")

hamming_distance = (h2_plt != base_plt).sum().item()
print(f"Hamming distance between h2 and base: {hamming_distance}")

hamming_distance = (h3_plt != base_plt).sum().item()
print(f"Hamming distance between h3 and base: {hamming_distance}")

# Passed through to extract_sae_features; presumably the number of top SAE
# features kept per input - confirm against utils.
k = 170

z1, s1, s1_acts = extract_sae_features(h1_raw, sae, model_type, k)
z2, s2, s2_acts = extract_sae_features(h2_raw, sae, model_type, k)
z3, s3, s3_acts = extract_sae_features(h3_raw, sae, model_type, k)

# The z* values are not just 0/1, so binarise them with the same positivity
# pattern before comparing.
z1_plt, z2_plt, z3_plt = (polytope(code) for code in (z1, z2, z3))

hamming_distance = (z1_plt != z2_plt).sum().item()
print(f"Hamming distance between z1 and z2: {hamming_distance}")
hamming_distance = (z1_plt != z3_plt).sum().item()
print(f"Hamming distance between z1 and z3: {hamming_distance}")
hamming_distance = (z2_plt != z3_plt).sum().item()
print(f"Hamming distance between z2 and z3: {hamming_distance}")

# Overlap between the s* outputs of extract_sae_features, as scored by get_overlap.
overlap_score = get_overlap(s1, s2)
print(f"Overlap score between s1 and s2: {overlap_score}")
overlap_score = get_overlap(s1, s3)
print(f"Overlap score between s1 and s3: {overlap_score}")
overlap_score = get_overlap(s2, s3)
print(f"Overlap score between s2 and s3: {overlap_score}")

Cosine similarity between h1_raw and h2_raw: 0.8860180974006653
Cosine similarity between h1_raw and h3_raw: 0.8103485703468323
Cosine similarity between h2_raw and h3_raw: 0.9335434436798096
------------------------------------------------------------
Cosine similarity between h1_raw and base_raw: 0.860113263130188
Cosine similarity between h2_raw and base_raw: 0.8216165900230408
Cosine similarity between h3_raw and base_raw: 0.788619875907898
Shape of h1_plt: torch.Size([2304])
Hamming distance between h1 and h2: 595
Hamming distance between h1 and h3: 743
Hamming distance between h2 and h3: 412
Hamming distance between h1 and base: 643
Hamming distance between h2 and base: 744
Hamming distance between h3 and base: 806
Hamming distance between z1 and z2: 229
Hamming distance between z1 and z3: 234
Hamming distance between z2 and z3: 157
Overlap score between s1 and s2: 0.4000000059604645
Overlap score between s1 and s3: 0.3529411852359772
Overlap score between s2 and s3: 0.6235294342

# Boundary Crossing

# Interpretation and Conclusions


def boundary_search_outcome(result):
    """Return ``(target_reached, delta_min)`` for one ``batch_results`` entry.

    Experiment 3 stores each pair as a dict holding ``'hidden_transition'`` and
    ``'sae_transition'`` result objects, so attribute access on the dict itself
    (``v.target_reached``) raises ``AttributeError``. Dicts are unwrapped to their
    SAE-space transition; any other object exposing ``target_reached`` and
    ``delta_min`` attributes is used as-is.
    """
    if isinstance(result, dict):
        # NOTE(review): treating the SAE-space transition as "the" boundary search
        # for this pair is an assumption - confirm which transition is intended.
        result = result['sae_transition']
    return result.target_reached, result.delta_min


print("=== BOUNDARY CROSSING ANALYSIS CONCLUSIONS ===\n")

if 'batch_results' in locals() and batch_results:
    successful_results = {
        key: val for key, val in batch_results.items() if boundary_search_outcome(val)[0]
    }

    if successful_results:
        delta_norms = [boundary_search_outcome(r)[1] for r in successful_results.values()]
        mean_boundary_dist = np.mean(delta_norms)
        std_boundary_dist = np.std(delta_norms)

        print("1. MARGIN ANALYSIS:")
        print(f"   - Average minimal perturbation to cross boundary: {mean_boundary_dist:.6f}")
        print(f"   - Standard deviation: {std_boundary_dist:.6f}")
        print(f"   - This suggests SAE boundaries are {'close' if mean_boundary_dist < 0.1 else 'distant'}")

        # Compare with typical input norms for context
        if 'all_texts' in locals():
            hidden_norms = []
            for text in all_texts[:3]:  # Sample a few to get typical norms
                h = tracer.get_hidden_representation(text, layer_num)
                hidden_norms.append(h.norm().item())
            avg_hidden_norm = np.mean(hidden_norms)
            relative_perturbation = mean_boundary_dist / avg_hidden_norm
            print(f"   - Relative to input magnitude ({avg_hidden_norm:.2f}): {relative_perturbation:.4f} ({relative_perturbation*100:.2f}%)")

if 'concept_results' in locals() and concept_results:
    print("\n2. CONCEPT DISTANCE ANALYSIS:")

    concept_distances = []
    boundary_distances = []
    ratios = []

    for data in concept_results.values():
        concept_distances.extend([data['a_to_b_distance'], data['b_to_a_distance']])
        boundary_distances.extend([data['a_boundary_distance'], data['b_boundary_distance']])
        if data['distance_ratio_a'] != float('inf'):
            ratios.append(data['distance_ratio_a'])
        if data['distance_ratio_b'] != float('inf'):
            ratios.append(data['distance_ratio_b'])

    if ratios:
        mean_ratio = np.mean(ratios)
        print(f"   - Concept distances are {mean_ratio:.1f}x larger than single boundary distances")
        # NOTE(review): a distance ratio is not a count of boundaries crossed;
        # the "~N boundaries" reading below is a heuristic.
        print(f"   - This means crossing from Education→Technology requires crossing ~{mean_ratio:.0f} boundaries")

        if mean_ratio > 3:
            print("   - INTERPRETATION: Concepts are well-separated in SAE space")
        elif mean_ratio > 1.5:
            print("   - INTERPRETATION: Moderate concept separation")
        else:
            print("   - INTERPRETATION: Concepts may be close together (potential vulnerability)")

if 'feature_flips' in locals() and feature_flips:
    print("\n3. FEATURE VULNERABILITY:")
    vulnerable_features = sum(1 for counts in feature_flips.values() if counts['total'] >= 2)
    # The denominator is the SAE code from Experiment 1; without it the rate is
    # undefined, so don't report a (misleading) 0% / "robust" verdict.
    total_features = len(original_code) if 'original_code' in locals() else 0

    if total_features > 0:
        vulnerability_rate = vulnerable_features / total_features

        print(f"   - {vulnerable_features} out of {total_features} features are consistently vulnerable")
        print(f"   - Vulnerability rate: {vulnerability_rate:.3f} ({vulnerability_rate*100:.1f}%)")

        if vulnerability_rate > 0.1:
            print("   - INTERPRETATION: High feature vulnerability - SAE may be susceptible to attacks")
        elif vulnerability_rate > 0.05:
            print("   - INTERPRETATION: Moderate vulnerability")
        else:
            print("   - INTERPRETATION: Low vulnerability - SAE appears robust")
    else:
        print(f"   - {vulnerable_features} features are consistently vulnerable (total feature count unavailable; run Experiment 1 first)")

print("\n4. ROBUSTNESS ASSESSMENT:")
if 'successful_results' in locals() and len(successful_results) > 0:
    success_rate = len(successful_results) / len(batch_results) if batch_results else 0
    print(f"   - Boundary search success rate: {success_rate:.2f} ({success_rate*100:.1f}%)")

    if success_rate > 0.8:
        print("   - INTERPRETATION: Boundaries are easily found - potential robustness concerns")
    elif success_rate > 0.5:
        print("   - INTERPRETATION: Moderate boundary accessibility")
    else:
        print("   - INTERPRETATION: Boundaries are hard to find - suggests robustness")

print("\n=== RECOMMENDATIONS ===")
print("1. Consider adversarial training to increase boundary distances")
print("2. Monitor vulnerable features during deployment")
print("3. Test with more diverse text pairs to validate generalization")
print("4. Compare with other SAE architectures for robustness benchmarking")

In [None]:
# Research Conclusions and Theoretical Interpretations
# NOTE(review): section 4 reads `hidden_success` / `total`, which only the
# Experiment 4 cell (placed later in the notebook) defines; it falls back to 0.

print("=== RESEARCH CONCLUSIONS ===")
print()

# Every section needs the Experiment 3 statistics (boundary_ratios,
# hidden_distances, sae_distances), so the whole report - including sections 5-6
# and the future-work list - sits inside this guard; the trailing `else` prints a
# hint instead.
if 'batch_results' in locals() and batch_results and boundary_ratios:
    avg_boundary_ratio = np.mean(boundary_ratios)
    avg_distance_ratio = np.mean(np.array(hidden_distances) / np.array(sae_distances)) if sae_distances else 0

    print("1. POLYTOPE STRUCTURE ANALYSIS:")
    print(f"   • Average boundary crossing ratio: {avg_boundary_ratio:.2f}")

    if avg_boundary_ratio > 2:
        print("   ✓ FINDING: Hidden space traversal crosses MANY SAE decision boundaries")
        print("   → INTERPRETATION: SAE creates fine-grained polytope structure")
        print("   → IMPLICATION: Concepts are subdivided into multiple SAE regions")
        print("   → ROBUSTNESS: More boundaries = more potential attack vectors")
    elif avg_boundary_ratio > 1.2:
        print("   ✓ FINDING: Hidden space moderately more complex than SAE space")
        print("   → INTERPRETATION: SAE provides some structural organization")
        print("   → IMPLICATION: Balanced trade-off between compression and structure")
    else:
        print("   ✓ FINDING: Similar boundary complexity in both spaces")
        print("   → INTERPRETATION: SAE preserves geometric structure of hidden space")
        print("   → IMPLICATION: Minimal structural reorganization")

    print("\n2. DECISION BOUNDARY EFFICIENCY:")
    print(f"   • Average distance ratio: {avg_distance_ratio:.2f}")

    if avg_distance_ratio > 1.5:
        print("   ✓ FINDING: Hidden-to-hidden transitions require larger perturbations")
        print("   → INTERPRETATION: SAE provides more direct concept pathways")
        print("   → IMPLICATION: SAE may be MORE robust for concept stability")
        print("   → SURPRISE: Sparse representation offers better concept separation")
    elif avg_distance_ratio < 0.8:
        print("   ✓ FINDING: SAE-to-SAE transitions require larger perturbations")
        print("   → INTERPRETATION: Hidden space has more direct paths")
        print("   → IMPLICATION: SAE may create artificial separation")
    else:
        print("   ✓ FINDING: Similar perturbation requirements in both spaces")
        print("   → INTERPRETATION: Equivalent geometric efficiency")

    print("\n3. ROBUSTNESS IMPLICATIONS:")

    # Analyze vulnerability patterns
    min_boundary_crossing = min(boundary_ratios) if boundary_ratios else 0
    max_boundary_crossing = max(boundary_ratios) if boundary_ratios else 0

    print(f"   • Boundary crossing range: {min_boundary_crossing:.2f} to {max_boundary_crossing:.2f}")

    if max_boundary_crossing > 5:
        print("   ⚠️  HIGH VULNERABILITY: Some concepts require crossing many boundaries")
        print("   → RISK: Complex attack paths but multiple failure points")

    if min_boundary_crossing < 0.5:
        print("   ⚠️  LOW VULNERABILITY: Some concepts very close in SAE space")
        print("   → RISK: Easy concept confusion with small perturbations")

    # Success rate analysis
    hidden_success_rate = hidden_success/total if 'hidden_success' in locals() and total > 0 else 0
    sae_success_rate = sae_success/total if 'sae_success' in locals() and total > 0 else 0

    print("\n4. ATTACK SUCCESS ANALYSIS:")
    print(f"   • Hidden space attack success: {hidden_success_rate:.1%}")
    print(f"   • SAE space attack success: {sae_success_rate:.1%}")

    if sae_success_rate > hidden_success_rate:
        print("   ⚠️  SAE space MORE vulnerable to targeted attacks")
        print("   → IMPLICATION: Sparse codes easier to manipulate")
    elif hidden_success_rate > sae_success_rate:
        print("   ✓ SAE space LESS vulnerable to targeted attacks")
        print("   → IMPLICATION: Sparsity provides some protection")
    else:
        print("   → Both spaces show similar attack susceptibility")

    print("\n5. ANSWERS TO RESEARCH QUESTIONS:")

    print("\n   Q1: Do code changes align with crossing ReLU hyperplanes?")
    if 'boundary_result' in locals():
        print("   ✓ YES: DeepFool successfully found exact decision boundaries")
        print(f"   → Boundary distance: {boundary_result.delta_min:.6f}")
        print(f"   → Feature flips: {len(boundary_result.flipped_features)}")

    print("\n   Q2: How many intermediate boundaries separate concepts?")
    if boundary_ratios:
        print(f"   ✓ ANSWER: {avg_boundary_ratio:.1f} SAE boundaries on average")
        print(f"   → Range: {min_boundary_crossing:.1f} to {max_boundary_crossing:.1f}")
        if avg_boundary_ratio > 3:
            print("   → CONCLUSION: Concepts well-separated by multiple boundaries")
        elif avg_boundary_ratio > 1:
            print("   → CONCLUSION: Moderate separation with some intermediate structure")
        else:
            print("   → CONCLUSION: Direct transitions possible between concepts")

    print("\n6. METHODOLOGICAL VALIDATION:")
    print("   ✓ DeepFool adaptation successfully finds minimal perturbations")
    print("   ✓ Boundary counting provides quantitative complexity measure")
    print("   ✓ Hidden vs SAE comparison reveals structural differences")
    print("   ✓ Statistical analysis enables generalization beyond single examples")

    print("\n=== FUTURE RESEARCH DIRECTIONS ===")
    print("1. Test with larger concept vocabularies (beyond education/technology)")
    print("2. Analyze boundary density as function of SAE width/sparsity")
    print("3. Develop adversarial defenses based on boundary distance metrics")
    print("4. Compare different SAE architectures (standard vs. TopK vs. Gated)")
    print("5. Investigate relationship between boundary crossing and semantic similarity")

else:
    print("Please run the experiments above to generate conclusions.")

print("\n" + "="*80)

In [None]:
# Experiment 4: Visualization of Boundary Crossing Analysis
# NOTE(review): `plt` is not imported in the visible import cell - it presumably
# arrives via `from utils import *`; an explicit matplotlib import would be safer.
# This cell also defines `hidden_success`, `sae_success` and `total`, which the
# research-conclusions cell reads.


def safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``float('inf')`` if the denominator is 0.

    Used for every "Ratio (H/S)" column of the summary table, so a zero SAE-side
    value (e.g. a zero standard deviation when only one pair was analysed) gives
    ``inf`` consistently instead of a numpy divide-by-zero warning and ``inf``/``nan``.
    """
    return numerator / denominator if denominator != 0 else float('inf')


print("=== EXPERIMENT 4: VISUALIZATION ===")

# Create comprehensive visualization
if 'batch_results' in locals() and batch_results:
    print("Creating boundary crossing comparison plots...")

    # Use the specialized visualization function
    visualize_boundary_comparison(batch_results, save_path='boundary_comparison.png')

    # Additional analysis plot
    plt.figure(figsize=(15, 5))

    # Plot 1: Boundary trajectory comparison
    plt.subplot(1, 3, 1)
    if boundary_ratios:
        plt.hist(boundary_ratios, bins=max(10, len(boundary_ratios)), alpha=0.7, edgecolor='black')
        plt.axvline(x=1.0, color='red', linestyle='--', alpha=0.7, label='Equal complexity')
        plt.axvline(x=np.mean(boundary_ratios), color='blue', linestyle='-', alpha=0.7, label=f'Mean: {np.mean(boundary_ratios):.2f}')
        plt.xlabel('Boundary Ratio (Hidden/SAE)')
        plt.ylabel('Frequency')
        plt.title('Complexity Ratio Distribution')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Plot 2: Distance vs Boundary Count
    plt.subplot(1, 3, 2)
    if hidden_distances and hidden_boundaries:
        plt.scatter(hidden_distances, hidden_boundaries, alpha=0.7, label='Hidden space', color='blue')
        plt.scatter(sae_distances, sae_boundaries, alpha=0.7, label='SAE space', color='orange')
        plt.xlabel('Perturbation Distance')
        plt.ylabel('Boundaries Crossed')
        plt.title('Distance vs Boundary Complexity')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Plot 3: Success rate comparison
    plt.subplot(1, 3, 3)
    hidden_success = sum(1 for data in batch_results.values() if data['hidden_transition'].target_reached)
    sae_success = sum(1 for data in batch_results.values() if data['sae_transition'].target_reached)
    total = len(batch_results)

    categories = ['Hidden\nSpace', 'SAE\nSpace']
    success_rates = [hidden_success/total if total > 0 else 0, sae_success/total if total > 0 else 0]
    colors = ['skyblue', 'orange']

    bars = plt.bar(categories, success_rates, alpha=0.7, color=colors)
    plt.ylabel('Success Rate')
    plt.title('Transition Success Rates')
    plt.ylim(0, 1)

    # Add value labels on bars
    for bar, rate in zip(bars, success_rates):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{rate:.1%}', ha='center', va='bottom')

    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary statistics table (total > 0 here because batch_results is non-empty)
    print("\n=== FINAL SUMMARY TABLE ===")
    print("Metric                    | Hidden Space | SAE Space   | Ratio (H/S)")
    print("-" * 65)
    print(f"Avg. Distance             | {np.mean(hidden_distances):8.4f}   | {np.mean(sae_distances):7.4f}   | {safe_ratio(np.mean(hidden_distances), np.mean(sae_distances)):6.2f}")
    print(f"Avg. Boundaries Crossed   | {np.mean(hidden_boundaries):8.1f}   | {np.mean(sae_boundaries):7.1f}   | {np.mean(boundary_ratios):6.2f}")
    print(f"Success Rate              | {hidden_success/total:8.1%}   | {sae_success/total:7.1%}   | {safe_ratio(hidden_success, sae_success):6.2f}")
    print(f"Std. Distance             | {np.std(hidden_distances):8.4f}   | {np.std(sae_distances):7.4f}   | {safe_ratio(np.std(hidden_distances), np.std(sae_distances)):6.2f}")

else:
    print("No batch results available for visualization.")
    print("Please run the previous experiments first.")

print("\n" + "="*60)

In [None]:
# Experiment 3: Batch Analysis of Concept Transitions
# Analyze multiple concept pairs to get statistical insights
# NOTE(review): needs `tracer`, `education_texts` and `technology_texts` from the
# setup cell further down the notebook; run that cell first.

print("=== EXPERIMENT 3: BATCH CONCEPT TRANSITION ANALYSIS ===")

# Create concept pairs (i-th education sentence with i-th technology sentence)
concept_pairs = [
    (education_texts[0], technology_texts[0]),
    (education_texts[1], technology_texts[1]),
    (education_texts[2], technology_texts[2])
]

print(f"Analyzing {len(concept_pairs)} concept pairs...")
print("This will compare hidden-space vs SAE-space transitions for each pair")
print()

# Run batch analysis
batch_results = tracer.analyze_concept_transitions_batch(
    concept_pairs=concept_pairs,
    layer_idx=layer_num,
    max_iter=100,
    verbose=True
)

print("\n=== BATCH ANALYSIS RESULTS ===")
print(f"Successfully analyzed: {len(batch_results)} pairs")

# Collect statistics
hidden_distances = []
sae_distances = []
boundary_ratios = []
hidden_boundaries = []
sae_boundaries = []

for pair_key, data in batch_results.items():
    hidden_res = data['hidden_transition']
    sae_res = data['sae_transition']

    hidden_distances.append(hidden_res.delta_min)
    sae_distances.append(sae_res.delta_min)
    hidden_boundaries.append(hidden_res.num_boundaries_crossed)
    sae_boundaries.append(sae_res.num_boundaries_crossed)
    boundary_ratios.append(data['boundary_ratio'])

    print(f"\n{pair_key}:")
    print(f"  Hidden: {hidden_res.delta_min:.4f} (boundaries: {hidden_res.num_boundaries_crossed})")
    print(f"  SAE:    {sae_res.delta_min:.4f} (boundaries: {sae_res.num_boundaries_crossed})")
    print(f"  Ratio:  {data['boundary_ratio']:.2f}")

if hidden_distances and sae_distances:
    print("\n=== STATISTICAL SUMMARY ===")
    print("Hidden space distances:")
    print(f"  - Mean: {np.mean(hidden_distances):.6f}")
    print(f"  - Std:  {np.std(hidden_distances):.6f}")
    print(f"  - Range: {np.min(hidden_distances):.6f} to {np.max(hidden_distances):.6f}")

    print("\nSAE space distances:")
    print(f"  - Mean: {np.mean(sae_distances):.6f}")
    print(f"  - Std:  {np.std(sae_distances):.6f}")
    print(f"  - Range: {np.min(sae_distances):.6f} to {np.max(sae_distances):.6f}")

    print("\nBoundary crossings:")
    print(f"  - Hidden space: {np.mean(hidden_boundaries):.1f} ± {np.std(hidden_boundaries):.1f}")
    print(f"  - SAE space: {np.mean(sae_boundaries):.1f} ± {np.std(sae_boundaries):.1f}")
    print(f"  - Average ratio: {np.mean(boundary_ratios):.2f}")

    # Key research insights
    avg_boundary_ratio = np.mean(boundary_ratios)
    # Per-pair Hidden/SAE distance ratio. A zero SAE distance would divide by
    # zero, so (matching the guard in Experiment 2) such pairs get an infinite
    # ratio and are left out of the average.
    pair_distance_ratios = [
        h_dist / s_dist if s_dist > 0 else float('inf')
        for h_dist, s_dist in zip(hidden_distances, sae_distances)
    ]
    finite_distance_ratios = [r for r in pair_distance_ratios if np.isfinite(r)]
    avg_distance_ratio = np.mean(finite_distance_ratios) if finite_distance_ratios else float('inf')

    print("\n=== RESEARCH INSIGHTS ===")
    print("1. BOUNDARY COMPLEXITY:")
    if avg_boundary_ratio > 2:
        print(f"   → Hidden space transitions cross {avg_boundary_ratio:.1f}x more SAE boundaries")
        print("   → SAE creates complex intermediate structure")
    elif avg_boundary_ratio > 1.2:
        print(f"   → Hidden space moderately more complex ({avg_boundary_ratio:.1f}x boundaries)")
        print("   → SAE introduces some structural organization")
    else:
        print(f"   → Similar boundary complexity ({avg_boundary_ratio:.1f}x)")
        print("   → SAE preserves hidden space structure")

    print("\n2. DISTANCE EFFICIENCY:")
    if avg_distance_ratio > 1.5:
        print(f"   → Hidden transitions require {avg_distance_ratio:.1f}x larger perturbations")
        print("   → SAE provides more direct paths between concepts")
    elif avg_distance_ratio < 0.8:
        # Guard the inversion: an average of exactly 0 would otherwise divide by zero.
        inverse_ratio = 1 / avg_distance_ratio if avg_distance_ratio > 0 else float('inf')
        print(f"   → SAE transitions require {inverse_ratio:.1f}x larger perturbations")
        print("   → Hidden space provides more direct paths")
    else:
        print(f"   → Similar perturbation requirements ({avg_distance_ratio:.1f}x)")
        print("   → Both spaces have comparable path efficiency")

print("\n" + "="*60)

In [None]:
# Experiment 2: Concept Transition Analysis
# Compare hidden-to-hidden vs SAE-to-SAE transitions for a single concept pair.
# NOTE(review): needs `tracer`, `education_texts` and `technology_texts` from the
# setup cell further down; the NameError below shows it had not run yet.

print("=== EXPERIMENT 2: CONCEPT TRANSITION ANALYSIS ===")
print("Comparing:")
print("1. Hidden space: source_hidden → target_hidden (counting SAE boundary crossings)")
print("2. SAE space: source_hidden → target_SAE_code")
print()

# One education → technology pair.
source_text, target_text = education_texts[0], technology_texts[0]

print(f"Source concept: '{source_text[:50]}...'")
print(f"Target concept: '{target_text[:50]}...'")
print()

print("Running targeted transition analysis...")
hidden_result, sae_result = tracer.targeted_concept_transition(
    source_text=source_text,
    target_text=target_text,
    layer_idx=layer_num,
    max_iter=100,  # kept small while testing
    verbose=True,
)

print(f"\n=== TRANSITION COMPARISON RESULTS ===")

print(f"\n1. HIDDEN SPACE TRANSITION (source_hidden → target_hidden):")
print(f"   - Perturbation magnitude: {hidden_result.delta_min:.6f}")
print(f"   - SAE boundaries crossed: {hidden_result.num_boundaries_crossed}")
print(f"   - Target reached: {hidden_result.target_reached}")
print(f"   - Optimization steps: {hidden_result.num_steps}")

print(f"\n2. SAE SPACE TRANSITION (source_hidden → target_SAE_code):")
print(f"   - Perturbation magnitude: {sae_result.delta_min:.6f}")
print(f"   - Boundaries crossed: {sae_result.num_boundaries_crossed}")
print(f"   - Target reached: {sae_result.target_reached}")
print(f"   - Optimization steps: {sae_result.num_steps}")

# Boundary-crossing ratio: only reported when both searches crossed something.
hidden_crossings = hidden_result.num_boundaries_crossed
sae_crossings = sae_result.num_boundaries_crossed
if hidden_crossings > 0 and sae_crossings > 0:
    boundary_ratio = hidden_crossings / sae_crossings
    print(f"\n=== KEY INSIGHT ===")
    print(f"Boundary crossing ratio (Hidden/SAE): {boundary_ratio:.2f}")

    if boundary_ratio > 2:
        print("→ Hidden space requires crossing MANY more SAE boundaries")
        print("  This suggests SAE creates a complex intermediate representation")
    elif boundary_ratio > 1.2:
        print("→ Hidden space requires moderately more SAE boundaries")
        print("  This suggests some intermediate structure in SAE space")
    else:
        print("→ Hidden and SAE space have similar boundary complexity")
        print("  This suggests SAE preserves the structure of hidden space")

# Perturbation-magnitude ratio; a zero SAE distance is reported as infinite.
if sae_result.delta_min > 0:
    distance_ratio = hidden_result.delta_min / sae_result.delta_min
else:
    distance_ratio = float('inf')
print(f"\nDistance ratio (Hidden/SAE): {distance_ratio:.2f}")

if distance_ratio > 1.5:
    print("→ Hidden space transitions require LARGER perturbations")
elif distance_ratio < 0.8:
    print("→ SAE space transitions require LARGER perturbations")
else:
    print("→ Both transitions require similar perturbation magnitudes")

print("\n" + "="*60)

NameError: name 'education_texts' is not defined

In [None]:
# Experiment 1: DeepFool-style Nearest Boundary Search
# Find the minimal perturbation of one hidden state that crosses the nearest
# SAE decision boundary.
# NOTE(review): needs `tracer` and `education_texts` from the setup cell below.

print("=== EXPERIMENT 1: NEAREST BOUNDARY ANALYSIS ===")

test_text = education_texts[0]  # a single example to start with
print(f"Testing with: '{test_text}'")

# Hidden representation of the example at the analysed layer.
hidden_state = tracer.get_hidden_representation(test_text, layer_num)
print(f"Hidden state shape: {hidden_state.shape}")

# SAE code of the unperturbed hidden state; "active" = strictly positive entries.
original_code = tracer.get_sae_code(hidden_state)
original_active = int((original_code > 0).sum())
print(f"Original code has {original_active} active features out of {len(original_code)}")

print("\nFinding nearest decision boundary...")
boundary_result = tracer.find_nearest_boundary_deepfool(
    hidden_state,
    max_iter=50,
    overshoot=0.02,
    verbose=True,
)

print(f"\n=== BOUNDARY ANALYSIS RESULTS ===")
print(f"Minimal perturbation norm ||δ_min||: {boundary_result.delta_min:.6f}")
print(f"Boundary successfully crossed: {boundary_result.target_reached}")
print(f"Number of DeepFool iterations: {boundary_result.num_steps}")
print(f"Features that flipped: {len(boundary_result.flipped_features)}")

if boundary_result.flipped_features:
    print("Feature flip details:")
    # Only the first five flips are listed.
    for feature_idx, flip_type in boundary_result.flipped_features[:5]:
        print(f"  - Feature {feature_idx}: {flip_type}")

# Original vs perturbed number of active SAE features.
print(f"\nOriginal active features: {original_active}")
perturbed_active = int((boundary_result.perturbed_code > 0).sum())
print(f"Perturbed active features: {perturbed_active}")
print(f"Net change in active features: {perturbed_active - original_active}")

# Size of the perturbation relative to the hidden state it was applied to.
perturbation_norm = boundary_result.perturbation.norm().item()
hidden_norm = hidden_state.norm().item()
relative_perturbation = perturbation_norm / hidden_norm
print(f"\nPerturbation analysis:")
print(f"  - Perturbation magnitude: {perturbation_norm:.6f}")
print(f"  - Hidden state magnitude: {hidden_norm:.2f}")
print(f"  - Relative perturbation: {relative_perturbation:.4f} ({relative_perturbation*100:.2f}%)")

print("\n" + "="*60)

In [None]:
# Setup for the boundary-crossing experiments: tracer plus example concept texts.
# NOTE(review): Experiments 1-3 above use `tracer`, `education_texts` and
# `technology_texts`, so this cell must be run before them.
from boundary_crossing_fixed import SAEBoundaryTracer, visualize_boundary_comparison

tracer = SAEBoundaryTracer(sae, tokenizer, model, device=device)

# Three sentences per concept; the experiments pair index i of one list with
# index i of the other.
education_texts = [
    "The university offers comprehensive programs in computer science and engineering.",
    "Students can pursue advanced degrees in mathematics and physics.",
    "The curriculum includes hands-on laboratory experiences and research opportunities.",
]

technology_texts = [
    "The new smartphone features advanced AI capabilities and improved battery life.",
    "Cloud computing platforms enable scalable and efficient data processing.",
    "Machine learning algorithms are revolutionizing industries across the globe.",
]

for banner_line in (
    "=== IMPROVED SAE BOUNDARY CROSSING ANALYSIS ===",
    "This implementation addresses:",
    "1. DeepFool-style minimal perturbation finding",
    "2. Hidden-to-hidden vs SAE-to-SAE transition comparison",
    "3. Boundary counting for intermediate crossings",
    "4. Exact decision boundary identification",
    "",
    "Setup complete!",
):
    print(banner_line)

In [None]:
# Concept Distance Analysis: Education vs Technology
# How far is one concept from the other, relative to a single boundary distance?
# NOTE(review): this rebinds `concept_pairs` (Experiment 3 also defines it), and it
# passes `max_iterations=` while the other tracer calls use `max_iter=` - confirm
# the keyword against concept_distance_analysis's signature.

concept_pairs = [
    (education_texts[0], technology_texts[0]),
]

print("Analyzing concept distances between education and technology...")
print(f"Number of concept pairs: {len(concept_pairs)}")

concept_results = tracer.concept_distance_analysis(
    concept_pairs,
    layer_idx=layer_num,
    max_iterations=100,  # more iterations for the targeted search
)

print(f"\nConcept Distance Results:")
for pair_key, data in concept_results.items():
    print(f"\n{pair_key}:")
    print(f"  Education → Technology: {data['a_to_b_distance']:.6f} (reached: {data['a_to_b_reached']})")
    print(f"  Technology → Education: {data['b_to_a_distance']:.6f} (reached: {data['b_to_a_reached']})")
    print(f"  Education boundary distance: {data['a_boundary_distance']:.6f}")
    print(f"  Technology boundary distance: {data['b_boundary_distance']:.6f}")

    if data['distance_ratio_a'] != float('inf'):
        print(f"  Education concept distance ratio: {data['distance_ratio_a']:.2f}x single boundary")
    if data['distance_ratio_b'] != float('inf'):
        print(f"  Technology concept distance ratio: {data['distance_ratio_b']:.2f}x single boundary")

# Pool both directions of every pair; infinite ratios are excluded.
all_concept_distances, all_boundary_distances, all_ratios = [], [], []
for data in concept_results.values():
    all_concept_distances += [data['a_to_b_distance'], data['b_to_a_distance']]
    all_boundary_distances += [data['a_boundary_distance'], data['b_boundary_distance']]
    all_ratios += [
        ratio
        for ratio in (data['distance_ratio_a'], data['distance_ratio_b'])
        if ratio != float('inf')
    ]

print(f"\nOverall Summary:")
print(f"- Mean concept distance: {np.mean(all_concept_distances):.6f}")
print(f"- Mean single boundary distance: {np.mean(all_boundary_distances):.6f}")
if all_ratios:
    print(f"- Mean distance ratio: {np.mean(all_ratios):.2f}x")
    print(f"- This suggests concepts are {np.mean(all_ratios):.1f}x farther apart than single boundaries")

In [None]:
# CORRECTED: Hidden Space Polytope Boundary Analysis
# This fixes the fundamental error in _targeted_transition_hidden
# Banner for the corrected analysis: the function defined next counts sign
# changes of the hidden vector itself (h_i crossing 0) during optimisation,
# rather than changes in the SAE code.

print("=== CORRECTED HIDDEN SPACE POLYTOPE ANALYSIS ===")
print("ISSUE FOUND: Previous implementation tracked SAE boundary crossings")
print("CORRECTION: Now tracking ReLU polytope boundaries in HIDDEN space")
print()

def analyze_hidden_polytope_transition(
    source_hidden: torch.Tensor,
    target_hidden: torch.Tensor,
    max_iter: int = 200,
    step_size: float = 0.01,
    verbose: bool = True
) -> dict:
    """
    Move ``source_hidden`` towards ``target_hidden`` with an optimized additive
    perturbation and track sign-pattern ("ReLU polytope") changes on the way.

    The region of a point h is identified by the mask ``h > 0``. Two counters
    are kept along the Adam path:

    * ``hidden_polytope_boundaries_crossed`` -- number of optimizer steps on
      which the mask changed at all (one region transition per step, however
      many coordinates flipped on that step).
    * ``hyperplane_crossings`` -- total number of individual coordinates that
      flipped sign, i.e. crossings of hyperplanes ``h_i = 0``.

    Args:
        source_hidden: Starting hidden vector. Detached internally; only the
            perturbation ``delta`` is optimized.
        target_hidden: Hidden vector to move towards (MSE objective).
        max_iter: Maximum number of loss evaluations. ``<= 0`` returns the
            zero perturbation without optimizing.
        step_size: Adam learning rate.
        verbose: Print the loss every 50 steps, and boundary details for
            crossings that occur on steps divisible by 20.

    Returns:
        dict with ``delta_min`` (norm of the best perturbation),
        ``perturbation`` (best perturbation found), ``num_steps`` (optimizer
        updates applied), ``hidden_polytope_boundaries_crossed``,
        ``hyperplane_crossings``, ``target_reached`` (best MSE < 1e-3),
        ``final_loss`` (best MSE seen) and ``trajectory`` (delta after each
        update).
    """
    # Gradients are only needed w.r.t. delta. Detaching prevents writing .grad
    # into a caller's leaf tensor and avoids "backward through the graph a
    # second time" errors when the inputs still carry autograd history.
    source_hidden = source_hidden.detach()
    target_hidden = target_hidden.detach()

    delta = torch.zeros_like(source_hidden, requires_grad=True)
    optimizer = torch.optim.Adam([delta], lr=step_size)

    # Region of the previous evaluated point; starts at the source region.
    previous_hidden_polytope = (source_hidden > 0).float()
    boundary_count = 0
    hyperplane_crossings = 0
    trajectory = []

    # The unperturbed start is a valid fallback: keeps the result well-defined
    # when max_iter <= 0 or every loss is nan.
    best_delta = torch.zeros_like(source_hidden)
    best_loss = F.mse_loss(source_hidden, target_hidden).item()
    num_updates = 0

    for step in range(max_iter):
        optimizer.zero_grad()

        current_hidden = source_hidden + delta

        current_hidden_polytope = (current_hidden > 0).float()
        flipped = current_hidden_polytope != previous_hidden_polytope
        num_flipped = int(flipped.sum().item())

        if num_flipped > 0:
            boundary_count += 1
            hyperplane_crossings += num_flipped

            if verbose and step % 20 == 0:
                differences = current_hidden_polytope - previous_hidden_polytope
                # NOTE(review): [0] yields dimension indices only for a 1-D
                # hidden vector; for batched input it is the batch index.
                activated_dims = (differences > 0).nonzero(as_tuple=True)[0]
                deactivated_dims = (differences < 0).nonzero(as_tuple=True)[0]
                print(f"Step {step}: Hidden polytope boundary #{boundary_count}")
                if len(activated_dims) > 0:
                    print(f"  Activated dimensions: {activated_dims[:5].tolist()}...")
                if len(deactivated_dims) > 0:
                    print(f"  Deactivated dimensions: {deactivated_dims[:5].tolist()}...")

        # A fresh tensor each iteration and never mutated, so no clone needed.
        previous_hidden_polytope = current_hidden_polytope

        loss = F.mse_loss(current_hidden, target_hidden)
        loss_value = loss.item()

        if loss_value < best_loss:
            best_loss = loss_value
            best_delta = delta.detach().clone()

        if loss_value < 1e-6:
            if verbose:
                print(f"Target reached at step {step}")
            break

        loss.backward()
        optimizer.step()
        num_updates += 1
        trajectory.append(delta.detach().clone())

        if verbose and step % 50 == 0:
            print(f"Step {step}: loss = {loss_value:.6f}, ||δ|| = {delta.norm().item():.6f}, boundaries = {boundary_count}")

    return {
        'delta_min': best_delta.norm().item(),
        'perturbation': best_delta,
        # Counts applied updates (was the last loop index, i.e. max_iter - 1
        # when the loop ran to completion).
        'num_steps': num_updates,
        'hidden_polytope_boundaries_crossed': boundary_count,
        'hyperplane_crossings': hyperplane_crossings,
        'target_reached': best_loss < 1e-3,
        'final_loss': best_loss,
        'trajectory': trajectory
    }

# Test with a single concept pair: first education text -> first technology text.
source_text = education_texts[0]
target_text = technology_texts[0]

print(f"Source: '{source_text[:50]}...'")
print(f"Target: '{target_text[:50]}...'")
print()

# Hidden representations at layer `layer_num`, obtained from the notebook's tracer.
source_hidden = tracer.get_hidden_representation(source_text, layer_num)
target_hidden = tracer.get_hidden_representation(target_text, layer_num)

print(f"Source hidden shape: {source_hidden.shape}")
print(f"Target hidden shape: {target_hidden.shape}")

# Sign patterns (h > 0) of the endpoints. Their Hamming distance is the number
# of coordinates whose sign differs; any continuous path between the two points
# must cross at least that many hyperplanes h_i = 0.
source_polytope = (source_hidden > 0).float()
target_polytope = (target_hidden > 0).float()
polytope_hamming = (source_polytope != target_polytope).sum().item()

# numel() rather than len(): len() only counts the first axis and would report
# e.g. "1" for a (1, d) hidden state.
print(f"Source polytope active dims: {int(source_polytope.sum().item())}/{source_polytope.numel()}")
print(f"Target polytope active dims: {int(target_polytope.sum().item())}/{target_polytope.numel()}")
print(f"Polytope Hamming distance: {polytope_hamming}")
print()

# Run corrected analysis
print("Running CORRECTED hidden polytope transition analysis...")
result = analyze_hidden_polytope_transition(
    source_hidden, target_hidden,
    max_iter=100,
    verbose=True
)

print("\n=== CORRECTED RESULTS ===")
print(f"Minimal perturbation: {result['delta_min']:.6f}")
print(f"Hidden polytope boundaries crossed: {result['hidden_polytope_boundaries_crossed']}")
print(f"Target reached: {result['target_reached']}")
print(f"Final loss: {result['final_loss']:.6f}")
print(f"Optimization steps: {result['num_steps']}")

# Compare with SAE features for perspective. This is a start-vs-end difference
# (features whose activity differs between the two points), not a count of
# flips accumulated along the optimizer path, so the ratio below compares two
# different kinds of quantity.
original_sae_code = tracer.get_sae_code(source_hidden)
final_sae_code = tracer.get_sae_code(source_hidden + result['perturbation'])
sae_changes = ((original_sae_code > 0).float() != (final_sae_code > 0).float()).sum().item()

print("\nFor comparison:")
print(f"SAE features flipped between start and end of this transition: {sae_changes}")
print(f"Hidden polytope boundaries crossed: {result['hidden_polytope_boundaries_crossed']}")
print(f"Ratio (Hidden boundaries / SAE changes): {result['hidden_polytope_boundaries_crossed'] / max(sae_changes, 1):.2f}")

print("\n" + "="*60)
print("KEY INSIGHT:")
print("Now we're properly measuring transitions between")
print("ReLU polytopes in the HIDDEN space (h_i > 0),")
print("not SAE feature boundaries!")
print("="*60)

In [None]:
# COMPARISON: SAE-boundary tracking ("wrong") vs hidden sign-pattern tracking
# ("correct"), run on the same source -> target transition.

print("=== COMPARISON: WRONG vs CORRECT BOUNDARY TRACKING ===\n")

def analyze_transition_wrong_way(
    source_hidden: torch.Tensor,
    target_hidden: torch.Tensor,
    sae_model,
    max_iter: int = 100,
    step_size: float = 0.01,
    sae_code_fn=None,
) -> dict:
    """
    Baseline kept for comparison ("wrong way").

    Runs the same hidden -> hidden MSE optimization as the corrected analysis,
    but counts the optimizer steps on which the set of active SAE features
    (``code > 0``) changes, instead of the hidden sign pattern.

    Args:
        source_hidden: Starting hidden vector (detached internally).
        target_hidden: Hidden vector to move towards.
        sae_model: Not used by the computation. Kept so existing calls such as
            ``analyze_transition_wrong_way(src, tgt, tracer.sae)`` keep working.
        max_iter: Maximum number of loss evaluations.
        step_size: Adam learning rate.
        sae_code_fn: Optional callable mapping a hidden tensor to its SAE code.
            Defaults to the notebook-global ``tracer.get_sae_code``.

    Returns:
        dict with ``delta_min`` (norm of the best perturbation),
        ``sae_boundaries_crossed``, ``target_reached`` (best MSE < 1e-3) and
        ``num_steps`` (optimizer updates applied).
    """
    if sae_code_fn is None:
        # Resolved at call time so the global tracer is only needed when no
        # encoder is supplied explicitly.
        sae_code_fn = tracer.get_sae_code

    # Gradients are only needed w.r.t. delta (see analyze_hidden_polytope_transition).
    source_hidden = source_hidden.detach()
    target_hidden = target_hidden.detach()

    delta = torch.zeros_like(source_hidden, requires_grad=True)
    optimizer = torch.optim.Adam([delta], lr=step_size)

    # WRONG (by design): track SAE feature activity during a hidden-space transition.
    previous_active = sae_code_fn(source_hidden) > 0
    sae_boundary_count = 0

    # Unperturbed start as fallback for max_iter <= 0 or nan losses.
    best_delta = torch.zeros_like(source_hidden)
    best_loss = F.mse_loss(source_hidden, target_hidden).item()
    num_updates = 0

    for _ in range(max_iter):
        optimizer.zero_grad()

        current_hidden = source_hidden + delta

        # The SAE code is only observed, so encode a detached copy. This keeps
        # the SAE encoder out of the autograd graph on every step.
        current_active = sae_code_fn(current_hidden.detach()) > 0
        if not torch.equal(current_active, previous_active):
            sae_boundary_count += 1
        previous_active = current_active

        # Loss: L2 distance to target
        loss = F.mse_loss(current_hidden, target_hidden)
        loss_value = loss.item()

        if loss_value < best_loss:
            best_loss = loss_value
            best_delta = delta.detach().clone()

        if loss_value < 1e-6:
            break

        loss.backward()
        optimizer.step()
        num_updates += 1

    return {
        'delta_min': best_delta.norm().item(),
        'sae_boundaries_crossed': sae_boundary_count,  # WRONG metric
        'target_reached': best_loss < 1e-3,
        'num_steps': num_updates
    }

def analyze_transition_correct_way(
    source_hidden: torch.Tensor,
    target_hidden: torch.Tensor,
    max_iter: int = 100,
    step_size: float = 0.01
) -> dict:
    """
    Count hidden sign-pattern ("polytope") changes along an Adam path from
    ``source_hidden`` towards ``target_hidden`` ("correct way").

    The region of a point h is the mask ``h > 0``. The counter is incremented
    once per optimizer step on which that mask differs from the previous
    step's mask, however many coordinates flipped on that step.

    Args:
        source_hidden: Starting hidden vector (detached internally).
        target_hidden: Hidden vector to move towards (MSE objective).
        max_iter: Maximum number of loss evaluations.
        step_size: Adam learning rate.

    Returns:
        dict with ``delta_min`` (norm of the best perturbation),
        ``hidden_polytope_boundaries_crossed``, ``target_reached``
        (best MSE < 1e-3) and ``num_steps`` (optimizer updates applied).
    """
    # Gradients are only needed w.r.t. delta. Detaching avoids side effects on
    # (or repeated backward through) the caller's autograd graph.
    source_hidden = source_hidden.detach()
    target_hidden = target_hidden.detach()

    delta = torch.zeros_like(source_hidden, requires_grad=True)
    optimizer = torch.optim.Adam([delta], lr=step_size)

    # CORRECT: Track hidden polytope changes
    previous_hidden_polytope = (source_hidden > 0).float()
    hidden_polytope_boundary_count = 0

    # Unperturbed start as fallback for max_iter <= 0 or nan losses.
    best_delta = torch.zeros_like(source_hidden)
    best_loss = F.mse_loss(source_hidden, target_hidden).item()
    num_updates = 0

    for _ in range(max_iter):
        optimizer.zero_grad()

        current_hidden = source_hidden + delta

        # CORRECT: Monitor hidden polytope changes (h_i > 0)
        current_hidden_polytope = (current_hidden > 0).float()
        if not torch.equal(current_hidden_polytope, previous_hidden_polytope):
            hidden_polytope_boundary_count += 1
        previous_hidden_polytope = current_hidden_polytope

        # Loss: L2 distance to target
        loss = F.mse_loss(current_hidden, target_hidden)
        loss_value = loss.item()

        if loss_value < best_loss:
            best_loss = loss_value
            best_delta = delta.detach().clone()

        if loss_value < 1e-6:
            break

        loss.backward()
        optimizer.step()
        num_updates += 1

    return {
        'delta_min': best_delta.norm().item(),
        'hidden_polytope_boundaries_crossed': hidden_polytope_boundary_count,  # CORRECT metric
        'target_reached': best_loss < 1e-3,
        'num_steps': num_updates
    }

# Uses source_hidden / target_hidden from the single-pair cell above.
print("Testing both approaches on the same transition...")
print("Education → Technology concept transition")
print()

# Run wrong approach
print("1. WRONG APPROACH (tracking SAE boundaries):")
wrong_result = analyze_transition_wrong_way(
    source_hidden, target_hidden, tracer.sae
)
print(f"   - Perturbation magnitude: {wrong_result['delta_min']:.6f}")
print(f"   - SAE boundaries crossed: {wrong_result['sae_boundaries_crossed']}")
print(f"   - Target reached: {wrong_result['target_reached']}")

# Run correct approach
print("\n2. CORRECT APPROACH (tracking hidden polytope boundaries):")
correct_result = analyze_transition_correct_way(
    source_hidden, target_hidden
)
print(f"   - Perturbation magnitude: {correct_result['delta_min']:.6f}")
print(f"   - Hidden polytope boundaries crossed: {correct_result['hidden_polytope_boundaries_crossed']}")
print(f"   - Target reached: {correct_result['target_reached']}")

print("\n=== CRITICAL COMPARISON ===")
print(f"Wrong method counted: {wrong_result['sae_boundaries_crossed']} 'boundaries'")
print(f"Correct method counted: {correct_result['hidden_polytope_boundaries_crossed']} boundaries")
print()
print("EXPLANATION:")
# Both counters increment once per optimizer step on which any unit flips.
print("- Wrong method: Counts optimizer steps on which any SAE feature flips during hidden→hidden transition")
print("- Correct method: Counts optimizer steps on which any hidden dimension crosses its h_i = 0 hyperplane")
print()
print("The wrong method conflates two different spaces:")
print("  • Hidden space polytopes (defined by h_i > 0)")
print("  • SAE feature space (defined by f_i > 0 where f = ReLU(Wh + b))")
print()
print("For polytope analysis, we want HIDDEN space boundaries, not SAE boundaries!")
print()

# Show what the numbers mean
print("=== INTERPRETATION ===")
region_changes = correct_result['hidden_polytope_boundaries_crossed']
# Coordinates whose sign differs between the endpoints: a lower bound on the
# hyperplanes h_i = 0 that any continuous path between them must cross.
net_sign_flips = int(((source_hidden > 0) != (target_hidden > 0)).sum().item())
if region_changes > 0:
    # The counter is per optimizer step, not per hyperplane: one step can flip
    # many hidden dimensions at once.
    print("The optimizer path from education→technology concepts")
    print(f"changed hidden polytope region on {region_changes} optimization steps")
    print("(a single step can flip several hidden dimensions, h_i: positive ↔ negative, at once)")
    print(f"The endpoints differ in sign on {net_sign_flips} hidden dimensions")
else:
    print("The transition stays within the same polytope region!")
    print("(No hidden dimensions changed sign)")
print()

print("This is the CORRECT way to measure polytope complexity!")
print("="*60)