# Feature connectivity demo

This notebook is a short introduction to the [SAE network analysis package](https://github.com/owenparsons/sae_network_analysis).


We'll start by importing the package as well as the torch library.

In [None]:
import sae_network_analysis
import torch
torch.set_grad_enabled(False)  # Disable gradient computation for inference

Next, we select a device: MPS (Metal Performance Shaders) if it is available on macOS, otherwise CUDA, and finally CPU.

In [None]:
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Next, we can load the pre-trained model using HookedSAETransformer. 
This example will use the "gpt2-small" model.

In [None]:
from sae_lens import SAE, HookedSAETransformer
model = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Then we load the pre-trained sparse autoencoder (SAE) for the layer we want to examine.

In [None]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # The specific release of the model
    sae_id="blocks.7.hook_resid_pre",  # The specific layer to examine (SAE id)
    device=device
)

We'll load our dataset, which is a list of athletes and the sport each one plays.
This will be converted into the task dataset that we'll be using.

In [None]:
from sae_network_analysis.data_processing import create_prompt_dataframe, process_target_tokens
url = 'https://raw.githubusercontent.com/ali-ce/datasets/master/Most-paid-athletes/Athletes.csv'
context_col = "Name"
target_col = "Sport"
prompt = "Fact: {context_word} plays the sport of"

Create a DataFrame with the prompts based on the athletes' data.

In [None]:
prompt_df = create_prompt_dataframe(url, context_col, target_col, prompt)
prompt_df.head()
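
To make the templating step concrete, here is a minimal sketch of how a prompt template with a `{context_word}` placeholder can be filled in row by row. The `build_prompts` helper and the two example rows are hypothetical and only illustrate the idea; the package's `create_prompt_dataframe` may be implemented differently.

```python
# Hypothetical stand-in for the prompt-building step; not the package's own code.
def build_prompts(rows, context_col, target_col, template):
    """Fill `template` with each row's context value and keep the target alongside it."""
    return [
        {
            "prompt": template.format(context_word=row[context_col]),
            "target": row[target_col],
        }
        for row in rows
    ]

# Made-up placeholder rows, standing in for the athletes CSV.
rows = [
    {"Name": "Athlete A", "Sport": "Tennis"},
    {"Name": "Athlete B", "Sport": "Golf"},
]
prompts = build_prompts(rows, "Name", "Sport", "Fact: {context_word} plays the sport of")
print(prompts[0]["prompt"])
```

Each prompt ends just before the answer, so the model's next-token prediction can be compared against the target sport.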

Tokenize the DataFrame using the model loaded earlier.

In [None]:
tokenised_df = process_target_tokens(prompt_df, model)
tokenised_df.head()

Run feature attribution using the tokenized data and the model. The results are merged into a comparison DataFrame.

In [None]:
from sae_network_analysis.sae_utils import run_feature_attribution, metric_fn
from sae_network_analysis.data_processing import convert_feature_attribution_dict_to_long, merge_dataframes

merge_method = 'intersection'
agg_method = 'last'
attribution_dict = run_feature_attribution(tokenised_df, "input", model, sae, metric_fn)
long_dict = convert_feature_attribution_dict_to_long(attribution_dict, sae)
comparison_df = merge_dataframes(long_dict, merge_method=merge_method, agg_method=agg_method)
comparison_df.head()
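
One plausible reading of `merge_method='intersection'` is that only features with an attribution score for every prompt are kept. The sketch below illustrates that reading on plain dictionaries. The `intersect_features` helper, the feature ids, and the scores are all made up for illustration, and the package's `merge_dataframes` may behave differently.

```python
# Hypothetical illustration of an "intersection" merge; the package may differ.
def intersect_features(per_prompt_scores):
    """Keep only the feature ids that appear in every prompt's score dict."""
    shared = set.intersection(*(set(scores) for scores in per_prompt_scores.values()))
    return {
        prompt: {feature: scores[feature] for feature in sorted(shared)}
        for prompt, scores in per_prompt_scores.items()
    }

# Made-up feature ids and attribution scores.
per_prompt_scores = {
    "prompt_0": {101: 0.42, 205: 0.10, 311: 0.05},
    "prompt_1": {101: 0.38, 311: 0.07},
    "prompt_2": {101: 0.51, 205: 0.02, 311: 0.09},
}
merged = intersect_features(per_prompt_scores)
print(merged["prompt_0"])
```

Feature 205 is dropped because it has no score for `prompt_1`, which leaves a table with no missing values to correlate.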

Next, let's cluster the features based on their correlations using network analysis.
We create a correlation matrix, then extract clusters of features from it.

In [None]:
from sae_network_analysis.network_analysis import cluster_features_with_network_analysis
full_network_filename = None # Save filename for the full network (set to None to avoid saving).

corr_matrix = comparison_df.corr()
clusters = cluster_features_with_network_analysis(corr_matrix, full_network_filename, threshold=0.7, plot_graph=True)
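
As a rough illustration of threshold-based clustering, the sketch below links two features whenever their correlation exceeds the threshold and then treats each connected component of that graph as a cluster. This is an assumption about the general approach, not the package's implementation: `threshold_clusters` is a hypothetical helper, the toy matrix is made up, and `cluster_features_with_network_analysis` may use a different community-detection method or handle negative correlations differently.

```python
import numpy as np

# Hypothetical helper: connected components of a thresholded correlation graph.
def threshold_clusters(corr, threshold=0.7):
    n = corr.shape[0]
    # Edge between i and j when their correlation exceeds the threshold (no self-loops).
    adjacency = (corr > threshold) & ~np.eye(n, dtype=bool)
    unvisited = set(range(n))
    found = {}
    while unvisited:
        start = min(unvisited)
        unvisited.remove(start)
        stack, component = [start], []
        # Depth-first search collects every node reachable from `start`.
        while stack:
            node = stack.pop()
            component.append(node)
            for neighbour in np.flatnonzero(adjacency[node]):
                neighbour = int(neighbour)
                if neighbour in unvisited:
                    unvisited.remove(neighbour)
                    stack.append(neighbour)
        found[len(found)] = sorted(component)
    return found

# Made-up 4x4 correlation matrix with two obvious groups.
toy_corr = np.array([
    [1.0, 0.9, 0.1, 0.0],
    [0.9, 1.0, 0.2, 0.1],
    [0.1, 0.2, 1.0, 0.8],
    [0.0, 0.1, 0.8, 1.0],
])
toy_clusters = threshold_clusters(toy_corr)
print(toy_clusters)  # {0: [0, 1], 1: [2, 3]}
```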

Let's sort the correlation matrix by clusters and then plot heatmaps to visualize the correlation matrix and cluster organization.

In [None]:
from sae_network_analysis.plot_utils import plot_heatmap_single

sorted_indices = []
for cluster, indices in clusters.items():
    sorted_indices.extend(indices)

sorted_corr_matrix = corr_matrix.loc[sorted_indices, sorted_indices]

corr_filename = None
sorted_corr_filename = None

plot_heatmap_single(corr_matrix, corr_filename, title='Feature Correlation Heatmap', cbar_range=(-1,1), figsize=(9, 7))
plot_heatmap_single(sorted_corr_matrix, sorted_corr_filename, title='Feature Correlation Heatmap (Sorted by Clusters)', cbar_range=(-1,1), figsize=(9, 7))
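
The reason for sorting is that placing members of the same cluster next to each other turns strongly correlated groups into visible blocks along the diagonal. A small NumPy sketch of the same reordering, using a made-up matrix and a made-up cluster assignment:

```python
import numpy as np

# Made-up correlation matrix: features 0 and 2 are strongly correlated.
toy_corr = np.array([
    [1.0, 0.1, 0.9],
    [0.1, 1.0, 0.2],
    [0.9, 0.2, 1.0],
])
toy_clusters = {0: [0, 2], 1: [1]}

# Flatten the cluster membership lists into a single row/column ordering.
order = [i for members in toy_clusters.values() for i in members]

# Apply the same ordering to rows and columns so the matrix stays symmetric.
sorted_toy = toy_corr[np.ix_(order, order)]
print(sorted_toy)
```

After reordering, the 0.9 entry sits directly next to the diagonal instead of in the corner.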

We can also filter the clusters by size and re-cluster using only the features in the larger clusters.

In [None]:
from sae_network_analysis.data_processing import filter_for_cluster_size_threshold, filter_for_specific_clusters
from sae_network_analysis.plot_utils import plot_cluster_heatmap_single

main_clusters = filter_for_cluster_size_threshold(clusters, threshold=10)
subset_corr_matrix = filter_for_specific_clusters(corr_matrix, main_clusters)

subset_corr_filename = None
subset_sorted_corr_filename = None

# Re-cluster using only the features in the subset.
secondary_clusters = cluster_features_with_network_analysis(
    subset_corr_matrix,
    subset_corr_filename,
    threshold=0.7,
    plot_graph=True,
    show_edge_weights=True,
    show_legend=True,
)
plot_cluster_heatmap_single(subset_corr_matrix, secondary_clusters, subset_sorted_corr_filename, title='Feature Correlation Heatmap Grouped by Clusters', cbar_range=(-1,1))
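
Size filtering itself is simple to picture. The sketch below keeps clusters with at least `min_size` members and collects the surviving feature ids. The helper name, the toy clusters, and the choice of `>=` rather than `>` are assumptions for illustration; `filter_for_cluster_size_threshold` may treat the boundary differently.

```python
# Hypothetical helper; assumes clusters map a cluster id to a list of feature ids.
def keep_large_clusters(clusters, min_size):
    """Return only the clusters that have at least `min_size` members."""
    return {
        cluster_id: members
        for cluster_id, members in clusters.items()
        if len(members) >= min_size
    }

# Made-up cluster assignment.
toy_clusters = {0: [3, 8, 15], 1: [42], 2: [7, 9]}
large = keep_large_clusters(toy_clusters, min_size=2)

# Feature ids that survive the filter, e.g. for subsetting a correlation matrix.
kept_features = sorted(f for members in large.values() for f in members)
print(large, kept_features)
```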

Finally, let's calculate centrality measures for the features based on the network and also include metrics for the attribution scores.

In [None]:
from sae_network_analysis.network_analysis import compute_centrality_measures, calculate_activation_metrics

centrality_df = compute_centrality_measures(corr_matrix, threshold=0.7)
centrality_df = calculate_activation_metrics(centrality_df, comparison_df)

centrality_df.head()
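
As one example of a centrality measure, degree centrality counts how many other features a feature is linked to in the thresholded graph, divided by the maximum possible number of links. The sketch below computes it on a made-up matrix. `degree_centrality` here is a hypothetical helper, and the set of measures returned by `compute_centrality_measures` is not assumed to match it.

```python
import numpy as np

# Hypothetical helper: degree centrality on a thresholded correlation graph.
def degree_centrality(corr, threshold=0.7):
    n = corr.shape[0]
    # Link features whose correlation exceeds the threshold, ignoring self-links.
    adjacency = (corr > threshold) & ~np.eye(n, dtype=bool)
    degrees = adjacency.sum(axis=1)
    # Normalise by the number of other nodes so values lie in [0, 1].
    return degrees / (n - 1)

# Made-up matrix: feature 0 links to features 1 and 2, feature 3 is isolated.
toy_corr = np.array([
    [1.0, 0.9, 0.8, 0.1],
    [0.9, 1.0, 0.3, 0.0],
    [0.8, 0.3, 1.0, 0.2],
    [0.1, 0.0, 0.2, 1.0],
])
centrality = degree_centrality(toy_corr)
print(centrality)
```

Feature 0 has the highest score because it is connected to two of the three other features.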

In [None]:
# Plot correlations between metrics

import seaborn as sns
import matplotlib.pyplot as plt
sns.heatmap(
    centrality_df.corr(),
    cmap="YlGnBu",
    center=0,
    vmin=0,  # Set minimum value for colorbar
    vmax=1,  # Set maximum value for colorbar
    cbar_kws={'label': 'Correlation Coefficient', 'orientation': 'vertical'}  # Add colorbar label
)
plt.title("Correlation Coefficients Between Metric Scores", fontsize=14, pad=20)
plt.show()

Next, we create an activation store, which streams text through the model and buffers the resulting activations in batches.

In [None]:
from sae_lens import ActivationsStore

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

We can also calculate the maximum activations for the selected features over multiple batches. These can be used to scale activations during feature steering if needed.

In [None]:
from sae_network_analysis.sae_utils import find_max_activations
import pandas as pd
import numpy as np

multi_features = torch.tensor(subset_corr_matrix.columns.values)
max_activations = find_max_activations(model, sae, activation_store, multi_features, num_batches=100)

max_act_df = pd.DataFrame(
    np.array([
        multi_features.detach().cpu().numpy(),
        max_activations.detach().cpu().numpy(),
    ]).T,
    columns=['Features', 'Max_activations'],
)
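
The idea behind tracking maxima over batches can be sketched as a running element-wise maximum. The example below uses random NumPy arrays in place of real SAE activations. `running_max` is a hypothetical helper, not the package's `find_max_activations`, and it assumes activations are non-negative (as with ReLU-style SAE features), so zero is a safe starting value.

```python
import numpy as np

# Hypothetical helper: largest activation seen per selected feature across batches.
def running_max(batches, feature_ids):
    max_acts = np.zeros(len(feature_ids))  # assumes non-negative activations
    for batch in batches:  # each batch has shape (tokens, n_features)
        batch_max = batch[:, feature_ids].max(axis=0)
        max_acts = np.maximum(max_acts, batch_max)
    return max_acts

# Random stand-in data: 5 batches of 16 tokens with 8 features each.
rng = np.random.default_rng(0)
batches = [rng.random((16, 8)) for _ in range(5)]
feature_ids = [1, 4, 6]
max_acts = running_max(batches, feature_ids)
print(max_acts)
```

Only one batch needs to be held in memory at a time, which matters when the activations come from a streamed dataset.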