In [1]:
import json
from collections.abc import Sequence
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb

from e2e_sae.log import logger
from e2e_sae.plotting import plot_facet, plot_per_layer_metric
from e2e_sae.scripts.analysis.plot_settings import (
    SIMILAR_CE_RUNS,
    SIMILAR_RUN_INFO,
    STYLE_MAP,
)
from e2e_sae.scripts.analysis.utils import create_run_df, get_df_gpt2
from e2e_sae.scripts.analysis.plot_performance import format_two_axes

  from .autonotebook import tqdm as notebook_tqdm


## Plot Performance of BAE with Different Beta Values

In [30]:
# Let wandb resolve credentials itself (WANDB_API_KEY environment variable or a
# prior `wandb login`) so that no secret is stored in the notebook.
# NOTE(review): the key that used to be hardcoded here must be treated as leaked
# and revoked in the W&B account settings.
api = wandb.Api()
project = "raymondl/tinystories-1m-local-bayesian-beta-sweep"
runs = api.runs(project)
# Collect the runs of the sweep into a DataFrame via the project helper.
df = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)
# Optional: restrict the plot to a couple of beta values.
# df = df[df["name"].str.contains("beta_0.5") | df["name"].str.contains("beta_0.3")]

def assign_group(run_name: str) -> str:
    """Return the beta value embedded in a run name.

    The value is the text between the last ``"beta_"`` marker and the next
    underscore, e.g. ``"..._beta_0.5_seed-0"`` -> ``"0.5"``. When the name has
    no ``"beta_"`` marker, the text before the first underscore is returned.
    """
    # The split result is already a str, so no f-string wrapper is needed; the
    # previous form nested double quotes inside a double-quoted f-string, which
    # only parses on Python >= 3.12.
    return run_name.split("beta_")[-1].split("_")[0]

# Tag every run with the beta value parsed from its name; one line is drawn per value.
df["grouping_type"] = df["name"].apply(assign_group)

plot_facet(
    # Data: L0 on the y axis against two x metrics.
    df=df,
    y="L0",
    xs=["CELossIncrease", "out_to_in"],
    # Faceting / grouping: only layer 4, one line per beta value.
    facet_by="layer",
    facet_vals=[4],
    line_by="grouping_type",
    # Labels.
    ylabel="L0",
    xlabels=["CE Loss Increase", "Reconstruction MSE"],
    legend_title="Beta Value",
    # Axis limits, keyed by facet value (one xlims entry per x metric).
    ylim={4: (None, None)},
    xlims=[{4: (None, None)}, {4: (None, None)}],
    # Appearance.
    styles=STYLE_MAP,
    plot_type="line",
    annotate_col="sparsity_coeff",
    axis_formatter=partial(format_two_axes, better_labels=True),
    # Output location.
    out_file="plots/bayesian_local_beta_sweep_layer_4.png",
)

Processing runs: 100%|██████████| 72/72 [00:00<00:00, 99.05it/s]
2025-05-02 15:12:53 - INFO - Saved plot to plots/bayesian_local_beta_sweep_layer_4.png
2025-05-02 15:12:53 - INFO - Saved SVG plot to plots/bayesian_local_beta_sweep_layer_4.svg


In [31]:

# Let wandb resolve credentials itself (WANDB_API_KEY environment variable or a
# prior `wandb login`) so that no secret is stored in the notebook.
# NOTE(review): the key that used to be hardcoded here must be treated as leaked
# and revoked in the W&B account settings.
api = wandb.Api()
project = "raymondl/tinystories-1m-e2e-bayesian-beta-sweep-no-relu"
runs = api.runs(project)
# Collect the runs of the sweep into a DataFrame via the project helper.
df = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)
# Optional: restrict the plot to a couple of beta values.
# df = df[df["name"].str.contains("beta_0.5") | df["name"].str.contains("beta_0.2")]

def assign_group(run_name: str) -> str:
    """Return the beta value embedded in a run name.

    The value is the text between the last ``"beta_"`` marker and the next
    underscore, e.g. ``"..._beta_0.5_seed-0"`` -> ``"0.5"``. When the name has
    no ``"beta_"`` marker, the text before the first underscore is returned.
    """
    # The split result is already a str, so no f-string wrapper is needed; the
    # previous form nested double quotes inside a double-quoted f-string, which
    # only parses on Python >= 3.12.
    return run_name.split("beta_")[-1].split("_")[0]

