# Setup

In [1]:
try:
    import google.colab
    IN_COLAB = True
    %pip install sae-lens transformer-lens
except ImportError:
    IN_COLAB = False

Collecting sae-lens
  Downloading sae_lens-3.22.2-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.7.0-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Downloading pytest_profiling-1.7.0-py2.py3-none-any.whl.metadata (12 kB)
Collecting python-dot

In [8]:
# Standard imports
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path

# GPU memory saver (this script doesn't need gradients computation)
torch.set_grad_enabled(False)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda


# Config

In [9]:
# define the model to work with
MODEL = 'GPT2' # GEMMA, MISTRAL, GPT2

if MODEL == 'GEMMA':
    # Base model stuff
    BASE_MODEL = "google/gemma-2b"
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    BASE_TOKENIZER_NAME = BASE_MODEL

    # Finetuned model stuff
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    FINETUNE_PATH = None

    # SAE stuff
    RELEASE = 'gemma-2b-res-jb'
    hook_part = 'post'
    layer_num = 6
elif MODEL == 'MISTRAL':
    # Base model stuff
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    # Finetuned model stuff
    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B'
    FINETUNE_PATH = f'/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    # SAE stuff
    RELEASE = 'mistral-7b-res-wg'
    hook_part = 'pre'
    layer_num = 8
elif MODEL == 'GPT2':
    # Base model stuff
    BASE_MODEL = "gpt2-small"
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = 'openai-community/gpt2'

    # Finetuned model stuff
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None

    # SAE stuff
    RELEASE = 'gpt2-small-res-jb'
    hook_part = 'pre'
    layer_num = 6

SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

In [10]:
COLAB_BASE_PATH = '/content/drive/My Drive/sae_data'

if IN_COLAB:
    # If in Colab, mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Define the path to your JSON file in Google Drive
    datapath = Path(COLAB_BASE_PATH)
else:
    # If not in Colab, use local folder
    # Assuming this is being run from the 'notebooks' folder
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
import plotly.graph_objs as go
from functools import partial

