In [5]:
import os
import sys
from pathlib import Path
from safetensors.torch import load_model
import json

import torch

import einops
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import requests

from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
#from jaxtyping import Float, Int

from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy


#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..",)))
#print("\n".join(sys.path))
current_dir = os.path.dirname(os.path.abspath("spar_sae_circuit_sandbox.ipynb"))
model_dir = os.path.join(current_dir, '..') # Assuming it's one level up
#toy_model_dir = os.path.join(current_dir, '..', 'llm_from_scratch/LLM_from_scratch/')
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
checkpoint_dir = Path('/Volumes/MacMini/gpt-circuits') / "checkpoints"
gpt_dir = checkpoint_dir / "shakespeare_64x4"
sae_dir = checkpoint_dir / "standard.shakespeare_64x4"

sys.path.append(model_dir)
#sys.path.append(toy_model_dir)

from config.gpt.training import options
from config.sae.models import sae_options
from models.gpt import GPT
from models.sparsified import SparsifiedGPT
from data.tokenizers import ASCIITokenizer, TikTokenTokenizer

# Loading SparsifiedGPT


# Imports from the project
from config.sae.models import SAEConfig
from models.sparsified import SparsifiedGPT
from models.gpt import GPT



In [6]:
# Load GPT model
print("Loading GPT model...")
gpt = GPT.load(gpt_dir, device=device)
    
# Load SAE config
print("Loading SAE configuration...")
sae_config_dir = sae_dir / "sae.json"
with open(sae_config_dir, "r") as f:
    meta = json.load(f)
config = SAEConfig(**meta)
config.gpt_config = gpt.config
    
# Create model using saved config
print("Creating SparsifiedGPT model...")
model = SparsifiedGPT(config)
model.gpt = gpt
    
# Load SAE weights
print("Loading SAE weights...")
for layer_name, module in model.saes.items():
    weights_path = os.path.join(sae_dir, f"sae.{layer_name}.safetensors")
    load_model(module, weights_path, device=device.type)

Loading GPT model...
Loading SAE configuration...
Creating SparsifiedGPT model...
Loading SAE weights...


In [7]:
#from utils import generate
#c_name = 'standardx8.shakespeare_64x4'
name = 'standard.shakespeare_64x4'
#name = 'shakespeare_64x4'
#config = sae_options[c_name]

#model = SparsifiedGPT(config)
#model_path = os.path.join("../checkpoints", name)
#model = model.load(model_path, device=config.device)

#load tokenizer
tokenizer = ASCIITokenizer() if "shake" in name else TikTokenTokenizer()

In [8]:
def generate(model, tokenizer, prompt, max_length=50, temperature=0.7) -> str:
    """
    Generate text from a prompt using the model
    """
    tokens = tokenizer.encode(prompt)
    tokens = torch.Tensor(tokens).long().unsqueeze(0)
    
    for _ in range(max_length):
        logits = model(tokens).logits[0][-1]
        probs = torch.softmax(logits / temperature, dim=-1)
        #next_token = torch.multinomial(probs, num_samples=1)
        next_token = torch.argmax(probs, keepdim=True)
        
        tokens = torch.cat([tokens.squeeze(0), next_token], dim=-1).unsqueeze(0)
        
    #return tokenizer.decode_sequence(tokens[0].tolist())
    return tokens

In [9]:
prompt = "His name is Licio, born in Mantua."
tokens = tokenizer.encode(prompt)
tokens = torch.Tensor(tokens).long().unsqueeze(0)
with torch.no_grad():
    model_output = model(tokens)

# Generate output text/tokens using "generate" function
#output = generate(model, tokenizer, prompt, max_length=1)
#print(output)
#print(tokenizer.decode_sequence(output[0].tolist()))

In [616]:
#with model.use_saes():
    #output_sae = model(tokens)

In [11]:
model_output

