# Install Packages

In [None]:
%pip install jaxtyping transformer_lens plotly-utils einops torch sae_lens
%pip install numpy==1.23.5
%pip install gensim
%pip install git+https://github.com/callummcdougall/eindex.git

# Import Packages

In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
from rich.table import Table
from IPython.display import display, HTML
from pathlib import Path
from sae_lens import SAE
import plotly.express as px

from transformer_lens import HookedTransformer, FactoredMatrix
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import (
    load_dataset,
    tokenize_and_concatenate,
    download_file_from_hf,
)
# from plotly_utils import imshow, line, hist

device = (
    "cuda"
    if t.cuda.is_available()
    else "mps"
    if t.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)

MAIN = __name__ == "__main__"

Using device: mps


# Coherence Score Analysis

## Load Models

First load the pretrained transformer and SAE:

In [2]:
model = HookedTransformer.from_pretrained("gpt2-small").to(device)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps


In [3]:
sae, cfg_dict, sparsity = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre")

Next load the fine-tuned transformer and custom SAE:

In [6]:
path_to_custom_transformer_dict = '../models/fine-tuned/fine_tuned_gpt2/model.safetensors'

from safetensors import safe_open

# Initialize the HookedTransformer with the same architecture as your fine-tuned model
custom_model = HookedTransformer.from_pretrained("gpt2")  # base model

# Load the state dict from the .safetensors file
with safe_open(path_to_custom_transformer_dict, framework="pt", device=device) as f:
    state_dict = {key: f.get_tensor(key) for key in f.keys()}

# Load the state dict into the model
custom_model.load_state_dict(state_dict, strict=False)

custom_model = model.to(device)

Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  mps


In [5]:
path_to_custom_sae_dict = '../models/sae/gpt2-small-fine-tuned-layer-8'
custom_sae = SAE.load_from_pretrained(path_to_custom_sae_dict, device=device)

## Load Word Embeddings

In [12]:
import gensim.downloader as api

# Load word embeddings
print("Loading GloVe embeddings...")
word_embeddings = api.load("glove-wiki-gigaword-100")
print("GloVe embeddings loaded.")

Loading GloVe embeddings...
GloVe embeddings loaded.


## Define Functions

In [14]:
from sae_lens.analysis.feature_statistics import get_W_U_W_dec_stats_df
from tqdm import tqdm
from typing import Dict, List, Tuple
from torch import Tensor, topk
from scipy.spatial.distance import cosine


def calculate_coherence_stats(feature_summaries, feature_ids=None):
    if feature_ids is None:
        coherence_scores = [summary['coherence_score'] for summary in feature_summaries.values()]
    else:
        coherence_scores = [feature_summaries[i]['coherence_score'] for i in feature_ids if i in feature_summaries]
    
    coherence_scores = np.array(coherence_scores)
    non_zero_scores = coherence_scores[coherence_scores > 0]
    
    stats = {
        "mean_all": np.mean(coherence_scores),
        "median_all": np.median(coherence_scores),
        "mean_non_zero": np.mean(non_zero_scores) if len(non_zero_scores) > 0 else 0,
        "median_non_zero": np.median(non_zero_scores) if len(non_zero_scores) > 0 else 0,
        "fraction_non_zero": len(non_zero_scores) / len(coherence_scores),
    }
    
    return stats

def semantic_coherence_score(activated_tokens, activation_scores, word_embeddings):
    tokens_and_scores = [(t.lower(), score) for t, score in zip(activated_tokens, activation_scores) if t.lower().isalpha()]

    token_embeddings = []
    weights = []
    for token, score in tokens_and_scores:
        if token in word_embeddings:
            token_embeddings.append(word_embeddings[token])
            weights.append(score)

    similarities = []
    total_weight = 0
    for i in range(len(token_embeddings)):
        for j in range(i+1, len(token_embeddings)):
            sim = 1 - cosine(token_embeddings[i], token_embeddings[j])
            weight = weights[i] * weights[j]
            similarities.append(sim * weight)
            total_weight += weight

    return np.sum(similarities) / total_weight if total_weight > 0 else 0

def get_top_k_words(
    feature_activations: Tensor, words: List[str], k: int = 10
) -> List[Tuple[str, float]]:
    """
    Get the top k activated words for a given feature.

    Args:
        feature_activations (torch.Tensor): Activation values for a feature.
        words (List[str]): List of words in the vocabulary.
        k (int): Number of top words to return.

    Returns:
        List[Tuple[str, float]]: List of tuples containing top words and their activation values.
    """
    if feature_activations.numel() == 0:
        return []

    k = min(k, feature_activations.numel())
    top_k_values, top_k_indices = topk(feature_activations, k)
    top_k_words = [words[i] for i in top_k_indices.tolist()]
    top_k_activations = top_k_values.tolist()

    return list(zip(top_k_words, top_k_activations))

