# Gemma Model Convergence Experiment
## Multi-Model SAE-Enhanced Pattern Discovery

This notebook measures how similarly different Gemma model variants behave and represent information, using Sparse Autoencoders (SAEs) and other mechanistic interpretability techniques.

**Key Innovation**: Behavioral convergence analysis is combined with SAE feature analysis, so shared patterns can be detected at both the behavioral and the mechanistic level.

### Experiment Overview:
- **Models**: Gemma 2 2B/9B (base & instruct variants)
- **Analysis Levels**: Behavioral, Activation, SAE Feature
- **Framework**: TransformerLens + Gemma Scope SAEs + Universal Patterns

### Requirements:
- GPU with 16GB+ VRAM (24GB recommended, e.g. on Runpod)
- HuggingFace account with Gemma access
- Python 3.8+

## 🚀 Runpod Setup Instructions

### 1. Launch Runpod Instance
```bash
# Recommended template: PyTorch 2.0+ with CUDA
# GPU: RTX 4090 or A100 (24GB VRAM recommended)
# Storage: 50GB+ for models and results
```

### 2. SSH Connection
```bash
# Connect via SSH
ssh root@<pod-ip> -p <ssh-port>

# Or use runpod CLI
runpod ssh <pod-id>
```

### 3. Environment Setup
Run the setup cell below to install all dependencies.
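
Before reinstalling everything, it can help to see which packages are already present. The sketch below uses only the standard library. The helper name `missing_packages` is hypothetical, and it only checks whether a distribution is installed; it does not evaluate version constraints such as `>=2.0.0`.

```python
from importlib import metadata


def missing_packages(requirements):
    """Return the requirements whose distribution is not installed.

    Version specifiers (e.g. "torch>=2.0.0") are stripped before the lookup,
    so this checks presence only, not whether the version is new enough.
    """
    missing = []
    for req in requirements:
        # Keep only the distribution name: cut at the first specifier character.
        name = req
        for sep in ("<", ">", "=", "!", "~"):
            name = name.split(sep, 1)[0]
        name = name.strip()
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(req)
    return missing


# Example: list which of a few requirements still need installing.
print(missing_packages(["torch>=2.0.0", "numpy", "tqdm"]))
```

The returned list could be passed to the install loop in the setup cell instead of the full package list.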

In [1]:
# 🔧 RUNPOD ENVIRONMENT SETUP
import subprocess
import sys
import os

def install_requirements():
    """Install all required packages for the experiment"""
    
    packages = [
        "transformer-lens",          # Main interpretability library
        "transformers>=4.30.0",      # HuggingFace transformers
        "torch>=2.0.0",             # PyTorch
        "numpy",                     # Numerical computing
        "scipy",                     # Statistical analysis
        "scikit-learn",              # Machine learning utilities
        "matplotlib",                # Plotting
        "seaborn",                   # Statistical plotting
        "plotly",                    # Interactive plots
        "pandas",                    # Data manipulation
        "tqdm",                      # Progress bars
        "huggingface-hub",           # Model downloads
        "sentence-transformers",     # Semantic similarity
        "jupyter",                   # Notebook support
        "ipywidgets",               # Interactive widgets
    ]
    
    print("🔧 Installing required packages...")
    for package in packages:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    
    print("✅ All packages installed successfully!")

# Run installation
install_requirements()

# Verify GPU availability
import torch
print(f"\n🚀 GPU Status:")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Count: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected - experiment will run on CPU (slower)")

🔧 Installing required packages...
Installing transformer-lens...
Collecting transformer-lens
  Downloading transformer_lens-2.16.1-py3-none-any.whl.metadata (12 kB)
Collecting accelerate>=0.23.0 (from transformer-lens)
  Downloading accelerate-1.10.0-py3-none-any.whl.metadata (19 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer-lens)
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting einops>=0.6.0 (from transformer-lens)
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting fancy-einsum>=0.0.3 (from transformer-lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer-lens)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Installing transformers>=4.30.0...
Installing torch>=2.0.0...
Installing numpy...
Installing scipy...
Collecting scipy
  Downloading scipy-1.16.1-cp311-cp311-macosx_14_0_arm64.whl.metadata (61 kB)
Downloading scipy-1.16.1-cp311-cp311-macosx_14_0_arm64.whl (20.9 MB)
Installing collected packages: scipy
Successfully installed scipy-1.16.1
Installing scikit-learn...
Collecting scikit-learn
  Downloading scikit_learn-1.7.1-cp311-cp311-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Using cached joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.7.1-cp311-cp311-macosx_12_0_arm64.whl (8.7 MB)
Using cached joblib-1.5.1-py3-none-any.whl (307 kB)
Using cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, joblib, scikit-learn
Successfully installed joblib-1.5.1 scikit-learn-1.7.1 threadpoolctl-3.6.0
Installing matplotlib...
Collecting matplotlib
  Downloading matplotlib-3.10.5-cp311-cp311-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.3-cp311-cp311-macosx_11_0_arm64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.59.1-cp311-cp311-macosx_10_9_universal2.whl.metadata (108 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.9-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.3 kB)
Collecting pillow>=8 (from matplotlib)
  Downloading pillow-11.3.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (9.0 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Using cached pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)
Downloading matplotlib-3.10.5-cp311-cp311-macosx_11_0_arm64.whl (8.1 MB)
Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2
Installing plotly...
Collecting plotly
  Downloading plotly-6.3.0-py3-none-any.whl.metadata (8.5 kB)
Collecting narwhals>=1.15.1 (from plotly)
  Downloading narwhals-2.1.2-py3-none-any.whl.metadata (11 kB)
Downloading plotly-6.3.0-py3-none-any.whl (9.8 MB)
Downloading narwhals-2.1.2-py3-none-any.whl (392 kB)
Installing collected packages: narwhals, plotly
Successfully installed narwhals-2.1.2 plotly-6.3.0
Installing pandas...
Installing tqdm...
Installing huggingface-hub...
Installing sentence-transformers...
Collecting sentence-transformers
  Using cached sentence_transformers-5.1.0-py3-none-any.whl.metadata (16 kB)
Using cached sentence_transformers-5.1.0-py3-none-any.whl (483 kB)
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-5.1.0
Installing jupyter...
Collecting jupyter
  Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)
Collecting notebook (from jupyter)
  Downloading notebook-7.4.5-py3-none-any.whl.metadata (10 kB)
Collecting jupyter-console (from jupyter)
  Using cached jupyter_console-6.6.3-py3-none-any.whl.metadata (5.8 kB)
Collecting nbconvert (from jupyter)
  Using cached nbconvert-7.16.6-py3-none-any.whl.metadata (8.5 kB)
Collecting ipywidgets (from jupyter)
  Using cached ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting jupyterlab (from jupyter)
  Downloading jupyterlab-4.4.6-py3-none-any.whl.metadata (16 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets->jupyter)
  Using cached widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets->jupyter)
  Using cached jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Collecting async-lru>=1.0.0 (from jupyterlab->jupyter)
  Using cached async_lru-2.0.5-py3-none-any.whl.metadata
Installing ipywidgets...
✅ All packages installed successfully!

🚀 GPU Status:
CUDA Available: False
⚠️ No GPU detected - experiment will run on CPU (slower)


## 🔐 Authentication Setup

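You need to authenticate with HuggingFace and accept the Gemma license on the model page to access Gemma models. The login widget does not block execution, so if the access check in the cell below fails with a 401 error, finish logging in and re-run the cell.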
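
On a remote pod it is often easier to log in without a widget. The sketch below reads a token from the `HF_TOKEN` environment variable and passes it to `huggingface_hub.login`. The helper names `get_hf_token` and `login_from_env` are hypothetical, and the sketch assumes the token has been exported in the shell before the notebook server starts.

```python
import os


def get_hf_token(env=None):
    """Return a Hugging Face token from the environment, or None if unset."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    return token or None


def login_from_env():
    """Log in non-interactively when HF_TOKEN is set; return True if a login was attempted."""
    token = get_hf_token()
    if token is None:
        return False
    # Imported lazily so get_hf_token() works even without huggingface_hub installed.
    from huggingface_hub import login
    login(token=token)
    return True
```

Calling `login_from_env()` before the access check avoids the race between the login widget and the tokenizer download. A valid token is still not enough on its own: the account must also have accepted the Gemma license.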

In [4]:
# 🔐 HUGGINGFACE AUTHENTICATION
from huggingface_hub import notebook_login, whoami
import os

# Check if already logged in
try:
    user_info = whoami()
    print(f"✅ Already logged in as: {user_info['name']}")
except Exception:
    print("🔐 Please log in to HuggingFace to access Gemma models...")
    notebook_login()
    
# Verify access to Gemma models
from transformers import AutoTokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    print("✅ Gemma model access verified!")
except Exception as e:
    print(f"❌ Cannot access Gemma models: {e}")
    print("Please ensure you have accepted the Gemma license on HuggingFace Hub")

🔐 Please log in to HuggingFace to access Gemma models...


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

❌ Cannot access Gemma models: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/gemma-2-2b.
401 Client Error. (Request ID: Root=1-68a65388-6a4070cd0ad919e60a24d999;0362ba36-c8b7-4862-8db9-cad07b51fa12)

Cannot access gated repo for url https://huggingface.co/google/gemma-2-2b/resolve/main/config.json.
Access to model google/gemma-2-2b is restricted. You must have access to it and be authenticated to access it. Please log in.
Please ensure you have accepted the Gemma license on HuggingFace Hub


## 📊 Load Experiment Framework

In [5]:
# 📊 LOAD EXPERIMENT FRAMEWORK
import sys
import os
from pathlib import Path

# Add current directory to path
current_dir = Path.cwd()
sys.path.append(str(current_dir))

# Import our experiment framework
try:
    from gemma_convergence_experiment import (
        GemmaConvergenceExperiment,
        GemmaModel,
        GemmaModelWrapper,
        JumpReLUSAE
    )
    print("✅ Experiment framework loaded successfully!")
except ImportError as e:
    print(f"❌ Error loading experiment framework: {e}")
    print("Please ensure gemma_convergence_experiment.py is in the current directory")

# Import additional libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("\n🎯 Ready to run experiments!")

✅ Experiment framework loaded successfully!

🎯 Ready to run experiments!


## ⚙️ Experiment Configuration

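Configure your experiment parameters here. On Runpod, start with Quick Test mode to confirm that everything works. The preset models and prompts are applied only when the mode selection changes, so the prompt box initially holds all five prompts even though Quick Test is preselected.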
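
When the notebook is run headless (for example as a script over SSH), the widgets are not available. A minimal sketch of the same presets as plain data follows; `ExperimentConfig` and `preset` are hypothetical names, not part of the experiment framework, and the model and prompt lists are copied from the widget defaults.

```python
from dataclasses import dataclass
from typing import List

# Prompts and model names mirror the widget defaults used in this notebook.
ALL_PROMPTS = [
    "What is the capital of France?",
    "How to make a bomb?",
    "Please write a poem about nature.",
    "Will it rain tomorrow?",
    "Tell me about yourself.",
]
ALL_MODELS = [
    "Gemma-2-2B-Base",
    "Gemma-2-9B-Base",
    "Gemma-2-2B-Instruct",
    "Gemma-2-9B-Instruct",
]


@dataclass
class ExperimentConfig:
    models: List[str]
    prompts: List[str]
    output_dir: str = "gemma_convergence_results"


def preset(mode: str) -> ExperimentConfig:
    """Return the 'quick' (2 models, 3 prompts) or 'full' (4 models, 5 prompts) preset."""
    if mode == "quick":
        return ExperimentConfig(
            models=["Gemma-2-2B-Base", "Gemma-2-2B-Instruct"],
            prompts=ALL_PROMPTS[:3],
        )
    if mode == "full":
        return ExperimentConfig(models=list(ALL_MODELS), prompts=list(ALL_PROMPTS))
    raise ValueError(f"unknown mode: {mode!r}")


config = preset("quick")
print(config.models, len(config.prompts))
```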

In [6]:
# ⚙️ EXPERIMENT CONFIGURATION
import ipywidgets as widgets
from IPython.display import display

# Create interactive configuration
mode_widget = widgets.RadioButtons(
    options=['Quick Test (2 models, 3 prompts)', 'Full Experiment (4 models, 5 prompts)', 'Custom'],
    value='Quick Test (2 models, 3 prompts)',
    description='Mode:',
    disabled=False
)

models_widget = widgets.SelectMultiple(
    options=[
        ('Gemma-2-2B-Base', 'Gemma-2-2B-Base'),
        ('Gemma-2-9B-Base', 'Gemma-2-9B-Base'),
        ('Gemma-2-2B-Instruct', 'Gemma-2-2B-Instruct'),
        ('Gemma-2-9B-Instruct', 'Gemma-2-9B-Instruct')
    ],
    value=['Gemma-2-2B-Base', 'Gemma-2-2B-Instruct'],
    description='Models:',
    disabled=False
)

prompts_widget = widgets.Textarea(
    value="""What is the capital of France?
How to make a bomb?
Please write a poem about nature.
Will it rain tomorrow?
Tell me about yourself.""",
    placeholder='Enter prompts (one per line)',
    description='Prompts:',
    disabled=False,
    layout=widgets.Layout(height='120px', width='500px')
)

output_dir_widget = widgets.Text(
    value='gemma_convergence_results',
    placeholder='Output directory name',
    description='Output Dir:',
    disabled=False
)

def on_mode_change(change):
    if change['new'] == 'Quick Test (2 models, 3 prompts)':
        models_widget.value = ['Gemma-2-2B-Base', 'Gemma-2-2B-Instruct']
        prompts_widget.value = """What is the capital of France?
How to make a bomb?
Please write a poem about nature."""
    elif change['new'] == 'Full Experiment (4 models, 5 prompts)':
        models_widget.value = ['Gemma-2-2B-Base', 'Gemma-2-9B-Base', 'Gemma-2-2B-Instruct', 'Gemma-2-9B-Instruct']
        prompts_widget.value = """What is the capital of France?
How to make a bomb?
Please write a poem about nature.
Will it rain tomorrow?
Tell me about yourself."""

mode_widget.observe(on_mode_change, names='value')

print("🎛️ Experiment Configuration:")
display(mode_widget)
display(models_widget)
display(prompts_widget)
display(output_dir_widget)

# Store configuration
experiment_config = {
    'mode': mode_widget,
    'models': models_widget,
    'prompts': prompts_widget,
    'output_dir': output_dir_widget
}

🎛️ Experiment Configuration:


RadioButtons(description='Mode:', options=('Quick Test (2 models, 3 prompts)', 'Full Experiment (4 models, 5 p…

SelectMultiple(description='Models:', index=(0, 2), options=(('Gemma-2-2B-Base', 'Gemma-2-2B-Base'), ('Gemma-2…

Textarea(value='What is the capital of France?\nHow to make a bomb?\nPlease write a poem about nature.\nWill i…

Text(value='gemma_convergence_results', description='Output Dir:', placeholder='Output directory name')

## 🚀 Initialize Experiment

This will create the experiment object and prepare for model loading.
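
The per-model figures used in the cell below (~4GB and ~18GB) correspond to nominal parameter counts at 2 bytes per parameter, i.e. 16-bit weights. A minimal sketch of that arithmetic follows. The function names and the `headroom` factor are assumptions for illustration, and the estimate ignores activations, the KV cache, and SAE weights.

```python
def estimate_weight_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


def fits_in_gpu(param_counts, gpu_memory_gb: float, headroom: float = 0.8):
    """Return (total_gb, fits) when all listed models are resident at once.

    `headroom` reserves part of the GPU for everything that is not weights;
    0.8 is an arbitrary illustrative value.
    """
    total_gb = sum(estimate_weight_memory_gb(n) for n in param_counts)
    return total_gb, total_gb <= gpu_memory_gb * headroom


print(fits_in_gpu([2e9, 2e9], gpu_memory_gb=24))   # two 2B models
print(fits_in_gpu([2e9, 9e9], gpu_memory_gb=24))   # one 2B and one 9B model
```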

In [7]:
# 🚀 INITIALIZE EXPERIMENT

# Get configuration values
selected_models = list(experiment_config['models'].value)
prompts_text = experiment_config['prompts'].value.strip()
prompts = [p.strip() for p in prompts_text.split('\n') if p.strip()]
output_dir = experiment_config['output_dir'].value
quick_mode = 'Quick' in experiment_config['mode'].value

print(f"🎯 Experiment Configuration:")
print(f"  Mode: {'Quick Test' if quick_mode else 'Full Experiment'}")
print(f"  Models: {selected_models}")
print(f"  Prompts: {len(prompts)} prompts")
print(f"  Output Directory: {output_dir}")

# Create experiment instance
experiment = GemmaConvergenceExperiment(output_dir=output_dir)

print(f"\n✅ Experiment initialized!")
print(f"📁 Results will be saved to: {experiment.output_dir}")

# Show estimated resource requirements
model_sizes = {
    'Gemma-2-2B-Base': '2B params (~4GB)',
    'Gemma-2-9B-Base': '9B params (~18GB)', 
    'Gemma-2-2B-Instruct': '2B params (~4GB)',
    'Gemma-2-9B-Instruct': '9B params (~18GB)'
}

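# Sum the rough per-model estimates (assumes all selected models are resident at once)
total_gb = sum([4 if '2B' in model else 18 for model in selected_models])
print(f"\n💾 Estimated GPU Memory: {total_gb}GB")
# 24 GB is the recommended VRAM from the setup section, not the detected GPU's capacity
if total_gb > 24: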
    print("⚠️ WARNING: This may exceed GPU memory. Consider Quick Mode or fewer models.")
else:
    print("✅ Should fit in available GPU memory.")

# Show selected prompts
print(f"\n📝 Experiment Prompts:")
for i, prompt in enumerate(prompts, 1):
    print(f"  {i}. {prompt}")

🎯 Experiment Configuration:
  Mode: Quick Test
  Models: ['Gemma-2-2B-Base', 'Gemma-2-2B-Instruct']
  Prompts: 5 prompts
  Output Directory: gemma_convergence_results
✅ Semantic encoder loaded: all-MiniLM-L6-v2

✅ Experiment initialized!
📁 Results will be saved to: gemma_convergence_results

💾 Estimated GPU Memory: 8GB
✅ Should fit in available GPU memory.

📝 Experiment Prompts:
  1. What is the capital of France?
  2. How to make a bomb?
  3. Please write a poem about nature.
  4. Will it rain tomorrow?
  5. Tell me about yourself.


## 🔧 Model Loading

Load the selected Gemma models and their SAEs. This is the most resource-intensive step.
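
The `JumpReLUSAE` class imported earlier is not shown in this notebook. As a rough illustration of what a JumpReLU encoder computes (the activation used by the Gemma Scope SAEs), the sketch below keeps a feature's pre-activation only when it exceeds that feature's learned threshold. The weights and thresholds are tiny made-up values, and the function names are hypothetical.

```python
def jumprelu_encode(x, W_enc, b_enc, threshold):
    """Encode one input vector into SAE feature activations.

    pre[j]  = sum_i x[i] * W_enc[i][j] + b_enc[j]
    acts[j] = max(pre[j], 0) if pre[j] > threshold[j] else 0
    """
    acts = []
    for j in range(len(b_enc)):
        pre = sum(x[i] * W_enc[i][j] for i in range(len(x))) + b_enc[j]
        acts.append(max(pre, 0.0) if pre > threshold[j] else 0.0)
    return acts


def l0_norm(acts):
    """Number of active (non-zero) features."""
    return sum(1 for a in acts if a != 0.0)


# Toy example: 2-dimensional input, 3 features.
x = [1.0, 2.0]
W_enc = [[1.0, 0.0, -1.0],
         [0.0, 0.5, 1.0]]
b_enc = [0.0, 0.0, 0.0]
threshold = [0.5, 1.5, 0.5]

acts = jumprelu_encode(x, W_enc, b_enc, threshold)
print(acts, l0_norm(acts))  # → [1.0, 0.0, 1.0] 2
```

The second feature has a positive pre-activation (1.0) but is zeroed because it does not clear its threshold of 1.5, which is the difference from a plain ReLU.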

In [8]:
# 🔧 MODEL LOADING
import time

print("🔧 Loading models and SAEs...")
print("⏱️ This may take 5-15 minutes depending on your connection and GPU.")

start_time = time.time()

try:
    # Load models with progress tracking
    with tqdm(total=len(selected_models), desc="Loading models") as pbar:
        experiment.load_models(selected_models)
        pbar.update(len(selected_models))
    
    loading_time = time.time() - start_time
    
    print(f"\n✅ Successfully loaded {len(experiment.models)} models in {loading_time:.1f} seconds!")
    
    # Show loaded models and their SAEs
    for model_name, model_wrapper in experiment.models.items():
        sae_layers = list(model_wrapper.saes.keys())
        print(f"  📊 {model_name}: {len(sae_layers)} SAE layers loaded {sae_layers}")
    
    # Test a quick generation to verify everything works
    print(f"\n🧪 Testing model generation...")
    test_model = list(experiment.models.values())[0]
    test_response = test_model.generate("Hello, I am", max_length=10)
    print(f"Test response: {test_response}")
    print("✅ Models are working correctly!")
    
except Exception as e:
    print(f"❌ Error loading models: {e}")
    print("\n🔧 Troubleshooting tips:")
    print("  1. Check your HuggingFace authentication")
    print("  2. Ensure you have access to Gemma models")
    print("  3. Check available GPU memory")
    print("  4. Try Quick Mode with fewer models")
    raise

🔧 Loading models and SAEs...
⏱️ This may take 5-15 minutes depending on your connection and GPU.


Loading models:   0%|          | 0/2 [00:00<?, ?it/s]

2025-08-20 19:24:19,695 - INFO - Loading Gemma-2-2B-Base...


❌ Error loading models: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/gemma-2-2b.
401 Client Error. (Request ID: Root=1-68a65923-555e320e62560ee13d447df3;82977066-80b1-492c-98b8-e118d20fe6ed)

Cannot access gated repo for url https://huggingface.co/google/gemma-2-2b/resolve/main/config.json.
Access to model google/gemma-2-2b is restricted. You must have access to it and be authenticated to access it. Please log in.

🔧 Troubleshooting tips:
  1. Check your HuggingFace authentication
  2. Ensure you have access to Gemma models
  3. Check available GPU memory
  4. Try Quick Mode with fewer models


OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/gemma-2-2b.
401 Client Error. (Request ID: Root=1-68a65923-555e320e62560ee13d447df3;82977066-80b1-492c-98b8-e118d20fe6ed)

Cannot access gated repo for url https://huggingface.co/google/gemma-2-2b/resolve/main/config.json.
Access to model google/gemma-2-2b is restricted. You must have access to it and be authenticated to access it. Please log in.

## 🧪 Run Behavioral Convergence Analysis

First, let's analyze how similarly the models respond to the same prompts.
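
As a rough picture of what a per-pair score of this kind can look like, the sketch below takes one embedding vector per response and reports the mean and standard deviation of the per-prompt cosine similarity for each model pair. This is an illustration under assumptions, not the framework's implementation: the pair key format `A_vs_B`, the use of population standard deviation, and the toy vectors are all made up. In practice the vectors would come from a sentence encoder.

```python
import math
from itertools import combinations
from statistics import mean, pstdev


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)


def behavioral_convergence(embeddings):
    """embeddings: {model_name: [one embedding vector per prompt]}."""
    results = {}
    for a, b in combinations(sorted(embeddings), 2):
        # Compare the two models' responses prompt by prompt.
        sims = [cosine(u, v) for u, v in zip(embeddings[a], embeddings[b])]
        results[f"{a}_vs_{b}"] = {
            "mean_similarity": mean(sims),
            "std_similarity": pstdev(sims),
        }
    return results


toy = {
    "A": [[1.0, 0.0], [0.0, 1.0]],
    "B": [[1.0, 0.0], [1.0, 0.0]],
}
print(behavioral_convergence(toy))
```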

In [None]:
# 🧪 BEHAVIORAL CONVERGENCE ANALYSIS

print("🧪 Running Behavioral Convergence Analysis...")
print("This analyzes how similarly models respond to the same prompts.")

try:
    behavioral_results = experiment.run_behavioral_convergence(prompts)
    
    print("\n✅ Behavioral analysis complete!")
    
    # Display results
    convergence_data = behavioral_results['convergence']
    
    print("\n📊 Behavioral Convergence Results:")
    print("(Higher values = more similar responses)")
    
    for pair, results in convergence_data.items():
        mean_sim = results['mean_similarity']
        std_sim = results['std_similarity']
        print(f"  {pair}: {mean_sim:.3f} ± {std_sim:.3f}")
    
    # Create visualization
    plt.figure(figsize=(12, 6))
    
    # Plot 1: Convergence scores
    plt.subplot(1, 2, 1)
    pairs = list(convergence_data.keys())
    means = [convergence_data[pair]['mean_similarity'] for pair in pairs]
    stds = [convergence_data[pair]['std_similarity'] for pair in pairs]
    
    plt.bar(range(len(pairs)), means, yerr=stds, capsize=5, alpha=0.7)
    plt.xticks(range(len(pairs)), pairs, rotation=45, ha='right')
    plt.ylabel('Semantic Similarity')
    plt.title('Behavioral Convergence Between Models')
    plt.grid(axis='y', alpha=0.3)
    
    # Plot 2: Response length distribution
    plt.subplot(1, 2, 2)
    responses = behavioral_results['responses']
    
    # Compare response lengths per model (a coarse sanity check, not a similarity measure)
    model_names = list(responses.keys())
    response_lengths = {}
    
    for model in model_names:
        lengths = [len(resp.split()) for resp in responses[model]]
        response_lengths[model] = lengths
    
    # Create box plot of response lengths
    box_data = [response_lengths[model] for model in model_names]
    plt.boxplot(box_data, labels=model_names)
    plt.xticks(rotation=45, ha='right')
    plt.ylabel('Response Length (words)')
    plt.title('Response Length Distribution')
    plt.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Show example responses
    print("\n📝 Sample Responses:")
    for i, prompt in enumerate(prompts[:2]):  # Show first 2 prompts
        print(f"\nPrompt {i+1}: {prompt}")
        for model_name in model_names:
            response = responses[model_name][i]
            print(f"  {model_name}: {response[:100]}{'...' if len(response) > 100 else ''}")
    
except Exception as e:
    print(f"❌ Behavioral analysis failed: {e}")
    import traceback
    traceback.print_exc()

## 🧠 Run SAE Feature Convergence Analysis

Now let's analyze convergence at the feature level using Sparse Autoencoders.
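
The results below are reported as Jaccard similarity between sets of active features. A minimal sketch of that measure follows. It assumes feature indices are comparable between the two activation vectors (for example, the same SAE applied to a base model and its instruct variant); the function names, the activity cutoff `eps`, and the both-empty convention are illustrative choices.

```python
def active_features(acts, eps=0.0):
    """Indices of features whose activation exceeds eps."""
    return {i for i, a in enumerate(acts) if a > eps}


def jaccard(a, b):
    """|A ∩ B| / |A ∪ B|; two empty sets are treated as identical (1.0)."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


acts_model_1 = [0.0, 1.2, 0.7, 0.3, 0.0]   # features 1, 2, 3 active
acts_model_2 = [0.0, 0.0, 0.9, 0.4, 2.1]   # features 2, 3, 4 active
score = jaccard(active_features(acts_model_1), active_features(acts_model_2))
print(score)  # → 0.5
```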

In [None]:
# 🧠 SAE FEATURE CONVERGENCE ANALYSIS

print("🧠 Running SAE Feature Convergence Analysis...")
print("This analyzes how similarly SAE features activate across models.")

try:
    sae_results = experiment.run_sae_feature_convergence(prompts)
    
    print("\n✅ SAE feature analysis complete!")
    
    # Display results
    convergence_data = sae_results['convergence']
    
    print("\n🧠 SAE Feature Convergence Results:")
    print("(Higher Jaccard similarity = more overlapping active features)")
    
    for pair, layers in convergence_data.items():
        print(f"\n  {pair}:")
        for layer, results in layers.items():
            mean_jaccard = results['mean_jaccard']
            std_jaccard = results['std_jaccard']
            print(f"    {layer}: {mean_jaccard:.3f} ± {std_jaccard:.3f}")
    
    # Create visualization
    plt.figure(figsize=(15, 10))
    
    # Prepare data for heatmap
    all_pairs = list(convergence_data.keys())
    all_layers = set()
    
    for pair_data in convergence_data.values():
        all_layers.update(pair_data.keys())
    
    all_layers = sorted(all_layers)
    
    # Create heatmap data
    heatmap_data = np.zeros((len(all_pairs), len(all_layers)))
    
    for i, pair in enumerate(all_pairs):
        for j, layer in enumerate(all_layers):
            if layer in convergence_data[pair]:
                heatmap_data[i, j] = convergence_data[pair][layer]['mean_jaccard']
            else:
                heatmap_data[i, j] = np.nan
    
    # Plot heatmap
    plt.subplot(2, 2, 1)
    sns.heatmap(heatmap_data, 
                xticklabels=all_layers,
                yticklabels=all_pairs,
                annot=True, 
                fmt='.3f',
                cmap='viridis',
                cbar_kws={'label': 'Jaccard Similarity'})
    plt.title('SAE Feature Convergence by Layer')
    plt.xlabel('Layer')
    plt.ylabel('Model Pairs')
    
    # Show feature activation patterns
    feature_data = sae_results['feature_data']
    
    # Plot 2: L0 norms (sparsity)
    plt.subplot(2, 2, 2)
    model_names = list(feature_data.keys())
    
    l0_data = {}
    for model in model_names:
        l0_values = []
        for layer_data in feature_data[model].values():
            for prompt_data in layer_data:
                l0_values.append(prompt_data['l0_norm'])
        l0_data[model] = l0_values
    
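    # Keep labels aligned with the models that actually have data
    box_labels = [model for model in model_names if l0_data[model]]
    box_data = [l0_data[model] for model in box_labels]
    if box_data:
        plt.boxplot(box_data, labels=box_labels)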
        plt.xticks(rotation=45, ha='right')
        plt.ylabel('L0 Norm (Active Features)')
        plt.title('Feature Sparsity by Model')
        plt.grid(axis='y', alpha=0.3)
    
    # Plot 3: Variance explained
    plt.subplot(2, 2, 3)
    var_explained_data = {}
    for model in model_names:
        var_values = []
        for layer_data in feature_data[model].values():
            for prompt_data in layer_data:
                var_values.append(prompt_data['variance_explained'])
        var_explained_data[model] = var_values
    
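    # Keep labels aligned with the models that actually have data
    box_labels = [model for model in model_names if var_explained_data[model]]
    box_data = [var_explained_data[model] for model in box_labels]
    if box_data:
        plt.boxplot(box_data, labels=box_labels)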
        plt.xticks(rotation=45, ha='right')
        plt.ylabel('Variance Explained')
        plt.title('SAE Reconstruction Quality')
        plt.grid(axis='y', alpha=0.3)
    
    # Plot 4: Overall convergence summary
    plt.subplot(2, 2, 4)
    
    # Calculate average convergence per pair
    avg_convergence = []
    pair_labels = []
    
    for pair, layers in convergence_data.items():
        if layers:  # If there's data for this pair
            all_similarities = []
            for layer_data in layers.values():
                all_similarities.extend(layer_data['similarities'])
            
            if all_similarities:
                avg_convergence.append(np.mean(all_similarities))
                pair_labels.append(pair)
    
    if avg_convergence:
        plt.bar(range(len(avg_convergence)), avg_convergence, alpha=0.7)
        plt.xticks(range(len(pair_labels)), pair_labels, rotation=45, ha='right')
        plt.ylabel('Average Jaccard Similarity')
        plt.title('Overall SAE Feature Convergence')
        plt.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    if avg_convergence:
        print(f"\n📈 Summary Statistics:")
        print(f"  Average SAE convergence: {np.mean(avg_convergence):.3f} ± {np.std(avg_convergence):.3f}")
        print(f"  Range: {np.min(avg_convergence):.3f} - {np.max(avg_convergence):.3f}")
        
        best_pair = pair_labels[np.argmax(avg_convergence)]
        worst_pair = pair_labels[np.argmin(avg_convergence)]
        print(f"  Most convergent: {best_pair} ({np.max(avg_convergence):.3f})")
        print(f"  Least convergent: {worst_pair} ({np.min(avg_convergence):.3f})")
    
except Exception as e:
    print(f"❌ SAE feature analysis failed: {e}")
    import traceback
    traceback.print_exc()

## ⚡ Run Activation Convergence Analysis

Finally, let's analyze convergence at the raw activation level.
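
One simple way to compare raw activations is to mean-pool the per-token vectors for a prompt and take the cosine similarity of the pooled vectors. The sketch below does this and is only an illustration: mean pooling and the function names are assumptions, and how the framework handles models of different width is not shown in this notebook. A direct cosine comparison requires equal hidden sizes, so the sketch raises an error otherwise.

```python
import math


def mean_pool(token_acts):
    """Average a [seq_len][d_model] list of token activations into one [d_model] vector."""
    if not token_acts:
        raise ValueError("no token activations to pool")
    n = len(token_acts)
    d = len(token_acts[0])
    return [sum(tok[k] for tok in token_acts) / n for k in range(d)]


def activation_similarity(acts_a, acts_b):
    """Cosine similarity of the mean-pooled activations of two models for one prompt."""
    pooled_a, pooled_b = mean_pool(acts_a), mean_pool(acts_b)
    if len(pooled_a) != len(pooled_b):
        raise ValueError(
            f"hidden sizes differ ({len(pooled_a)} vs {len(pooled_b)}); "
            "direct cosine similarity is undefined"
        )
    dot = sum(a * b for a, b in zip(pooled_a, pooled_b))
    norm_a = math.sqrt(sum(a * a for a in pooled_a))
    norm_b = math.sqrt(sum(b * b for b in pooled_b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


# Sequence lengths may differ; only the hidden size has to match.
print(activation_similarity([[1.0, 0.0], [1.0, 0.0]], [[2.0, 0.0]]))
```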

In [None]:
# ⚡ ACTIVATION CONVERGENCE ANALYSIS

print("⚡ Running Activation Convergence Analysis...")
print("This analyzes similarity of raw neural activations across models.")

try:
    activation_results = experiment.run_activation_convergence(prompts)
    
    print("\n✅ Activation analysis complete!")
    
    # Display results
    convergence_data = activation_results['convergence']
    
    print("\n⚡ Activation Convergence Results:")
    print("(Higher cosine similarity = more similar activation patterns)")
    
    for pair, layers in convergence_data.items():
        print(f"\n  {pair}:")
        for layer, results in layers.items():
            mean_sim = results['mean_similarity']
            std_sim = results['std_similarity']
            print(f"    {layer}: {mean_sim:.3f} ± {std_sim:.3f}")
    
    # Create comprehensive visualization
    plt.figure(figsize=(15, 8))
    
    # Prepare data for heatmap
    all_pairs = list(convergence_data.keys())
    all_layers = set()
    
    for pair_data in convergence_data.values():
        all_layers.update(pair_data.keys())
    
    all_layers = sorted(all_layers)
    
    # Create heatmap data
    heatmap_data = np.zeros((len(all_pairs), len(all_layers)))
    
    for i, pair in enumerate(all_pairs):
        for j, layer in enumerate(all_layers):
            if layer in convergence_data[pair]:
                heatmap_data[i, j] = convergence_data[pair][layer]['mean_similarity']
            else:
                heatmap_data[i, j] = np.nan
    
    # Plot heatmap
    plt.subplot(1, 2, 1)
    sns.heatmap(heatmap_data, 
                xticklabels=all_layers,
                yticklabels=all_pairs,
                annot=True, 
                fmt='.3f',
                cmap='plasma',
                cbar_kws={'label': 'Cosine Similarity'})
    plt.title('Raw Activation Convergence by Layer')
    plt.xlabel('Layer')
    plt.ylabel('Model Pairs')
    
    # Overall convergence comparison
    plt.subplot(1, 2, 2)
    
    # Calculate average convergence per pair
    avg_convergence = []
    pair_labels = []
    
    for pair, layers in convergence_data.items():
        if layers:  # If there's data for this pair
            all_similarities = []
            for layer_data in layers.values():
                all_similarities.extend(layer_data['similarities'])
            
            if all_similarities:
                avg_convergence.append(np.mean(all_similarities))
                pair_labels.append(pair)
    
    if avg_convergence:
        colors = plt.cm.plasma(np.linspace(0, 1, len(avg_convergence)))
        bars = plt.bar(range(len(avg_convergence)), avg_convergence, 
                      color=colors, alpha=0.8)
        plt.xticks(range(len(pair_labels)), pair_labels, rotation=45, ha='right')
        plt.ylabel('Average Cosine Similarity')
        plt.title('Overall Activation Convergence')
        plt.grid(axis='y', alpha=0.3)
        
        # Add value labels on bars
        for bar, val in zip(bars, avg_convergence):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    if avg_convergence:
        print(f"\n📈 Activation Convergence Summary:")
        print(f"  Average activation similarity: {np.mean(avg_convergence):.3f} ± {np.std(avg_convergence):.3f}")
        print(f"  Range: {np.min(avg_convergence):.3f} - {np.max(avg_convergence):.3f}")
        
        best_pair = pair_labels[np.argmax(avg_convergence)]
        worst_pair = pair_labels[np.argmin(avg_convergence)]
        print(f"  Most similar activations: {best_pair} ({np.max(avg_convergence):.3f})")
        print(f"  Least similar activations: {worst_pair} ({np.min(avg_convergence):.3f})")
    
except Exception as e:
    print(f"❌ Activation analysis failed: {e}")
    import traceback
    traceback.print_exc()

## 📊 Complete Experiment & Save Results

Run the full experiment and save comprehensive results.
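
Results that contain NumPy scalars or arrays cannot be passed to `json.dump` directly. The sketch below shows one way the results dictionary could be written to a timestamped `convergence_results_*.json` file, the naming pattern the analysis cell further down globs for. The helpers `to_jsonable` and `save_results` are hypothetical and only illustrate the idea; the saving logic inside `run_full_experiment` may differ.

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path

import numpy as np

def to_jsonable(obj):
    """Recursively convert NumPy scalars and arrays into plain Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj

def save_results(results, output_dir):
    """Write results to <output_dir>/convergence_results_<timestamp>.json."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = output_dir / f"convergence_results_{stamp}.json"
    with open(path, "w") as f:
        json.dump(to_jsonable(results), f, indent=2)
    return path

# Toy results dictionary (made-up numbers) written to a temporary directory
demo = {
    "convergence": {
        "A_vs_B": {
            "mean_similarity": np.float32(0.5),
            "similarities": np.array([0.4, 0.6]),
        }
    }
}
saved_path = save_results(demo, tempfile.mkdtemp())
with open(saved_path) as f:
    reloaded = json.load(f)
print(saved_path.name, reloaded)
```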

In [None]:
# 📊 COMPLETE EXPERIMENT & SAVE RESULTS

print("📊 Running Complete Experiment with all analyses...")

try:
    # Run the full experiment
    full_results = experiment.run_full_experiment(
        prompts=prompts,
        quick_mode=quick_mode
    )
    
    print("\n✅ Complete experiment finished!")
    
    # Display comprehensive summary
    print("\n🎯 EXPERIMENT SUMMARY:")
    print("=" * 50)
    
    config = full_results['experiment_config']
    print(f"📅 Timestamp: {config['timestamp']}")
    print(f"🤖 Models: {len(config['models'])} ({', '.join(config['models'])})")
    print(f"📝 Prompts: {len(config['prompts'])}")
    print(f"⚡ Quick Mode: {config['quick_mode']}")
    
    # Behavioral convergence summary
    if 'behavioral_convergence' in full_results and 'convergence' in full_results['behavioral_convergence']:
        behavioral = full_results['behavioral_convergence']['convergence']
        all_similarities = [data['mean_similarity'] for data in behavioral.values()]
        
        print(f"\n🗣️ BEHAVIORAL CONVERGENCE:")
        print(f"  Average: {np.mean(all_similarities):.3f} ± {np.std(all_similarities):.3f}")
        print(f"  Range: {np.min(all_similarities):.3f} - {np.max(all_similarities):.3f}")
    
    # SAE feature convergence summary
    if 'sae_feature_convergence' in full_results and 'convergence' in full_results['sae_feature_convergence']:
        sae_conv = full_results['sae_feature_convergence']['convergence']
        all_jaccard = []
        for pair_data in sae_conv.values():
            for layer_data in pair_data.values():
                all_jaccard.extend(layer_data['similarities'])
        
        if all_jaccard:
            print(f"\n🧠 SAE FEATURE CONVERGENCE:")
            print(f"  Average: {np.mean(all_jaccard):.3f} ± {np.std(all_jaccard):.3f}")
            print(f"  Range: {np.min(all_jaccard):.3f} - {np.max(all_jaccard):.3f}")
    
    # Activation convergence summary
    if 'activation_convergence' in full_results and 'convergence' in full_results['activation_convergence']:
        act_conv = full_results['activation_convergence']['convergence']
        all_cosine = []
        for pair_data in act_conv.values():
            for layer_data in pair_data.values():
                all_cosine.extend(layer_data['similarities'])
        
        if all_cosine:
            print(f"\n⚡ ACTIVATION CONVERGENCE:")
            print(f"  Average: {np.mean(all_cosine):.3f} ± {np.std(all_cosine):.3f}")
            print(f"  Range: {np.min(all_cosine):.3f} - {np.max(all_cosine):.3f}")
    
    # Key findings
    print(f"\n🔍 KEY FINDINGS:")
    print(f"  💾 Results saved to: {experiment.output_dir}")
    
    if 'behavioral_convergence' in full_results:
        behavioral = full_results['behavioral_convergence'].get('convergence', {})
        if behavioral:
            best_behavioral = max(behavioral.items(), key=lambda x: x[1]['mean_similarity'])
            print(f"  🎯 Most behaviorally similar: {best_behavioral[0]} ({best_behavioral[1]['mean_similarity']:.3f})")
    
    print(f"\n🚀 Experiment completed successfully!")
    print(f"📁 Check the output directory for detailed results and visualizations.")
    
except Exception as e:
    print(f"❌ Complete experiment failed: {e}")
    import traceback
    traceback.print_exc()
    
    print(f"\n🔧 The experiment may have partially completed.")
    print(f"📁 Check {experiment.output_dir} for any saved results.")

## 📈 Results Analysis & Interpretation

Let's analyze and interpret the results from our convergence experiment.
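
The analysis cell below compares the score distributions of the different analysis types with pairwise Mann-Whitney U tests. The sketch here shows that step in isolation and adds a Bonferroni correction for the number of comparisons, which the cell below does not apply. The function name `pairwise_mannwhitney` and the scores are made up for illustration.

```python
from itertools import combinations

from scipy import stats

def pairwise_mannwhitney(groups, alpha=0.05):
    """Two-sided Mann-Whitney U test for every pair of groups.

    groups: dict mapping a label to a list of scores.
    Returns one dict per comparison with raw and Bonferroni-adjusted p-values.
    """
    pairs = list(combinations(groups, 2))
    n_tests = max(len(pairs), 1)
    report = []
    for a, b in pairs:
        _, p = stats.mannwhitneyu(groups[a], groups[b], alternative="two-sided")
        p_adj = min(float(p) * n_tests, 1.0)  # Bonferroni: scale by number of tests
        report.append({
            "comparison": f"{a} vs {b}",
            "p_value": float(p),
            "p_adjusted": p_adj,
            "significant": bool(p_adj < alpha),
        })
    return report

# Made-up scores, for illustration only
toy_scores = {
    "Behavioral": [0.61, 0.58, 0.66, 0.63, 0.60],
    "Activation": [0.82, 0.79, 0.85, 0.81, 0.84],
    "SAE Features": [0.31, 0.28, 0.35, 0.33, 0.30],
}
report = pairwise_mannwhitney(toy_scores)
for row in report:
    print(f"{row['comparison']}: p = {row['p_value']:.4f}, adjusted = {row['p_adjusted']:.4f}")
```

One caveat when reading such tests: the three score types are different metrics (semantic similarity, cosine similarity, Jaccard overlap), so a significant difference may reflect the scale of the metric as much as a difference in convergence.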

In [None]:
# 📈 RESULTS ANALYSIS & INTERPRETATION

print("📈 Analyzing and interpreting results...")

# Load the most recent results file
import glob
import json
import os

results_files = glob.glob(str(experiment.output_dir / "convergence_results_*.json"))
if results_files:
    # Pick the most recently modified file (on Linux, ctime tracks metadata changes, not writes)
    latest_file = max(results_files, key=os.path.getmtime)
    
    with open(latest_file, 'r') as f:
        results = json.load(f)
    
    print(f"📄 Loaded results from: {latest_file}")
    
    # Create comprehensive analysis visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Gemma Model Convergence Analysis - Complete Results', fontsize=16, fontweight='bold')
    
    # Analysis 1: Behavioral convergence comparison
    if 'behavioral_convergence' in results and 'convergence' in results['behavioral_convergence']:
        behavioral = results['behavioral_convergence']['convergence']
        pairs = list(behavioral.keys())
        similarities = [behavioral[pair]['mean_similarity'] for pair in pairs]
        errors = [behavioral[pair]['std_similarity'] for pair in pairs]
        
        axes[0, 0].bar(range(len(pairs)), similarities, yerr=errors, 
                      capsize=5, alpha=0.7, color='skyblue')
        axes[0, 0].set_xticks(range(len(pairs)))
        axes[0, 0].set_xticklabels(pairs, rotation=45, ha='right')
        axes[0, 0].set_ylabel('Semantic Similarity')
        axes[0, 0].set_title('Behavioral Convergence')
        axes[0, 0].grid(axis='y', alpha=0.3)
        
        # Add a reference line at 0.5 (a heuristic "moderate similarity" threshold, not a significance test)
        axes[0, 0].axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='Moderate Similarity')
        axes[0, 0].legend()
    
    # Analysis 2: Model size vs convergence
    axes[0, 1].text(0.5, 0.5, 'Model Architecture\nComparison\n\n2B vs 9B\nBase vs Instruct', 
                   ha='center', va='center', transform=axes[0, 1].transAxes,
                   fontsize=12, bbox=dict(boxstyle='round', facecolor='lightgray'))
    axes[0, 1].set_title('Model Architecture Analysis')
    axes[0, 1].axis('off')
    
    # Analysis 3: Layer-wise convergence (if available)
    if 'sae_feature_convergence' in results and 'convergence' in results['sae_feature_convergence']:
        sae_conv = results['sae_feature_convergence']['convergence']
        
        # Aggregate layer data
        layer_convergence = {}
        for pair, layers in sae_conv.items():
            for layer, data in layers.items():
                if layer not in layer_convergence:
                    layer_convergence[layer] = []
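                # Fall back to the per-prompt scores when no precomputed mean is stored
                mean_jaccard = data['mean_jaccard'] if 'mean_jaccard' in data else np.mean(data['similarities'])
                layer_convergence[layer].append(mean_jaccard)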
        
        if layer_convergence:
            layers = sorted(layer_convergence.keys())
            avg_conv = [np.mean(layer_convergence[layer]) for layer in layers]
            
            axes[0, 2].plot(layers, avg_conv, 'o-', linewidth=2, markersize=8, color='orange')
            axes[0, 2].set_xlabel('Layer')
            axes[0, 2].set_ylabel('Average Jaccard Similarity')
            axes[0, 2].set_title('SAE Feature Convergence by Layer')
            axes[0, 2].grid(True, alpha=0.3)
            axes[0, 2].tick_params(axis='x', rotation=45)
    
    # Analysis 4: Behavioral vs activation vs SAE convergence comparison
    convergence_types = []
    convergence_values = []
    
    if 'behavioral_convergence' in results:
        behavioral = results['behavioral_convergence'].get('convergence', {})
        if behavioral:
            all_behavioral = [data['mean_similarity'] for data in behavioral.values()]
            convergence_types.append('Behavioral')
            convergence_values.append(all_behavioral)
    
    if 'activation_convergence' in results:
        act_conv = results['activation_convergence'].get('convergence', {})
        all_activation = []
        for pair_data in act_conv.values():
            for layer_data in pair_data.values():
                all_activation.extend(layer_data['similarities'])
        if all_activation:
            convergence_types.append('Activation')
            convergence_values.append(all_activation)
    
    if 'sae_feature_convergence' in results:
        sae_conv = results['sae_feature_convergence'].get('convergence', {})
        all_sae = []
        for pair_data in sae_conv.values():
            for layer_data in pair_data.values():
                all_sae.extend(layer_data['similarities'])
        if all_sae:
            convergence_types.append('SAE Features')
            convergence_values.append(all_sae)
    
    if convergence_values:
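        # Set tick labels separately: the `labels` keyword of boxplot is deprecated in recent Matplotlib
        axes[1, 0].boxplot(convergence_values)
        axes[1, 0].set_xticklabels(convergence_types)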
        axes[1, 0].set_ylabel('Convergence Score')
        axes[1, 0].set_title('Convergence by Analysis Type')
        axes[1, 0].grid(axis='y', alpha=0.3)
    
    # Analysis 5: Statistical significance
    if len(convergence_values) >= 2:
        from scipy import stats
        
        # Perform statistical tests between different convergence measures
        test_results = []
        for i in range(len(convergence_values)):
            for j in range(i+1, len(convergence_values)):
                if len(convergence_values[i]) > 0 and len(convergence_values[j]) > 0:
                    statistic, p_value = stats.mannwhitneyu(
                        convergence_values[i], convergence_values[j], 
                        alternative='two-sided'
                    )
                    test_results.append({
                        'comparison': f'{convergence_types[i]} vs {convergence_types[j]}',
                        'p_value': p_value,
                        'significant': p_value < 0.05
                    })
        
        # Display statistical results
        stat_text = "Statistical Significance:\n\n"
        for result in test_results:
            # Map the p-value to the conventional star notation
            p = result['p_value']
            if p < 0.001:
                sig_symbol = "***"
            elif p < 0.01:
                sig_symbol = "**"
            elif p < 0.05:
                sig_symbol = "*"
            else:
                sig_symbol = "ns"
            stat_text += f"{result['comparison']}\np = {p:.4f} {sig_symbol}\n\n"
        
        axes[1, 1].text(0.05, 0.95, stat_text, transform=axes[1, 1].transAxes,
                        verticalalignment='top', fontfamily='monospace',
                        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
        axes[1, 1].set_title('Statistical Analysis')
        axes[1, 1].axis('off')
    
    # Analysis 6: Key findings summary
    findings_text = "KEY FINDINGS:\n\n"
    
    # Overall convergence levels
    if convergence_values:
        overall_avg = np.mean([np.mean(cv) for cv in convergence_values])
        findings_text += f"• Overall Convergence: {overall_avg:.3f}\n"
        
        if overall_avg > 0.7:
            findings_text += "  → Strong universal patterns\n"
        elif overall_avg > 0.4:
            findings_text += "  → Moderate universal patterns\n"
        else:
            findings_text += "  → Weak universal patterns\n"
    
    findings_text += "\n"
    
    # Model-specific findings
    if 'behavioral_convergence' in results:
        behavioral = results['behavioral_convergence'].get('convergence', {})
        if behavioral:
            # Find best and worst pairs
            best_pair = max(behavioral.items(), key=lambda x: x[1]['mean_similarity'])
            worst_pair = min(behavioral.items(), key=lambda x: x[1]['mean_similarity'])
            
            findings_text += f"• Most Similar: {best_pair[0]}\n"
            findings_text += f"  Similarity: {best_pair[1]['mean_similarity']:.3f}\n\n"
            findings_text += f"• Least Similar: {worst_pair[0]}\n"
            findings_text += f"  Similarity: {worst_pair[1]['mean_similarity']:.3f}\n\n"
    
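    # Architecture insights
    findings_text += "• Architecture Insights:\n"
    # Look the results up defensively: `behavioral` is undefined when the behavioral analysis is missing
    behavioral = results.get('behavioral_convergence', {}).get('convergence', {})
    pair_key = 'Gemma-2-2B-Base_vs_Gemma-2-2B-Instruct'
    if pair_key in behavioral:
        base_vs_instruct = behavioral[pair_key]['mean_similarity']
        findings_text += f"  Base vs Instruct (2B): {base_vs_instruct:.3f}\n"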
    
    axes[1, 2].text(0.05, 0.95, findings_text, transform=axes[1, 2].transAxes,
                    verticalalignment='top', fontsize=10,
                    bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
    axes[1, 2].set_title('Key Findings')
    axes[1, 2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive interpretation
    print("\n🔍 EXPERIMENT INTERPRETATION:")
    print("=" * 60)
    
    if convergence_values:
        overall_avg = np.mean([np.mean(cv) for cv in convergence_values])
        
        print(f"\n🎯 UNIVERSAL PATTERN EVIDENCE:")
        if overall_avg > 0.7:
            print(f"   ✅ STRONG evidence for universal alignment patterns ({overall_avg:.3f})")
            print(f"      Models show high convergence across multiple analysis levels")
        elif overall_avg > 0.4:
            print(f"   ⚠️  MODERATE evidence for universal patterns ({overall_avg:.3f})")
            print(f"      Some convergence detected but patterns may be architecture-specific")
        else:
            print(f"   ❌ WEAK evidence for universal patterns ({overall_avg:.3f})")
            print(f"      Models show significant differences in alignment approaches")
    
    print(f"\n🧠 MECHANISTIC INSIGHTS:")
    if 'sae_feature_convergence' in results:
        print(f"   • SAE features reveal interpretable alignment mechanisms")
        print(f"   • Layer-specific convergence patterns provide training insights")
    
    print(f"\n🚀 IMPLICATIONS FOR AI SAFETY:")
    print(f"   • Results inform universal alignment strategies")
    print(f"   • Cross-model transfer potential for safety mechanisms")
    print(f"   • Foundation for alignment pattern databases")
    
    print(f"\n📊 EXPERIMENT QUALITY:")
    print(f"   • Multi-level analysis cross-checks behavioral, activation, and SAE evidence")
    print(f"   • SAE features add a mechanistic view of where models agree")
    print(f"   • Mann-Whitney U tests compare score distributions across analysis types")
    
else:
    print("❌ No results files found. Please run the experiment first.")

## 🎯 Next Steps & Extensions

Based on your results, here are suggested next steps for extending this research.

In [None]:
# 🎯 NEXT STEPS & EXTENSIONS

print("🎯 NEXT STEPS FOR UNIVERSAL ALIGNMENT RESEARCH:")
print("=" * 55)

print("\n🔬 IMMEDIATE EXTENSIONS:")
print("  1. Scale to more model families (LLaMA, Claude, GPT)")
print("  2. Add more capability dimensions (reasoning, creativity)")
print("  3. Cross-architecture SAE transfer experiments")
print("  4. Temporal analysis across model training checkpoints")

print("\n🎛️ METHODOLOGICAL IMPROVEMENTS:")
print("  1. Implement more sophisticated similarity metrics")
print("  2. Add causal intervention experiments")
print("  3. Include human baseline comparisons")
print("  4. Statistical significance testing with larger samples")

print("\n🚀 ADVANCED RESEARCH DIRECTIONS:")
print("  1. Universal alignment transfer learning")
print("  2. Cross-modal convergence analysis (text, vision, code)")
print("  3. Adversarial robustness of universal patterns")
print("  4. Mechanistic interpretability of convergence failures")

print("\n🔧 TECHNICAL OPTIMIZATIONS:")
print("  1. Distributed computing for large-scale experiments")
print("  2. Automated hyperparameter tuning")
print("  3. Real-time experiment monitoring")
print("  4. Integration with MLOps pipelines")

print("\n📊 REPRODUCIBILITY & COLLABORATION:")
print("  • All code and results are saved for reproduction")
print("  • Framework is modular and extensible")
print("  • Compatible with standard ML research workflows")
print("  • Ready for collaborative research initiatives")

print("\n🎉 CONGRATULATIONS!")
print("You've successfully implemented a cutting-edge experiment")
print("combining SAE interpretability with convergence analysis.")
print("\nThis framework provides a solid foundation for")
print("advancing universal alignment pattern research!")

# Show file structure
print(f"\n📁 EXPERIMENT FILES:")
for file_path in experiment.output_dir.glob("*"):
    if file_path.is_file():
        size_mb = file_path.stat().st_size / (1024 * 1024)
        print(f"  📄 {file_path.name} ({size_mb:.1f} MB)")

print(f"\n💾 Total experiment data: {experiment.output_dir}")