In [None]:
def install_dependencies():
    import os
    
    # Set GitHub token as an environment variable
    os.environ['GITHUB_TOKEN'] = "gh_"
    
    # Remove existing directories if they exist
    ! rm -rf sae || True
    ! rm -rf TinySQL || True
    
    # Clone repositories with authentication
    ! pip install --upgrade pip
    ! git clone https://${GITHUB_TOKEN}@github.com/amirabdullah19852020/sae.git
    ! cd sae && git checkout old_saes && pip install .
    
    ! git clone https://${GITHUB_TOKEN}@github.com/withmartian/TinySQL.git
    ! cd TinySQL && pip install .
    
install_dependencies()

### You may need to restart the notebook kernel after installing the dependencies.

In [None]:
! pip install circuitsvis nnsight plotly matplotlib einops -q

In [None]:
import json
import os

os.environ["SAE_DISABLE_TRITON"] = "1"

import psutil
import re

from copy import deepcopy
from dataclasses import dataclass
from IPython.display import display, HTML
from typing import Callable
from math import ceil
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import sae
import torch
import torch.fx
import einops
from datasets import load_dataset
from huggingface_hub import snapshot_download
from nnsight import NNsight, LanguageModel
from sae import Sae
from sae.sae_interp import GroupedSaeOutput, SaeOutput, SaeCollector, LoadedSAES
from sae.sae_plotting import plot_layer_curves, plot_layer_features

from tqdm import tqdm
from TinySQL import sql_interp_model_location


In [None]:
# Get the current process
def process_info():
    """Print the resident (RSS) and virtual (VMS) memory of this process in MB."""
    bytes_per_mb = 1024 ** 2
    usage = psutil.Process(os.getpid()).memory_info()

    rss_mb = usage.rss / bytes_per_mb  # Resident Set Size
    vms_mb = usage.vms / bytes_per_mb  # Virtual Memory Size
    print(f"RSS: {rss_mb:.2f} MB")
    print(f"VMS: {vms_mb:.2f} MB")

process_info()

In [None]:
# Seed the RNGs that are already imported so the sampled generations below are repeatable.
# (Previously `seed` was defined but never applied.)
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

repo_name = "withmartian/sql_interp_saes"   # Hugging Face repo that holds the trained SAEs
cache_dir = "working_directory"             # local directory the SAE snapshot is downloaded into

syn = True

# Sub-folder of `repo_name` holding the SAEs for this model / synonym setting.
full_model_name = sql_interp_model_location(model_num=2, cs_num=1, synonym=syn)
model_alias = f"saes_{full_model_name.split('/')[1]}_syn={syn}"
print(model_alias)


process_info()

In [None]:
# Log in to the Hugging Face Hub from the shell.
# NOTE(review): `hf_` looks like a redacted placeholder — supply a real token at run time
# and never commit it with the notebook.
! huggingface-cli login --token hf_

In [None]:
# Fetch only the files under `<model_alias>/` from the SAE repo into `cache_dir`.
local_snapshot_dir = snapshot_download(
    repo_name,
    allow_patterns=f"{model_alias}/*",
    local_dir=cache_dir,
)
repo_path = Path(local_snapshot_dir)

In [None]:
def format_example(example):
    """Add `prompt` and `response` fields to a dataset record, in place.

    `prompt` is the instruction / context / response template filled with the
    record's `english_prompt`, `create_statement` and `sql_statement`;
    `response` is a copy of `sql_statement`. The same (mutated) record is returned.
    """
    instruction = example['english_prompt']
    context = example['create_statement']
    answer = example['sql_statement']

    example['prompt'] = "### Instruction: {} ### Context: {} ### Response: {}".format(
        instruction, context, answer
    )
    example['response'] = answer
    return example

In [None]:
# Load the SAEs stored under `model_alias` in `cache_dir` (k=32), using `format_example`
# as the dataset mapper, and take the language model and tokenizer from the loaded bundle.
# Note: `model` is rebound later by `qts.load_tinysql_model(...)`; `tokenizer` is not.
loaded_saes = LoadedSAES.load_from_path(model_alias=model_alias, k=32, cache_dir=cache_dir, dataset_mapper=format_example)
model = loaded_saes.language_model
tokenizer = loaded_saes.tokenizer