def get_feature_summaries(sae, word_embeddings):
    # Assuming W_dec, model, and other necessary variables are already defined
    W_dec = sae.W_dec.detach().cpu()
    W_U_stats_df_dec, dec_projection_onto_W_U = get_W_U_W_dec_stats_df(
        W_dec, model, cosine_sim=False
    )

    number_of_features = dec_projection_onto_W_U.shape[0]

    # Get vocabulary
    vocab = model.tokenizer.get_vocab()
    words = sorted(vocab.keys(), key=lambda x: vocab[x])

    feature_summaries = {}
    for i in tqdm(range(number_of_features), desc="Processing features"):
        feature_activations = dec_projection_onto_W_U[i]
        top_activated_words = get_top_k_words(feature_activations, words)

        activated_tokens, activation_scores = zip(*top_activated_words)

        coherence_score = semantic_coherence_score(activated_tokens, activation_scores, word_embeddings)
        feature_summary = {
            "feature_idx": i,
            "top_activated_words": top_activated_words,
            "activation_scores": activation_scores,
            "coherence_score": coherence_score
        }
        feature_summaries[i] = feature_summary
    
    return feature_summaries

## Compute Feature Summaries

In [15]:
baseline_feature_summaries = get_feature_summaries(sae, word_embeddings)
finetuned_feature_summaries = get_feature_summaries(custom_sae, word_embeddings)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Processing features: 100%|██████████| 24576/24576 [00:23<00:00, 1060.10it/s]
Processing features: 100%|██████████| 24576/24576 [00:27<00:00, 906.09it/s] 


## Get the Feature IDs

In [16]:
import json
import pandas as pd

path_to_parsed_features = '../features/parsed_features.json'

# Load the JSON data
with open(path_to_parsed_features, 'r') as f:
    data = json.load(f)

# Function to convert the nested dictionary to a DataFrame
def dict_to_df(data_dict):
    rows = []
    for feature, tokens in data_dict.items():
        for token, activation in tokens.items():
            rows.append({
                'feature': int(feature),
                'token': token,
                'activation': activation
            })
    return pd.DataFrame(rows)

# Create DataFrames for baseline and finetuned data
baseline_df = dict_to_df(data['baseline'])
finetuned_df = dict_to_df(data['finetuned'])

baseline_feature_ids = baseline_df.feature.unique()
finetuned_feature_ids = finetuned_df.feature.unique()

## Compute Coherence Score Stats!

In [19]:
# Calculate stats for both models
baseline_stats = calculate_coherence_stats(baseline_feature_summaries)
baseline_medical_stats = calculate_coherence_stats(baseline_feature_summaries, baseline_feature_ids)
finetuned_stats = calculate_coherence_stats(finetuned_feature_summaries)
finetuned_medical_stats = calculate_coherence_stats(finetuned_feature_summaries, finetuned_feature_ids)

# Function to print stats
def print_stats(name, stats):
    print(f"{name}:")
    print(f"  Mean (all): {stats['mean_all']:.4f}")
    print(f"  Median (all): {stats['median_all']:.4f}")
    print(f"  Mean (non-zero): {stats['mean_non_zero']:.4f}")
    print(f"  Median (non-zero): {stats['median_non_zero']:.4f}")
    print(f"  Fraction of non-zero scores: {stats['fraction_non_zero']:.2%}")

# Print stats for both models
print("Baseline gpt2-small:")
print_stats("All features", baseline_stats)
print_stats("Medical features", baseline_medical_stats)

print("\nFine-tuned gpt2-small:")
print_stats("All features", finetuned_stats)
print_stats("Medical features", finetuned_medical_stats)

Baseline gpt2-small:
All features:
  Mean (all): 0.0719
  Median (all): 0.0000
  Mean (non-zero): 0.2057
  Median (non-zero): 0.1399
  Fraction of non-zero scores: 38.82%
Medical features:
  Mean (all): 0.0644
  Median (all): 0.0000
  Mean (non-zero): 0.1998
  Median (non-zero): 0.1342
  Fraction of non-zero scores: 37.62%

Fine-tuned gpt2-small:
All features:
  Mean (all): 0.0728
  Median (all): 0.0246
  Mean (non-zero): 0.1470
  Median (non-zero): 0.1026
  Fraction of non-zero scores: 57.06%
Medical features:
  Mean (all): 0.0671
  Median (all): 0.0000
  Mean (non-zero): 0.1584
  Median (non-zero): 0.0938
  Fraction of non-zero scores: 47.12%


# Compute Coherence Scores From JSON