# Tag every run with the beta value parsed from its name; one line is drawn per value.
df["grouping_type"] = df["name"].apply(assign_group)

plot_facet(
    # Data: L0 on the y axis against two x metrics.
    df=df,
    y="L0",
    xs=["CELossIncrease", "out_to_in"],
    # Faceting / grouping: only layer 4, one line per beta value.
    facet_by="layer",
    facet_vals=[4],
    line_by="grouping_type",
    # Labels.
    ylabel="L0",
    xlabels=["CE Loss Increase", "Reconstruction MSE"],
    legend_title="Beta Value",
    # Axis limits, keyed by facet value (one xlims entry per x metric).
    ylim={4: (None, None)},
    xlims=[{4: (None, None)}, {4: (None, None)}],
    # Appearance.
    styles=STYLE_MAP,
    plot_type="line",
    annotate_col="sparsity_coeff",
    axis_formatter=partial(format_two_axes, better_labels=True),
    # Output location.
    out_file="plots/bayesian_e2e_beta_sweep_layer_4.png",
)

Processing runs: 100%|██████████| 130/130 [00:02<00:00, 48.79it/s]
2025-05-02 15:13:01 - INFO - Saved plot to plots/bayesian_e2e_beta_sweep_layer_4.png
2025-05-02 15:13:01 - INFO - Saved SVG plot to plots/bayesian_e2e_beta_sweep_layer_4.svg


## ReLU vs No ReLU


In [23]:
# Let wandb resolve credentials itself (WANDB_API_KEY environment variable or a
# prior `wandb login`) so that no secret is stored in the notebook.
# NOTE(review): the key that used to be hardcoded here must be treated as leaked
# and revoked in the W&B account settings.
api = wandb.Api()
sweep_project = "raymondl/tinystories-1m-e2e-bayesian-beta-sweep"
runs = api.runs(sweep_project)
sweep_project_no_relu = "raymondl/tinystories-1m-e2e-bayesian-beta-sweep-no-relu"
runs_no_relu = api.runs(sweep_project_no_relu)

# Runs from the sweep with ReLU.
df = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)
df["grouping_type"] = "ReLU"

# Runs from the sweep without ReLU, restricted to names that also exist in the
# ReLU sweep so both lines cover the same configurations. The explicit .copy()
# makes the filtered frame independent before a column is added to it.
df_no_relu = create_run_df(runs_no_relu, per_layer_metrics=False, use_run_name=True, grad_norm=False)
df_no_relu = df_no_relu[df_no_relu["name"].isin(df["name"])].copy()
df_no_relu["grouping_type"] = "No ReLU"

df = pd.concat([df, df_no_relu], axis=0)
plot_facet(
    df=df,
    xs=["CELossIncrease"],
    y="L0",
    facet_by="layer",
    facet_vals=[4],
    line_by="grouping_type",
    xlabels=["CE Loss Increase"],
    ylabel="L0",
    # The lines are grouped by ReLU / No ReLU, not by beta value.
    legend_title="Activation",
    axis_formatter=None,
    out_file="plots/bayesian_e2e_relu_vs_no_relu_layer_4.png",
    # NOTE(review): two xlims entries are passed although there is a single x
    # metric; kept as-is because plot_facet's handling is not visible here.
    xlims=[{4: (None, None)}, {4: (None, None)}],
    ylim={4: (None, None)},
    styles=STYLE_MAP,
    plot_type='line',
    annotate_col="sparsity_coeff"
)

Processing runs: 100%|██████████| 40/40 [00:00<00:00, 18471.01it/s]
Processing runs: 100%|██████████| 130/130 [00:02<00:00, 50.00it/s]
2025-05-02 14:49:15 - INFO - Saved plot to plots/bayesian_e2e_relu_vs_no_relu_layer_4.png
2025-05-02 14:49:15 - INFO - Saved SVG plot to plots/bayesian_e2e_relu_vs_no_relu_layer_4.svg


## Learning Rates

