# SAE Features Visualization

![Collect Internal Activations](./images/decoder_only_model_internal_activations.jpg)

![SAE Feature Activations](./images/internal_activations_to_sae.jpg)

## Dependencies

In [None]:
import torch
from tqdm.auto import tqdm

In [None]:
from pathlib import Path

# Parent of the (resolved) current working directory.
project_dir = Path().resolve().parent
# Sibling folders that hold the collected statistics and the helper scripts.
statistic_dir, script_dir = (
    project_dir.joinpath(folder_name) for folder_name in ("statistics", "scripts")
)

In [None]:
import sys

# Make the modules under `script_dir` importable. The membership check keeps
# sys.path from gaining one duplicate entry every time this cell is re-run.
if str(script_dir) not in sys.path:
    sys.path.append(str(script_dir))

In [None]:
%reload_ext autoreload
%autoreload 2

from visualization import (
    plot_all_layers,
    plot_all_lang_feature_overlap,
    plot_lang_feature_overlap_trend,
    plot_all_co_occurrence,
    plot_all_cross_co_occurrence,
    plot_all_count_box_plots,
    plot_lape_result,
    plot_umap,
    plot_ppl_change_matrix,
    generate_ppl_change_matrix,
)

from feature_visualizer import (
    generate_feature_activations_visualization,
)

from loader import (
    load_layer_to_summary,
    load_lang_to_dataset_token_activations,
    load_lang_to_dataset_token_activations_aggregate,
    load_all_interpretations,
)

from const import lang_choices_to_qualified_name, layer_to_index

from delphi.log.result_analysis import get_metrics_per_latent, load_data

## Llama 3.2-1B

In [None]:
# XNLI run description: model name, SAE name and latent count, dataset and
# split, language codes, and the MLP hookpoint names for layers 0-15.
config_xnli = dict(
    model="meta-llama/Llama-3.2-1B",
    sae=dict(
        model="EleutherAI/sae-Llama-3.2-1B-131k",
        num_latents=131072,
    ),
    dataset="facebook/xnli",
    split="train",
    languages=["en", "de", "fr", "hi", "es", "th", "bg", "ru", "tr", "vi"],
    # "model.layers.0.mlp" ... "model.layers.15.mlp"
    layers=[f"model.layers.{block_index}.mlp" for block_index in range(16)],
)

# PAWS-X run description: same model and SAE as the other configs, with the
# PAWS-X dataset, its language codes, and the MLP hookpoints for layers 0-15.
config_pawsx = dict(
    model="meta-llama/Llama-3.2-1B",
    sae=dict(
        model="EleutherAI/sae-Llama-3.2-1B-131k",
        num_latents=131072,
    ),
    dataset="google-research-datasets/paws-x",
    split="train",
    languages=["en", "de", "fr", "es", "ja", "ko", "zh"],
    # "model.layers.0.mlp" ... "model.layers.15.mlp"
    layers=[f"model.layers.{block_index}.mlp" for block_index in range(16)],
)

# FLORES+ run description: same model and SAE as the other configs, evaluated
# on the "dev" split with FLORES-style language codes and the MLP hookpoints
# for layers 0-15.
config_flores = dict(
    model="meta-llama/Llama-3.2-1B",
    sae=dict(
        model="EleutherAI/sae-Llama-3.2-1B-131k",
        num_latents=131072,
    ),
    dataset="openlanguagedata/flores_plus",
    split="dev",
    languages=[
        "eng_Latn",
        "deu_Latn",
        "fra_Latn",
        "ita_Latn",
        "por_Latn",
        "hin_Deva",
        "spa_Latn",
        "tha_Thai",
        "bul_Cyrl",
        "rus_Cyrl",
        "tur_Latn",
        "vie_Latn",
        "jpn_Jpan",
        "kor_Hang",
        "cmn_Hans",
    ],
    # "model.layers.0.mlp" ... "model.layers.15.mlp"
    layers=[f"model.layers.{block_index}.mlp" for block_index in range(16)],
)

## Visualizations

### XNLI