In [None]:
def calculate_similarity(text1, text2):
    """Score how closely two `SELECT <cols> FROM <rest>` strings agree.

    The score is a sum of five checks worth 0.2 each:
      1. both texts contain "SELECT";
      2. both texts contain "FROM";
      3. the text after "FROM" is identical;
      4. the first selected column is identical;
      5. the second selected column is identical (only when both have at least two).

    A part that cannot be extracted because its keyword is missing never counts
    as a match, so two unrelated or empty strings score 0 instead of picking up
    credit from meaningless slices.

    Returns:
        A number in [0, 1] (the int 0 when nothing matches).
    """
    def extract_sql_parts(text):
        # Returns (columns, from_part); columns is [] without SELECT and
        # from_part is None without a FROM after it.
        select_at = text.find("SELECT")
        search_start = select_at + len("SELECT") if select_at != -1 else 0
        from_at = text.find("FROM", search_start)

        columns = []
        if select_at != -1:
            select_end = from_at if from_at != -1 else len(text)
            select_part = text[search_start:select_end].strip()
            columns = [c.strip() for c in select_part.split(',')]

        from_part = text[from_at + len("FROM"):].strip() if from_at != -1 else None
        return columns, from_part

    cols1, from1 = extract_sql_parts(text1)
    cols2, from2 = extract_sql_parts(text2)

    score = 0
    if "SELECT" in text1 and "SELECT" in text2:
        score += 0.2
    if "FROM" in text1 and "FROM" in text2:
        score += 0.2
    if from1 is not None and from1 == from2:
        score += 0.2
    if cols1 and cols2 and cols1[0] == cols2[0]:
        score += 0.2
    if len(cols1) >= 2 and len(cols2) >= 2 and cols1[1] == cols2[1]:
        score += 0.2
    return score

## Ablation experiment

In [None]:
# Index of the transformer layer studied below: it selects both the SAE
# (`model.layers.{layer_idx}.post_attention_layernorm`) and the `self_attn` output that is patched.
layer_idx = 16

In [None]:
import TinySQL as qts

model_num = 2                   # 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
cs_num = 1                      # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
feature_name = qts.ENGTABLENAME   # ENGTABLENAME, ENGFIELDNAME, DEFTABLESTART, DEFTABLENAME, DEFFIELDNAME, DEFFIELDSEPARATOR
use_novel_names = False           # If True, we corrupt using words not found in the clean prompt or create sql e.g. "little" or "hammer"
use_synonyms_table = False
use_synonyms_field = False
batch_size = 100                  # number of examples requested from the generator

# This rebinds `model`, replacing the language model taken from `loaded_saes` earlier.
# NOTE(review): `synonym=True` is hard-coded here rather than using `syn`, and `model_hf`
# is resolved without a synonym argument — confirm both are intended.
model = qts.load_tinysql_model(model_num, cs_num, synonym=True)
model_hf = qts.sql_interp_model_location(model_num, cs_num)

In [None]:
generator = qts.CorruptFeatureTestGenerator(model_num, cs_num, model.tokenizer, use_novel_names=use_novel_names, use_synonyms_field=use_synonyms_field, use_synonyms_table=use_synonyms_table)
examples = generator.generate_feature_examples(feature_name, batch_size)

# Only the clean side of each generated example is used here: the text up to and including
# 'Response: ' becomes the prompt and whatever follows it becomes the reference answer.
# (The corruption indices and the tokenised prompt were computed before but never used.)
response_marker = 'Response: '
prompts = []
ref_answers = []
for example in examples:
    clean_prompt = example.clean_BatchItem.get_alpaca_prompt() + example.clean_BatchItem.sql_statement

    # Split once and reuse both halves.
    prompt_parts = clean_prompt.split(response_marker)
    prompts.append(prompt_parts[0] + response_marker)
    ref_answers.append(prompt_parts[1])

In [None]:
def nnsight_eval_string_for_layer(layer: str, language_model_var_name="language_model"):
    """Build the Python expression that saves the output of `layer`.

    Purely numeric path segments become index access (`layers.3` -> `layers[3]`),
    then the path is prefixed with `language_model_var_name` and suffixed with
    `.output.save()`.
    """
    numeric_segment = re.compile(r'\.([0-9]+)(?=\.|$)')
    indexed_path = numeric_segment.sub(r'[\1]', layer)
    return f"{language_model_var_name}.{indexed_path}.output.save()"