In [1]:
import json
import pandas as pd
import numpy as np
from scipy.spatial.distance import cosine
import gensim.downloader as api

In [7]:
# Load JSON data
with open('../features/medical_features_finetuned.json', 'r') as f:
    data_ft = json.load(f)

with open('../features/medical_features_baseline.json', 'r') as f:
    data_b = json.load(f)

In [None]:
# Load word embeddings
print("Loading GloVe embeddings...")
word_embeddings = api.load("glove-wiki-gigaword-100")
print("GloVe embeddings loaded.")

In [8]:
import re

def preprocess_token(token):
    # Remove 'Ġ' artifact and convert to lowercase
    token = token.lstrip('Ġ').lower()
    # Keep only alphabetic characters
    token = re.sub(r'[^a-z]', '', token)
    return token

def semantic_coherence_score(activated_tokens, activation_scores, word_embeddings, epsilon=1e-8):
    # Preprocess tokens and scores
    tokens_and_scores = [(preprocess_token(t), score) for t, score in zip(activated_tokens, activation_scores)]
    tokens_and_scores = [(t, score) for t, score in tokens_and_scores if t]  # Remove empty tokens

    # Get embeddings for each token
    token_embeddings = []
    weights = []
    for token, score in tokens_and_scores:
        if token in word_embeddings:
            token_embeddings.append(word_embeddings[token])
            weights.append(score)
        else:
            print(f"Warning: '{token}' not found in word embeddings")

    # If we have fewer than 2 valid tokens, return 0
    if len(token_embeddings) < 2:
        print(f"Warning: Fewer than 2 valid tokens for feature. Tokens: {activated_tokens}")
        return 0

    # Calculate weighted pairwise similarities
    similarities = []
    total_weight = 0
    for i in range(len(token_embeddings)):
        for j in range(i+1, len(token_embeddings)):
            sim = 1 - cosine(token_embeddings[i], token_embeddings[j])  # Cosine similarity
            weight = weights[i] * weights[j]  # Weight by product of activation scores
            similarities.append(sim * weight)
            total_weight += weight

    # Compute weighted average similarity
    score = np.sum(similarities) / (total_weight + epsilon)
    if score == 0:
        print(f"Warning: Zero score for feature. Tokens: {activated_tokens}")
    return score

In [9]:
def compute_coherence_scores(data, title):
    # Compute coherence scores for each feature
    results = []
    for feature_id, feature_data in data.items():
        top_words = feature_data['top_words']
        activated_tokens = [word['word'].strip() for word in top_words]
        activation_scores = [word['activation'] for word in top_words]
        
        coherence_score = semantic_coherence_score(activated_tokens, activation_scores, word_embeddings)
        
        results.append({
            'feature_id': feature_id,
            'coherence_score': coherence_score,
            'top_words': ', '.join(activated_tokens[:5])  # Show top 5 words for brevity
        })

    # Create DataFrame
    df = pd.DataFrame(results)

    # Sort by coherence score in descending order
    df = df.sort_values('coherence_score', ascending=False)

    # Display top 10 features by coherence score
    print(df.head(10))

    # Display summary statistics
    print("\nSummary Statistics:")
    print(df['coherence_score'].describe())

    # Save results to CSV
    df.to_csv(f'feature_coherence_scores_{title}.csv', index=False)
    print("\nResults saved to f'feature_coherence_scores_{title}.csv'")

In [11]:
compute_coherence_scores(data_b, "baseline")
compute_coherence_scores(data_ft, "finetuned")

    feature_id  coherence_score  \
2           98         0.764068   
68       15344         0.760279   
89       18508         0.753547   
84       17863         0.749772   
99       20390         0.743273   
95       19494         0.741164   
23        6487         0.730967   
41        9360         0.723107   
118      24377         0.719579   
106      21509         0.714646   

                                             top_words  
2    Ġtreatment, treatment, ĠTreatment, Ġtreatments...  
68     ĠMedical, Medical, medical, Ġmedical, ĠHospital  
89   Ġdisease, ĠDisease, Ġdiseases, ĠDiseases, Ġill...  
84     Ġdoctor, doctor, Ġphysician, Ġdoctors, Ġsurgeon  
99       Ġcancer, cancer, ĠCancer, Ġcancers, Ġleukemia  
95        ĠCancer, cancer, Ġcancer, ĠDiabetes, ĠAutism  
23   Ġpatients, ĠPatients, Ġpatient, Ġclinicians, Ġ...  
41   Ġmedical, medical, Medical, ĠMedical, Ġbiomedical  
118      ĠHealth, Health, health, Ġhealth, ĠHealthcare  
106       Ġtreat, Ġtreats, ĠTreat, Ġtreating