In [None]:
# <statistics>/<model>/<sae model>/<dataset>/summary for the XNLI run.
data_path_summary_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "summary",
)

# Load the summaries for every configured layer and language.
df_layers_llama_xnli = load_layer_to_summary(
    data_path_summary_xnli,
    config_xnli["layers"],
    config_xnli["languages"],
)

In [None]:
# All plots in this group consume the XNLI summary loaded above together with
# its config; the helpers come from the project's `visualization` module.
plot_all_layers(df_layers_llama_xnli, config_xnli)

In [None]:
# range_y is fixed at 0-40,000 here and in the other dataset sections
# (presumably so the overlap plots share one y-axis scale — TODO confirm).
plot_all_lang_feature_overlap(df_layers_llama_xnli, config_xnli, range_y=[0, 40_000])

In [None]:
plot_lang_feature_overlap_trend(df_layers_llama_xnli, config_xnli)

In [None]:
plot_all_co_occurrence(df_layers_llama_xnli, config_xnli)

In [None]:
plot_all_count_box_plots(df_layers_llama_xnli, config_xnli)

In [None]:
# <statistics>/<model>/<sae model>/<dataset>/dataset_token_activations for XNLI.
data_path_dataset_token_activations_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "dataset_token_activations",
)

# Aggregate loader, called with all configured layers and languages.
df_dataset_token_activations_xnli = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_xnli,
    config_xnli["layers"],
    config_xnli["languages"],
)

In [None]:
# Write the aggregated XNLI activations to CSV (relative to the working
# directory) with more descriptive names for the "index" and "count" columns.
df_dataset_token_activations_xnli.rename(
    columns=dict(index="sae_feature_number", count="token_count")
).to_csv("sae_features_facebook_xnli.csv", index=False)

### PAWS-X

In [None]:
# <statistics>/<model>/<sae model>/<dataset>/summary for the PAWS-X run.
data_path_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "summary",
)

In [None]:
# Load the PAWS-X summaries for every configured layer and language.
df_layers_llama_pawsx = load_layer_to_summary(
    data_path_pawsx, config_pawsx["layers"], config_pawsx["languages"]
)

In [None]:
# Same sequence of plots as in the XNLI section, applied to the PAWS-X summary.
plot_all_layers(df_layers_llama_pawsx, config_pawsx)

In [None]:
# range_y matches the value used for the other datasets.
plot_all_lang_feature_overlap(df_layers_llama_pawsx, config_pawsx, range_y=[0, 40_000])

In [None]:
plot_lang_feature_overlap_trend(
    df_layers_llama_pawsx,
    config_pawsx,
)

In [None]:
plot_all_co_occurrence(df_layers_llama_pawsx, config_pawsx)

In [None]:
plot_all_count_box_plots(df_layers_llama_pawsx, config_pawsx)

In [None]:
# <statistics>/<model>/<sae model>/<dataset>/dataset_token_activations for PAWS-X.
data_path_dataset_token_activations_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "dataset_token_activations",
)

# Aggregate loader, called with all configured layers and languages.
df_dataset_token_activations_pawsx = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_pawsx,
    config_pawsx["layers"],
    config_pawsx["languages"],
)

In [None]:
# Write the aggregated PAWS-X activations to CSV (relative to the working
# directory) with more descriptive names for the "index" and "count" columns.
df_dataset_token_activations_pawsx.rename(
    columns=dict(index="sae_feature_number", count="token_count")
).to_csv("sae_features_google-research-datasets_paws-x.csv", index=False)

#### XNLI and PAWS-X

In [None]:
# Cross-dataset co-occurrence between the XNLI and PAWS-X summaries.
plot_all_cross_co_occurrence(
    df_layers_llama_xnli, config_xnli, df_layers_llama_pawsx, config_pawsx
)

In [None]:
# Same comparison with specific_feature_lang_count=1
# (presumably restricts to features tied to a single language — TODO confirm
# against plot_all_cross_co_occurrence).
plot_all_cross_co_occurrence(
    df_layers_llama_xnli,
    config_xnli,
    df_layers_llama_pawsx,
    config_pawsx,
    specific_feature_lang_count=1,
)