SparsifiedGPTOutput(logits=tensor([[[-15.0021, -14.9872, -14.9872,  ..., -15.0066, -15.0015, -14.9803],
         [-20.9965, -20.9861, -21.0117,  ..., -20.9863, -20.9932, -20.9777],
         [-13.5827, -13.5573, -13.5960,  ..., -13.5951, -13.5750, -13.5636],
         ...,
         [-15.3942, -15.4329, -15.4454,  ..., -15.4078, -15.4222, -15.4405],
         [ -9.6964,  -9.7105,  -9.7095,  ...,  -9.7102,  -9.6881,  -9.7186],
         [-13.8675, -13.8698, -13.8927,  ..., -13.8691, -13.8846, -13.8751]]]), cross_entropy_loss=None, activations={0: tensor([[[-0.3428, -0.0691,  0.1505,  ..., -0.3408, -0.0127,  0.1800],
         [ 0.0057, -0.0554, -0.2171,  ...,  0.0312, -0.0692,  0.2164],
         [ 0.0641,  0.1122,  0.0556,  ..., -0.0218, -0.1222,  0.0047],
         ...,
         [ 0.0346, -0.2275,  0.0020,  ..., -0.1253, -0.0569,  0.1417],
         [ 0.0639, -0.1175,  0.1612,  ...,  0.0462,  0.1182, -0.0252],
         [ 0.0842, -0.0115,  0.0195,  ...,  0.0756, -0.0850,  0.0403]]]), 1: tensor(

In [12]:
#feature layers
feat_layer0 = model_output.feature_magnitudes[0].squeeze(0)
feat_layer1 = model_output.feature_magnitudes[1].squeeze(0)
feat_layer2 = model_output.feature_magnitudes[2].squeeze(0)
feat_layer3 = model_output.feature_magnitudes[3].squeeze(0)
feat_layer4 = model_output.feature_magnitudes[4].squeeze(0)

#minimum value a feature can be considered "active"
feat_threshold = 0

In [19]:
feat_product4 = einops.einsum(feat_layer3, feat_layer4, "token act1, token act2 -> token act1 act2")
feat_product4[-1].shape

torch.Size([512, 512])

In [21]:
mask = torch.zeros(feat_product4[-1].shape)
mask[torch.where(feat_product4[-1] > feat_threshold)] = 1


In [24]:
model

SparsifiedGPT(
  (gpt): GPT(
    (transformer): ModuleDict(
      (wte): Embedding(128, 64)
      (wpe): Embedding(128, 64)
      (h): ModuleList(
        (0-3): 4 x Block(
          (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (attn): CausalSelfAttention(
            (c_attn): Linear(in_features=64, out_features=192, bias=True)
            (c_proj): Linear(in_features=64, out_features=64, bias=True)
          )
          (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Linear(in_features=64, out_features=256, bias=True)
            (gelu): GELU(approximate='tanh')
            (c_proj): Linear(in_features=256, out_features=64, bias=True)
          )
        )
      )
      (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=64, out_features=128, bias=False)
  )
  (saes): ModuleDict(
    (0): StandardSAE()
    (1): StandardSAE()
    (2): StandardSAE()
    (3):

In [26]:
model.saes['0']()

StandardSAE()

In [91]:
# Store the output of a specific layer
intermediate_outputs = {}

def hook_fn(module, input, output):
    intermediate_outputs['target_layer'] = output

# Register the hook on the layer you want
model.layer3.register_forward_hook(hook_fn)

# Run the full model
_ = model(input_data)

# Access the intermediate output
partial_output = intermediate_outputs['target_layer']

TypeError: 'SparsifiedGPT' object is not subscriptable

In [78]:
# load the prompts that return the correct tokens
double_newline_prompts = torch.load('double_newline_prompts.pt', weights_only=True)

In [28]:
from copy import deepcopy

def explore_sae_connections_sparsity(model, input_data, layer_i=0, layer_j=1, 
                                    pruning_fractions=np.linspace(0, 0.95, 20),
                                    pruning_methods=['magnitude', 'random'],
                                    num_trials=5):
    """
    Explore sparsity of connections between adjacent SAE layers by pruning 
    connections and measuring impact on output logits.
    
    Args:
        model: SparsifiedGPT model
        input_data: Input tensor to process
        layer_i, layer_j: Indices of adjacent SAE layers to analyze
        pruning_fractions: List of pruning fractions to test
        pruning_methods: List of pruning methods to use
        num_trials: Number of trials for random pruning
    """
    # Get baseline activations and logits
    with torch.no_grad():
        baseline_output = model(input_data)
    
    baseline_logits = baseline_output.logits
    feat_layer_i = baseline_output.feature_magnitudes[layer_i].squeeze(0)  # Shape: (H, W_i)
    feat_layer_j = baseline_output.feature_magnitudes[layer_j].squeeze(0)  # Shape: (H, W_j)
    
    # Compute connection strengths for each position in H
    # For each h, compute outer product between feat_layer_i[h] and feat_layer_j[h]
    # Shape: (H, W_i, W_j)
    H, W_i = feat_layer_i.shape
    _, W_j = feat_layer_j.shape
    
    # Use einsum for outer product while preserving H dimension
    connection_strengths = torch.einsum('hw,hv->hwv', feat_layer_i, feat_layer_j)
    
    # Results storage
    results = {method: {'kl_divs': [], 'logit_diffs': []} for method in pruning_methods}
    
    # For each pruning method and fraction, patch and measure impact
    for method in pruning_methods:
        all_kl_divs = []
        all_logit_diffs = []
        
        for fraction in pruning_fractions:
            kl_divs_trials = []
            logit_diffs_trials = []
            
            # Multiple trials for random pruning
            trials = num_trials if method == 'random' else 1
            
            for trial in range(trials):
                # Create pruning mask with same shape as connection_strengths
                mask = torch.ones_like(connection_strengths)
                
                if fraction > 0:
                    if method == 'magnitude':
                        # Prune weakest connections across all positions
                        flat_strengths = connection_strengths.abs().flatten()
                        num_to_prune = int(fraction * flat_strengths.numel())
                        threshold = torch.sort(flat_strengths)[0][num_to_prune]
                        mask = (connection_strengths.abs() > threshold).float()
                    elif method == 'random':
                        # Randomly prune connections
                        num_to_prune = int(fraction * mask.numel())
                        prune_indices = torch.randperm(mask.numel())[:num_to_prune]
                        mask.view(-1)[prune_indices] = 0
                
                # Step 4: Create patched features for layer_j
                def patch_hook(module, input, output):
                    # Get activations from layer_i
                    layer_i_activations = output.feature_magnitudes[layer_i]  # (1, H, W_i)
                    
                    # Original layer_j activations
                    orig_layer_j = output.feature_magnitudes[layer_j]  # (1, H, W_j)
                    
                    # Create patched layer_j activations based on masked connections
                    patched_layer_j = torch.zeros_like(orig_layer_j)
                    
                    # For each position h
                    for h in range(H):
                        # Get feature vectors at position h
                        feat_i_h = layer_i_activations[0, h]  # (W_i)
                        
                        # Apply masked connections to compute patched layer_j at position h
                        # For each feature in layer_j, sum contributions from all features in layer_i
                        for w_i in range(W_i):
                            for w_j in range(W_j):
                                if mask[h, w_i, w_j] > 0:
                                    # If connection exists, propagate activation
                                    patched_layer_j[0, h, w_j] += feat_i_h[w_i] * feat_layer_j[h, w_j] / feat_layer_i[h, w_i]
                    
                    # Replace layer_j activations
                    output.feature_magnitudes[layer_j] = patched_layer_j
                    return output
                
                # Register the hook
                hook_handle = model.register_forward_hook(patch_hook)
                
                # Run patched forward pass
                with torch.no_grad():
                    patched_output = model(input_data)
                
                # Remove the hook
                hook_handle.remove()
                
                # Measure impact on logits
                patched_logits = patched_output.logits
                
                # KL divergence between original and patched logits
                kl_div = torch.nn.functional.kl_div(
                    torch.log_softmax(patched_logits, dim=-1),
                    torch.softmax(baseline_logits, dim=-1),
                    reduction='batchmean'
                )
                
                # L2 distance between logits
                logit_diff = torch.norm(patched_logits - baseline_logits)
                
                kl_divs_trials.append(kl_div.item())
                logit_diffs_trials.append(logit_diff.item())
            
            # Average results across trials
            all_kl_divs.append(np.mean(kl_divs_trials))
            all_logit_diffs.append(np.mean(logit_diffs_trials))
        
        results[method]['kl_divs'] = all_kl_divs
        results[method]['logit_diffs'] = all_logit_diffs
    
    # Step 5: Analyze and visualize results
    plt.figure(figsize=(12, 10))
    
    plt.subplot(2, 2, 1)
    for method in pruning_methods:
        plt.plot(pruning_fractions, results[method]['kl_divs'], label=method)
    plt.xlabel('Pruning Fraction')
    plt.ylabel('KL Divergence')
    plt.title(f'Impact of Pruning SAE Connections (Layers {layer_i}->{layer_j})')
    plt.legend()
    
    plt.subplot(2, 2, 2)
    for method in pruning_methods:
        plt.plot(pruning_fractions, results[method]['logit_diffs'], label=method)
    plt.xlabel('Pruning Fraction')
    plt.ylabel('Logit L2 Distance')
    plt.title('Change in Output Logits')
    plt.legend()
    
    # Visualize connection strength distribution
    plt.subplot(2, 2, 3)
    plt.hist(connection_strengths.abs().flatten().cpu().numpy(), bins=50)
    plt.xlabel('Connection Strength (Absolute Value)')
    plt.ylabel('Frequency')
    plt.title('Distribution of Connection Strengths')
    
    # Visualize connection matrix
    plt.subplot(2, 2, 4)
    avg_conn = connection_strengths.abs().mean(dim=0).cpu()  # Average across H
    plt.imshow(avg_conn, cmap='viridis')
    plt.colorbar(label='Average Strength')
    plt.xlabel(f'Features in Layer {layer_j}')
    plt.ylabel(f'Features in Layer {layer_i}')
    plt.title('Average Connection Strength Matrix')
    
    plt.tight_layout()
    plt.show()
    
    return results, connection_strengths


In [31]:
input_data = tokens
results, connection_strengths = explore_sae_connections_sparsity(
    model, input_data, 
    layer_i=0, layer_j=1,
    pruning_fractions=[0, 0.9, 0.99],
    pruning_methods=['magnitude'],
    num_trials=1
)

KeyboardInterrupt: 

In [32]:
# Check if GPU is available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Move model to GPU
model = model.to(device)

# Make sure input data is on GPU for each analysis
input_data = input_data.to(device)

Using device: mps


In [35]:
def quick_sae_connection_analysis(model, input_data, layer_i=0, layer_j=1):
    """Simplified analysis focusing on just key pruning levels"""
    
    # Get baseline
    with torch.no_grad():
        baseline_output = model(input_data)
    
    baseline_logits = baseline_output.logits
    feat_layer_i = baseline_output.feature_magnitudes[layer_i].squeeze(0)
    feat_layer_j = baseline_output.feature_magnitudes[layer_j].squeeze(0)
    
    # Compute connection matrix
    connection_strengths = torch.einsum('hw,hv->hwv', feat_layer_i, feat_layer_j)
    
    # Just test two pruning levels (90% and 99%)
    results = {}
    
    for pruning_frac in [0.9, 0.99]:
        # Create mask for magnitude pruning
        flat_strengths = connection_strengths.abs().flatten()
        num_to_prune = int(pruning_frac * flat_strengths.numel())
        threshold = torch.sort(flat_strengths)[0][num_to_prune]
        mask = (connection_strengths.abs() > threshold).float()
        
        # Simple intervention with masked connections
        def intervention_hook(module, input, output):
            # Simplified patching logic
            return patched_output
            
        # Run intervened forward pass
        hook_handle = model.register_forward_hook(intervention_hook)
        with torch.no_grad():
            patched_output = model(input_data)
        hook_handle.remove()
        
        # Calculate simple metric like L2 distance
        logit_diff = torch.norm(patched_output.logits - baseline_logits)
        results[pruning_frac] = logit_diff.item()
    
    # Just print results instead of visualizing
    print(f"Pruning 90%: Logit diff = {results[0.9]}")
    print(f"Pruning 99%: Logit diff = {results[0.99]}")
    
    return results, connection_strengths

In [36]:
results, connection_strengths = quick_sae_connection_analysis(model, input_data,
                                                              layer_i=0,
                                                              layer_j=1)

NameError: cannot access free variable 'patched_output' where it is not associated with a value in enclosing scope