# SAE Training on Conditional Features (Standalone Colab Version)

This standalone version works around Colab dependency conflicts by pinning package versions and restarting the runtime before any code runs.

**Make sure GPU is enabled**: Runtime → Change runtime type → T4 GPU

## 1. Clean Install of Dependencies

In [None]:
# Complete reset and install.
# Remove the preinstalled copies first so the pinned versions below are the
# ones that get imported after the runtime restart in the next cell.
# (numexpr is removed here and not reinstalled.)
!pip uninstall -y numpy pandas torch transformers tokenizers numexpr
# Pinned versions, presumably chosen to be mutually compatible.
# NOTE(review): confirm these pins still have wheels for the current Colab
# Python version before relying on them.
!pip install numpy==1.23.5
!pip install pandas==1.5.3
!pip install torch==2.0.1
# tokenizers is not pinned; pip resolves a version that satisfies
# transformers 4.35.0's requirements.
!pip install transformers==4.35.0
# Unpinned helpers: pyyaml (config loading) and tqdm (progress bars) are used
# below; the others are presumably needed by code in the cloned repository.
!pip install pyyaml tqdm matplotlib einops accelerate

In [None]:
# Restart runtime programmatically.
# Signal 9 (SIGKILL) sent to this kernel's own process terminates it; the
# runtime then restarts so that the packages installed above are the ones
# imported from section 2 onward. Continue from section 2 afterwards.
import os
os.kill(os.getpid(), 9)

## 2. After Restart - Setup Environment

**After the runtime restarts, resume from this cell — do not re-run section 1.**

In [None]:
import os
import sys

# Check GPU
!nvidia-smi

# Clone repository
!git clone https://github.com/ychleee/SAE_Functional.git
%cd SAE_Functional

# Add to path
sys.path.append('/content/SAE_Functional/src')

In [None]:
# Mount Google Drive so results persist beyond the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Create the results directory in Python. `!mkdir -p {DRIVE_DIR}` hands the
# path to the shell unquoted, so a DRIVE_DIR containing spaces would be split
# into several directories.
import os

DRIVE_DIR = '/content/drive/MyDrive/sae_conditionals_results'
os.makedirs(DRIVE_DIR, exist_ok=True)
print(f"Results will be saved to: {DRIVE_DIR}")

## 3. Load Data and Configuration

In [None]:
import yaml
import torch
import numpy as np
from pathlib import Path
import csv