def encode_text_to_activations_for_layer(text: str, layer: str, language_model=model):
    """Trace `text` through `language_model` and return the saved output of module `layer`.

    Args:
        text: Input passed to `tracer.invoke`.
        layer: Dotted module path, e.g. `model.layers.16.post_attention_layernorm`.
        language_model: Object providing `.trace()`. The default is evaluated once, when this
            `def` runs, so it is whatever `model` referred to at that moment.

    Returns:
        The saved module output; if that output is a tuple with more than one element,
        only its first element is returned.
    """
    with torch.no_grad():
        with language_model.trace() as tracer:
            with tracer.invoke(text) as invoker:
                # The access expression is built as a string and eval'd with only
                # `language_model` in scope; `layer` must come from trusted notebook code.
                eval_string = nnsight_eval_string_for_layer(layer, language_model_var_name="language_model")
                my_output = eval(eval_string, {"language_model": language_model})
    
    # Unwrap tuple-valued module outputs to their first element.
    if isinstance(my_output, tuple) and len(my_output) > 1:
        return my_output[0]
    else:
        return my_output
        
def compute_diff_vector(prompts, num_features, layer_name=None, language_model=None, sae=None, tokenizer=None):
    """Average change in the SAE reconstruction when only the top features are kept.

    For every prompt the layer activations are SAE-encoded, the features active at the
    last position are ranked by activation, and every feature outside the top
    `num_features` is zeroed wherever it appears in the prompt. The returned tensor is
    the mean over prompts of `decode(ablated) - decode(original)`.

    Args:
        prompts: Iterable of prompt strings.
        num_features: Number of strongest last-position features to keep.
        layer_name: Dotted path of the layer to read. Defaults to the notebook's `layer_name`.
        language_model: Traced model. Defaults to the notebook's `model`.
        sae: SAE exposing `encode` / `decode`. Defaults to the notebook's `sae`.
        tokenizer: Tokenizer used for the `SaeOutput` record. Defaults to the notebook's `tokenizer`.

    Returns:
        `torch.stack(...).mean(0)` of the per-prompt difference tensors.

    Notes:
        Defaults are resolved when the function is called. The previous def-time defaults
        captured the imported `sae` module (the SAE object is assigned in a later cell),
        and `layer_name` had no default although callers omit it.
        NOTE(review): `torch.stack` needs every per-prompt tensor to have the same shape —
        confirm the prompts tokenise to equal lengths.
    """
    notebook_globals = globals()
    if layer_name is None:
        layer_name = notebook_globals["layer_name"]
    if language_model is None:
        language_model = notebook_globals["model"]
    if sae is None:
        sae = notebook_globals["sae"]
    if tokenizer is None:
        tokenizer = notebook_globals["tokenizer"]

    diff_vectors = []
    for prompt in tqdm(prompts, desc="Processing prompts"):

        original_activations = encode_text_to_activations_for_layer(prompt, layer_name, language_model)
        sae_encoder_output = sae.encode(original_activations)
        # Index 0 drops the leading (batch) dimension of the encoder output.
        top_indices = sae_encoder_output.top_indices[0].cpu().numpy()
        top_acts = sae_encoder_output.top_acts[0].detach().cpu().numpy()

        tokens = tokenizer.tokenize(prompt)
        sae_output = SaeOutput(
            sae_name=layer_name,
            sae=sae,
            text=prompt,
            tokens=tokens,
            raw_acts=original_activations.cpu().numpy().tolist(),
            top_acts=top_acts.tolist(),
            top_indices=top_indices.tolist()
        )

        # Rank the features active at the last position by activation, strongest first.
        last_position = len(sae_output.top_indices) - 1
        last_pos_indices = sae_output.top_indices[last_position]
        last_pos_acts = sae_output.top_acts[last_position]
        pairs = sorted(zip(last_pos_indices, last_pos_acts), key=lambda x: x[1], reverse=True)

        # Everything after the top `num_features` gets ablated (empty when there are fewer).
        target_features_to_ablate = [idx for idx, _ in pairs[num_features:]] if len(pairs) > num_features else []

        new_indices = deepcopy(sae_output.top_indices)
        new_acts = deepcopy(sae_output.top_acts)

        # Zero the activation of each ablated feature at every position where it is active.
        for row_indices, row_acts in zip(new_indices, new_acts):
            for target_feature in target_features_to_ablate:
                if target_feature in row_indices:
                    row_acts[row_indices.index(target_feature)] = 0

        # Decode using the original and modified activations
        old_vector = sae.decode(top_indices=torch.tensor(sae_output.top_indices).cuda(),
                                top_acts=torch.tensor(sae_output.top_acts).cuda())
        ablated_vector = sae.decode(top_indices=torch.tensor(new_indices).cuda(),
                                    top_acts=torch.tensor(new_acts).cuda())

        diff_vectors.append(ablated_vector - old_vector)

    final_diff = torch.stack(diff_vectors).mean(0)
    return final_diff