In [27]:
# Let wandb resolve credentials itself (WANDB_API_KEY environment variable or a
# prior `wandb login`) so that no secret is stored in the notebook.
# NOTE(review): the key that used to be hardcoded here must be treated as leaked
# and revoked in the W&B account settings.
api = wandb.Api()
project = "raymondl/tinystories-1m-e2e-bayesian-beta-sweep-no-relu"
runs = api.runs(project)
df = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)
# Keep only the beta=0.5, seed-0 runs. regex=False matches the text literally
# (otherwise the "." in "0.5" is a regex wildcard); .copy() makes the filtered
# frame independent before a column is added to it.
df = df[df['name'].str.contains('beta_0.5_bayesian_seed-0', regex=False)].copy()
# Label each run with the learning rate parsed from the "lr-<value>_" part of its name.
df["grouping_type"] = df["name"].apply(lambda x: f"LR: {x.split('lr-')[-1].split('_')[0]}")
plot_facet(
    df=df,
    xs=["CELossIncrease"],
    y="L0",
    facet_by="layer",
    facet_vals=[4],
    line_by="grouping_type",
    xlabels=["CE Loss Increase"],
    ylabel="L0",
    # The lines are grouped by learning rate, not by beta value.
    legend_title="Learning Rate",
    axis_formatter=None,
    out_file="plots/bayesian_e2e_learning_rates_layer_4.png",
    # NOTE(review): two xlims entries are passed although there is a single x
    # metric; kept as-is because plot_facet's handling is not visible here.
    xlims=[{4: (None, None)}, {4: (None, None)}],
    ylim={4: (None, None)},
    styles=STYLE_MAP,
    plot_type='line',
    annotate_col="sparsity_coeff"
)

Processing runs: 100%|██████████| 130/130 [00:02<00:00, 50.49it/s]
2025-05-02 15:10:32 - INFO - Saved plot to plots/bayesian_e2e_learning_rates_layer_4.png
2025-05-02 15:10:32 - INFO - Saved SVG plot to plots/bayesian_e2e_learning_rates_layer_4.svg


## TinyStories-1M Comparisons with Baselines

In [19]:
# Let wandb resolve credentials itself (WANDB_API_KEY environment variable or a
# prior `wandb login`) so that no secret is stored in the notebook.
# NOTE(review): the key that used to be hardcoded here must be treated as leaked
# and revoked in the W&B account settings.
api = wandb.Api()
project = "raymondl/tinystories-1m_play"
runs = api.runs(project)
df = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)
# Keep runs whose name mentions the layer-4 residual hook. regex=False matches
# the text literally; otherwise each "." would act as a regex wildcard.
df = df[df['name'].str.contains('blocks.4.hook_resid_pre', regex=False)]
# Keep names containing "local_seed", "e2e_seed" or "ds_seed". The trailing
# .copy() makes the filtered frame independent before columns are added to it.
df = df[
    df['name'].str.contains('local_seed', regex=False)
    | df['name'].str.contains('e2e_seed', regex=False)
    | df['name'].str.contains('ds_seed', regex=False)
].copy()
def assign_group(run_name: str) -> str:
    """Classify a run as 'ds', 'e2e' or 'local' from substrings of its name.

    The substrings are checked in that order, so a name containing both "ds"
    and "e2e" is labelled 'ds'. A name containing neither is labelled 'local'.
    """
    for sae_type in ("ds", "e2e"):
        if sae_type in run_name:
            return sae_type
    return "local"
# .assign returns a new frame, so the label column is never written into a
# filtered slice of the original run table.
df = df.assign(grouping_type=df["name"].apply(assign_group))

# NOTE(review): unlike the other project paths in this notebook, this one has no
# "raymondl/" entity prefix - presumably it is resolved against the default
# entity of the logged-in user; confirm and make it explicit.
project = "tinystories-1m-e2e-bayesian-beta-annealing"
runs = api.runs(project)
df2 = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)
# Keep only the linear-annealing runs; .copy() makes the filtered frame
# independent before columns are added to it.
df2 = df2[df2['name'].str.contains('linear_annealing', regex=False)].copy()
def assign_group2(run_name: str) -> str:
    """Build the legend label ``"e2e (bayesian) <beta>"`` for a run.

    ``<beta>`` is the text between the last ``'beta_'`` marker in the name and
    the next underscore.
    """
    after_marker = run_name.split('beta_')[-1]
    beta_value = after_marker.split('_')[0]
    return "e2e (bayesian) " + beta_value
# .assign returns a new frame, so the label column is never written into the
# filtered slice produced by the preceding row selection.
df2 = df2.assign(grouping_type=df2["name"].apply(assign_group2))

