In [1]:
# Automatically reload edited modules (such as the local `sprint` package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
from dash import no_update
from dash import dash_table
import pandas as pd
import plotly.graph_objs as go

# Matplotlib stuff
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm

# Same as notebook 7
from transformer_lens import utils
from sprint.loading import load_all
from sprint.linearization import analyze_linearized_feature
from sprint.attention import get_attn_head_contribs, get_attn_head_contribs_ov
from sprint.sae_tutorial import make_token_df, process_tokens
from sprint.visualization import visualize_topk_plotly, plot_attn_contribs_for_example

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def value_to_color(val, min_val, max_val):
    """Return the coolwarm colour for ``val`` as a ``#rrggbb`` hex string.

    ``val`` is rescaled with ``matplotlib.colors.Normalize(vmin=min_val, vmax=max_val)``
    before being looked up in the ``coolwarm`` colormap.
    """
    scaled = mcolors.Normalize(vmin=min_val, vmax=max_val)(val)
    rgba = cm.coolwarm(scaled)
    return mcolors.to_hex(rgba)


def generate_style_data_conditional(df):
    """Build Dash ``DataTable.style_data_conditional`` rules that heat-map numeric columns.

    Every distinct non-missing value of each numeric column gets one rule whose
    ``filter_query`` matches that exact value and whose background colour is
    ``value_to_color(value, column_min, column_max)``. Non-numeric columns are left unstyled.

    Args:
        df: The DataFrame that the DataTable will render.

    Returns:
        A list of style dicts suitable for ``DataTable(style_data_conditional=...)``.
    """
    style_data_conditional = []
    for column in df.columns:
        if not pd.api.types.is_numeric_dtype(df[column]):
            continue
        min_val = df[column].min()
        max_val = df[column].max()
        for val in df[column].unique():
            # NaN / pd.NA have no literal in Dash's filter-query language (they would be
            # rendered as "nan" / "<NA>"), so no rule can match them.
            if pd.isna(val):
                continue
            # The table data is serialised at float64 precision (e.g. float32 0.1 arrives as
            # 0.10000000149011612), whereas str() of a numpy float32 gives the short form "0.1",
            # which would never compare equal. Render floats at full float64 precision instead.
            literal = repr(float(val)) if pd.api.types.is_float(val) else str(val)
            style_data_conditional.append(
                {
                    "if": {"filter_query": f"{{{column}}} eq {literal}", "column_id": column},
                    "backgroundColor": value_to_color(val, min_val, max_val),
                    "color": "black",  # or 'white' depending on contrast
                }
            )
    return style_data_conditional

In [4]:
# Initialize Dash app
app = dash.Dash(__name__)

# Define your layout
app.layout = html.Div(
    [
        # Model and data loading section
        html.Div(
            [
                html.Div(
                    [
                        html.Label("Model + SAE:"),
                        dcc.Dropdown(
                            id="model_sae_name",
                            options=[
                                {"label": "gelu-1l, run 1", "value": "gelu-1l, run 1"},
                                {"label": "gelu-1l, run 2", "value": "gelu-1l, run 2"},
                                {"label": "gelu-2l, layer 0", "value": "gelu-2l, layer 0"},
                                {"label": "gelu-2l, layer 1", "value": "gelu-2l, layer 1"},
                            ],
                            style={"width": "50%"},  # Adjust width here
                            value="gelu-1l, run 1",
                        ),
                    ],
                    style={"margin-bottom": "10px"},  # Adjust spacing as needed
                ),
                html.Div([html.Button("Load model", id="load-button", n_clicks=0)], style={"margin-bottom": "10px"}),
            ],
            style={"width": "300px", "margin": "auto"},
        ),
        # Feature, sample, token, attention head callback happens here
        html.Div(
            [
                html.Div(
                    [html.Label("Feature ID:"), dcc.Input(id="feature-id", type="text", value="4542")],
                    style={"margin-bottom": "10px"},
                ),  # Add margin for spacing
                html.Div(
                    [html.Label("Sample #:"), dcc.Input(id="sample-idx", type="text", value="38")],
                    style={"margin-bottom": "10px"},
                ),
                html.Div(
                    [html.Label("Token #:"), dcc.Input(id="token-idx", type="text", value="73")],
                    style={"margin-bottom": "10px"},
                ),
                html.Div(
                    [
                        html.Label("Attention Head:"),
                        dcc.Dropdown(
                            id="attention-head",
                            options=[{"label": str(i), "value": i} for i in range(8)],
                            style={"width": "50%"},  # Adjust width here
                            value=0,
                        ),
                    ],
                    style={"margin-bottom": "10px"},
                ),
                html.Div(
                    [html.Label("Batch size"), dcc.Input(id="batch-size", type="text", value="64")],
                    style={"margin-bottom": "10px"},
                ),
                html.Button("Update", id="update-button", n_clicks=0),
            ],
            style={"width": "300px", "margin": "auto"},
        ),  # Adjust the overall width of the input container
        html.Div(id="model-name", children=[]),
        html.Div(id="output-container", children=[]),
    ]
)