def plot_log10_hist(y_data, y_value, num_bins=100, first_bin_name = 'First bin value',
                    y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10, title=None):
    """
    Computes the histogram using PyTorch and plots the feature density diagram with log-10 scale using Plotly.
    Y-axis is clipped to the value of the second-largest bin to prevent suppression of smaller values.
    """

    if title is None:
        title = f"SAE Features {y_value} histogram"

    # Flatten the tensor
    y_data_flat = torch.flatten(y_data)

    # Compute the logarithmic transformation using PyTorch
    log_y_data_flat = torch.log10(torch.abs(y_data_flat) + log_epsilon).detach().cpu()

    # Compute histogram using PyTorch
    hist_min = torch.min(log_y_data_flat).item()
    hist_max = torch.max(log_y_data_flat).item()
    hist_range = hist_max - hist_min
    bin_edges = torch.linspace(hist_min, hist_max, num_bins + 1)
    hist_counts, _ = torch.histogram(log_y_data_flat, bins=bin_edges)

    # Convert data to NumPy for Plotly
    bin_edges_np = bin_edges.detach().cpu().numpy()
    hist_counts_np = hist_counts.detach().cpu().numpy()

    # Find the largest and second-largest bin values
    first_bin_value = hist_counts_np[0]
    scale_bin_value = sorted(hist_counts_np)[y_scale_bin]  # Get the second largest bin value (by default)

    # Prepare the Plotly plot
    fig = go.Figure(
        data=[go.Bar(
            x=bin_edges_np[:-1],  # Exclude the last bin edge
            y=hist_counts_np,
            width=hist_range / num_bins,
        )]
    )

    # Update the layout for the plot, clipping the y-axis at the second largest bin value
    fig.update_layout(
        title=title,
        xaxis_title=f"Log10 of {y_value}",
        yaxis_title="Count",
        yaxis_range=[0, scale_bin_value * y_scalar],  # Clipping to the second-largest value by default
        bargap=0.2,
        bargroupgap=0.1,
    )

    # Add an annotation to display the value of the first bin
    fig.add_annotation(
        text=f"{first_bin_name}: {first_bin_value:.2e}",
        xref="paper", yref="paper",
        x=0.95, y=0.95,
        showarrow=False,
        font=dict(size=12, color="red"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    # Show the plot
    fig.show()

class FeatureDensityPlotter:
    def __init__(self, n_features=None, n_tokens=None, activation_threshold=1e-10, num_bins=100,
                 feature_densities=None):
        if feature_densities is not None:
            self.feature_densities = feature_densities
            return

        self.num_bins = num_bins
        self.activation_threshold = activation_threshold

        self.n_tokens = n_tokens
        self.n_features = n_features

        # Initialize a tensor of feature densities for all features,
        # where feature density is defined as the fraction of tokens on which the feature has a nonzero value.
        self.feature_densities = torch.zeros(n_features, dtype=torch.float32)

    def update(self, feature_acts):
        """
        Expects a tensor feature_acts of shape [N_TOKENS, N_FEATURES].

        Updates the feature_densities buffer:
        1. For each feature, count the number of tokens that the feature activated on (i.e. had an activation greater than the activation_threshold)
        2. Add this count at the feature's position in the feature_densities tensor, divided by the total number of tokens (to compute the fraction)
        """

        activating_tokens_count = (feature_acts > self.activation_threshold).float().sum(0)
        self.feature_densities += activating_tokens_count / self.n_tokens

    def plot(self, num_bins=100, y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10, title=None):
        plot_log10_hist(self.feature_densities, 'Density', num_bins=num_bins, first_bin_name='Dead features density',
                        y_scalar=y_scalar, y_scale_bin=y_scale_bin, log_epsilon=log_epsilon, title=title)

# Feature densities loading & plotting

## Base model feature densities

In [12]:
from pathlib import Path

# Choose saving names consistent with saetuning/get_scores.py
saving_name_base = BASE_MODEL if "/" not in BASE_MODEL else BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL if "/" not in FINETUNE_MODEL else FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME if "/" not in DATASET_NAME else DATASET_NAME.split("/")[-1]

base_feature_densities_fname = f'Feature_densities_{saving_name_base}_on_{saving_name_ds}.pt'

load_path = datapath / base_feature_densities_fname
base_feature_densities = torch.load(load_path)

  base_feature_densities = torch.load(load_path)


In [13]:
n_features = base_feature_densities.numel()
n_features

24576

In [14]:
density_plotter = FeatureDensityPlotter(feature_densities=base_feature_densities)
density_plotter.plot(title=f'{saving_name_base} SAE Features Density Histogram')

## Finetune model feature densities

In [15]:
finetune_feature_densities_fname = f'Feature_densities_{saving_name_ft}_on_{saving_name_ds}.pt'

load_path = datapath / finetune_feature_densities_fname
finetune_feature_densities = torch.load(load_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [16]:
density_plotter = FeatureDensityPlotter(feature_densities=finetune_feature_densities)
density_plotter.plot(title=f'{saving_name_ft} SAE Features Density')

In [17]:
log_epsilon=1e-10

base_feature_densities = torch.log10(base_feature_densities + log_epsilon)
finetune_feature_densities = torch.log10(finetune_feature_densities + log_epsilon)

## Dead features count comparison

In [18]:
dead_threshold = -9

base_dead = torch.nonzero(base_feature_densities <= dead_threshold)
base_dead_frac = base_dead.numel() / base_feature_densities.numel()

finetune_dead = torch.nonzero(finetune_feature_densities <= dead_threshold)
finetune_dead_frac = finetune_dead.numel() / finetune_feature_densities.numel()

print(f'Base model dead features fraction: {base_dead_frac:.4f}')
print(f'Finetune model dead features fraction: {finetune_dead_frac:.4f}')

Base model dead features fraction: 0.0006
Finetune model dead features fraction: 0.0481


In [19]:
import plotly.graph_objs as go
from scipy.stats import linregress

# Convert tensors to NumPy arrays for compatibility with other libraries
base_feature_densities_np = base_feature_densities.cpu().numpy()
finetune_feature_densities_np = finetune_feature_densities.cpu().numpy()

# Perform linear regression
slope, intercept, r_value, p_value, std_err = linregress(base_feature_densities_np, finetune_feature_densities_np)

# Define the regression line
regression_line = slope * base_feature_densities_np + intercept

# Create scatter plot
scatter_trace = go.Scatter(
    x=base_feature_densities_np,
    y=finetune_feature_densities_np,
    mode='markers',
    name='Data points',
    marker=dict(size=5, opacity=0.7)
)

# Create the regression line plot
line_trace = go.Scatter(
    x=base_feature_densities_np,
    y=regression_line,
    mode='lines',
    name=f'Regression line (R = {r_value:.2f})',
    line=dict(color='red')
)

# Set up the layout
layout = go.Layout(
    title='Scatter Plot of SAE Features Densities',
    xaxis=dict(title='Base Model SAE Densities'),
    yaxis=dict(title='Finetuned Model SAE Densities'),
    showlegend=True
)

# Combine the traces into a figure
fig = go.Figure(data=[scatter_trace, line_trace], layout=layout)

# Show the plot
fig.show()

# Print correlation coefficient
print(f"Correlation coefficient (R): {r_value:.4f}")

Correlation coefficient (R): 0.0923


In [20]:
base_feature_densities.mean(), finetune_feature_densities.mean()

(tensor(-3.1332), tensor(-3.6030))

In [21]:
df = pd.DataFrame.from_dict({
    'base_feature_densities': base_feature_densities_np,
    'finetune_feature_densities': finetune_feature_densities_np
})
df

Unnamed: 0,base_feature_densities,finetune_feature_densities
0,-3.374816,-3.408240
1,-3.494426,-3.550907
2,-3.876761,-4.065817
3,-3.539008,-3.204120
4,-3.726999,-2.436037
...,...,...
24571,-3.209583,-2.808357
24572,-2.170194,-3.024425
24573,-2.324021,-3.745482
24574,-3.046512,-2.494426


In [22]:
import plotly.express as px

fig = px.parallel_coordinates(df,
                              dimensions=['base_feature_densities', 'finetune_feature_densities']
                             )
fig.show()

# Sampling features from density intervals

In [23]:
DENSE_COUNT = 1000
DEAD_COUNT = 100

In [24]:
import torch

def subsample_features(base_tensor, finetune_tensor, dead_threshold=-8, dense_range=(-5, -1),
                       dead_samples=DEAD_COUNT, dense_samples=DENSE_COUNT, log=False):
    """
    Subsample features from two tensors for the base and finetuned models based on density values,
    and return counts of dead features in both models.

    Parameters:
    - base_tensor (torch.Tensor): The input tensor of feature densities for the base model.
    - finetune_tensor (torch.Tensor): The input tensor of feature densities for the finetuned model.
    - dead_threshold (float): The log10 threshold for determining dead features (default: -8).
    - dense_range (tuple): The range (inclusive) in log scale for determining dense features (default: (-5, -1)).
    - dead_samples (int): The total number of dead feature samples to return for both models.
    - dense_samples (int): The total number of dense feature samples to return for both models.
    - log (bool): Whether to apply log10 scaling to the tensors.

    Returns:
    - dead_base_indices (torch.Tensor): Subsampled indices for dead features in the base model (but not in finetuned).
    - dead_finetune_indices (torch.Tensor): Subsampled indices for dead features in the finetuned model (but not in base).
    - dense_both_indices (torch.Tensor): Subsampled indices for dense features in both base and finetuned models.
    - total_dead_base (int): Total number of dead features in the base model.
    - total_dead_finetune (int): Total number of dead features in the finetuned model.
    """
    if log:
        # Convert tensors to log10 scale
        base_tensor = torch.log10(base_tensor)
        finetune_tensor = torch.log10(finetune_tensor)

    # 1. Dead features: below dead_threshold in one model, but not in both
    dead_base_indices = torch.nonzero((base_tensor <= dead_threshold) & (finetune_tensor > dead_threshold)).squeeze(1)
    dead_finetune_indices = torch.nonzero((finetune_tensor <= dead_threshold) & (base_tensor > dead_threshold)).squeeze(1)

    # Total counts of dead features in base and finetuned models
    total_dead_base = dead_base_indices.numel()
    total_dead_finetune = dead_finetune_indices.numel()

    # Randomly subsample if needed
    if len(dead_base_indices) > dead_samples:
        dead_base_indices = dead_base_indices[torch.randperm(len(dead_base_indices))[:dead_samples]]
    if len(dead_finetune_indices) > dead_samples:
        dead_finetune_indices = dead_finetune_indices[torch.randperm(len(dead_finetune_indices))[:dead_samples]]

    # 2. Dense features: in the interval dense_range for both base and finetune
    dense_base_indices = (base_tensor >= dense_range[0]) & (base_tensor <= dense_range[1])
    dense_finetune_indices = (finetune_tensor >= dense_range[0]) & (finetune_tensor <= dense_range[1])
    dense_both_indices = torch.nonzero(dense_base_indices & dense_finetune_indices).squeeze(1)

    if len(dense_both_indices) > dense_samples:
        dense_both_indices = dense_both_indices[torch.randperm(len(dense_both_indices))[:dense_samples]]

    return dead_base_indices, dead_finetune_indices, dense_both_indices, total_dead_base, total_dead_finetune


In [25]:
dead_base_features, dead_finetune_features, dense_features, total_dead_base, total_dead_finetune = subsample_features(base_feature_densities, finetune_feature_densities)

print(f'Total dead base features, but not dead in the finetuned model: {total_dead_base}')
print(f'Total dead finetune features, but not dead in the base model: {total_dead_finetune}')

torch.save(dead_base_features, datapath / f'{saving_name_base}_dead_features.pt')
torch.save(dead_finetune_features, datapath / f'{saving_name_ft}_dead_features.pt')
torch.save(dense_features, datapath / f'{saving_name_base}_{saving_name_ft}_dense_features.pt')

Total dead base features, but not dead in the finetuned model: 2
Total dead finetune features, but not dead in the base model: 1169


In [26]:
base_feature_densities[dead_base_features], finetune_feature_densities[dead_finetune_features]

(tensor([-10., -10.]),
 tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
         -10., -10., -10., -10.]))

In [27]:
base_feature_densities[dense_features][:10], finetune_feature_densities[dense_features][:10]

(tensor([-3.0353, -3.7850, -3.3829, -3.6158, -2.1530, -2.7688, -2.2734, -3.9611,
         -3.3010, -2.7955]),
 tensor([-3.4540, -3.5274, -2.4758, -3.3590, -2.6539, -3.0580, -4.0658, -2.9737,
         -4.1072, -2.5030]))

# Computing feature representations

In [28]:
# import the required libraries
from sae_lens import SAE

sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)