In [None]:

   
def find_similarity_ffeatures(prompts, ref_answers, difference_vector):
    """Generate with `difference_vector` added to one attention output and score the results.

    For each prompt, the notebook-level `model` generates up to 20 new tokens while the
    `self_attn` output of layer `layer_idx` is replaced by itself plus `difference_vector`.
    The text after 'Response: ' is compared with the matching entry of `ref_answers`
    via `calculate_similarity`.

    Args:
        prompts: Prompt strings ending in 'Response: '.
        ref_answers: Reference SQL strings, aligned with `prompts`.
        difference_vector: Tensor added to the attention output.

    Returns:
        The mean similarity over all prompts (0.0 when `prompts` is empty); it is also printed.
    """
    simils = []
    # Wrap the list rather than the enumerate iterator so tqdm can show a total.
    for i, prompt in enumerate(tqdm(prompts)):

        with model.generate(prompt,max_new_tokens=20, pad_token_id=model.tokenizer.eos_token_id, temperature=0.01, 
                                top_p=0.9, eos_token_id=model.tokenizer.eos_token_id, do_sample=True, early_stopping=True) as tracer:

            attention_output = model.model.layers[layer_idx].self_attn.output[0].save()

            # NOTE(review): `attention_output` is already element 0 of the module output; the extra
            # `[0]` indexes its first dimension — confirm the resulting shape is what the layer expects.
            modified_attention = attention_output[0] + difference_vector

            # Put the modified tensor back, keeping the remaining tuple elements untouched.
            model.model.layers[layer_idx].self_attn.output = (modified_attention,) + model.model.layers[layer_idx].self_attn.output[1:]

            out = model.generator.output.save()

        gen_text = model.tokenizer.decode(out[0], skip_special_tokens=True).split('Response: ')[1]
        similarity = calculate_similarity(gen_text, ref_answers[i])
        simils.append(similarity)

    mean_similarity = sum(simils) / len(simils) if simils else 0.0
    print('Total Similarity', mean_similarity)

    return mean_similarity

    
    

In [None]:
# Dotted path of the module whose SAE is analysed, and the SAE registered under that path.
# From here on the name `sae` refers to this SAE object, no longer to the `sae` package imported above.
layer_name =  f'model.layers.{layer_idx}.post_attention_layernorm'
sae = loaded_saes.layer_to_saes[layer_name]

In [None]:
def find_optimal_features(prompts, ref_answers, max_features=32, min_similarity=0.7, window_size=3, flatness_threshold=0.02):
    """Find the smallest feature count whose similarity is high enough and has stopped improving.

    Feature counts 1..`max_features` are tried in order. For each count a difference
    vector is computed for the notebook's `layer_name` / `sae` and scored with
    `find_similarity_ffeatures`. The search stops at the first count where the score is
    at least `min_similarity` and the mean step-to-step improvement over the last
    `window_size` scores is below `flatness_threshold`.

    Args:
        prompts: Prompt strings.
        ref_answers: Reference answers aligned with `prompts`.
        max_features: Largest feature count to try (must be >= 1).
        min_similarity: Minimum score before flatness is checked.
        window_size: Number of most recent scores used for the flatness check;
            values below 2 are treated as 2 because one score gives no improvement.
        flatness_threshold: Mean improvement below which the curve counts as flat.

    Returns:
        `(num_features, similarity)` at the stopping point, or
        `(max_features, last_similarity)` if the curve never flattens.

    Raises:
        ValueError: If `max_features` is smaller than 1.
    """
    if max_features < 1:
        raise ValueError("max_features must be at least 1")

    window = max(window_size, 2)
    similarities = []

    # Test each feature count
    for num in range(1, max_features + 1):
        # `layer_name` and the SAE are passed explicitly: `layer_name` is a required
        # argument, and the function's own `sae` default is bound before the SAE is loaded.
        diff = compute_diff_vector(prompts, num, layer_name, sae=sae)
        sim = find_similarity_ffeatures(prompts, ref_answers, diff)
        similarities.append(sim)

        print(f"{num} features: {sim:.4f} similarity")

        # Flatness is only checked once the score is good enough and the window is full.
        if sim >= min_similarity and len(similarities) >= window:
            recent = similarities[-window:]
            recent_improvements = [later - earlier for earlier, later in zip(recent, recent[1:])]
            avg_improvement = sum(recent_improvements) / len(recent_improvements)

            # If improvement is below threshold, we've flattened
            if avg_improvement < flatness_threshold:
                print(f"Stopping at {num} features - similarity {sim:.4f} has flattened (avg improvement: {avg_improvement:.4f})")
                return num, sim

    # If we reach max_features or don't find a flattening point
    print(f"Reached maximum {max_features} features with similarity {similarities[-1]:.4f}")
    return max_features, similarities[-1]