### FLORES+

In [None]:
# <statistics>/<model>/<sae model>/<dataset>/summary for the FLORES+ run.
data_path_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "summary",
)

In [None]:
# Load the FLORES+ summaries for every configured layer and language.
df_layers_llama_flores = load_layer_to_summary(
    data_path_flores, config_flores["layers"], config_flores["languages"]
)

In [None]:
# Same sequence of plots as in the XNLI and PAWS-X sections, applied to FLORES+.
plot_all_layers(df_layers_llama_flores, config_flores)

In [None]:
# range_y matches the value used for the other datasets.
plot_all_lang_feature_overlap(
    df_layers_llama_flores, config_flores, range_y=[0, 40_000]
)

In [None]:
plot_lang_feature_overlap_trend(
    df_layers_llama_flores,
    config_flores,
)

In [None]:
plot_all_co_occurrence(df_layers_llama_flores, config_flores)

In [None]:
plot_all_count_box_plots(df_layers_llama_flores, config_flores)

In [None]:
# <statistics>/<model>/<sae model>/<dataset>/dataset_token_activations for FLORES+.
data_path_dataset_token_activations_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "dataset_token_activations",
)

# Aggregate loader, called with all configured layers and languages.
df_dataset_token_activations_flores = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_flores,
    config_flores["layers"],
    config_flores["languages"],
)

In [None]:
# Write the aggregated FLORES+ activations to CSV (relative to the working
# directory) with more descriptive names for the "index" and "count" columns.
#
# The file name is derived from the configured dataset id with "/" replaced by
# "_", which is the pattern the XNLI and PAWS-X exports follow. A hard-coded
# "gsarti_flores_101" name would label this file with a dataset other than the
# one in config_flores ("openlanguagedata/flores_plus").
df_dataset_token_activations_flores.rename(
    columns={
        "index": "sae_feature_number",
        "count": "token_count",
    }
).to_csv(
    f"sae_features_{config_flores['dataset'].replace('/', '_')}.csv",
    index=False,
)

#### FLORES+ with XNLI and PAWS-X

In [None]:
# FLORES+ vs. XNLI.
plot_all_cross_co_occurrence(
    df_layers_llama_flores, config_flores, df_layers_llama_xnli, config_xnli
)

In [None]:
# FLORES+ vs. XNLI with specific_feature_lang_count=1
# (presumably restricts to features tied to a single language — TODO confirm).
plot_all_cross_co_occurrence(
    df_layers_llama_flores,
    config_flores,
    df_layers_llama_xnli,
    config_xnli,
    specific_feature_lang_count=1,
)

In [None]:
# FLORES+ vs. PAWS-X.
plot_all_cross_co_occurrence(
    df_layers_llama_flores, config_flores, df_layers_llama_pawsx, config_pawsx
)

In [None]:
# FLORES+ vs. PAWS-X with specific_feature_lang_count=1.
plot_all_cross_co_occurrence(
    df_layers_llama_flores,
    config_flores,
    df_layers_llama_pawsx,
    config_pawsx,
    specific_feature_lang_count=1,
)

### Feature Index Visualization

In [None]:
# Token-activation directories for the three datasets, each laid out as
# <statistics>/<model>/<sae model>/<dataset>/dataset_token_activations.
# These are the same paths the per-dataset sections compute.
(
    data_path_dataset_token_activations_xnli,
    data_path_dataset_token_activations_pawsx,
    data_path_dataset_token_activations_flores,
) = (
    statistic_dir.joinpath(
        dataset_config["model"],
        dataset_config["sae"]["model"],
        dataset_config["dataset"],
        "dataset_token_activations",
    )
    for dataset_config in (config_xnli, config_pawsx, config_flores)
)

In [None]:
# Short names: the text after the last "/" of the model and SAE identifiers.
# NOTE: `model` holds the short name only in this cell; the next cell rebinds
# it to the full identifier.
model = config_xnli["model"].rsplit("/", 1)[-1]
sae_model_name = config_xnli["sae"]["model"].rsplit("/", 1)[-1]

