# SAE Training - Minimal Working Version for Colab

This version relies on Colab's pre-installed packages wherever possible (step 2 installs only a few extras) to avoid conflicts.

**Enable GPU**: Runtime → Change runtime type → T4 GPU

## 1. Setup (Standard-Library Imports Only)

In [None]:
# Use only pre-installed packages
import os
import sys
import json
import csv
from datetime import datetime

# Check GPU
!nvidia-smi | grep "Tesla" || echo "No GPU found - enable it in Runtime settings"

In [None]:
# Clone your repository
!rm -rf SAE_Functional
!git clone https://github.com/ychleee/SAE_Functional.git
%cd SAE_Functional

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create results directory
DRIVE_DIR = '/content/drive/MyDrive/sae_conditionals_results'
!mkdir -p {DRIVE_DIR}
print(f"Results will be saved to: {DRIVE_DIR}")

## 2. Install Only Essential Packages

In [None]:
# Install only what we absolutely need (-q keeps pip's output short).
# NOTE(review): of these three packages only `transformers` is imported by the
# cells visible in this notebook — confirm `accelerate` and `einops` are still
# required before keeping them here.
!pip install -q transformers accelerate einops

## 3. Load Data (No Pandas)

In [None]:
# Load data using pure Python (csv module only, no pandas).
# The CSV must provide a 'text' column and a 'has_conditional' column.
texts = []
labels = []