optimal_features, final_sim = find_optimal_features(
    prompts,
    ref_answers,
    max_features=32,
    min_similarity=0.7,
    window_size=3,
    flatness_threshold=0.02
)
# Recompute the difference vector for the chosen feature count. The layer name and the SAE are
# passed explicitly: `layer_name` is a required argument of `compute_diff_vector`, and its `sae`
# default is evaluated before the SAE object is assigned.
final_diff = compute_diff_vector(prompts, optimal_features, layer_name, sae=sae)

In [None]:
# Define the block size based on the model architecture
# NOTE(review): the earlier comment derived head_dim as 1024 / 16, but `num_heads` is 14 here
# (14 * 64 = 896) while `N_HEADS` in the head-ablation cell is 16 — confirm against the model config.
head_dim = 64   # width of one head when a decoder row is reshaped to (num_heads, head_dim)
num_heads = 14  
num_features = optimal_features  # number of strongest last-position features kept per prompt

def analyze_feature_to_heads(feature_num):
    """Return one magnitude per head for SAE feature `feature_num`.

    The feature's row of `sae.W_dec` is reshaped to `(num_heads, head_dim)` and the
    L2 norm of each head slice is returned as a NumPy array. If the row has an
    unexpected size a warning is printed; for layers whose name lacks "attn" zeros
    are returned, otherwise a failed reshape falls back to the overall norm
    repeated `num_heads` times.
    """
    with torch.no_grad():
        direction = sae.W_dec[feature_num].clone()
        total_size = direction.numel()

        size_matches = total_size == head_dim * num_heads
        if not size_matches:
            print(f"Warning: Vector size {total_size} doesn't match expected size {head_dim * num_heads}")
            if "attn" not in layer_name:
                print(f"This may not be an attention layer: {layer_name}")
                return np.zeros(num_heads)

        # Split the decoder direction into per-head slices.
        try:
            per_head = direction.reshape(num_heads, head_dim)
        except RuntimeError:
            print(f"Could not reshape vector of size {total_size} to {num_heads}x{head_dim}")
            # No per-head split possible: report the overall norm for every head.
            overall_norm = torch.norm(direction.view(1, -1), dim=1)
            return overall_norm.repeat(num_heads).cpu().numpy()

    return torch.norm(per_head, dim=-1).cpu().numpy()