# Output directory for single-feature visualizations.
out_path = project_dir.joinpath(
    "visualization", "feature_index", model, sae_model_name
)

In [None]:
# Feature and hookpoint to inspect in the following cells.
feature_index = 25
layer = "model.layers.0.mlp"

# Rebinds `model` to the full model identifier (the previous cell used the
# short name to build out_path).
model = config_flores["model"]
sae_model = config_flores["sae"]["model"]
layers = config_flores["layers"]

In [None]:
# Load the token activations of the selected feature for each dataset, in the
# order XNLI, PAWS-X, FLORES+.
(
    lang_to_dataset_token_activations_xnli,
    lang_to_dataset_token_activations_pawsx,
    lang_to_dataset_token_activations_flores,
) = (
    load_lang_to_dataset_token_activations(
        activations_dir,
        layer,
        dataset_config["languages"],
        [feature_index],
    )
    for activations_dir, dataset_config in (
        (data_path_dataset_token_activations_xnli, config_xnli),
        (data_path_dataset_token_activations_pawsx, config_pawsx),
        (data_path_dataset_token_activations_flores, config_flores),
    )
)

# Bundle every dataset's activations with a shallow copy of its config.
dataset_lang_to_dataset_token_activations = {
    dataset_key: {
        "dataset_token_activations": token_activations,
        "config": dict(dataset_config),
    }
    for dataset_key, token_activations, dataset_config in (
        ("xnli", lang_to_dataset_token_activations_xnli, config_xnli),
        ("paws-x", lang_to_dataset_token_activations_pawsx, config_pawsx),
        ("flores", lang_to_dataset_token_activations_flores, config_flores),
    )
}

In [None]:
# Metadata for an ad-hoc feature: no language, probability, entropy,
# interpretation or scores are available, so every such field is the "-"
# placeholder (and "auc" is None).
feature_info = {
    "feature_index": feature_index,
    "layer": layer,
    "lang": "None",
    **dict.fromkeys(("selected_prob", "entropy", "interpretation"), "-"),
    "metrics": [
        {
            **dict.fromkeys(
                (
                    "score_type",
                    "true_positives",
                    "true_negatives",
                    "false_positives",
                    "false_negatives",
                    "total_examples",
                    "total_positives",
                    "total_negatives",
                    "failed_count",
                    "precision",
                    "recall",
                    "f1_score",
                    "accuracy",
                    "true_positive_rate",
                    "true_negative_rate",
                    "false_positive_rate",
                    "false_negative_rate",
                    "positive_class_ratio",
                    "negative_class_ratio",
                ),
                "-",
            ),
            "auc": None,
        }
    ],
}

# Render the visualization for the selected feature into out_path.
generate_feature_activations_visualization(
    dataset_lang_to_dataset_token_activations,
    feature_index,
    feature_info,
    model,
    layer,
    sae_model,
    out_path,
    lang_choices_to_qualified_name,
)

### LAPE

In [None]:
# Saved LAPE result "lape_top_10_by_entropy.pt" for this model / SAE pair.
lape_top_10_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_10_by_entropy.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load files
# produced by this project.
lape_top_10_result = torch.load(lape_top_10_result_path, weights_only=False)

# NOTE(review): out_dir is relative to the working directory, while most other
# outputs in this notebook are rooted at project_dir — confirm this is intended.
plot_lape_result(
    lape_top_10_result,
    title="Distribution of Top-10 Language-Specific Features by Entropy",
    out_dir=Path(
        r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_10_by_entropy"
    ),
)

In [None]:
# Saved LAPE result "lape_top_10_by_freq.pt" for this model / SAE pair.
# Rebinds the names used by the entropy cell above.
lape_top_10_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_10_by_freq.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_top_10_result = torch.load(lape_top_10_result_path, weights_only=False)

# out_dir is relative to the working directory.
plot_lape_result(
    lape_top_10_result,
    title="Distribution of Top-10 Language-Specific Features by Frequency",
    out_dir=Path(
        r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_10_by_freq"
    ),
)