# Stack the baseline runs and the Bayesian runs into one frame for plotting.
# NOTE(review): the recorded output shows a warning pointing at this line -
# presumably pandas' FutureWarning about empty or all-NA columns in concat;
# confirm and drop such columns beforehand if so.
df = pd.concat([df, df2], axis=0)

# Optional outlier filters.
# df = df[df['CELossIncrease'] < 1]
# df = df[df['L0'] < 100]

plot_facet(
    df=df,
    xs=["L0", "alive_dict_elements"],
    y="CELossIncrease",
    facet_by="layer",
    facet_vals=[4],
    line_by="grouping_type",
    xlabels=["L0", "Alive Dict Elements"],
    ylabel="CE Loss Increase",
    legend_title="SAE Type",
    axis_formatter=partial(format_two_axes, better_labels=True),
    out_file="plots/l0_vs_ce_loss_layer_4.png",
    xlims=[{4: (None, None)}, {4: (None, None)}],
    # Limits are given high-to-low - presumably to invert the y axis; confirm.
    ylim={4: (.5, 0)},
    styles=STYLE_MAP,
    plot_type='line',
    # annotate_col="sparsity_coeff"
)

Processing runs:   0%|          | 0/105 [00:00<?, ?it/s]

Run e2e_bayesian_bayesian_seed-0_lpcoeff-0.05_logits-kl-1.0_lr-0.001_ratio-50.0_blocks.4.hook_resid_pre is not finished, skipping
Run local_bayesian_bayesian_seed-0_lpcoeff-9e-05_in-to-out-1.0_lr-0.001_ratio-50.0_blocks.4.hook_resid_pre is not finished, skipping


Processing runs: 100%|██████████| 105/105 [00:01<00:00, 53.08it/s]
Processing runs: 100%|██████████| 51/51 [00:00<00:00, 643.18it/s]
  df = pd.concat([df, df2], axis=0)
2025-05-03 18:24:07 - INFO - Saved plot to plots/l0_vs_ce_loss_layer_4.png
2025-05-03 18:24:07 - INFO - Saved SVG plot to plots/l0_vs_ce_loss_layer_4.svg


In [None]:
# Let wandb resolve credentials itself (WANDB_API_KEY environment variable or a
# prior `wandb login`) so that no secret is stored in the notebook.
# NOTE(review): the key that used to be hardcoded here must be treated as leaked
# and revoked in the W&B account settings.
api = wandb.Api()
project = "raymondl/tinystories-1m_play"
# Only fetch runs tagged "local".
runs = api.runs(project, filters={"tags": "local"})
df = create_run_df(runs, per_layer_metrics=False, use_run_name=True, grad_norm=False)

def assign_group(name):
    """Label a run "local (bayesian)" if its name contains "bayesian", else "local"."""
    return "local (bayesian)" if "bayesian" in name else "local"

# Tag every run as "local" or "local (bayesian)"; one line is drawn per label.
df["grouping_type"] = df["name"].apply(assign_group)

plot_facet(
    # Data: L0 on the y axis against two x metrics.
    df=df,
    y="L0",
    xs=["CELossIncrease", "out_to_in"],
    # Faceting / grouping: only layer 4, one line per SAE type.
    facet_by="layer",
    facet_vals=[4],
    line_by="grouping_type",
    # Labels.
    ylabel="L0",
    xlabels=["CE Loss Increase", "Reconstruction MSE"],
    legend_title="SAE Type",
    # Axis limits, keyed by facet value (one xlims entry per x metric).
    ylim={4: (None, None)},
    xlims=[{4: (-0.5, 10)}, {4: (None, None)}],
    # Appearance (no plot_type is passed in this cell).
    styles=STYLE_MAP,
    annotate_col="sparsity_coeff",
    axis_formatter=partial(format_two_axes, better_labels=True),
    # Output location.
    out_file="plots/l0_vs_ce_loss_local_layer_4.png",
)


Processing runs:   0%|          | 0/61 [00:00<?, ?it/s]

Run local_bayesian_bayesian_seed-0_lpcoeff-9e-05_in-to-out-1.0_lr-0.001_ratio-50.0_blocks.4.hook_resid_pre is not finished, skipping


Processing runs: 100%|██████████| 61/61 [00:00<00:00, 182.75it/s]
2025-05-03 18:18:06 - INFO - Saved plot to plots/l0_vs_ce_loss_local_layer_4.png
2025-05-03 18:18:06 - INFO - Saved SVG plot to plots/l0_vs_ce_loss_local_layer_4.svg
