In [1]:
from itertools import permutations

import plotly.graph_objects as go
import plotly.subplots as sp
import torch

from univ_utils import load_model_and_sae, get_running_activation_stats, load_data
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


#### Universality across models for final checkpoints

In [2]:
# Sampling budget for the activation-statistics pass below:
# n_batches batches, each of size batch_size.
batch_size = 32
n_batches = 40
# Only train_data is consumed by the visible cells; val_data is kept for later use.
train_data, val_data = load_data(dataset="openwebtext", device=device)

In [3]:
# (model name, SAE run id) pairs to compare. Their order fixes the
# row/column order of the plot grid further down.
model_sae_pairs = list(
    {
        "8-128": "f5khzlt0",
        "8-256": "fx5swct8",
        "8-512": "ga7p2hm5",
        # "8-768": "zisfsfel",  # currently excluded
    }.items()
)

In [4]:
# Activation statistics for every ordered pair of distinct models.
# Both directions are computed because the scatter cell treats the two models
# differently: x comes from `stats.max_x`, y from reducing `stats.corr_matrix`
# over its last axis.
# Models are loaded per pair and freed (del + empty_cache) afterwards, so only
# one pair is resident at a time.
STATS_SEED = 34  # passed as `seed=` to get_running_activation_stats for every pair

# permutations() skips the diagonal by position; that matches a skip by model
# name only when names are unique (also required for unambiguous all_stats keys).
assert len({name for name, _ in model_sae_pairs}) == len(model_sae_pairs), "model names must be unique"

all_stats = {}
for (model1_name, sae1_name), (model2_name, sae2_name) in permutations(model_sae_pairs, 2):
    model1, sae1 = load_model_and_sae(model1_name, sae1_name, None, device)
    model2, sae2 = load_model_and_sae(model2_name, sae2_name, None, device)
    print(model1_name, sae1_name, model2_name, sae2_name)

    stats = get_running_activation_stats(
        model1, model2, train_data, batch_size=batch_size, n_batches=n_batches, seed=STATS_SEED
    )

    # Keyed by ((model1, sae1), (model2, sae2)); to_cpu() presumably moves the
    # stored tensors off the accelerator so the models can be freed below.
    all_stats[(model1_name, sae1_name), (model2_name, sae2_name)] = stats.to_cpu()

    del model1, model2, sae1, sae2, stats
    torch.cuda.empty_cache()

  checkpoint = torch.load(ckpt_file, map_location=device)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


8-128 f5khzlt0 8-256 fx5swct8


100%|██████████| 40/40 [00:06<00:00,  6.30it/s]


8-128 f5khzlt0 8-512 ga7p2hm5


100%|██████████| 40/40 [00:13<00:00,  2.91it/s]


8-256 fx5swct8 8-128 f5khzlt0


100%|██████████| 40/40 [00:06<00:00,  6.28it/s]


8-256 fx5swct8 8-512 ga7p2hm5


100%|██████████| 40/40 [00:19<00:00,  2.08it/s]


8-512 ga7p2hm5 8-128 f5khzlt0


100%|██████████| 40/40 [00:14<00:00,  2.82it/s]


8-512 ga7p2hm5 8-256 fx5swct8


100%|██████████| 40/40 [00:19<00:00,  2.05it/s]


In [5]:
def num_layers(model_name: str) -> int:
    """Return the second "-"-separated field of a model name as an int, e.g. "8-128" -> 128.

    NOTE(review): for the names used here ("8-128", "8-256", "8-512") the first
    field is constant and the second varies; confirm that the second field really
    is the layer count rather than another size parameter.
    """
    return int(model_name.split("-")[1])

In [51]:
# Per-pair scatter data. For each ordered pair (model1, model2):
#   x = stats.max_x                      (per-feature maximum activation)
#   y = stats.corr_matrix.amax(dim=-1)   (best similarity over the last axis)
# Pairs are visited row-major over model_sae_pairs, matching the plot grid;
# diagonal cells only contribute a placeholder title so titles stay aligned.