In [None]:
# Neuron-level LAPE result; stored under "mlp_acts_specific/<model>" with no
# SAE component in the path.
lape_neuron_result_path = project_dir.joinpath(
    "mlp_acts_specific",
    config_xnli["model"],
    "lape_neuron.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_neuron_result = torch.load(lape_neuron_result_path, weights_only=False)

# out_dir is relative to the working directory.
plot_lape_result(
    lape_neuron_result,
    title="Distribution of Language-specific Neurons",
    out_dir=Path(
        r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/lape_neuron"
    ),
)

In [None]:
# Saved LAPE result "lape_top_1_per_layer_by_entropy.pt" for this model / SAE pair.
lape_top_1_per_layer_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_1_per_layer_by_entropy.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_result_top_1_per_layer = torch.load(
    lape_top_1_per_layer_result_path,
    weights_only=False,
)

# out_dir is relative to the working directory.
plot_lape_result(
    lape_result_top_1_per_layer,
    title="Distribution of Top-1 per Layer Language-Specific Features by Entropy",
    out_dir=Path(
        r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_1_per_layer_by_entropy"
    ),
)

In [None]:
# Saved LAPE result "lape_top_1_per_layer_by_freq.pt" for this model / SAE pair.
# Rebinds the names used by the entropy cell above.
lape_top_1_per_layer_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_1_per_layer_by_freq.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_result_top_1_per_layer = torch.load(
    lape_top_1_per_layer_result_path,
    weights_only=False,
)

# out_dir is relative to the working directory.
plot_lape_result(
    lape_result_top_1_per_layer,
    title="Distribution of Top-1 per Layer Language-Specific Features by Frequency",
    out_dir=Path(
        r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_1_per_layer_by_freq"
    ),
)

In [None]:
# Saved LAPE result "lape_all.pt" for this model / SAE pair.
lape_all_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_all.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

# out_dir is relative to the working directory.
plot_lape_result(
    lape_all_result,
    title="Distribution of LAPE for all Language-Specific Features",
    out_dir=Path(
        r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_all"
    ),
)

### Language-Specific Features Visualization

In [None]:
# (Re)load "lape_all.pt" — the same file the last LAPE plotting cell reads — so
# that lape_all_result holds it when the cells of this section run.
lape_all_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_all.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

In [None]:
def convert_df_metrics_to_nested_dict(df):
    """Group metric rows into ``{layer_key: {latent_idx: [metrics, ...]}}``.

    Args:
        df: DataFrame with ``layer`` and ``latent_idx`` columns; every other
            column is treated as a metric value.

    Returns:
        A dict keyed by ``"model.<layer>"`` and then by latent index. Each
        leaf is a list with one dict per input row (in row order) holding the
        remaining columns; float values are rounded to three decimals and all
        other values are kept as they are.
    """

    def _round_if_float(value):
        return round(value, 3) if isinstance(value, float) else value

    nested = {}

    for _, record in df.iterrows():
        metric_values = record.drop(["layer", "latent_idx"]).apply(_round_if_float)

        # Create the layer / latent buckets on first sight, then append the row.
        per_layer = nested.setdefault(f"model.{record['layer']}", {})
        per_layer.setdefault(record["latent_idx"], []).append(metric_values.to_dict())

    return nested

In [None]:
# Inputs: explanations and scores under <project>/interpret_sae_features.
interpretation_folder = project_dir.joinpath("interpret_sae_features", "explanations")
scores_path = project_dir.joinpath("interpret_sae_features", "scores")

# Output location for score visualizations.
visualize_path = project_dir.joinpath(
    "visualization", "interpret_sae_features", "scores"
)

# "layers.0.mlp" ... "layers.15.mlp". Unlike the config "layers" entries these
# carry no "model." prefix.
hookpoints = [f"layers.{block_index}.mlp" for block_index in range(16)]

In [None]:
# Explanations come from the project loader; scores go through delphi's
# result-analysis helpers.
interpretations = load_all_interpretations(interpretation_folder)
# `counts` is not used anywhere else in the visible notebook.
latent_df, counts = load_data(scores_path, hookpoints)
df_metrics = get_metrics_per_latent(latent_df)
# Reshape into {"model.<layer>": {latent_idx: [metric dict, ...]}} so the
# metrics can be looked up per layer and feature.
metrics = convert_df_metrics_to_nested_dict(df_metrics)

In [None]:
# Full model / SAE identifiers are handed to the visualizer; the short names
# (text after the last "/") are only used to build the output directory.
model = config_flores["model"]
sae_model = config_flores["sae"]["model"]
model_name = model.split("/")[-1]
sae_model_name = sae_model.split("/")[-1]
layers = config_flores["layers"]

sorted_lang = lape_all_result["sorted_lang"]

# For every language, layer and selected feature, render one HTML page under
# <project>/visualization/feature_index/<model>/<sae>/<layer>/<lang>/.
#
# enumerate() supplies each language's position directly instead of searching
# sorted_lang with .index() on every iteration.
for lang_index, lang in enumerate(tqdm(sorted_lang, desc="Processing languages")):
    for layer in tqdm(layers, desc="Processing layers", leave=False):
        layer_index = layer_to_index[layer]
        lang_final_indices = lape_all_result["final_indice"][lang_index][
            layer_index
        ].tolist()

        # No feature was selected for this language in this layer.
        if not lang_final_indices:
            continue

        # `layer` is used as-is from the loop. Re-deriving it through
        # layers[layer_index] is either a no-op or, when `layers` is not the
        # complete ordered layer list, selects the wrong entry / raises
        # IndexError.

        lang_to_dataset_token_activations_xnli = load_lang_to_dataset_token_activations(
            data_path_dataset_token_activations_xnli,
            layer,
            config_xnli["languages"],
            lang_final_indices,
        )

        lang_to_dataset_token_activations_pawsx = (
            load_lang_to_dataset_token_activations(
                data_path_dataset_token_activations_pawsx,
                layer,
                config_pawsx["languages"],
                lang_final_indices,
            )
        )

        lang_to_dataset_token_activations_flores = (
            load_lang_to_dataset_token_activations(
                data_path_dataset_token_activations_flores,
                layer,
                config_flores["languages"],
                lang_final_indices,
            )
        )

        # Each dataset's activations are bundled with a shallow copy of its config.
        dataset_lang_to_dataset_token_activations = {
            "xnli": {
                "dataset_token_activations": lang_to_dataset_token_activations_xnli,
                "config": {**config_xnli},
            },
            "paws-x": {
                "dataset_token_activations": lang_to_dataset_token_activations_pawsx,
                "config": {**config_pawsx},
            },
            "flores": {
                "dataset_token_activations": lang_to_dataset_token_activations_flores,
                "config": {**config_flores},
            },
        }

        out_path = (
            project_dir
            / "visualization"
            / "feature_index"
            / model_name
            / sae_model_name
            / layer
            / lang
        )

        # Per-language LAPE details; looked up once per (language, layer)
        # rather than once per feature.
        lang_features_info = lape_all_result["features_info"][lang]
        selected_probs = lang_features_info["selected_probs"]
        entropies = lang_features_info["entropies"]

        for feature_index in tqdm(
            lang_final_indices, desc="Processing indices", leave=False
        ):
            # Best effort: a failure on one feature (e.g. a missing
            # interpretation or metric entry) is reported and the loop moves on.
            try:
                file_path = out_path / f"feature_{feature_index}.html"

                # Pages that already exist are not regenerated.
                if file_path.exists():
                    continue

                # Position of this (layer, feature) pair in the language's
                # feature list; selected_probs / entropies share that indexing.
                arg_index = lang_features_info["indicies"].index(
                    (layer_index, feature_index)
                )

                feature_info = {
                    "feature_index": feature_index,
                    "layer": layer,
                    "lang": lang,
                    "selected_prob": round(selected_probs[arg_index].item(), ndigits=3),
                    "entropy": round(entropies[arg_index].item(), ndigits=3),
                    "interpretation": interpretations[layer][feature_index],
                    "metrics": metrics[layer][feature_index],
                }

                generate_feature_activations_visualization(
                    dataset_lang_to_dataset_token_activations,
                    feature_index,
                    feature_info,
                    model,
                    layer,
                    sae_model,
                    out_path,
                    lang_choices_to_qualified_name,
                    examples_per_section=40,
                )
            except Exception as e:
                print(f"Error processing {lang} - {layer} - {feature_index}")
                print(e)

### UMAP

In [None]:
# NOTE(review): these cells reuse the names `lape_all_result_path` and
# `lape_all_result` for "lape_umap.pt". After they run, `lape_all_result` no
# longer holds the "lape_all.pt" data read by the language-specific feature
# loop; re-run that section's load cell before going back to it.
lape_all_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_umap.pt"
)