# Select heads based on area under the curve
def select_important_heads(heads, magnitudes, contribution_threshold=0.8):
    """Pick the strongest heads that together cover a share of the area above the minimum.

    Heads are sorted by magnitude (descending). Each head's contribution is its
    magnitude minus the smallest magnitude, normalised by the sum of those
    differences. Heads are taken from the top until the cumulative contribution
    reaches `contribution_threshold`; if it never does, every head is selected.

    Returns:
        `(selected_heads, analysis)` where `selected_heads` is a list and `analysis`
        is a dict with the sorted heads and magnitudes, the differences, the individual
        and cumulative contributions, and the cutoff index, value and covered share.
    """
    head_ids = np.array(heads)
    values = np.array([float(m) for m in magnitudes])

    # Strongest first.
    order = np.argsort(values)[::-1]
    sorted_heads = head_ids[order]
    sorted_magnitudes = values[order]

    # Height of each head above the weakest one, and the total of those heights.
    differences = sorted_magnitudes - np.min(sorted_magnitudes)
    total_area = np.sum(differences)

    if total_area > 0:
        contributions = differences / total_area
    else:
        contributions = np.zeros_like(differences)
    cumulative_contributions = np.cumsum(contributions)

    # First position where the running share reaches the threshold (last head if never).
    reached = np.where(cumulative_contributions >= contribution_threshold)[0]
    if len(reached) > 0:
        cutoff_index = reached[0]
    else:
        cutoff_index = len(sorted_heads) - 1

    cutoff_in_range = cutoff_index < len(sorted_magnitudes)
    analysis = {
        'all_heads': sorted_heads,
        'all_magnitudes': sorted_magnitudes,
        'differences': differences,
        'contributions': contributions,
        'cumulative_contributions': cumulative_contributions,
        'cutoff_index': cutoff_index,
        'cutoff_value': sorted_magnitudes[cutoff_index] if cutoff_in_range else None,
        'total_selected_contribution': cumulative_contributions[cutoff_index] if cutoff_in_range else 1.0
    }

    return sorted_heads[:cutoff_index + 1].tolist(), analysis

In [None]:
from tqdm import tqdm
import numpy as np

# Running total of per-head magnitudes over all prompts and kept features.
accumulated_head_scores = {}

for prompt in tqdm(prompts, desc="Processing prompts"):
    # SAE-encode the layer activations for this prompt.
    original_activations = encode_text_to_activations_for_layer(prompt, layer_name)
    sae_encoder_output = sae.encode(original_activations)
    top_indices = sae_encoder_output.top_indices[0].cpu().numpy()
    top_acts = sae_encoder_output.top_acts[0].detach().cpu().numpy()
    tokens = tokenizer.tokenize(prompt)
    sae_output = SaeOutput(
        sae_name=layer_name,
        sae=sae,
        text=prompt,
        tokens=tokens,
        raw_acts=original_activations.cpu().numpy().tolist(),
        top_acts=top_acts.tolist(),
        top_indices=top_indices.tolist()
    )

    # Features active at the last position, strongest activation first.
    last_position = len(sae_output.top_indices) - 1
    last_position_indices = sae_output.top_indices[last_position]
    last_position_acts = sae_output.top_acts[last_position]
    pairs = sorted(
        zip(last_position_indices, last_position_acts),
        key=lambda pair: pair[1],
        reverse=True,
    )

    # Slicing already returns every pair when there are no more than `num_features` of them.
    features_kept = [idx for idx, _ in pairs[:num_features]]

    # Add each kept feature's per-head magnitudes to the running totals.
    for feature in features_kept:
        for head_idx, magnitude in enumerate(analyze_feature_to_heads(feature)):
            accumulated_head_scores[head_idx] = accumulated_head_scores.get(head_idx, 0) + magnitude

# Sort heads by their total scores
sorted_head_scores = sorted(
    accumulated_head_scores.items(),
    key=lambda item: item[1],
    reverse=True,
)
print("\nHeads sorted by total score across all prompts:")
for head_idx, total_score in sorted_head_scores:
    print(f"Head {head_idx}: Total score {total_score:.4f}")

# Apply area under the curve analysis to select important heads
def select_heads_by_auc(head_scores, contribution_threshold=0.8):
    """Select the top-scoring heads covering a share of the area above the minimum score.

    `head_scores` maps head index to score. Heads are sorted by score (descending);
    each head contributes its score minus the minimum score, normalised by the sum
    of those differences. Heads are taken from the top until the cumulative
    contribution reaches `contribution_threshold` (all heads if it never does).

    Returns:
        `(selected_heads, info)` where `info` holds the sorted heads and scores,
        the cutoff index and the share covered by the selection.
    """
    heads = np.array(list(head_scores.keys()))
    scores = np.array(list(head_scores.values()))

    # Highest score first.
    order = np.argsort(scores)[::-1]
    sorted_heads = heads[order]
    sorted_scores = scores[order]

    # Height above the lowest score, and each head's share of the total height.
    differences = sorted_scores - np.min(sorted_scores)
    total_area = np.sum(differences)
    if total_area > 0:
        contributions = differences / total_area
    else:
        contributions = np.zeros_like(differences)
    cumulative_contributions = np.cumsum(contributions)

    # First index at which the running share reaches the threshold.
    last_valid = len(sorted_heads) - 1
    reached = np.where(cumulative_contributions >= contribution_threshold)[0]
    cutoff_index = reached[0] if len(reached) > 0 else last_valid

    # Keep the cutoff inside the valid index range.
    if cutoff_index > last_valid:
        cutoff_index = last_valid
    elif cutoff_index < 0:
        cutoff_index = 0

    if cutoff_index < len(cumulative_contributions):
        covered = cumulative_contributions[cutoff_index]
    else:
        covered = 1.0

    info = {
        'all_heads': sorted_heads,
        'all_scores': sorted_scores,
        'cutoff_index': cutoff_index,
        'total_selected_contribution': covered
    }
    return sorted_heads[:cutoff_index + 1].tolist(), info