# Module-level state shared by the callbacks. The "Load model" callback fills these in.
# All of them start as None so every name exists from the start. (A module-level
# `global` statement is a no-op and does not create the names; previously `data`, `sae`
# and `sae_layer` were left undefined until a model had been loaded.)
model = None
data = None
sae = None
sae_layer = None
# Eager alternative: model, data, sae = load_all(model_name="gelu-1l", run_id="run1")


# Maps each "Model + SAE" dropdown value to (model_name, SAE run_id, SAE layer).
MODEL_SAE_CONFIGS = {
    "gelu-1l, run 1": ("gelu-1l", "run1", 0),
    "gelu-1l, run 2": ("gelu-1l", "run2", 0),
    "gelu-2l, layer 0": ("gelu-2l", "l0", 0),
    "gelu-2l, layer 1": ("gelu-2l", "l1", 1),
}


# Initialize your data and model
@app.callback(
    Output("model-name", "children"),
    [Input("load-button", "n_clicks")],
    [State("model_sae_name", "value")],
)
def load_model_data_sae(n_clicks, model_sae_name):
    """Load the model, dataset and SAE chosen in the dropdown into the module globals.

    Args:
        n_clicks: Click count of the "Load model" button; nothing happens before the first click.
        model_sae_name: Selected dropdown value; expected to be a key of ``MODEL_SAE_CONFIGS``.

    Returns:
        A list with one ``html.Div`` describing what was loaded. If the selection is not
        recognised, the Div holds an error message instead. Returns an empty list before
        the first click.
    """
    global model, data, sae, sae_layer
    if not n_clicks:
        return []

    config = MODEL_SAE_CONFIGS.get(model_sae_name)
    if config is None:
        # A cleared dropdown yields None. Previously this fell through the if/elif chain and
        # raised UnboundLocalError on `model_name`.
        return [html.Div(f"Unknown model/SAE selection: {model_sae_name!r}")]

    model_name, sae_name, sae_layer = config
    model, data, sae = load_all(model_name=model_name, run_id=sae_name)
    return [html.Div(f"Loaded model: {model_name}, SAE: {sae_name}, SAE Layer: {sae_layer}")]


# Set to True to also compute and display the attention-head contribution plot for the
# selected example/token. It used to be computed on every update and then discarded.
SHOW_ATTN_CONTRIBS_PLOT = False

# Number of highest-scoring token rows shown in each ranking table.
TOP_K_ROWS = 20


def _make_data_table(df, title):
    """Return ``[title Div, DataTable]`` rendering ``df`` with per-column heat-map colouring."""
    return [
        html.Div(title),
        dash_table.DataTable(
            data=df.to_dict("records"),
            columns=[{"name": col, "id": col} for col in df.columns],
            style_cell={
                "minWidth": "50px",
                "width": "auto",
                "textAlign": "center",
            },
            style_table={"maxHeight": "300px", "overflowY": "scroll", "margin": "auto", "overflowX": "scroll"},
            style_data_conditional=generate_style_data_conditional(df),
        ),
    ]