blocks.6.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]



This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)



## Base model

In [30]:
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

model.safetensors:  96%|#########5| 524M/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]


`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884



Loaded pretrained model gpt2-small into HookedTransformer


In [31]:
from sae_lens import ActivationsStore

if MODEL == 'MISTRAL' or MODEL == 'GPT2':
    batch_size_prompts = 20
else:
    batch_size_prompts = 5

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=batch_size_prompts,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

batch_size_tokens = activation_store.context_size * batch_size_prompts
batch_size_tokens

Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.35k [00:00<?, ?B/s]



2560

### Feature activation vectors

In [32]:
def get_feature_activations(features_ids, model, sae=sae, total_batches=50, activation_store=None,
                            batch_size_prompts=batch_size_prompts, sae_id=sae_id, layer_num=layer_num,
                            model_name=BASE_MODEL, exclude_bos=False, tokens_sample=None, return_dtype=torch.float16):
    base_model_run = tokens_sample is None

    if base_model_run:
        assert activation_store is not None
        def get_tokens(k):
            """Returns the tokens sampled from the activation store"""
            # Get the corresponding batch of tokens from all_tokens
            tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
            if exclude_bos and activation_store.prepend_bos:
                tokens = tokens[:, 1:]

            return tokens
    else:
        def get_tokens(k):
            """Returns the tokens for the k-th outer batch, where 0 <= k < TOTAL_BATCHES"""

            # Get the corresponding batch of tokens from all_tokens
            tokens = tokens_sample[k]  # [N_BATCH, N_CONTEXT]
            return tokens

    if base_model_run:
        all_tokens = []
    all_feature_activations = []

    for k in tqdm(range(total_batches)):
        # Get a batch of tokens from the dataset
        tokens = get_tokens(k)  # [N_BATCH, N_CONTEXT]

        # Run the model and store the activations
        _, cache = model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                        names_filter=[sae_id]) # [N_BATCH, N_CONTEXT, D_MODEL]

        # Get the activations from the cache at the sae_id
        sae_in = cache[sae_id].flatten(0, 1) # [N_BATCH * N_CONTEXT, D_MODEL]

        del cache
        clear_cache()

        # Store tokens for later reuse
        if base_model_run:
            all_tokens.append(tokens)

        # Encode the activations with the SAE
        sae_hidden = sae.encode(sae_in) # [N_BATCH * N_CONTEXT, N_HIDDEN]

        # Select only given features
        feature_activations = sae_hidden[:, features_ids] # [N_BATCH * N_CONTEXT, len(feature_ids)]
        all_feature_activations.append(feature_activations)

        # Explicitly free up memory by deleting the cache and emptying the CUDA cache
        del sae_in, sae_hidden
        clear_cache()

    tokens_dataset = torch.stack(all_tokens, dim=0) if base_model_run else None
    all_feature_activations = torch.cat(all_feature_activations, dim=0).to(return_dtype)

    return all_feature_activations, tokens_dataset

In [33]:
base_model_dead_base_act, act_tokens = get_feature_activations(dead_base_features, base_model,
                                                               activation_store=activation_store)
base_model_dead_finetune_act, _ = get_feature_activations(dead_finetune_features, base_model,
                                                          tokens_sample=act_tokens)

torch.save(base_model_dead_base_act, datapath / f'{saving_name_base}_dead_base_act.pt')
torch.save(base_model_dead_finetune_act, datapath / f'{saving_name_base}_dead_finetune_act.pt')

torch.save(act_tokens, datapath / f'{saving_name_base}_act_tokens.pt')

base_model_dead_base_act.shape, base_model_dead_finetune_act.shape

  0%|          | 0/50 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 50/50 [00:24<00:00,  2.02it/s]
100%|██████████| 50/50 [00:22<00:00,  2.17it/s]


(torch.Size([128000, 2]), torch.Size([128000, 100]))

In [34]:
base_model_dense_act, _ = get_feature_activations(dense_features, base_model,
                                                  tokens_sample=act_tokens)

torch.save(base_model_dense_act, datapath / f'{saving_name_base}_dense_act.pt')

base_model_dense_act.shape

100%|██████████| 50/50 [00:22<00:00,  2.20it/s]


torch.Size([128000, 1000])

In [35]:
del base_model_dead_base_act, base_model_dead_finetune_act, base_model_dense_act
clear_cache()

### Feature logit vectors

In [36]:
def get_feature_vector(feature_id, sae=sae):
    return sae.W_dec[feature_id]

def get_feature_logits(feature_id, model, sae=sae):
    W_U = model.W_U.to(torch.float32)
    feature_vector = get_feature_vector(feature_id, sae)

    logits = feature_vector @ W_U
    return logits

In [37]:
def get_features_logit_vectors(features_ids, model, sae=sae):
    """
    Assumes tokens of shape [total_batches, batch*seq]
    """
    logit_vectors = []

    for feature_id in features_ids:
        feature_logits = get_feature_logits(feature_id, model, sae=sae)
        logit_vectors.append(feature_logits)

    return torch.stack(logit_vectors).T

In [38]:
base_model_dense_logits = get_features_logit_vectors(dense_features, base_model)
base_model_dead_base_logits = get_features_logit_vectors(dead_base_features, base_model)
base_model_dead_finetune_logits = get_features_logit_vectors(dead_finetune_features, base_model)

torch.save(base_model_dense_logits, datapath / f'{saving_name_base}_dense_logit_vectors.pt')
torch.save(base_model_dead_base_logits, datapath / f'{saving_name_base}_dead_base_logit_vectors.pt')
torch.save(base_model_dead_finetune_logits, datapath / f'{saving_name_base}_dead_finetune_logit_vectors.pt')

base_model_dense_logits.shape, base_model_dead_base_logits.shape

(torch.Size([50257, 1000]), torch.Size([50257, 2]))

In [39]:
del base_model_dense_logits, base_model_dead_base_logits, base_model_dead_finetune_logits
clear_cache()

In [40]:
base_model_unembed = base_model.W_U.detach().clone()

del base_model, activation_store
clear_cache()

In [41]:
base_model_unembed.shape

torch.Size([768, 50257])

## Finetune model

In [42]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def adjust_state_dict(model, base_model_vocab_size):
    """Adjust the state_dict of the model to match the base model's vocab size."""
    state_dict = model.state_dict()

    # Adjust the embedding matrix
    if state_dict['model.embed_tokens.weight'].shape[0] > base_model_vocab_size:
        state_dict['model.embed_tokens.weight'] = state_dict['model.embed_tokens.weight'][:base_model_vocab_size, :]

    # Adjust the unembedding (lm_head) matrix
    if state_dict['lm_head.weight'].shape[0] > base_model_vocab_size:
        state_dict['lm_head.weight'] = state_dict['lm_head.weight'][:base_model_vocab_size, :]

    return state_dict

def load_hf_model(path, base_model=BASE_MODEL, device='cuda', dtype=None):
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)

    # Adjust the model's state dict to match the base model's vocab size
    if base_model == 'mistral-7b':
      base_model_vocab_size = 32000  # Mistral 7B base vocab size
      adjusted_state_dict = adjust_state_dict(model, base_model_vocab_size)

      # Adjust model architecture to match the new vocab size
      model.resize_token_embeddings(base_model_vocab_size)

      # Load the adjusted state dict back into the model
      model.load_state_dict(adjusted_state_dict, strict=False)

    # Now load the fine-tuned model into the HookedSAETransformer
    finetune_model = HookedSAETransformer.from_pretrained(
        base_model, device=device, hf_model=model, dtype=dtype
    )

    del model  # offload the HF model as it's already wrapped into HookedSAETransformer (finetune_model)
    clear_cache()

    return tokenizer, finetune_model

In [43]:
# Load the finetune model and its tokenizer
finetune_tokenizer, finetune_model = load_hf_model(FINETUNE_PATH if FINETUNE_PATH is not None else FINETUNE_MODEL,
                                                   device=device, dtype=torch.float16)

tokenizer_config.json:   0%|          | 0.00/92.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/850k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/508k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/510M [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


In [44]:
act_tokens = torch.load(datapath / f'{saving_name_base}_act_tokens.pt')

act_tokens.shape


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



torch.Size([50, 20, 128])

### Feature activation vectors

In [45]:
finetune_model_dead_base_act, _ = get_feature_activations(dead_base_features, finetune_model,
                                                          tokens_sample=act_tokens)
finetune_model_dead_finetune_act, _ = get_feature_activations(dead_finetune_features, finetune_model,
                                                              tokens_sample=act_tokens)

torch.save(finetune_model_dead_base_act, datapath / f'{saving_name_ft}_dead_base_act.pt')
torch.save(finetune_model_dead_finetune_act, datapath / f'{saving_name_ft}_dead_finetune_act.pt')

100%|██████████| 50/50 [00:23<00:00,  2.17it/s]
100%|██████████| 50/50 [00:23<00:00,  2.16it/s]


In [46]:
finetune_model_dense_act, _ = get_feature_activations(dense_features, finetune_model,
                                                      tokens_sample=act_tokens)
torch.save(finetune_model_dense_act, datapath / f'{saving_name_ft}_dense_act.pt')

finetune_model_dense_act.shape

100%|██████████| 50/50 [00:24<00:00,  2.04it/s]


torch.Size([128000, 1000])

In [47]:
del finetune_model_dead_base_act, finetune_model_dead_finetune_act, finetune_model_dense_act
clear_cache()

### Feature logit vectors

In [48]:
finetune_model_dense_logits = get_features_logit_vectors(dense_features, finetune_model)
finetune_model_dead_base_logits = get_features_logit_vectors(dead_base_features, finetune_model)
finetune_model_dead_finetune_logits = get_features_logit_vectors(dead_finetune_features, finetune_model)

torch.save(finetune_model_dense_logits, datapath / f'{saving_name_ft}_dense_logit_vectors.pt')
torch.save(finetune_model_dead_base_logits, datapath / f'{saving_name_ft}_dead_base_logit_vectors.pt')
torch.save(finetune_model_dead_finetune_logits, datapath / f'{saving_name_ft}_dead_finetune_logit_vectors.pt')

In [49]:
finetune_model_unembed = finetune_model.W_U.detach().clone()

del finetune_model
clear_cache()

In [50]:
def compare_unembedding_matrices(matrix1: torch.Tensor, matrix2: torch.Tensor):
    # Compute the Frobenius norm of both matrices
    frobenius_norm1 = torch.norm(matrix1, p='fro')
    frobenius_norm2 = torch.norm(matrix2, p='fro')

    # Compute the ratio of the Frobenius norms
    frobenius_norm_ratio = frobenius_norm1 / frobenius_norm2

    frobenius_error = torch.norm(matrix1 - matrix2, p='fro').item()

    return frobenius_norm_ratio.item(),frobenius_error


frobenius_ratio, frobenius_error = compare_unembedding_matrices(base_model_unembed, finetune_model_unembed)
print(f"Frobenius norm ratio: {frobenius_ratio}")
print(f"Frobenius error (Frobenius norm of (Unembed_base - Unembed_finetune)): {frobenius_error}")

Frobenius norm ratio: 1.439453125
Frobenius error (Frobenius norm of (Unembed_base - Unembed_finetune)): 1270.0


# Ploting & reporting feature similarities

In [51]:
PRELOAD_FEATURES = True

if PRELOAD_FEATURES:
  base_model_dead_base_act_path = datapath / f'{saving_name_base}_dead_base_act.pt'
  base_model_dead_finetune_act_path = datapath / f'{saving_name_base}_dead_finetune_act.pt'
  base_model_dense_act_path = datapath / f'{saving_name_base}_dense_act.pt'

  finetune_model_dead_base_act_path = datapath / f'{saving_name_ft}_dead_base_act.pt'
  finetune_model_dead_finetune_act_path = datapath / f'{saving_name_ft}_dead_finetune_act.pt'
  finetune_model_dense_act_path = datapath / f'{saving_name_ft}_dense_act.pt'

  base_model_dense_logits_path = datapath / f'{saving_name_base}_dense_logit_vectors.pt'
  base_model_dead_base_logits_path = datapath / f'{saving_name_base}_dead_base_logit_vectors.pt'
  base_model_dead_finetune_logits_path = datapath / f'{saving_name_base}_dead_finetune_logit_vectors.pt'

  finetune_model_dense_logits_path = datapath / f'{saving_name_ft}_dense_logit_vectors.pt'
  finetune_model_dead_base_logits_path = datapath / f'{saving_name_ft}_dead_base_logit_vectors.pt'
  finetune_model_dead_finetune_logits_path = datapath / f'{saving_name_ft}_dead_finetune_logit_vectors.pt'

  ### RETRIEVE THE LOOOST TENSORS
  base_model_dead_base_act = torch.load(base_model_dead_base_act_path)
  base_model_dead_finetune_act = torch.load(base_model_dead_finetune_act_path)
  base_model_dense_act = torch.load(base_model_dense_act_path)

  finetune_model_dead_base_act = torch.load(finetune_model_dead_base_act_path)
  finetune_model_dead_finetune_act = torch.load(finetune_model_dead_finetune_act_path)
  finetune_model_dense_act = torch.load(finetune_model_dense_act_path)

  base_model_dense_logits = torch.load(base_model_dense_logits_path)
  base_model_dead_base_logits = torch.load(base_model_dead_base_logits_path)
  base_model_dead_finetune_logits = torch.load(base_model_dead_finetune_logits_path)

  finetune_model_dense_logits = torch.load(finetune_model_dense_logits_path)
  finetune_model_dead_base_logits = torch.load(finetune_model_dead_base_logits_path)
  finetune_model_dead_finetune_logits = torch.load(finetune_model_dead_finetune_logits_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [52]:
dead_base_features =  torch.load(datapath / f'{saving_name_base}_dead_features.pt')
dead_finetune_features =  torch.load(datapath / f'{saving_name_ft}_dead_features.pt')
dense_features =  torch.load(datapath / f'{saving_name_base}_{saving_name_ft}_dense_features.pt')

print('Dead base features:', dead_base_features.shape)
print('Dead finetune features:', dead_finetune_features.shape)
print('Dense features:', dense_features.shape)

Dead base features: torch.Size([2])
Dead finetune features: torch.Size([100])
Dense features: torch.Size([1000])



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [53]:
# @title
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def make_similarity_plots(x, y, feature_ids, plot_rows, plot_cols, similarity_score='Activation'):
    n_plots = x.shape[1]
    assert n_plots == plot_rows * plot_cols

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    activations_similarity = [np.corrcoef(x[:, i], y[:, i])[0, 1] for i in range(n_plots)]

    # Set up subplot grid
    fig = make_subplots(rows=plot_rows, cols=plot_cols,
                        x_title=f'Base Feature {similarity_score}',
                        y_title=f'Finetune Feature {similarity_score}',
                        subplot_titles=[f'Activation similarity = {similarity}' for similarity in activations_similarity])

    # Generate scatter plots for each column
    for i in range(n_plots):
        # Compute correlation coefficient
        corr_coef = np.corrcoef(x[:, i], y[:, i])[0, 1]

        # Add scatter plot
        scatter = go.Scatter(
            x=x[:, i],
            y=y[:, i],
            mode='markers',
            marker=dict(color='red', opacity=0.5),
            name=f'Feature {feature_ids[i]} (Corr={corr_coef:.2f})'
        )

        # Place scatter plot in correct subplot
        fig.add_trace(scatter, row=(i//plot_cols)+1, col=(i%plot_cols)+1)

    # Update layout for better aesthetics
    fig.update_layout(
        title="Feature Activations Scatter Plots and Correlations",
        height=600,
        width=1000,
        showlegend=False
    )

    # Show the figure
    fig.show()

def make_single_similarity_plot(x, y, feature_id, features_family, similarity_score='Activation'):
    # Convert tensors to numpy arrays
    x = x.cpu().numpy()
    y = y.cpu().numpy()

    # Compute correlation coefficient for the given feature (column index)
    corr_coef = np.corrcoef(x[:, feature_id], y[:, feature_id])[0, 1]

    # Set up scatter plot for the specific feature
    scatter = go.Scatter(
        x=x[:, feature_id],
        y=y[:, feature_id],
        mode='markers',
        marker=dict(color='red', opacity=0.5),
        name=f'Feature {feature_id} (Corr={corr_coef:.2f})'
    )

    # Set up layout with axis labels and title
    layout = go.Layout(
        title=f'Feature Activations Scatter Plot (Feature {features_family[feature_id]})<br>Activation similarity = {corr_coef:.2f}',
        xaxis=dict(title=f'Base Feature {similarity_score}'),
        yaxis=dict(title=f'Finetune Feature {similarity_score}'),
        height=600,
        width=600,
        showlegend=False
    )

    # Create figure and add trace
    fig = go.Figure(data=[scatter], layout=layout)

    # Show the figure
    fig.show()

In [54]:
import numpy as np
import plotly.graph_objects as go

def show_features_similarities(x, y, features_family, similarity_score='Activation',
                               similarity_metric='corr', log_epsilon=1e-10, show_each_feature=True):
    # Ensure correct number of plots
    n_features = x.shape[1]

    # Convert tensors to float32 numpy arrays
    x = x.float().cpu().numpy()
    y = y.float().cpu().numpy()

    # Initialize list for storing similarity results
    activations_similarity = []

    # Compute similarities based on the selected metric
    if similarity_metric == 'corr':
        # Compute correlation for each feature
        activations_similarity = [np.corrcoef(x[:, i], y[:, i])[0, 1] for i in range(n_features)]
    elif similarity_metric == 'mae':
        # Compute mean absolute error for each feature
        x_log = np.log10(x + log_epsilon)
        y_log = np.log10(y + log_epsilon)

        activations_similarity = [np.mean(np.abs(x_log[:, i] - y_log[:, i])) for i in range(n_features)]
    else:
        raise ValueError(f"Unsupported similarity metric: {similarity_metric}. Use 'corr' or 'mae'.")

    # Compute and print mean similarity
    mean_similarity = np.nanmean(activations_similarity)
    print(f'\nMean {similarity_score} {similarity_metric} = {mean_similarity}')

    # Sort similarities along with feature names in ascending order
    sorted_indices = np.argsort(activations_similarity)
    sorted_features = [features_family[i] for i in sorted_indices]
    sorted_similarities = [activations_similarity[i] for i in sorted_indices]

    # Print sorted similarities
    if show_each_feature:
        print(f"\n--- {similarity_score} {similarity_metric} for each feature (sorted) ---")
        for feature, similarity in zip(sorted_features, sorted_similarities):
            print(f'Feature {feature} {similarity_score} {similarity_metric} = {similarity}')

    # Create a histogram of similarity scores using Plotly
    fig = go.Figure(data=[go.Histogram(x=activations_similarity, histnorm='probability',
                                       marker=dict(color='skyblue', line=dict(color='black', width=1)))])

    # Set plot titles and labels
    fig.update_layout(title=f'Histogram of {similarity_score} {similarity_metric.capitalize()} Scores',
                      xaxis_title=f'{similarity_metric.capitalize()} Score',
                      yaxis_title='Frequency',
                      bargap=0.2)

    # Show the figure
    fig.show()

In [55]:
def compute_features_density(feature_acts, activation_threshold=1e-10, log_epsilon=1e-10):
  n_fired = (feature_acts > activation_threshold).float().sum(0)
  total_tokens = feature_acts.shape[0]

  density = n_fired / total_tokens

  return torch.log10(density + log_epsilon)

## Dead features in the base model

In [56]:
# compute_features_density(base_model_dead_base_act), compute_features_density(finetune_model_dead_base_act),

In [57]:
show_features_similarities(base_model_dead_base_act, finetune_model_dead_base_act,
                           dead_base_features, similarity_metric='mae')


Mean Activation mae = 0.0011657890863716602

--- Activation mae for each feature (sorted) ---
Feature 10305 Activation mae = 0.0003908365615643561
Feature 19228 Activation mae = 0.0019407415529713035


In [58]:
show_features_similarities(base_model_dead_base_logits, finetune_model_dead_base_logits,
                           dead_base_features, similarity_score='Logits', similarity_metric='corr')


Mean Logits corr = 0.00913674147295268

--- Logits corr for each feature (sorted) ---
Feature 19228 Logits corr = 0.006383587528649522
Feature 10305 Logits corr = 0.01188989541725584


## Dead features in the finetune model

In [59]:
# compute_features_density(base_model_dead_finetune_act), compute_features_density(finetune_model_dead_finetune_act),

In [60]:
show_features_similarities(base_model_dead_finetune_act, finetune_model_dead_finetune_act,
                           dead_finetune_features, similarity_metric='mae')


Mean Activation mae = 0.02259930595755577

--- Activation mae for each feature (sorted) ---
Feature 3073 Activation mae = 0.0008876610081642866
Feature 15092 Activation mae = 0.001083930255845189
Feature 6133 Activation mae = 0.0011763365473598242
Feature 1390 Activation mae = 0.0012912701349705458
Feature 6459 Activation mae = 0.0017719403840601444
Feature 3470 Activation mae = 0.001902880030684173
Feature 12732 Activation mae = 0.0019588409923017025
Feature 7474 Activation mae = 0.002081411425024271
Feature 9877 Activation mae = 0.002191046951338649
Feature 16480 Activation mae = 0.0024867388419806957
Feature 8129 Activation mae = 0.0025594737380743027
Feature 21987 Activation mae = 0.0025621491950005293
Feature 19622 Activation mae = 0.0026385008823126554
Feature 15190 Activation mae = 0.0026872879825532436
Feature 8169 Activation mae = 0.0027476793620735407
Feature 19707 Activation mae = 0.0028988022822886705
Feature 22515 Activation mae = 0.003003152087330818
Feature 22630 Activa

In [61]:
# base_model_dead_finetune_act.max(), finetune_model_dead_finetune_act.max()

In [62]:
show_features_similarities(base_model_dead_finetune_logits, finetune_model_dead_finetune_logits,
                           dead_finetune_features, similarity_score='Logits', similarity_metric='corr')


Mean Logits corr = 0.0031031567557988525

--- Logits corr for each feature (sorted) ---
Feature 6853 Logits corr = -0.014434480895721268
Feature 20794 Logits corr = -0.013746589450139528
Feature 21987 Logits corr = -0.0115157482922034
Feature 23788 Logits corr = -0.01071752323447072
Feature 21182 Logits corr = -0.00948531845694934
Feature 19707 Logits corr = -0.007808629206158789
Feature 5081 Logits corr = -0.007065722094264308
Feature 13119 Logits corr = -0.007010014448979359
Feature 12732 Logits corr = -0.006584250501906217
Feature 18675 Logits corr = -0.00652795719801002
Feature 3990 Logits corr = -0.00650513862627843
Feature 22630 Logits corr = -0.006024768491504199
Feature 1390 Logits corr = -0.0055643291235544685
Feature 3368 Logits corr = -0.005556637991107616
Feature 10812 Logits corr = -0.005441104076116915
Feature 15092 Logits corr = -0.005225530562758399
Feature 10360 Logits corr = -0.005194092541484985
Feature 21749 Logits corr = -0.005153808976135503
Feature 8169 Logits c

## Dense features

In [63]:
show_features_similarities(base_model_dense_act, finetune_model_dense_act, dense_features)


invalid value encountered in divide


invalid value encountered in divide




Mean Activation corr = 0.0038617609901369114

--- Activation corr for each feature (sorted) ---
Feature 12969 Activation corr = -0.01165897000192567
Feature 2341 Activation corr = -0.005332488956431337
Feature 19817 Activation corr = -0.003611674808909776
Feature 17457 Activation corr = -0.0033149200861657705
Feature 16058 Activation corr = -0.0031275942701578696
Feature 13134 Activation corr = -0.002974154760357654
Feature 7054 Activation corr = -0.0026400460323596385
Feature 17836 Activation corr = -0.0025635520089877456
Feature 3857 Activation corr = -0.0023676697177537754
Feature 12699 Activation corr = -0.002260842033686954
Feature 17835 Activation corr = -0.0022261060915806457
Feature 14990 Activation corr = -0.002087537928927146
Feature 9009 Activation corr = -0.002072929833039102
Feature 19480 Activation corr = -0.0020475496937973053
Feature 8686 Activation corr = -0.001992951926084863
Feature 4689 Activation corr = -0.0019923529176611708
Feature 23143 Activation corr = -0.001

In [64]:
show_features_similarities(base_model_dense_logits, finetune_model_dense_logits, dense_features, similarity_score='Logits')


Mean Logits corr = 0.0035014999257401893

--- Logits corr for each feature (sorted) ---
Feature 16045 Logits corr = -0.023721937964668864
Feature 19303 Logits corr = -0.01450932913979746
Feature 3112 Logits corr = -0.013750313865247912
Feature 10701 Logits corr = -0.013726703974903427
Feature 206 Logits corr = -0.01356622480829009
Feature 5887 Logits corr = -0.013273630809777228
Feature 6210 Logits corr = -0.012859157089678903
Feature 17938 Logits corr = -0.012729197859902484
Feature 3987 Logits corr = -0.011326313914646462
Feature 6195 Logits corr = -0.011324625813384913
Feature 476 Logits corr = -0.011322394953848427
Feature 22915 Logits corr = -0.011245661602050436
Feature 11570 Logits corr = -0.010141429269561233
Feature 6214 Logits corr = -0.010116121154373306
Feature 17835 Logits corr = -0.010008902795032725
Feature 6022 Logits corr = -0.01000592110479715
Feature 5774 Logits corr = -0.009877701886827812
Feature 3185 Logits corr = -0.009875507069578744
Feature 20610 Logits corr =

In [65]:
# make_single_similarity_plot(base_model_dense_act, finetune_model_dense_act, 0, dense_features)