In [None]:
# weights_only=False unpickles arbitrary Python objects; only load trusted files.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

In [None]:
# <project>/visualization/umap/<model>/<sae model>
umap_output_dir = (
    project_dir
    / "visualization"
    / "umap"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
)

In [None]:
# Rebinds the notebook-level `model`, `sae_model` and `layers` to the XNLI config values.
model = config_xnli["model"]
sae_model = config_xnli["sae"]["model"]
layers = config_xnli["layers"]

plot_umap(lape_all_result, layers, model, sae_model, umap_output_dir)

### Perplexity

In [None]:
# Stored perplexity result of the "normal" run:
# <project>/ppl/<model>/<dataset>/normal/ppl.pt
normal_ppl_output_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "normal",
    "ppl.pt",
)

# weights_only=False unpickles arbitrary Python objects; only load trusted files.
normal_ppl_result = torch.load(normal_ppl_output_path, weights_only=False)

#### Neuron Intervention

In [None]:
# Despite its name, out_path is the directory the per-language results are
# READ from in this cell.
out_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "neuron_intervention",
)

# One "ppl_<language code>.pt" file per FLORES+ language, keyed by the
# language's qualified name.
intervened_neuron_ppl_results = {
    lang_choices_to_qualified_name[intervened_lang]: torch.load(
        out_path.joinpath(f"ppl_{intervened_lang}.pt"),
        weights_only=False,
    )
    for intervened_lang in config_flores["languages"]
}