# Define the callback for updating outputs
@app.callback(
    Output("output-container", "children"),
    [Input("update-button", "n_clicks")],
    [
        State("feature-id", "value"),
        State("sample-idx", "value"),
        State("token-idx", "value"),
        State("attention-head", "value"),
        State("batch-size", "value"),
    ],
)
def update_output(n_clicks, feature_id, sample_idx, token_idx, attention_head, batch_size):
    """Analyse one SAE feature at one (sample, token) position and render tables and plots.

    Reads the module-level ``model``, ``data`` and ``sae`` set by the "Load model" callback.

    Args:
        n_clicks: Click count of the "Update" button; nothing is rendered before the first click.
        feature_id, sample_idx, token_idx, attention_head, batch_size: Raw widget values
            (strings from text inputs, or the dropdown value); each must be convertible with ``int()``.

    Returns:
        A list of Dash components. On success this is three titled DataTables (top SAE
        activations, top activation scores, per-token scores) followed by the plotly figures.
        Instead, a single message Div is returned when no model is loaded or an input is not
        an integer.
    """
    if not n_clicks:
        return []
    if model is None:
        return [html.Div("No model loaded yet: choose a model + SAE and click 'Load model' first.")]

    # Parse every input once; a non-numeric entry or a cleared dropdown must not crash the callback.
    try:
        feature_idx = int(feature_id)
        sample_idx = int(sample_idx)
        token_idx = int(token_idx)
        head = int(attention_head)
        batch_size = int(batch_size)
    except (TypeError, ValueError):
        return [html.Div("Feature ID, Sample #, Token #, Attention Head and Batch size must all be integers.")]

    result = analyze_linearized_feature(
        feature_idx=feature_idx,
        sample_idx=sample_idx,
        token_idx=token_idx,
        model=model,
        data=data,
        encoder=sae,
        head=head,
        batch_size=batch_size,
    )
    batch = data[:batch_size]

    # Both ranking tables annotate the same per-token frame, so build it only once.
    token_df = make_token_df(batch, model=model)

    def top_tokens(scores):
        """Attach ``scores`` as a "feature" column and keep the TOP_K_ROWS highest-scoring rows."""
        return (
            token_df.assign(feature=utils.to_numpy(scores))
            .sort_values("feature", ascending=False)
            .head(TOP_K_ROWS)
        )

    tables = [
        ("SAE Activations", top_tokens(result["sae activations"][:, feature_idx])),
        ("Activation Scores", top_tokens(result["activation scores"])),
        (
            "Token Scores",
            pd.DataFrame(
                dict(str_tokens=result["token strings"], feature_scores=utils.to_numpy(result["token scores"]))
            ),
        ),
    ]

    figures = [visualize_topk_plotly(feature_id=feature_idx, n_examples=10, model=model, pad=True, clip=20)]
    if SHOW_ATTN_CONTRIBS_PLOT:
        figures.append(
            plot_attn_contribs_for_example(
                model=model,
                data=data,
                example_idx=sample_idx,
                token_idx=token_idx,
                feature_mid=result["mid"],
                ov_only=False,
                batch_size=batch_size,
            )
        )

    # Combine all elements for display: every table (title + DataTable), then every plot.
    components = []
    for title, df in tables:
        components.extend(_make_data_table(df, title))
    components.extend(dcc.Graph(figure=fig, style={"overflowX": "scroll"}) for fig in figures)
    return components


# Launch the Dash server for the layout and callbacks defined above, with debug mode enabled.
app.run(debug=True)

In [5]:
# Stand-alone check: load only the SAE with run id "l1" (the gelu-2l layer-1 SAE) and display it.
from sprint.loading import load_sae

SAE_RUN_ID = "l1"
load_sae(run_id=SAE_RUN_ID)

{'act_name': 'blocks.1.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'd_mlp': 512,
 'device': 'cuda:0',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 1,
 'lr': 0.0001,
 'model_batch_size': 512,
 'model_name': 'gelu-2l',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 50,
 'seq_len': 128,
 'site': 'mlp_out'}
Encoder device: cuda:0


AutoEncoder()