# newline='' lets the csv module handle line breaks inside quoted fields itself;
# the explicit encoding avoids depending on the platform default.
with open('data/processed/conditionals_dataset.csv', 'r', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        texts.append(row['text'])
        # Case/whitespace-insensitive so 'True', 'true' and ' TRUE ' all count
        # as conditionals; a missing value (None on short rows) is a control.
        labels.append((row['has_conditional'] or '').strip().lower() == 'true')

print(f"Loaded {len(texts)} sentences")
print(f"Conditionals: {sum(labels)}")
print(f"Controls: {len(labels) - sum(labels)}")
print("\nSamples:")
# Slice instead of range(3) so a file with fewer than three rows does not
# raise IndexError.
for i, (label, text) in enumerate(zip(labels[:3], texts[:3])):
    print(f"{i+1}. [{label}] {text[:60]}...")

## 4. Load Model and Extract Activations

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer

# Pick the compute device: CUDA when it is available, CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Checkpoint to load (a small model)
model_name = "EleutherAI/pythia-70m"
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model = model.to(device)
model.eval()

# Fall back to the EOS token for padding when the tokenizer defines no pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded!")

In [None]:
# Extract activations
def get_activations(texts, batch_size=16, max_samples=500):
    """Return one mean-pooled last-layer hidden-state vector per text.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: Sequence of strings; only the first ``max_samples`` are used.
        batch_size: Number of texts tokenized and run through the model at once.
        max_samples: Upper bound on the number of texts that are processed.

    Returns:
        A CPU tensor with one row per processed text, obtained by averaging
        the last hidden state over the non-padding tokens of that text.
    """
    texts = texts[:max_samples]
    activations = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]

        # Tokenize (pad to the longest text in the batch, cut at 128 tokens)
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(device)

        # Get hidden states
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            hidden = outputs.hidden_states[-1]  # Last layer

            # Mean pooling over real tokens only (padding has mask == 0)
            mask = inputs.attention_mask.unsqueeze(-1).to(hidden.dtype)
            # A text that yields no tokens has an all-zero mask; clamping the
            # count keeps its row at zero instead of producing 0/0 = NaN,
            # which would otherwise turn every downstream loss into NaN.
            token_counts = mask.sum(1).clamp(min=1)
            pooled = (hidden * mask).sum(1) / token_counts

            activations.append(pooled.cpu())

        if (i // batch_size) % 10 == 0:
            print(f"Processed {i}/{len(texts)} texts")

    return torch.cat(activations, dim=0)

# Encode at most the first 500 texts, 16 per forward pass
print("Extracting activations...")
acts = get_activations(texts, max_samples=500, batch_size=16)
print(f"Activations shape: {acts.shape}")

## 5. Train Simple SAE

In [None]:
import torch.nn as nn
import torch.optim as optim

class SAE(nn.Module):
    """Single-layer sparse autoencoder.

    The encoder is a linear layer followed by ReLU; the decoder is a
    bias-free linear layer that maps the code back to the input space.

    Args:
        input_dim: Size of the vectors being reconstructed.
        hidden_dim: Number of features in the code.
        sparsity: Coefficient stored on the module as ``self.sparsity``; it is
            not applied inside ``forward``, the training code reads it.
    """

    def __init__(self, input_dim, hidden_dim=1024, sparsity=0.01):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
        self.sparsity = sparsity

        # Initialize the encoder
        nn.init.xavier_uniform_(self.encoder.weight)
        nn.init.zeros_(self.encoder.bias)

        # Start the decoder as the transpose of the encoder by copying the
        # values into the decoder's own parameter. Wrapping
        # `encoder.weight.t()` in a new nn.Parameter merely aliases the
        # storage: on CPU the optimizer then steps the same memory through two
        # parameters, while `.to('cuda')` copies each parameter separately and
        # silently drops the alias, so results depended on the device.
        with torch.no_grad():
            self.decoder.weight.copy_(self.encoder.weight.t())

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Returns:
            ``(recon, code)``: the reconstruction of ``x`` and the
            non-negative (post-ReLU) code it was decoded from.
        """
        code = torch.relu(self.encoder(x))
        recon = self.decoder(code)
        return recon, code

# Create SAE
# The code width lives in one variable so the constructor argument and the
# printed summary cannot drift apart.
input_dim = acts.shape[1]
hidden_dim = 1024
sae = SAE(input_dim, hidden_dim=hidden_dim, sparsity=0.01).to(device)
print(f"SAE: {input_dim} -> {hidden_dim} features")

In [None]:
# Training: 50 full-batch Adam steps over all extracted activations
optimizer = optim.Adam(sae.parameters(), lr=0.001)
acts_gpu = acts.to(device)

print("Training SAE...")
for epoch in range(50):
    # Clear the gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass
    recon, code = sae(acts_gpu)

    # Objective: reconstruction MSE plus the mean absolute code value
    # scaled by sae.sparsity
    recon_loss = nn.functional.mse_loss(recon, acts_gpu)
    sparse_loss = code.abs().mean() * sae.sparsity
    loss = recon_loss + sparse_loss

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Report every 10th epoch
    if not epoch % 10:
        print(f"Epoch {epoch}: Loss={loss:.4f} (R:{recon_loss:.4f} S:{sparse_loss:.4f})")

print("Training complete!")

## 6. Analyze Features

In [None]:
# Get sparse codes
with torch.no_grad():
    _, codes = sae(acts_gpu)
    codes = codes.cpu().numpy()

# Split by label
import numpy as np
# Trim the labels to the number of encoded samples; dtype=bool guarantees the
# array is usable as a mask and with `~`.
labels_array = np.array(labels[:len(codes)], dtype=bool)
cond_codes = codes[labels_array]
ctrl_codes = codes[~labels_array]

# The comparison below needs both groups: the mean of an empty group is NaN,
# which would make the ranking (and the saved scores) meaningless.
if len(cond_codes) == 0 or len(ctrl_codes) == 0:
    raise ValueError(
        f"Need both classes among the first {len(codes)} samples "
        f"(conditionals={len(cond_codes)}, controls={len(ctrl_codes)}); "
        "shuffle the dataset or increase max_samples."
    )

# Find differential features: per-feature mean activation, conditionals minus controls
cond_mean = cond_codes.mean(0)
ctrl_mean = ctrl_codes.mean(0)
diff = cond_mean - ctrl_mean

# Top features: the ten largest differences, largest first
top_idx = np.argsort(diff)[-10:][::-1]

print("Top 10 Conditional Features:")
for i, idx in enumerate(top_idx):
    print(f"{i+1}. Feature {idx}: diff={diff[idx]:.3f}")

# Stats: average number of non-zero features per sample in each group
print(f"\nActive features per sample:")
print(f"  Conditionals: {(cond_codes > 0).sum(1).mean():.1f}")
print(f"  Controls: {(ctrl_codes > 0).sum(1).mean():.1f}")

In [None]:
# Show the texts on which the highest-ranked feature fires most strongly
top_feat = top_idx[0]
feat_acts = codes[:, top_feat]

# Ascending sort; keep the last five positions and flip them to descending
ascending_order = np.argsort(feat_acts)
top_examples = ascending_order[-5:][::-1]

print(f"\nTop activating texts for Feature {top_feat}:")
for i, idx in enumerate(top_examples):
    print(f"{i+1}. [{feat_acts[idx]:.2f}] {texts[idx][:80]}")

## 7. Save Results

In [None]:
# Save to Drive, in a fresh timestamped folder per run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
save_dir = f"{DRIVE_DIR}/run_{timestamp}"
# os.makedirs raises if the folder cannot be created, instead of letting a
# failed shell `mkdir` go unnoticed until the writes below fail.
os.makedirs(save_dir, exist_ok=True)

# Save model
torch.save({
    'model': sae.state_dict(),
    'input_dim': input_dim,
    # Read the code width from the model itself so the stored value always
    # matches the saved weights.
    'hidden_dim': sae.encoder.out_features
}, f"{save_dir}/sae_model.pt")

# Save analysis
results = {
    'top_features': top_idx.tolist(),
    'diff_scores': diff[top_idx].tolist(),
    'n_samples': len(codes),
    'n_conditional': int(labels_array.sum())
}

with open(f"{save_dir}/results.json", 'w') as f:
    json.dump(results, f, indent=2)

print(f"\nResults saved to: {save_dir}")
print("Download from Google Drive when done!")