In [42]:
import os

# Point both the HTTPS and HTTP proxy environment variables at the same URL.
_PROXY_URL = "http://xen03.iitd.ac.in:3128"
os.environ.update({"https_proxy": _PROXY_URL, "http_proxy": _PROXY_URL})

import sys
# sys.path.append('../sae')
from sae import Sae
from utils import *
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
import torch
import torch.nn.functional as F
import os
import pandas as pd
import numpy as np
import time

In [4]:
# Experiment configuration: which model to load, which layer index to use,
# and the device everything is placed on.
model_type, layer_num, device = "gemma2-2b", 20, "cpu"

# load_model_and_sae returns the model, its tokenizer and the SAE, in that order.
model, tokenizer, sae = load_model_and_sae(model_type, layer_num, device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

# Interpretation and Conclusions
# Prints a text summary of whichever analyses have already been run in this
# kernel. Every section checks for the variables it needs and is skipped (or
# reduced to a short note) when they are missing, so the cell never fails
# just because another cell has not been executed yet.
print("=== BOUNDARY CROSSING ANALYSIS CONCLUSIONS ===\n")

# True only when the batch analysis has produced a non-empty result mapping.
# Sections 1 and 4 both key off this flag rather than off a possibly stale
# `successful_results` variable left behind by another cell.
_has_batch_results = 'batch_results' in locals() and bool(batch_results)

if _has_batch_results:
    successful_results = {k: v for k, v in batch_results.items() if v.target_reached}

    if successful_results:
        delta_norms = [r.delta_min for r in successful_results.values()]
        mean_boundary_dist = np.mean(delta_norms)
        std_boundary_dist = np.std(delta_norms)

        print("1. MARGIN ANALYSIS:")
        print(f"   - Average minimal perturbation to cross boundary: {mean_boundary_dist:.6f}")
        print(f"   - Standard deviation: {std_boundary_dist:.6f}")
        # 0.1 is this notebook's ad-hoc cut-off between "close" and "distant".
        print(f"   - This suggests SAE boundaries are {'close' if mean_boundary_dist < 0.1 else 'distant'}")

        # Compare with typical input norms for context
        if 'all_texts' in locals() and 'tracer' in locals():
            hidden_norms = []
            for text in all_texts[:3]:  # Sample a few to get typical norms
                h = tracer.get_hidden_representation(text, layer_num)
                hidden_norms.append(h.norm().item())
            # Skip the relative figure when there is nothing to average or the
            # average norm is zero (the ratio would be undefined).
            avg_hidden_norm = np.mean(hidden_norms) if hidden_norms else 0.0
            if avg_hidden_norm > 0:
                relative_perturbation = mean_boundary_dist / avg_hidden_norm
                print(f"   - Relative to input magnitude ({avg_hidden_norm:.2f}): {relative_perturbation:.4f} ({relative_perturbation*100:.2f}%)")

if 'concept_results' in locals() and concept_results:
    print("\n2. CONCEPT DISTANCE ANALYSIS:")

    concept_distances = []
    boundary_distances = []
    ratios = []

    for data in concept_results.values():
        concept_distances.extend([data['a_to_b_distance'], data['b_to_a_distance']])
        boundary_distances.extend([data['a_boundary_distance'], data['b_boundary_distance']])
        # Infinite ratios are left out of the average.
        if data['distance_ratio_a'] != float('inf'):
            ratios.append(data['distance_ratio_a'])
        if data['distance_ratio_b'] != float('inf'):
            ratios.append(data['distance_ratio_b'])

    if ratios:
        mean_ratio = np.mean(ratios)
        print(f"   - Concept distances are {mean_ratio:.1f}x larger than single boundary distances")
        print(f"   - This means crossing from Education→Technology requires crossing ~{mean_ratio:.0f} boundaries")

        if mean_ratio > 3:
            print("   - INTERPRETATION: Concepts are well-separated in SAE space")
        elif mean_ratio > 1.5:
            print("   - INTERPRETATION: Moderate concept separation")
        else:
            print("   - INTERPRETATION: Concepts may be close together (potential vulnerability)")

if 'feature_flips' in locals() and feature_flips:
    print("\n3. FEATURE VULNERABILITY:")
    # A feature counts as consistently vulnerable when it flipped at least twice.
    vulnerable_features = len([f for f, counts in feature_flips.items() if counts['total'] >= 2])
    # `original_code` only exists after the single-example experiment has run;
    # fall back to 0 instead of raising NameError when it has not.
    total_features = len(original_code) if 'original_code' in locals() else 0

    if total_features > 0:
        vulnerability_rate = vulnerable_features / total_features

        print(f"   - {vulnerable_features} out of {total_features} features are consistently vulnerable")
        print(f"   - Vulnerability rate: {vulnerability_rate:.3f} ({vulnerability_rate*100:.1f}%)")

        if vulnerability_rate > 0.1:
            print("   - INTERPRETATION: High feature vulnerability - SAE may be susceptible to attacks")
        elif vulnerability_rate > 0.05:
            print("   - INTERPRETATION: Moderate vulnerability")
        else:
            print("   - INTERPRETATION: Low vulnerability - SAE appears robust")
    else:
        # Without the code size a rate (and any verdict based on it) would be
        # meaningless, so only the raw count is reported.
        print(f"   - {vulnerable_features} features are consistently vulnerable")
        print("   - Vulnerability rate unavailable: total feature count is unknown (original_code not defined or empty)")

print("\n4. ROBUSTNESS ASSESSMENT:")
if _has_batch_results:
    # `successful_results` was rebuilt from `batch_results` above, so the rate
    # is always consistent, including the case of zero successful searches.
    success_rate = len(successful_results) / len(batch_results)
    print(f"   - Boundary search success rate: {success_rate:.2f} ({success_rate*100:.1f}%)")

    if success_rate > 0.8:
        print("   - INTERPRETATION: Boundaries are easily found - potential robustness concerns")
    elif success_rate > 0.5:
        print("   - INTERPRETATION: Moderate boundary accessibility")
    else:
        print("   - INTERPRETATION: Boundaries are hard to find - suggests robustness")
else:
    print("   - No batch results available")

print("\n=== RECOMMENDATIONS ===")
print("1. Consider adversarial training to increase boundary distances")
print("2. Monitor vulnerable features during deployment")
print("3. Test with more diverse text pairs to validate generalization")
print("4. Compare with other SAE architectures for robustness benchmarking")

In [None]:
# Analysis: Which SAE features are most vulnerable to boundary crossings?
# This helps identify which features are closest to decision boundaries

if 'batch_results' in locals() and batch_results:
    print("Analyzing feature vulnerability...")

    # Tally, per feature index, how often it flipped in each direction across
    # all successful boundary crossings. The loop variable is deliberately not
    # named `result`, so the `result` produced by the single-example
    # experiment is not overwritten at notebook scope.
    feature_flips = {}
    for batch_result in batch_results.values():
        if not batch_result.target_reached:
            continue
        for feature_idx, flip_type in batch_result.flipped_features:
            counts = feature_flips.setdefault(
                feature_idx, {'activate': 0, 'deactivate': 0, 'total': 0}
            )
            # flip_type is used directly as the counter key
            # ('activate' / 'deactivate'); anything else raises KeyError.
            counts[flip_type] += 1
            counts['total'] += 1

    if feature_flips:
        # Sort features by total flip frequency
        sorted_features = sorted(feature_flips.items(), key=lambda x: x[1]['total'], reverse=True)

        print(f"\nMost Vulnerable Features (top 10):")
        print("Feature ID | Total Flips | Activations | Deactivations")
        print("-" * 55)

        for feature_idx, counts in sorted_features[:10]:
            print(f"{feature_idx:9d} | {counts['total']:11d} | {counts['activate']:11d} | {counts['deactivate']:13d}")

        # Analyze flip patterns
        total_features_flipped = len(feature_flips)
        total_flips = sum(counts['total'] for counts in feature_flips.values())

        print(f"\nFeature Vulnerability Summary:")
        print(f"- Total unique features that flipped: {total_features_flipped}")
        print(f"- Total feature flips across all texts: {total_flips}")
        # feature_flips is non-empty here, so the divisor is at least 1.
        print(f"- Average flips per vulnerable feature: {total_flips/total_features_flipped:.2f}")

        # Check if some features are consistently more vulnerable
        frequent_flippers = [f for f, counts in feature_flips.items() if counts['total'] >= 2]
        print(f"- Features that flipped multiple times: {len(frequent_flippers)}")

        if frequent_flippers:
            print(f"- Most vulnerable feature: {sorted_features[0][0]} (flipped {sorted_features[0][1]['total']} times)")
    else:
        print("No feature flips found in the results.")
else:
    print("No batch results available for feature vulnerability analysis.")

In [None]:
# Visualize the boundary analysis results
import matplotlib.pyplot as plt

print("Creating visualizations...")

# Three-panel summary of the batch boundary analysis:
# perturbation-norm histogram, flip-count histogram, success-rate pie.
if 'batch_results' in locals() and batch_results:
    fig, (ax_dist, ax_flips, ax_rate) = plt.subplots(1, 3, figsize=(15, 5))
    successful_results = {k: v for k, v in batch_results.items() if v.target_reached}

    # The first two panels are only populated when at least one search succeeded.
    if successful_results:
        # Panel 1: boundary distance distribution
        delta_norms = [r.delta_min for r in successful_results.values()]
        ax_dist.hist(delta_norms, bins=10, alpha=0.7, edgecolor='black')
        ax_dist.set_xlabel('Minimal Perturbation Norm ||δ_min||')
        ax_dist.set_ylabel('Frequency')
        ax_dist.set_title('Distribution of Boundary Distances')
        ax_dist.grid(True, alpha=0.3)

        # Panel 2: feature flips distribution (histogram drawn only if any
        # crossing flipped at least one feature; labels are set regardless)
        num_flips = [len(r.flipped_features) for r in successful_results.values()]
        if num_flips and max(num_flips) > 0:
            ax_flips.hist(num_flips, bins=max(num_flips)+1, alpha=0.7, edgecolor='black')
        ax_flips.set_xlabel('Number of Feature Flips')
        ax_flips.set_ylabel('Frequency')
        ax_flips.set_title('Feature Flips at Boundaries')
        ax_flips.grid(True, alpha=0.3)

    # Panel 3: success rate
    total_texts = len(batch_results)
    successful_texts = len(successful_results)
    ax_rate.pie(
        [successful_texts, total_texts - successful_texts],
        labels=['Successful', 'Failed'],
        colors=['lightgreen', 'lightcoral'],
        autopct='%1.1f%%',
        startangle=90,
    )
    ax_rate.set_title('Boundary Search Success Rate')

    fig.tight_layout()
    plt.show()

# Visualize concept distances if available
if 'concept_results' in locals() and concept_results:
    visualize_concept_distances(concept_results, save_path='concept_distances.png')

print("Visualization complete!")

In [None]:
# Concept Distance Analysis: Education vs Technology
# Measure how many boundaries need to be crossed to go from one concept to another

# Pair the texts position-wise. zip() works for any number of texts and stops
# at the shorter list, so nothing is indexed out of range.
concept_pairs = list(zip(education_texts, technology_texts))

print("Analyzing concept distances between education and technology...")
print(f"Number of concept pairs: {len(concept_pairs)}")

# Run concept distance analysis
concept_results = tracer.concept_distance_analysis(
    concept_pairs,
    layer_idx=layer_num,
    max_iterations=100  # More iterations for targeted search
)

# Accumulators for the summary statistics; filled in the same pass that
# prints the per-pair results.
all_concept_distances = []
all_boundary_distances = []
all_ratios = []

print(f"\nConcept Distance Results:")
for pair_key, data in concept_results.items():
    print(f"\n{pair_key}:")
    print(f"  Education → Technology: {data['a_to_b_distance']:.6f} (reached: {data['a_to_b_reached']})")
    print(f"  Technology → Education: {data['b_to_a_distance']:.6f} (reached: {data['b_to_a_reached']})")
    print(f"  Education boundary distance: {data['a_boundary_distance']:.6f}")
    print(f"  Technology boundary distance: {data['b_boundary_distance']:.6f}")

    all_concept_distances.extend([data['a_to_b_distance'], data['b_to_a_distance']])
    all_boundary_distances.extend([data['a_boundary_distance'], data['b_boundary_distance']])

    # Infinite ratios are neither printed nor included in the average.
    if data['distance_ratio_a'] != float('inf'):
        print(f"  Education concept distance ratio: {data['distance_ratio_a']:.2f}x single boundary")
        all_ratios.append(data['distance_ratio_a'])
    if data['distance_ratio_b'] != float('inf'):
        print(f"  Technology concept distance ratio: {data['distance_ratio_b']:.2f}x single boundary")
        all_ratios.append(data['distance_ratio_b'])

# Summary statistics
print(f"\nOverall Summary:")
if all_concept_distances:
    print(f"- Mean concept distance: {np.mean(all_concept_distances):.6f}")
    print(f"- Mean single boundary distance: {np.mean(all_boundary_distances):.6f}")
else:
    # np.mean of an empty list would only yield nan plus a RuntimeWarning.
    print("- No concept pairs were analyzed")
if all_ratios:
    overall_mean_ratio = np.mean(all_ratios)
    print(f"- Mean distance ratio: {overall_mean_ratio:.2f}x")
    print(f"- This suggests concepts are {overall_mean_ratio:.1f}x farther apart than single boundaries")

In [1]:
# Batch margin analysis over every education and technology text.
# Requires `education_texts`, `technology_texts` and `tracer` to be defined already.
all_texts = education_texts + technology_texts

print(f"Analyzing {len(all_texts)} texts for boundary distances...")

# This call might take a while; max_iterations is kept at a reduced value
# for faster testing.
batch_results = tracer.analyze_margins_batch(all_texts, layer_idx=layer_num, max_iterations=50, verbose=True)

print("\nBatch Analysis Summary:")
print(f"- Total texts processed: {len(batch_results)}")

# Keep only the entries whose search reached its target
successful_results = {
    text_key: outcome
    for text_key, outcome in batch_results.items()
    if outcome.target_reached
}
print(f"- Successful boundary crossings: {len(successful_results)}")

if not successful_results:
    print("- No successful boundary crossings found")
else:
    # Mean / std / min / max of the minimal perturbation norms
    delta_norms = [outcome.delta_min for outcome in successful_results.values()]
    for _label, _reduce in (("Mean", np.mean), ("Std", np.std), ("Min", np.min), ("Max", np.max)):
        print(f"- {_label} boundary distance: {_reduce(delta_norms):.6f}")

    # Flip counts per crossing, plus the number of distinct features involved
    all_flips = [len(outcome.flipped_features) for outcome in successful_results.values()]
    print(f"- Mean feature flips: {np.mean(all_flips):.2f}")
    _unique_features = {
        feature
        for outcome in successful_results.values()
        for feature, _ in outcome.flipped_features
    }
    print(f"- Total unique flipped features: {len(_unique_features)}")

NameError: name 'education_texts' is not defined

# Experiment 1: Measuring Margins and Boundary Crossings
# Goal: find the minimal perturbation that changes the SAE code.

# Start with a single example: the first education text
test_text = education_texts[0]
print(f"Testing with: '{test_text}'")

# Hidden representation of the text at the configured layer
hidden_state = tracer.get_hidden_representation(test_text, layer_num)
print(f"Hidden state shape: {hidden_state.shape}")

# SAE code of the unperturbed hidden state; entries > 0 count as active
original_code = tracer.get_sae_code(hidden_state)
_active_mask = original_code > 0
original_active = _active_mask.sum().item()
print(f"Original code has {original_active} active features out of {len(original_code)}")

# Search for the minimal boundary perturbation
# (max_iterations kept low while testing)
result = tracer.find_minimal_boundary_perturbation(hidden_state, max_iterations=50, verbose=True)

# Report every field of the result in a single write
print(
    "\nBoundary Analysis Results:\n"
    f"- Minimal perturbation norm: {result.delta_min:.6f}\n"
    f"- Boundary reached: {result.target_reached}\n"
    f"- Number of steps: {result.num_steps}\n"
    f"- Features flipped: {result.flipped_features}\n"
    f"- Boundary type: {result.boundary_type}"
)

In [None]:
# Example: Boundary Crossing Analysis for SAE Robustness
# Sets up the boundary_crossing tracer and the example texts used by the
# analysis cells.

from boundary_crossing import (
    SAEBoundaryTracer,
    visualize_boundary_analysis,
    visualize_concept_distances,
)

# Build the tracer from the SAE, tokenizer and model already in the namespace
tracer = SAEBoundaryTracer(sae, tokenizer, model, device=device)

# Example texts from Li et al. study
education_texts = [
    "The university offers comprehensive programs in computer science and engineering.",
    "Students can pursue advanced degrees in mathematics and physics.",
    "The curriculum includes hands-on laboratory experiences and research opportunities.",
]

technology_texts = [
    "The new smartphone features advanced AI capabilities and improved battery life.",
    "Cloud computing platforms enable scalable and efficient data processing.",
    "Machine learning algorithms are revolutionizing industries across the globe.",
]

print("Boundary crossing analysis setup complete!")
# One count line per concept group
for _label, _texts in (("Education", education_texts), ("Technology", technology_texts)):
    print(f"{_label} texts: {len(_texts)}")

Cosine similarity between h1_raw and h2_raw: 0.8860180974006653
Cosine similarity between h1_raw and h3_raw: 0.8103485703468323
Cosine similarity between h2_raw and h3_raw: 0.9335434436798096
------------------------------------------------------------
Cosine similarity between h1_raw and base_raw: 0.860113263130188
Cosine similarity between h2_raw and base_raw: 0.8216165900230408
Cosine similarity between h3_raw and base_raw: 0.788619875907898


In [45]:
def polytope(x):
    """
    Binarize a tensor element-wise: 1.0 where the value is strictly
    greater than 0, 0.0 everywhere else. The result is a float32 tensor
    with the same shape as ``x``.
    """
    positive_mask = x > 0
    return positive_mask.to(torch.float32)

# Binarize each raw hidden state into its positive/non-positive pattern
h1_plt, h2_plt, h3_plt, base_plt = (
    polytope(raw) for raw in (h1_raw, h2_raw, h3_raw, h_base_raw)
)

# Print the shape of one pattern as a sanity check
print(f"Shape of h1_plt: {h1_plt.shape}")

# Hamming distance = number of positions where two binary patterns differ.
# Pairs are listed explicitly so the report order is fixed:
# first among h1/h2/h3, then each of them against the base.
_patterns = {"h1": h1_plt, "h2": h2_plt, "h3": h3_plt, "base": base_plt}
_pairs = (
    ("h1", "h2"),
    ("h1", "h3"),
    ("h2", "h3"),
    ("h1", "base"),
    ("h2", "base"),
    ("h3", "base"),
)
for _left, _right in _pairs:
    hamming_distance = (_patterns[_left] != _patterns[_right]).sum().item()
    print(f"Hamming distance between {_left} and {_right}: {hamming_distance}")

Shape of h1_plt: torch.Size([2304])
Hamming distance between h1 and h2: 595
Hamming distance between h1 and h3: 743
Hamming distance between h2 and h3: 412
Hamming distance between h1 and base: 643
Hamming distance between h2 and base: 744
Hamming distance between h3 and base: 806


Extract SAE features

In [None]:
# Value passed as the last argument of extract_sae_features
# (presumably a top-k size — TODO confirm against utils)
k = 170

# One (z, s, s_acts) triple per raw hidden state, in the order h1, h2, h3
(z1, s1, s1_acts), (z2, s2, s2_acts), (z3, s3, s3_acts) = [
    extract_sae_features(raw, sae, model_type, k)
    for raw in (h1_raw, h2_raw, h3_raw)
]

In [46]:
# The values in z1, z2, z3 are not restricted to 1, so apply polytope()
# to turn them into binary patterns before comparing.
z1_plt, z2_plt, z3_plt = (polytope(code) for code in (z1, z2, z3))

# Pairwise Hamming distances between the binarized codes
_codes = {"z1": z1_plt, "z2": z2_plt, "z3": z3_plt}
for _left, _right in (("z1", "z2"), ("z1", "z3"), ("z2", "z3")):
    hamming_distance = (_codes[_left] != _codes[_right]).sum().item()
    print(f"Hamming distance between {_left} and {_right}: {hamming_distance}")

Hamming distance between z1 and z2: 229
Hamming distance between z1 and z3: 234
Hamming distance between z2 and z3: 157


In [43]:
# Alternative comparison: pairwise get_overlap() scores between the
# s1 / s2 / s3 outputs of extract_sae_features
_selections = {"s1": s1, "s2": s2, "s3": s3}
for _left, _right in (("s1", "s2"), ("s1", "s3"), ("s2", "s3")):
    overlap_score = get_overlap(_selections[_left], _selections[_right])
    print(f"Overlap score between {_left} and {_right}: {overlap_score}")

Overlap score between s1 and s2: 0.4000000059604645
Overlap score between s1 and s3: 0.3529411852359772
Overlap score between s2 and s3: 0.6235294342041016