In [None]:
# Target HTML file for the neuron-intervention matrix (rebinds out_path).
out_path = project_dir.joinpath(
    "visualization",
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "neuron_intervention",
    "ppl_change_matrix.html",
)

# Compare the normal perplexities against the per-language intervention results.
plot_ppl_change_matrix(
    config_flores["languages"],
    normal_ppl_result,
    intervened_neuron_ppl_results,
    out_path,
    num_examples=1000,
    title="PPL Change Matrix for Neuron Interventions",
)

#### SAE Feature Intervention

In [None]:
# Root directory of the SAE-intervention perplexity results:
# <project>/ppl/<model>/<dataset>/sae_intervention
in_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "sae_intervention",
)

##### All Layers

In [None]:
# Intervention settings to render, written as slash-separated identifiers
# (presumably sub-paths below in_path of the form
# <selection>/<ranking>/<aggregation>/mult_<multiplier> — TODO confirm against
# generate_ppl_change_matrix).
configs = [
    "top_10/entropy/max/mult_0.2",
    "top_10/entropy/max/mult_-0.2",
    "top_1_per_layer/entropy/avg/mult_1",
    "top_1_per_layer/entropy/avg/mult_-1",
    "top_1_per_layer/entropy/max/mult_0.2",
    "top_1_per_layer/entropy/max/mult_-0.2",
    "top_1_per_layer/freq/avg/mult_-1",
    "all/entropy/max/mult_0.2"
]

# The normal-run perplexities are passed in alongside the intervention
# settings and the FLORES+ model / dataset / languages.
generate_ppl_change_matrix(
    configs,
    config_flores["model"],
    config_flores["dataset"],
    config_flores["languages"],
    in_path,
    normal_ppl_result,
)