# Hand-picked mask thresholds: features with y above MASK_MIN_SIMILARITY AND
# x below MASK_MAX_ACTIVATION are dropped. TODO: derive this mask from UMAP.
MASK_MIN_SIMILARITY = 0.9
MASK_MAX_ACTIVATION = 1


def _pearson(a, b):
    """Pearson correlation between two equal-length 1-D tensors, as a float."""
    return torch.corrcoef(torch.stack((a, b), dim=0))[0, 1].item()


scatter_data = []
corr_coefs = []
subplot_titles = []
masked_scatter_data = []
masked_corr_coefs = []
masked_subplot_titles = []

for model1_name, sae1_name in model_sae_pairs:
    for model2_name, sae2_name in model_sae_pairs:
        if model1_name == model2_name:
            # Diagonal: a model is not compared with itself.
            subplot_titles.append("____")
            masked_subplot_titles.append("____")
            continue

        stats = all_stats[(model1_name, sae1_name), (model2_name, sae2_name)]
        x = stats.max_x
        y = stats.corr_matrix.amax(dim=-1)

        # Unmasked data
        corr_coef = _pearson(x, y)
        scatter_data.append((x, y))
        corr_coefs.append(corr_coef)
        subplot_titles.append(f"Corr={corr_coef:.4f}")

        # Masked data
        keep = ~((y > MASK_MIN_SIMILARITY) & (x < MASK_MAX_ACTIVATION))
        x_kept, y_kept = x[keep], y[keep]
        masked_corr_coef = _pearson(x_kept, y_kept)
        masked_scatter_data.append((x_kept, y_kept))
        masked_corr_coefs.append(masked_corr_coef)
        masked_subplot_titles.append(f"Corr={masked_corr_coef:.4f}")


In [None]:
# TODO: remove uninterpretable features from each SAE using UMAP

In [56]:
# N x N grid (N = number of models): one off-diagonal panel per ordered model
# pair, plotting the masked max activation (x) vs. max similarity (y).
grid_length = len(model_sae_pairs)
# NOTE(review): the label uses the second "-"-separated field of the model name
# (128/256/512 here) — confirm it is really the number of layers.
grid_titles = [f"n_layers={model_name.split('-')[1]}" for model_name, _ in model_sae_pairs]
fig = sp.make_subplots(
    rows=grid_length,
    cols=grid_length,
    subplot_titles=masked_subplot_titles,
    row_titles=grid_titles,
    column_titles=grid_titles,
    x_title="Maximum Feature Activation",
    y_title="Maximum Activation Similarity",
    horizontal_spacing=0.05,
    vertical_spacing=0.05,
)

# masked_scatter_data was filled row-major with the diagonal skipped,
# so consume it in exactly that order.
panels = iter(masked_scatter_data)
for row in range(1, grid_length + 1):
    for col in range(1, grid_length + 1):
        if row == col:
            continue
        x, y = next(panels)
        fig.add_trace(
            go.Scatter(x=x.numpy(), y=y.numpy(), mode="markers", name=""),
            row=row,
            col=col,
        )

# make_subplots appends annotations as: non-empty subplot titles, then column
# titles, then row titles. Locate the column titles from those counts instead
# of hard-coding indices 9-11 (only valid for a 3x3 grid), and nudge them up,
# presumably so they do not overlap the top row's subplot titles.
first_column_title = sum(1 for title in masked_subplot_titles[: grid_length ** 2] if title)
column_title_annotations = fig.layout.annotations[first_column_title : first_column_title + grid_length]
assert [a.text for a in column_title_annotations] == grid_titles, "unexpected annotation layout"
for annotation in column_title_annotations:
    annotation.update(y=1.025)

fig.update_layout(
    title_text="Feature Importance (x-axis) vs Universality (y-axis)",
    showlegend=False,
    height=1200,
    width=1200,
)

fig.show()