# Select the heads that together reach `auc_threshold` of the area under the score curve.
# The threshold is named once so the printed message cannot drift from the value actually used
# (the message previously said 80% while 0.85 was passed).
auc_threshold = 0.85
selected_heads, analysis = select_heads_by_auc(accumulated_head_scores, contribution_threshold=auc_threshold)
print(f"\nSelected heads (contributing to {auc_threshold:.0%} of area under curve): {selected_heads}")
print(f"Selection covers {analysis['total_selected_contribution']*100:.2f}% of the area")

## Testing

In [None]:


# Score generations with the final difference vector added to the layer's attention output.
# This loop was a verbatim copy of `find_similarity_ffeatures`, so the helper is called instead;
# it prints 'Total Similarity <mean>' and returns the mean.
steered_similarity = find_similarity_ffeatures(prompts, ref_answers, final_diff)

In [None]:
# Baseline: the same generation settings as the steered run, but with no intervention inside
# the trace, so the printed mean is the unmodified model's similarity to the references.
simils = []
for i, prompt in tqdm(enumerate(prompts)): 
    
    with model.generate(prompt,max_new_tokens=20, pad_token_id=model.tokenizer.eos_token_id, temperature=0.01, 
                            top_p=0.9, eos_token_id=model.tokenizer.eos_token_id, do_sample=True, early_stopping=True) as tracer:
        
        # Only the generated token ids are saved.
        out = model.generator.output.save()

    # Keep the text after the first 'Response: ' marker and score it against the reference.
    gen_text = model.tokenizer.decode(out[0], skip_special_tokens=True).split('Response: ')[1]
    similarity = calculate_similarity(gen_text, ref_answers[i])
    simils.append(similarity)

print('Total Similarity', sum(simils)/len(simils))

In [None]:
# Number of heads the attention output is split into by the two functions below.
# NOTE(review): `num_heads = 14` is used in `analyze_feature_to_heads` — confirm which value
# matches the model.
N_HEADS = 16

def compute_head_means(model, prompt_texts):
    """Collect, for each of 24 layers, the batch-mean attention output split into heads.

    A single `model.generate` trace is run over `prompt_texts`. In every layer the first
    element of the `self_attn` output is reshaped from `(b, s, nh*dh)` to `(b, s, nh, dh)`
    with `nh = N_HEADS`, averaged over the batch dimension and saved.

    Args:
        model: Traced model exposing `generate`, `tokenizer` and `model.layers`.
        prompt_texts: Prompts passed to `model.generate` in one call.

    Returns:
        A list with one saved value per layer (24 entries), each of shape `[s, nh, dh]`.
    """
    layer_means = []  
    
    with model.generate(prompt_texts,max_new_tokens=7, pad_token_id=model.tokenizer.eos_token_id, temperature=0.5, 
                        top_p=0.9, eos_token_id=model.tokenizer.eos_token_id, do_sample=True, early_stopping=True) as tracer:
        # `layer_idx` is local to this function; it does not change the notebook-level `layer_idx`.
        for layer_idx in range(24):
            attn_output = model.model.layers[layer_idx].self_attn.output[0]
            
            # Split the last dimension into N_HEADS equal head slices.
            output_reshaped = einops.rearrange(
                attn_output, 
                'b s (nh dh) -> b s nh dh',
                nh=N_HEADS
            )
            
            # Calculate mean across samples
            head_means = output_reshaped.mean(dim=0)  # Shape: [s, nh, dh]
            layer_means.append(head_means.save())
            
    return layer_means



