In [1]:
try:
    import google.colab
    IN_COLAB = True
    %pip install sae-lens transformer-lens
except ImportError:
    IN_COLAB = False



In [2]:
# Standard imports
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda


# Config

In [3]:
# define the model to work with
GEMMA = True

if GEMMA == True:
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    hook_part = 'post'
    layer_num = 6
else:
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    DATASET_NAME = "Skylion007/openwebtext"
    hook_part = 'pre'
    layer_num = 6

In [4]:
import plotly.graph_objs as go
from functools import partial

def plot_log10_hist(y_data, y_value, num_bins=100, first_bin_name = 'First bin value',
                    y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
    """
    Computes the histogram using PyTorch and plots the feature density diagram with log-10 scale using Plotly.
    Y-axis is clipped to the value of the second-largest bin to prevent suppression of smaller values.
    """
    # Flatten the tensor
    y_data_flat = torch.flatten(y_data)

    # Compute the logarithmic transformation using PyTorch
    log_y_data_flat = torch.log10(torch.abs(y_data_flat) + log_epsilon).detach().cpu()

    # Compute histogram using PyTorch
    hist_min = torch.min(log_y_data_flat).item()
    hist_max = torch.max(log_y_data_flat).item()
    hist_range = hist_max - hist_min
    bin_edges = torch.linspace(hist_min, hist_max, num_bins + 1)
    hist_counts, _ = torch.histogram(log_y_data_flat, bins=bin_edges)

    # Convert data to NumPy for Plotly
    bin_edges_np = bin_edges.detach().cpu().numpy()
    hist_counts_np = hist_counts.detach().cpu().numpy()

    # Find the largest and second-largest bin values
    first_bin_value = hist_counts_np[0]
    scale_bin_value = sorted(hist_counts_np)[y_scale_bin]  # Get the second largest bin value (by default)

    # Prepare the Plotly plot
    fig = go.Figure(
        data=[go.Bar(
            x=bin_edges_np[:-1],  # Exclude the last bin edge
            y=hist_counts_np,
            width=hist_range / num_bins,
        )]
    )

    # Update the layout for the plot, clipping the y-axis at the second largest bin value
    fig.update_layout(
        title=f"SAE Features {y_value} histogram ({first_bin_name}: {first_bin_value:.2e})",
        xaxis_title=f"Log10 of {y_value}",
        yaxis_title="Density",
        yaxis_range=[0, scale_bin_value * y_scalar],  # Clipping to the second-largest value by default
        bargap=0.2,
        bargroupgap=0.1,
    )

    # Add an annotation to display the value of the first bin
    fig.add_annotation(
        text=f"{first_bin_name}: {first_bin_value:.2e}",
        xref="paper", yref="paper",
        x=0.95, y=0.95,
        showarrow=False,
        font=dict(size=12, color="red"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    # Show the plot
    fig.show()

class FeatureDensityPlotter:
    def __init__(self, n_features=None, n_tokens=None, activation_threshold=1e-10, num_bins=100,
                 feature_densities=None):
        if feature_densities is not None:
            self.feature_densities = feature_densities
            return

        self.num_bins = num_bins
        self.activation_threshold = activation_threshold

        self.n_tokens = n_tokens
        self.n_features = n_features

        # Initialize a tensor of feature densities for all features,
        # where feature density is defined as the fraction of tokens on which the feature has a nonzero value.
        self.feature_densities = torch.zeros(n_features, dtype=torch.float32)

    def update(self, feature_acts):
        """
        Expects a tensor feature_acts of shape [N_TOKENS, N_FEATURES].

        Updates the feature_densities buffer:
        1. For each feature, count the number of tokens that the feature activated on (i.e. had an activation greater than the activation_threshold)
        2. Add this count at the feature's position in the feature_densities tensor, divided by the total number of tokens (to compute the fraction)
        """

        activating_tokens_count = (feature_acts > self.activation_threshold).float().sum(0)
        self.feature_densities += activating_tokens_count / self.n_tokens

    def plot(self, num_bins=100, y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
        plot_log10_hist(self.feature_densities, 'Density', num_bins=num_bins, first_bin_name='Dead features density',
                        y_scalar=y_scalar, y_scale_bin=y_scale_bin, log_epsilon=log_epsilon)

# Feature densities loading & plotting

## Base model feature densities

In [5]:
# Choose saving names consistent with saetuning/get_scores.py
saving_name_base = BASE_MODEL if "/" not in BASE_MODEL else BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL if "/" not in FINETUNE_MODEL else FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME if "/" not in DATASET_NAME else DATASET_NAME.split("/")[-1]

base_feature_densities_fname = f'Feature_densities_{saving_name_base}_on_{saving_name_ds}.pt'

if IN_COLAB:
    # If in Colab, mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Define the path to your JSON file in Google Drive
    load_path = os.path.join('/content/drive/My Drive', base_feature_densities_fname)
else:
    # If not in Colab, use local folder
    # Assuming this is being run from the 'notebooks' folder
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

    load_path = datapath / base_feature_densities_fname

base_feature_densities = torch.load(load_path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  base_feature_densities = torch.load(load_path)


In [6]:
n_features = base_feature_densities.numel()
n_features

16384

In [7]:
density_plotter = FeatureDensityPlotter(feature_densities=base_feature_densities)
density_plotter.plot()

## Finetune model feature densities

In [8]:
finetune_feature_densities_fname = f'Feature_densities_{saving_name_ft}_on_{saving_name_ds}.pt'

if IN_COLAB:
    # If in Colab, mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Define the path to your JSON file in Google Drive
    load_path = os.path.join('/content/drive/My Drive', finetune_feature_densities_fname)
else:
    # If not in Colab, use local folder
    # Assuming this is being run from the 'notebooks' folder
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

    load_path = datapath / finetune_feature_densities_fname

finetune_feature_densities = torch.load(load_path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [9]:
density_plotter = FeatureDensityPlotter(feature_densities=finetune_feature_densities)
density_plotter.plot()

In [10]:
log_epsilon=1e-10

base_feature_densities = torch.log10(base_feature_densities + log_epsilon)
finetune_feature_densities = torch.log10(finetune_feature_densities + log_epsilon)

In [11]:
import plotly.graph_objs as go
from scipy.stats import linregress

# Convert tensors to NumPy arrays for compatibility with other libraries
base_feature_densities_np = base_feature_densities.cpu().numpy()
finetune_feature_densities_np = finetune_feature_densities.cpu().numpy()

# Perform linear regression
slope, intercept, r_value, p_value, std_err = linregress(base_feature_densities_np, finetune_feature_densities_np)

# Define the regression line
regression_line = slope * base_feature_densities_np + intercept

# Create scatter plot
scatter_trace = go.Scatter(
    x=base_feature_densities_np,
    y=finetune_feature_densities_np,
    mode='markers',
    name='Data points',
    marker=dict(size=5, opacity=0.7)
)

# Create the regression line plot
line_trace = go.Scatter(
    x=base_feature_densities_np,
    y=regression_line,
    mode='lines',
    name=f'Regression line (R = {r_value:.2f})',
    line=dict(color='red')
)

# Set up the layout
layout = go.Layout(
    title='Scatter Plot of SAE Features Densities',
    xaxis=dict(title='Base Model SAE Densities'),
    yaxis=dict(title='Finetuned Model SAE Densities'),
    showlegend=True
)

# Combine the traces into a figure
fig = go.Figure(data=[scatter_trace, line_trace], layout=layout)

# Show the plot
fig.show()

# Print correlation coefficient
print(f"Correlation coefficient (R): {r_value:.4f}")

Correlation coefficient (R): 0.4551


In [12]:
base_feature_densities.mean(), finetune_feature_densities.mean()

(tensor(-6.1313), tensor(-2.6128))

In [13]:
df = pd.DataFrame.from_dict({
    'base_feature_densities': base_feature_densities_np,
    'finetune_feature_densities': finetune_feature_densities_np
})
df

Unnamed: 0,base_feature_densities,finetune_feature_densities
0,-2.683964,-2.194224
1,-2.648100,-2.369074
2,-10.000000,-3.688081
3,-2.822779,-2.440926
4,-10.000000,-3.743128
...,...,...
16379,-10.000000,-4.570966
16380,-10.000000,-1.866895
16381,-10.000000,-4.913386
16382,-3.567037,-2.337971


In [14]:
import plotly.express as px

fig = px.parallel_coordinates(df,
                              dimensions=['base_feature_densities', 'finetune_feature_densities']
                             )
fig.show()

# Sampling features from density intervals

In [15]:
import torch

def subsample_indices(tensor, total_samples=10, low_threshold=-8, low_percentage=0.2, high_percentage=0.8,
                      log=False):
    """
    Subsample indices from the tensor based on the log10 scale density.

    Parameters:
    - tensor (torch.Tensor): The input tensor of feature densities.
    - total_samples (int): The total number of samples to return.
    - low_threshold (float): The log10 threshold for the lowest bar (default: -9.5, approximates -10).
    - low_percentage (float): The percentage of total_samples to take from the lowest bar (default: 0.2).
    - high_percentage (float): The percentage of total_samples to take from the high density interval (default: 0.8).

    Returns:
    - combined_indices (torch.Tensor): A tensor containing the subsampled indices.
    """
    if log:
        # Convert tensor to log10 scale
        tensor = torch.log10(tensor)

    # 1. Subsample low_percentage of samples from the lowest bar
    lowest_bar_indices = (tensor <= low_threshold)
    num_lowest_samples = int(total_samples * low_percentage)
    lowest_bar_sample_indices = torch.nonzero(lowest_bar_indices).squeeze(1)

    if len(lowest_bar_sample_indices) > num_lowest_samples:
        lowest_bar_sample_indices = lowest_bar_sample_indices[torch.randperm(len(lowest_bar_sample_indices))[:num_lowest_samples]]

    # 2. Subsample high_percentage of samples from the interval [-5, -1] in log scale
    high_density_indices = (tensor >= -5) & (tensor <= -1)
    num_high_density_samples = int(total_samples * high_percentage)
    high_density_sample_indices = torch.nonzero(high_density_indices).squeeze(1)

    if len(high_density_sample_indices) > num_high_density_samples:
        high_density_sample_indices = high_density_sample_indices[torch.randperm(len(high_density_sample_indices))[:num_high_density_samples]]

    return lowest_bar_sample_indices, high_density_sample_indices

In [16]:
dead_base_features, dense_base_features = subsample_indices(base_feature_densities)
dead_base_features, dense_base_features

(tensor([ 4979, 15109]),
 tensor([12143,  8699,  3689,  2020,  8217, 14095,  8952, 14668]))

In [17]:
base_feature_densities[dead_base_features], base_feature_densities[dense_base_features]

(tensor([-10., -10.]),
 tensor([-2.3100, -3.4055, -2.5217, -2.7361, -2.7805, -3.2144, -1.2508, -2.8200]))

# Interpreting features from different density intervals

In [18]:
# import the required libraries
from sae_lens import SAE

sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)

## Base model

In [19]:
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [20]:
from sae_lens import ActivationsStore

batch_size_prompts = 8

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=batch_size_prompts,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

batch_size_tokens = activation_store.context_size * batch_size_prompts
batch_size_tokens

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]



8192

### Feature activation vectors

In [21]:
def get_feature_activations(features_ids, model, sae=sae, total_batches=50, activation_store=None,
                            batch_size_prompts=batch_size_prompts, sae_id=sae_id, layer_num=layer_num,
                            model_name=BASE_MODEL, exclude_bos=True, tokens_sample=None, return_dtype=torch.float16):
    base_model_run = tokens_sample is None

    if base_model_run:
        assert activation_store is not None
        def get_tokens(k):
            """Returns the tokens sampled from the activation store"""
            # Get the corresponding batch of tokens from all_tokens
            tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
            if exclude_bos and activation_store.prepend_bos:
                tokens = tokens[:, 1:]

            return tokens
    else:
        def get_tokens(k):
            """Returns the tokens for the k-th outer batch, where 0 <= k < TOTAL_BATCHES"""

            # Get the corresponding batch of tokens from all_tokens
            tokens = tokens_sample[k]  # [N_BATCH, N_CONTEXT]
            return tokens

    if base_model_run:
        all_tokens = []
    all_feature_activations = []

    for k in tqdm(range(total_batches)):
        # Get a batch of tokens from the dataset
        tokens = get_tokens(k)  # [N_BATCH, N_CONTEXT]

        # Run the model and store the activations
        _, cache = model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                        names_filter=[sae_id]) # [N_BATCH, N_CONTEXT, D_MODEL]

        # Get the activations from the cache at the sae_id
        sae_in = cache[sae_id].flatten(0, 1) # [N_BATCH * N_CONTEXT, D_MODEL]

        del cache
        clear_cache()

        # Store tokens for later reuse
        if base_model_run:
            all_tokens.append(tokens)

        # Encode the activations with the SAE
        sae_hidden = sae.encode(sae_in) # [N_BATCH * N_CONTEXT, N_HIDDEN]

        # Select only given features
        feature_activations = sae_hidden[:, features_ids] # [N_BATCH * N_CONTEXT, len(feature_ids)]
        all_feature_activations.append(feature_activations)

        # Explicitly free up memory by deleting the cache and emptying the CUDA cache
        del sae_in, sae_hidden
        clear_cache()

    tokens_dataset = torch.stack(all_tokens, dim=0) if base_model_run else None
    all_feature_activations = torch.cat(all_feature_activations, dim=0).to(return_dtype)

    return all_feature_activations, tokens_dataset

In [22]:
base_model_dead_act, base_model_dead_act_tokens = get_feature_activations(dead_base_features, base_model,
                                                                          activation_store=activation_store)
base_model_dead_act.shape, base_model_dead_act_tokens.shape

100%|██████████| 50/50 [01:22<00:00,  1.64s/it]


(torch.Size([409200, 2]), torch.Size([50, 8, 1023]))

In [23]:
base_model_dense_act, base_model_dense_act_tokens = get_feature_activations(dense_base_features, base_model,
                                                                            activation_store=activation_store)
base_model_dense_act.shape, base_model_dense_act_tokens.shape

100%|██████████| 50/50 [01:20<00:00,  1.61s/it]


(torch.Size([409200, 8]), torch.Size([50, 8, 1023]))

In [24]:
torch.save(base_model_dead_act, f'base_model_dead_act.pt')
torch.save(base_model_dense_act, f'base_model_dense_act.pt')

torch.save(base_model_dead_act_tokens, f'base_model_dead_act_tokens.pt')
torch.save(base_model_dense_act_tokens, f'base_model_dense_act_tokens.pt')

### Feature logit vectors (TODO)

In [25]:
def get_feature_vector(feature_id, sae=sae):
    return sae.W_dec[feature_id]

def get_feature_logits(feature_id, model, sae=sae):
    feature_vector = get_feature_vector(feature_id, sae)
    logits = feature_vector @ model.W_U

    return logits

In [26]:
def get_features_logit_vectors(features_ids, tokens, model, sae=sae):
    """
    Assumes tokens of shape [total_batches, batch*seq]
    """
    logit_vectors = []

    for feature_id in features_ids:
        feature_logits = get_feature_logits(feature_id, model)

        # TODO: build feature vector here

        logit_vectors.append(logit_vector)
    return logit_vectors

In [27]:
del base_model, activation_store
clear_cache()

## Finetune model

In [28]:
# Load the finetune model and its tokenizer
finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf, dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [29]:
del finetune_model_hf
clear_cache()

In [21]:
# base_model_dense_token_path = os.path.join('/content/drive/My Drive', f'base_model_dense_act_tokens.pt')
# base_model_dead_token_path = os.path.join('/content/drive/My Drive', f'base_model_dead_act_tokens.pt')

# base_model_dense_act_tokens = torch.load(base_model_dense_token_path)
# base_model_dead_act_tokens = torch.load(base_model_dead_token_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [30]:
base_model_dense_act_tokens.shape, base_model_dead_act_tokens.shape

(torch.Size([50, 8, 1023]), torch.Size([50, 8, 1023]))

In [31]:
finetune_model_dense_act, _ = get_feature_activations(dense_base_features, finetune_model,
                                                      tokens_sample=base_model_dense_act_tokens)
finetune_model_dead_act, _ = get_feature_activations(dead_base_features, finetune_model,
                                                     tokens_sample=base_model_dead_act_tokens)

100%|██████████| 50/50 [01:17<00:00,  1.55s/it]
100%|██████████| 50/50 [01:16<00:00,  1.53s/it]


In [50]:
torch.save(finetune_model_dense_act, f'finetune_model_dense_act.pt')
torch.save(finetune_model_dead_act, f'finetune_model_dead_act.pt')

## Ploting & reporting the feature similarities

In [54]:
PRELOAD_FEATURES = True

if PRELOAD_FEATURES:
  # Define the path to your JSON file in Google Drive
  base_model_dense_path = os.path.join('/content/drive/My Drive', f'base_model_dense_act.pt')
  base_model_dead_path = os.path.join('/content/drive/My Drive', f'base_model_dead_act.pt')
  finetune_model_dead_path = os.path.join('/content/drive/My Drive', f'finetune_model_dead_act.pt')
  finetune_model_dense_path = os.path.join('/content/drive/My Drive', f'finetune_model_dense_act.pt')

  base_model_dense_act = torch.load(base_model_dense_path)
  base_model_dead_act = torch.load(base_model_dead_path)
  finetune_model_dead_act = torch.load(finetune_model_dead_path)
  finetune_model_dense_act = torch.load(finetune_model_dense_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [33]:
assert finetune_model_dense_act.shape == base_model_dense_act.shape
assert finetune_model_dead_act.shape == base_model_dead_act.shape

finetune_model_dense_act.shape, finetune_model_dead_act.shape

(torch.Size([409200, 8]), torch.Size([409200, 2]))

In [34]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def make_similarity_plots(x, y, feature_ids, plot_rows, plot_cols, similarity_score='Activation'):
    n_plots = x.shape[1]
    assert n_plots == plot_rows * plot_cols

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    activations_similarity = [np.corrcoef(x[:, i], y[:, i])[0, 1] for i in range(n_plots)]

    # Set up subplot grid
    fig = make_subplots(rows=plot_rows, cols=plot_cols,
                        x_title=f'Base Feature {similarity_score}',
                        y_title=f'Finetune Feature {similarity_score}',
                        subplot_titles=[f'Activation similarity = {similarity}' for similarity in activations_similarity])

    # Generate scatter plots for each column
    for i in range(n_plots):
        # Compute correlation coefficient
        corr_coef = np.corrcoef(x[:, i], y[:, i])[0, 1]

        # Add scatter plot
        scatter = go.Scatter(
            x=x[:, i],
            y=y[:, i],
            mode='markers',
            marker=dict(color='red', opacity=0.5),
            name=f'Feature {feature_ids[i]} (Corr={corr_coef:.2f})'
        )

        # Place scatter plot in correct subplot
        fig.add_trace(scatter, row=(i//plot_cols)+1, col=(i%plot_cols)+1)

    # Update layout for better aesthetics
    fig.update_layout(
        title="Feature Activations Scatter Plots and Correlations",
        height=600,
        width=1000,
        showlegend=False
    )

    # Show the figure
    fig.show()

def make_single_similarity_plot(x, y, feature_id, features_family, similarity_score='Activation'):
    # Convert tensors to numpy arrays
    x = x.cpu().numpy()
    y = y.cpu().numpy()

    # Compute correlation coefficient for the given feature (column index)
    corr_coef = np.corrcoef(x[:, feature_id], y[:, feature_id])[0, 1]

    # Set up scatter plot for the specific feature
    scatter = go.Scatter(
        x=x[:, feature_id],
        y=y[:, feature_id],
        mode='markers',
        marker=dict(color='red', opacity=0.5),
        name=f'Feature {feature_id} (Corr={corr_coef:.2f})'
    )

    # Set up layout with axis labels and title
    layout = go.Layout(
        title=f'Feature Activations Scatter Plot (Feature {features_family[feature_id]})<br>Activation similarity = {corr_coef:.2f}',
        xaxis=dict(title=f'Base Feature {similarity_score}'),
        yaxis=dict(title=f'Finetune Feature {similarity_score}'),
        height=600,
        width=600,
        showlegend=False
    )

    # Create figure and add trace
    fig = go.Figure(data=[scatter], layout=layout)

    # Show the figure
    fig.show()

def show_features_similarities(x, y, features_family, similarity_score='Activation'):
    # Ensure correct number of plots
    n_features = x.shape[1]

    # Convert tensors to numpy arrays
    x = x.cpu().numpy()
    y = y.cpu().numpy()

    # Compute activation similarities for each feature
    activations_similarity = [np.corrcoef(x[:, i], y[:, i])[0, 1] for i in range(n_features)]

    # Print activation similarities for each feature
    for i, similarity in enumerate(activations_similarity):
        print(f'Feature {features_family[i]} {similarity_score} similarity = {similarity:.2f}')

### Dead features

In [40]:
show_features_similarities(base_model_dead_act, finetune_model_dead_act, dead_base_features)

Feature 4979 Activation similarity = nan
Feature 15109 Activation similarity = nan


In [53]:
make_similarity_plots(base_model_dead_act, finetune_model_dead_act, dead_base_features,
                      plot_rows=1, plot_cols=2)

### Dense features

In [43]:
show_features_similarities(base_model_dense_act, finetune_model_dense_act, dense_base_features)

Feature 12143 Activation similarity = -0.00
Feature 8699 Activation similarity = -0.00
Feature 3689 Activation similarity = -0.00
Feature 2020 Activation similarity = 0.00
Feature 8217 Activation similarity = -0.00
Feature 14095 Activation similarity = -0.00
Feature 8952 Activation similarity = 0.00
Feature 14668 Activation similarity = 0.00


In [44]:
n_dense = dense_base_features.numel()
n_dense

8

In [52]:
# make_single_similarity_plot(base_model_dense_act, finetune_model_dense_act, 1, dense_base_features)