# Load configuration. The path is relative, so the working directory must be
# the repository root (set in the setup cell). safe_load only builds plain
# YAML types (dicts, lists, scalars).
with open('configs/training_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Configuration loaded:")
print(f"Model: {config['model']['name']}")
print(f"SAE features: {config['sae']['hidden_dim']}")
# The conditional must sit inside the f-string. The previous form,
# `print(f"Device: cuda" if ... else "cpu")`, printed a bare "cpu" with no
# "Device:" label on CPU-only runtimes.
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

In [None]:
# Load dataset without pandas. Each CSV row provides 'text', 'type' and
# 'has_conditional' columns; the three lists below stay index-aligned.
data_path = 'data/processed/conditionals_dataset.csv'

texts = []
types = []
has_conditional = []

# newline='' is how the csv module expects files to be opened, so quoted
# fields that contain line breaks are parsed correctly.
with open(data_path, 'r', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        texts.append(row['text'])
        types.append(row['type'])
        # Accept common boolean spellings ('True', 'true', '1', 'yes').
        # Matching only the exact string 'True' would silently mark, e.g.,
        # 'true' rows as controls. A missing value (None) counts as False.
        flag = (row['has_conditional'] or '').strip().lower()
        has_conditional.append(flag in ('true', '1', 'yes'))

n_conditionals = sum(has_conditional)
print(f"Loaded {len(texts)} sentences")
print(f"Conditionals: {n_conditionals}")
print(f"Controls: {len(has_conditional) - n_conditionals}")
print("\nSample texts:")
# Slicing (rather than range(3)) avoids an IndexError on datasets with fewer
# than three rows.
for i, text in enumerate(texts[:3], start=1):
    ellipsis = '...' if len(text) > 80 else ''
    print(f"{i}. {text[:80]}{ellipsis}")

## 4. Extract Activations (Simplified)

In [None]:
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

# Backbone named in the YAML config, placed on the GPU when one is available.
model_name = config['model']['name']
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Batched tokenization below uses padding=True, which needs a pad token.
# Tokenizers that do not define one reuse their end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModel.from_pretrained(model_name)
model = model.to(device)
model.eval()  # inference only: disables dropout

print(f"Model loaded on {device}")

In [None]:
# Extract activations in batches
def extract_activations(texts, batch_size=32, max_samples=500, layer=-1, max_length=128):
    """Return one mean-pooled hidden-state vector per input text.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: Sentences to encode. Only the first ``max_samples`` are used,
            so row ``i`` of the result corresponds to ``texts[i]``.
        batch_size: Number of sentences per forward pass.
        max_samples: Cap on how many texts are processed, to limit memory
            (``None`` keeps all of them).
        layer: Index into ``outputs.hidden_states``. The default ``-1`` is
            the last layer, as before.
        max_length: Tokenizer truncation length in tokens.

    Returns:
        A CPU tensor with one row per processed text.

    Raises:
        ValueError: If there are no texts to process.
    """
    texts = texts[:max_samples]  # Limit for memory
    if len(texts) == 0:
        # torch.cat([]) below would otherwise fail with a less helpful error.
        raise ValueError("extract_activations() received no texts to encode")

    all_activations = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Extracting"):
        batch_texts = texts[start:start + batch_size]

        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            hidden = outputs.hidden_states[layer]
            # Mean-pool over real tokens only. Padded positions have
            # attention_mask == 0, so they are excluded from both the sum
            # and the count.
            mask = inputs.attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp: a text that tokenizes to zero tokens (e.g. an empty
            # string with a tokenizer that adds no special tokens) would
            # otherwise give 0/0 = NaN, which poisons SAE training.
            token_counts = mask.sum(dim=1).clamp(min=1)
            pooled = (hidden * mask).sum(dim=1) / token_counts
            all_activations.append(pooled.cpu())

    return torch.cat(all_activations, dim=0)

# Encode at most the first 500 sentences. Row i of `activations` corresponds
# to texts[i], because extract_activations only truncates the list.
print("Extracting activations...")
activations = extract_activations(texts, max_samples=500, batch_size=16)
print(f"Activations shape: {activations.shape}")

## 5. Train Sparse Autoencoder

In [None]:
import torch.nn as nn
import torch.optim as optim

class SimpleSAE(nn.Module):
    """Simple sparse autoencoder with tied encoder/decoder weights.

    The only weight matrix is ``encoder.weight`` (an ``nn.Linear`` weight,
    shape ``(hidden_dim, input_dim)``)::

        code           = relu(x @ encoder.weight.T + encoder.bias)
        reconstruction = code @ encoder.weight          (no decoder bias)

    Args:
        input_dim: Size of each input activation vector.
        hidden_dim: Number of SAE features (code units).
        sparsity_coeff: Weight of the L1 penalty on the code in :meth:`loss`.

    NOTE: an earlier version registered a separate ``decoder.weight``
    parameter. Checkpoints saved by it carry that extra state-dict key; load
    them with ``load_state_dict(..., strict=False)``, and note that the old
    decoder weight is then ignored.
    """

    def __init__(self, input_dim, hidden_dim, sparsity_coeff=0.01):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.sparsity_coeff = sparsity_coeff

        # Initialize
        nn.init.xavier_uniform_(self.encoder.weight)
        nn.init.zeros_(self.encoder.bias)

    def decoder(self, code):
        """Map codes back to input space through the transposed encoder weight.

        Tying happens here, at call time. Wrapping ``encoder.weight.t()`` in a
        second ``nn.Parameter`` only shares storage until ``.to(device)``,
        ``.cuda()`` or ``.double()`` replaces each parameter's data
        independently; after that the "tied" weights train as two separate
        matrices. While storage is still shared, the optimizer applies two
        independent updates to the same memory on every step.
        """
        return nn.functional.linear(code, self.encoder.weight.t())

    def forward(self, x):
        """Return ``(reconstruction, code)`` for inputs with last dim ``input_dim``."""
        code = self.relu(self.encoder(x))
        reconstruction = self.decoder(code)
        return reconstruction, code

    def loss(self, x):
        """Return ``(total, reconstruction_mse, sparsity_penalty)``.

        ``total = mse(reconstruction, x) + sparsity_coeff * mean(|code|)``.
        """
        recon, code = self.forward(x)
        recon_loss = nn.functional.mse_loss(recon, x)
        sparsity_loss = self.sparsity_coeff * code.abs().mean()
        return recon_loss + sparsity_loss, recon_loss, sparsity_loss

# Build the SAE: input width comes from the extracted activations, the number
# of features and the L1 weight come from the YAML config.
input_dim = activations.shape[1]
hidden_dim = config['sae']['hidden_dim']
sparsity_coeff = config['sae']['sparsity_coeff']
sae = SimpleSAE(input_dim, hidden_dim, sparsity_coeff).to(device)

print(f"SAE created: {input_dim} -> {hidden_dim} features")

In [None]:
# Training loop
from torch.utils.data import TensorDataset, DataLoader

N_EPOCHS = 50  # Reduced epochs for demo
LOG_EVERY = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.001

# Prepare data. Shuffle before the 80/20 split (seeded, so it is
# reproducible); otherwise the validation rows are just the tail of the
# extracted texts, which may be grouped by sentence type.
n_train = int(len(activations) * 0.8)
if n_train == 0:
    raise ValueError(
        f"Need at least 2 activation rows to train, got {len(activations)}"
    )
split_generator = torch.Generator().manual_seed(0)
perm = torch.randperm(len(activations), generator=split_generator)
train_data = activations[perm[:n_train]].to(device)
val_data = activations[perm[n_train:]].to(device)

train_loader = DataLoader(TensorDataset(train_data), batch_size=BATCH_SIZE, shuffle=True)

# Optimizer
optimizer = optim.Adam(sae.parameters(), lr=LEARNING_RATE)

# Train
print("Training SAE...")
for epoch in range(N_EPOCHS):
    sae.train()
    total_loss = 0.0
    for (x,) in train_loader:
        loss, _, _ = sae.loss(x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Log periodically and always on the final epoch (which was previously
    # never reported).
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        avg_loss = total_loss / len(train_loader)
        message = f"Epoch {epoch}: Loss = {avg_loss:.4f}"
        # Held-out loss; val_data was previously built but never used.
        if len(val_data) > 0:
            sae.eval()
            with torch.no_grad():
                val_loss, _, _ = sae.loss(val_data)
            message += f" | Val loss = {val_loss.item():.4f}"
        print(message)

print("Training complete!")

## 6. Analyze Features

In [None]:
# Encode all activations with the trained SAE (the same encoder + ReLU that
# SimpleSAE.forward applies).
sae.eval()
with torch.no_grad():
    all_codes = sae.relu(sae.encoder(activations.to(device))).cpu().numpy()

# Labels for exactly the rows that were encoded; extraction may have kept
# only the first max_samples texts.
has_cond_array = np.array(has_conditional[:len(all_codes)], dtype=bool)
n_cond = int(has_cond_array.sum())
n_non_cond = len(has_cond_array) - n_cond
if n_cond == 0 or n_non_cond == 0:
    # Without both groups the means below are NaN, and NaN "top features"
    # would be printed and saved without any warning.
    raise ValueError(
        f"Need both conditional and control sentences to compare features; "
        f"got {n_cond} conditionals and {n_non_cond} controls among the first "
        f"{len(has_cond_array)} rows. Increase max_samples or shuffle the dataset."
    )
cond_codes = all_codes[has_cond_array]
non_cond_codes = all_codes[~has_cond_array]

# Average activation of each feature per group; positive diff = fires more
# on conditionals.
cond_avg = cond_codes.mean(axis=0)
non_cond_avg = non_cond_codes.mean(axis=0)
diff = cond_avg - non_cond_avg

# Top differential features, largest diff first (capped at the number of
# features the SAE has).
TOP_K = 10
k = min(TOP_K, diff.shape[0])
top_features = np.argsort(diff)[::-1][:k]

print(f"Top {k} Conditional Features:")
for rank, feat_idx in enumerate(top_features, start=1):
    print(f"{rank}. Feature {feat_idx}: diff = {diff[feat_idx]:.3f}")

# Statistics: mean number of non-zero features per sentence in each group.
print("\nAverage active features:")
print(f"  Conditionals: {(cond_codes > 0).sum(axis=1).mean():.1f}")
print(f"  Non-conditionals: {(non_cond_codes > 0).sum(axis=1).mean():.1f}")

## 7. Save Results

In [None]:
from datetime import datetime
from pathlib import Path
import json

# Create a timestamped run directory under DRIVE_DIR so runs never overwrite
# each other. Path.mkdir replaces `!mkdir -p {run_dir}`, which passed the path
# to the shell unquoted: a DRIVE_DIR containing spaces would have been split
# into several directories.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_dir = f"{DRIVE_DIR}/run_{timestamp}"
Path(run_dir).mkdir(parents=True, exist_ok=True)

# Save model weights together with the dimensions and config needed to
# rebuild SimpleSAE.
model_path = f"{run_dir}/sae_model.pt"
torch.save({
    'model_state_dict': sae.state_dict(),
    'input_dim': input_dim,
    'hidden_dim': hidden_dim,
    'config': config
}, model_path)
print(f"Model saved to: {model_path}")

# Save analysis. .tolist() and int() turn NumPy values into plain Python
# numbers that json can serialize.
analysis = {
    'top_conditional_features': top_features.tolist(),
    'differential_scores': diff[top_features].tolist(),
    'n_samples': len(activations),
    'n_conditionals': int(has_cond_array.sum()),
    'n_controls': int((~has_cond_array).sum()),
    'timestamp': timestamp
}

with open(f"{run_dir}/analysis.json", 'w', encoding='utf-8') as f:
    json.dump(analysis, f, indent=2)

print(f"\nAll results saved to: {run_dir}")
print("\nYou can download the results from your Google Drive!")