def zero_heads_with_means(model, prompt_text, target_layers, heads_per_layer, layer_means, mean=True):
    """Generate while replacing the attention output of every head that is not selected.

    For each of 24 layers the first element of the `self_attn` output is split into
    `N_HEADS` heads, then:
      * layers not in `target_layers`: the whole output is replaced by the layer's
        mean (`mean=True`) or by zeros (`mean=False`);
      * layers in `target_layers`: each head NOT listed in `heads_per_layer[layer_id]`
        is replaced by its mean or by zeros; listed heads are left as they are.

    Args:
        model: Traced model exposing `generate`, `tokenizer`, `generator` and `model.layers`.
        prompt_text: Prompt passed to `model.generate`.
        target_layers: Layer ids that receive per-head treatment.
        heads_per_layer: Mapping layer id -> heads to keep; it is read for every layer id
            0..23, so it needs an entry for each of them.
        layer_means: Per-layer head means as returned by `compute_head_means`.
        mean: Replace with the means when True, with zeros when False.

    Returns:
        `(hidden_states, out)`: the saved first element of each layer's output, and the
        saved `model.generator.output`.
    """
    with model.generate(prompt_text,max_new_tokens=7, pad_token_id=model.tokenizer.eos_token_id, temperature=0.5, 
                        top_p=0.9, eos_token_id=model.tokenizer.eos_token_id, do_sample=True, early_stopping=True) as tracer:
        hidden_states = []
        for layer_id in range(24):
            # Get attention output
            attn_output = model.model.layers[layer_id].self_attn.output[0]
            target_heads = heads_per_layer[layer_id]
            
            
            output_reshaped = einops.rearrange(
                attn_output, 
                'b s (nh dh) -> b s nh dh',
                nh=N_HEADS
            )
            
            head_means = layer_means[layer_id]  # Shape: [s, nh, dh]
            # output_reshaped = head_means.expand_as(output_reshaped)
            # Non-target layer: replace the whole attention output.
            # NOTE(review): `expand_as` needs the stored means to be broadcastable to this
            # prompt's shape — confirm the sequence lengths agree.
            if layer_id not in target_layers:
                if mean:
                    output_reshaped = head_means.expand_as(output_reshaped)
                else:
                    output_reshaped = torch.zeros_like(head_means.expand_as(output_reshaped))

            # Target layer: overwrite only the heads that are not in the keep-list.
            if layer_id in target_layers:
                for head_idx in range(N_HEADS):
                    if head_idx not in target_heads:
                        if mean:
                            output_reshaped[:, :, head_idx, :] = head_means[:, head_idx, :].unsqueeze(0)
                        else:
                            output_reshaped[:, :, head_idx, :] = torch.zeros_like(head_means[:, head_idx, :].unsqueeze(0))
                        

            
            # Reshape back to original format
            modified_attn = einops.rearrange(
                output_reshaped,
                'b s nh dh -> b s (nh dh)', 
                nh=N_HEADS
            )
            
            # Update only the attention output
            model.model.layers[layer_id].self_attn.output = (modified_attn,) + model.model.layers[layer_id].self_attn.output[1:]
            
            # Store the full layer output (which includes both attention and MLP)
            hidden_states.append(model.model.layers[layer_id].output[0].save())
            
        out = model.generator.output.save()
        
    return hidden_states, out

In [None]:
# Per-layer, per-head mean attention outputs over all prompts (one generate call on the full list).
layer_means = compute_head_means(model, prompts)

In [None]:
# Every layer gets per-head treatment; by default every head of a layer is kept.
target_layers = list(range(24))
heads_per_layer = {layer_id: list(range(16)) for layer_id in range(24)}

# In the studied layer keep only these heads; all of its other heads are replaced.
heads_per_layer[layer_idx] = [12, 1, 9, 13, 5, 2, 6]


results = []
for i, prompt in enumerate(prompts):
    hidden_states, output = zero_heads_with_means(
        model, prompt, target_layers, heads_per_layer, layer_means, mean=False
    )
    decoded_text = model.tokenizer.decode(output[0], skip_special_tokens=True)
    gen_text = decoded_text.split('Response: ')[1]
    results.append({
        'output': gen_text,
        'similarity': calculate_similarity(gen_text, ref_answers[i]),
    })

In [None]:
# Mean similarity of the head-ablated generations to their references.
similarity_values = [r['similarity'] for r in results]
total_similarity = sum(similarity_values)
avg_similarity = total_similarity / len(results)
print(f"Average similarity: {avg_similarity}")