# SAE Features Visualization

![Collect Internal Activations](./images/decoder_only_model_internal_activations.jpg)

![SAE Feature Activations](./images/internal_activations_to_sae.jpg)

## Dependencies

In [30]:
import torch
import os
from tqdm.auto import tqdm
import json

In [23]:
from pathlib import Path

# The notebook runs from a sub-directory of the project; everything else is
# resolved relative to the project root one level up.
project_dir = Path().resolve().parent
statistic_dir, script_dir = (
    project_dir / "statistics",
    project_dir / "scripts",
)

In [24]:
import sys

# Make the project's scripts/ directory importable so the local helper modules
# imported in the next cell (visualization, feature_visualizer, loader, const)
# can be found — presumably they live there.
sys.path.append(str(script_dir))

In [25]:
%reload_ext autoreload
%autoreload 2

from visualization import (
    plot_all_layers,
    plot_all_lang_feature_overlap,
    plot_lang_feature_overlap_trend,
    plot_all_co_occurrence,
    plot_all_cross_co_occurrence,
    plot_all_count_box_plots,
    plot_lape_result,
    plot_umap,
    plot_ppl_change_matrix,
    generate_ppl_change_matrix,
    plot_metrics,
)

from feature_visualizer import (
    generate_feature_activations_visualization,
)

from loader import (
    load_layer_to_summary,
    load_lang_to_dataset_token_activations,
    load_lang_to_dataset_token_activations_aggregate,
    load_all_interpretations,
)

from const import lang_choices_to_qualified_name, layer_to_index

from delphi.log.result_analysis import get_metrics_per_latent, load_data

Triton not installed, using eager implementation of sparse decoder.


## Llama 3.2-1B

In [26]:
# Shared settings: all three datasets are analysed with the same base model,
# the same SAE and the same 16 MLP hook points (model.layers.0.mlp ...
# model.layers.15.mlp). They are defined once so the configs cannot drift.
LLAMA_MODEL = "meta-llama/Llama-3.2-1B"
LLAMA_SAE_MODEL = "EleutherAI/sae-Llama-3.2-1B-131k"
LLAMA_SAE_NUM_LATENTS = 131072
LLAMA_NUM_LAYERS = 16


def make_llama_config(dataset, split, languages):
    """Build one dataset config using the shared Llama / SAE settings.

    Args:
        dataset: Hugging Face dataset id; also used as a path segment under
            `statistic_dir` by the loading cells below.
        split: Dataset split the statistics were computed on.
        languages: Language codes (in the dataset's own naming) to analyse.

    Returns:
        A fresh dict with keys "model", "sae", "dataset", "split",
        "languages" and "layers". Nested dicts/lists are new objects on every
        call, so configs never share mutable state.
    """
    return {
        "model": LLAMA_MODEL,
        "sae": {
            "model": LLAMA_SAE_MODEL,
            "num_latents": LLAMA_SAE_NUM_LATENTS,
        },
        "dataset": dataset,
        "split": split,
        "languages": list(languages),
        "layers": [f"model.layers.{i}.mlp" for i in range(LLAMA_NUM_LAYERS)],
    }


config_xnli = make_llama_config(
    "facebook/xnli",
    "train",
    ["en", "de", "fr", "hi", "es", "th", "bg", "ru", "tr", "vi"],
)

config_pawsx = make_llama_config(
    "google-research-datasets/paws-x",
    "train",
    ["en", "de", "fr", "es", "ja", "ko", "zh"],
)

config_flores = make_llama_config(
    "openlanguagedata/flores_plus",
    "dev",
    [
        "eng_Latn",
        "deu_Latn",
        "fra_Latn",
        "ita_Latn",
        "por_Latn",
        "hin_Deva",
        "spa_Latn",
        "tha_Thai",
        "bul_Cyrl",
        "rus_Cyrl",
        "tur_Latn",
        "vie_Latn",
        "jpn_Jpan",
        "kor_Hang",
        "cmn_Hans",
    ],
)

## Visualizations

### XNLI

In [None]:
# Per-layer summary statistics for XNLI, stored under
# statistics/<model>/<sae model>/<dataset>/summary.
data_path_summary_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "summary",
)

df_layers_llama_xnli = load_layer_to_summary(
    data_path_summary_xnli,
    config_xnli["layers"],
    config_xnli["languages"],
)

In [None]:
# Plot the XNLI per-layer summaries (`df_layers_llama_xnli`) layer by layer.
plot_all_layers(df_layers_llama_xnli, config_xnli)

In [None]:
# `range_y` presumably fixes the y-axis range; the same [0, 40_000] is passed
# for PAWS-X and FLORES+ below so the figures can be compared side by side.
plot_all_lang_feature_overlap(df_layers_llama_xnli, config_xnli, range_y=[0, 40_000])

In [None]:
# Trend of the language feature overlap for XNLI (see
# `plot_lang_feature_overlap_trend` for the exact quantity plotted).
plot_lang_feature_overlap_trend(df_layers_llama_xnli, config_xnli)

In [None]:
# Co-occurrence plots for XNLI.
plot_all_co_occurrence(df_layers_llama_xnli, config_xnli)

In [None]:
# Box plots of the XNLI count statistics.
plot_all_count_box_plots(df_layers_llama_xnli, config_xnli)

In [None]:
# Per-token SAE activation statistics for XNLI, loaded through the
# `_aggregate` loader over every configured layer and language.
data_path_dataset_token_activations_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "dataset_token_activations",
)

df_dataset_token_activations_xnli = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_xnli,
    config_xnli["layers"],
    config_xnli["languages"],
)

In [None]:
# Export with self-describing column names (index -> SAE feature number,
# count -> token count) for use outside the notebook.
(
    df_dataset_token_activations_xnli
    .rename(columns={"index": "sae_feature_number", "count": "token_count"})
    .to_csv("sae_features_facebook_xnli.csv", index=False)
)

### PAWS-X

In [None]:
# Location of the per-layer PAWS-X summary statistics.
data_path_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "summary",
)

In [None]:
# Load the PAWS-X per-layer summaries for every configured layer and language.
df_layers_llama_pawsx = load_layer_to_summary(
    data_path_pawsx, config_pawsx["layers"], config_pawsx["languages"]
)

In [None]:
# Plot the PAWS-X per-layer summaries layer by layer.
plot_all_layers(df_layers_llama_pawsx, config_pawsx)

In [None]:
# Same y-axis range as the XNLI and FLORES+ overlap plots, for comparability.
plot_all_lang_feature_overlap(df_layers_llama_pawsx, config_pawsx, range_y=[0, 40_000])

In [None]:
# Trend of the language feature overlap for PAWS-X.
plot_lang_feature_overlap_trend(
    df_layers_llama_pawsx,
    config_pawsx,
)

In [None]:
# Co-occurrence plots for PAWS-X.
plot_all_co_occurrence(df_layers_llama_pawsx, config_pawsx)

In [None]:
# Box plots of the PAWS-X count statistics.
plot_all_count_box_plots(df_layers_llama_pawsx, config_pawsx)

In [None]:
# Per-token SAE activation statistics for PAWS-X, loaded through the
# `_aggregate` loader over every configured layer and language.
data_path_dataset_token_activations_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "dataset_token_activations",
)

df_dataset_token_activations_pawsx = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_pawsx,
    config_pawsx["layers"],
    config_pawsx["languages"],
)

In [None]:
# Export with self-describing column names (index -> SAE feature number,
# count -> token count) for use outside the notebook.
(
    df_dataset_token_activations_pawsx
    .rename(columns={"index": "sae_feature_number", "count": "token_count"})
    .to_csv("sae_features_google-research-datasets_paws-x.csv", index=False)
)

#### XNLI and PAWS-X

In [None]:
# Cross-dataset co-occurrence between the XNLI and PAWS-X summaries.
plot_all_cross_co_occurrence(
    df_layers_llama_xnli, config_xnli, df_layers_llama_pawsx, config_pawsx
)

In [None]:
# Same comparison with `specific_feature_lang_count=1` (presumably restricted
# to features specific to exactly one language — TODO confirm in the helper).
plot_all_cross_co_occurrence(
    df_layers_llama_xnli,
    config_xnli,
    df_layers_llama_pawsx,
    config_pawsx,
    specific_feature_lang_count=1,
)

### FLORES+

In [None]:
# Location of the per-layer FLORES+ summary statistics.
data_path_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "summary",
)

In [None]:
# Load the FLORES+ per-layer summaries for every configured layer and language.
df_layers_llama_flores = load_layer_to_summary(
    data_path_flores, config_flores["layers"], config_flores["languages"]
)

In [None]:
# Plot the FLORES+ per-layer summaries layer by layer.
plot_all_layers(df_layers_llama_flores, config_flores)

In [None]:
# Same y-axis range as the XNLI and PAWS-X overlap plots, for comparability.
plot_all_lang_feature_overlap(
    df_layers_llama_flores, config_flores, range_y=[0, 40_000]
)

In [None]:
# Trend of the language feature overlap for FLORES+.
plot_lang_feature_overlap_trend(
    df_layers_llama_flores,
    config_flores,
)

In [None]:
# Co-occurrence plots for FLORES+.
plot_all_co_occurrence(df_layers_llama_flores, config_flores)

In [None]:
# Box plots of the FLORES+ count statistics.
plot_all_count_box_plots(df_layers_llama_flores, config_flores)

In [None]:
# Per-token SAE activation statistics for FLORES+, loaded through the
# `_aggregate` loader over every configured layer and language.
data_path_dataset_token_activations_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "dataset_token_activations",
)

df_dataset_token_activations_flores = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_flores,
    config_flores["layers"],
    config_flores["languages"],
)

In [None]:
# Export with self-describing column names (index -> SAE feature number,
# count -> token count). The file name is derived from the dataset id the
# same way as the XNLI / PAWS-X exports ("/" -> "_"), so it matches the data
# actually loaded (openlanguagedata/flores_plus). The old hard-coded name,
# sae_features_gsarti_flores_101.csv, referred to a dataset this config no
# longer uses.
flores_csv_name = f"sae_features_{config_flores['dataset'].replace('/', '_')}.csv"

df_dataset_token_activations_flores.rename(
    columns={
        "index": "sae_feature_number",
        "count": "token_count",
    }
).to_csv(flores_csv_name, index=False)

#### FLORES+ with XNLI and PAWS-X

In [None]:
# Cross-dataset co-occurrence between the FLORES+ and XNLI summaries.
plot_all_cross_co_occurrence(
    df_layers_llama_flores, config_flores, df_layers_llama_xnli, config_xnli
)

In [None]:
# FLORES+ vs XNLI with `specific_feature_lang_count=1` (presumably restricted
# to features specific to exactly one language — TODO confirm in the helper).
plot_all_cross_co_occurrence(
    df_layers_llama_flores,
    config_flores,
    df_layers_llama_xnli,
    config_xnli,
    specific_feature_lang_count=1,
)

In [None]:
# Cross-dataset co-occurrence between the FLORES+ and PAWS-X summaries.
plot_all_cross_co_occurrence(
    df_layers_llama_flores, config_flores, df_layers_llama_pawsx, config_pawsx
)

In [None]:
# FLORES+ vs PAWS-X with `specific_feature_lang_count=1`.
plot_all_cross_co_occurrence(
    df_layers_llama_flores,
    config_flores,
    df_layers_llama_pawsx,
    config_pawsx,
    specific_feature_lang_count=1,
)

### Feature Index Visualization

In [None]:
# Re-derive the per-dataset token-activation directories so this section can
# run without the earlier per-dataset cells:
# statistics/<model>/<sae model>/<dataset>/dataset_token_activations.
(
    data_path_dataset_token_activations_xnli,
    data_path_dataset_token_activations_pawsx,
    data_path_dataset_token_activations_flores,
) = (
    statistic_dir.joinpath(
        cfg["model"], cfg["sae"]["model"], cfg["dataset"], "dataset_token_activations"
    )
    for cfg in (config_xnli, config_pawsx, config_flores)
)

In [None]:
# Output directory for the single-feature report below, keyed by the short
# model / SAE names (the last segment of each Hugging Face id). Bound to
# `model_name` (as the language-specific section does) rather than `model`,
# which the next cell rebinds to the full model id.
model_name = config_xnli["model"].split("/")[-1]
sae_model_name = config_xnli["sae"]["model"].split("/")[-1]

out_path = (
    project_dir / "visualization" / "feature_index" / model_name / sae_model_name
)

In [None]:
# Which SAE feature / hook point to inspect in the single-feature report.
# NOTE(review): 25 / layer 0 look like an arbitrary example — document why if not.
feature_index, layer = 25, "model.layers.0.mlp"

# Full model ids and hook points, taken from the FLORES+ config.
model, sae_model, layers = (
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["layers"],
)

In [None]:
# Load the token activations of `feature_index` at `layer` for each dataset's
# languages (xnli, then paws-x, then flores).
(
    lang_to_dataset_token_activations_xnli,
    lang_to_dataset_token_activations_pawsx,
    lang_to_dataset_token_activations_flores,
) = [
    load_lang_to_dataset_token_activations(
        activations_dir,
        layer,
        dataset_cfg["languages"],
        [feature_index],
    )
    for activations_dir, dataset_cfg in (
        (data_path_dataset_token_activations_xnli, config_xnli),
        (data_path_dataset_token_activations_pawsx, config_pawsx),
        (data_path_dataset_token_activations_flores, config_flores),
    )
]

# Bundle each dataset's activations with a shallow copy of its config.
dataset_lang_to_dataset_token_activations = {
    dataset_key: {
        "dataset_token_activations": activations,
        "config": dict(dataset_cfg),
    }
    for dataset_key, activations, dataset_cfg in (
        ("xnli", lang_to_dataset_token_activations_xnli, config_xnli),
        ("paws-x", lang_to_dataset_token_activations_pawsx, config_pawsx),
        ("flores", lang_to_dataset_token_activations_flores, config_flores),
    )
}

In [None]:
# Placeholder metadata for an ad-hoc feature: the report takes these fields,
# but no LAPE / interpretation / score values are looked up here, so every
# field is "-" (except "lang", the string "None", and "auc", which is None).
feature_info = {
    "feature_index": feature_index,
    "layer": layer,
    "lang": "None",
    **dict.fromkeys(("selected_prob", "entropy", "interpretation"), "-"),
    "metrics": [
        {
            **dict.fromkeys(
                (
                    "score_type",
                    "true_positives",
                    "true_negatives",
                    "false_positives",
                    "false_negatives",
                    "total_examples",
                    "total_positives",
                    "total_negatives",
                    "failed_count",
                    "precision",
                    "recall",
                    "f1_score",
                    "accuracy",
                    "true_positive_rate",
                    "true_negative_rate",
                    "false_positive_rate",
                    "false_negative_rate",
                    "positive_class_ratio",
                    "negative_class_ratio",
                ),
                "-",
            ),
            "auc": None,
        }
    ],
}

generate_feature_activations_visualization(
    dataset_lang_to_dataset_token_activations, feature_index, feature_info,
    model, layer, sae_model, out_path, lang_choices_to_qualified_name,
)

### LAPE

In [30]:
# Top-10 language-specific SAE features, ranked by entropy.
lape_top_10_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_top_10_by_entropy.pt"
)

# weights_only=False allows full unpickling (the result presumably holds
# non-tensor Python objects); only load files produced by this project.
lape_top_10_result = torch.load(lape_top_10_result_path, weights_only=False)

plot_lape_result(
    lape_top_10_result,
    # Anchored at project_dir like every other visualization output in this
    # notebook; the old relative path depended on the kernel's working dir.
    out_dir=(
        project_dir
        / "visualization"
        / "lape"
        / config_xnli["model"]
        / config_xnli["sae"]["model"]
        / "sae_features"
        / "lape_top_10_by_entropy"
    ),
    title="Distribution of Top-10 Language-Specific Features by Entropy",
)

In [31]:
# Top-10 language-specific SAE features, ranked by frequency.
lape_top_10_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_top_10_by_freq.pt"
)

# weights_only=False allows full unpickling; only load project-produced files.
lape_top_10_result = torch.load(lape_top_10_result_path, weights_only=False)

plot_lape_result(
    lape_top_10_result,
    # Anchored at project_dir like every other visualization output in this
    # notebook; the old relative path depended on the kernel's working dir.
    out_dir=(
        project_dir
        / "visualization"
        / "lape"
        / config_xnli["model"]
        / config_xnli["sae"]["model"]
        / "sae_features"
        / "lape_top_10_by_freq"
    ),
    title="Distribution of Top-10 Language-Specific Features by Frequency",
)

In [32]:
# LAPE over raw MLP neurons (no SAE); the input path is keyed by model only.
lape_neuron_result_path = (
    project_dir
    / "mlp_acts_specific"
    / config_xnli["model"]
    / "lape_neuron.pt"
)

# weights_only=False allows full unpickling; only load project-produced files.
lape_neuron_result = torch.load(lape_neuron_result_path, weights_only=False)

plot_lape_result(
    lape_neuron_result,
    # Anchored at project_dir like every other visualization output. The SAE
    # segment is kept so the output layout matches the previous location.
    out_dir=(
        project_dir
        / "visualization"
        / "lape"
        / config_xnli["model"]
        / config_xnli["sae"]["model"]
        / "lape_neuron"
    ),
    title="Distribution of Language-specific Neurons",
)

In [33]:
# Top-1 language-specific SAE feature per layer, ranked by entropy.
lape_top_1_per_layer_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_top_1_per_layer_by_entropy.pt"
)

# weights_only=False allows full unpickling; only load project-produced files.
lape_result_top_1_per_layer = torch.load(
    lape_top_1_per_layer_result_path, weights_only=False
)

plot_lape_result(
    lape_result_top_1_per_layer,
    # Anchored at project_dir like every other visualization output in this
    # notebook; the old relative path depended on the kernel's working dir.
    out_dir=(
        project_dir
        / "visualization"
        / "lape"
        / config_xnli["model"]
        / config_xnli["sae"]["model"]
        / "sae_features"
        / "lape_top_1_per_layer_by_entropy"
    ),
    title="Distribution of Top-1 per Layer Language-Specific Features by Entropy",
)

In [34]:
# Top-1 language-specific SAE feature per layer, ranked by frequency.
lape_top_1_per_layer_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_top_1_per_layer_by_freq.pt"
)

# weights_only=False allows full unpickling; only load project-produced files.
lape_result_top_1_per_layer = torch.load(
    lape_top_1_per_layer_result_path, weights_only=False
)

plot_lape_result(
    lape_result_top_1_per_layer,
    # Anchored at project_dir like every other visualization output in this
    # notebook; the old relative path depended on the kernel's working dir.
    out_dir=(
        project_dir
        / "visualization"
        / "lape"
        / config_xnli["model"]
        / config_xnli["sae"]["model"]
        / "sae_features"
        / "lape_top_1_per_layer_by_freq"
    ),
    title="Distribution of Top-1 per Layer Language-Specific Features by Frequency",
)

In [16]:
# All language-specific SAE features selected by LAPE.
lape_all_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_all.pt"
)

# weights_only=False allows full unpickling; only load project-produced files.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

plot_lape_result(
    lape_all_result,
    # Anchored at project_dir like every other visualization output in this
    # notebook; the old relative path depended on the kernel's working dir.
    out_dir=(
        project_dir
        / "visualization"
        / "lape"
        / config_xnli["model"]
        / config_xnli["sae"]["model"]
        / "sae_features"
        / "lape_all"
    ),
    title="Distribution of LAPE for all Language-Specific Features",
)

### Language-Specific Features Visualization

In [None]:
# (Re)load the full LAPE result so this section runs on its own; the report
# loop below reads "sorted_lang", "final_indice" and "features_info" from it.
lape_all_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_all.pt",
)
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

In [None]:
def convert_df_metrics_to_nested_dict(df):
    """Group per-latent metric rows by layer and latent index.

    Args:
        df: DataFrame with "layer" and "latent_idx" columns plus any number of
            metric columns (one row per latent and score type).

    Returns:
        ``{"model.<layer>": {latent_idx: [metrics, ...]}}`` where each
        ``metrics`` dict holds the remaining columns of one row, with float
        values rounded to 3 decimals. Rows sharing a layer/latent are appended
        in DataFrame order.
    """
    nested = {}

    for _, record in df.iterrows():
        # Layer names in the scores lack the "model." prefix used by the configs.
        layer_key = f"model.{record['layer']}"
        latent_key = record['latent_idx']

        payload = record.drop(['layer', 'latent_idx']).apply(
            lambda value: round(value, 3) if isinstance(value, float) else value
        )

        nested.setdefault(layer_key, {}).setdefault(latent_key, []).append(
            payload.to_dict()
        )

    return nested

In [None]:
# Inputs for the interpretation / scoring loaders below.
interpretation_folder = project_dir.joinpath("interpret_sae_features", "explanations")
scores_path = project_dir.joinpath("interpret_sae_features", "scores")
visualize_path = project_dir.joinpath(
    "visualization", "interpret_sae_features", "scores"
)

# Hook point names without the "model." prefix; convert_df_metrics_to_nested_dict
# adds the prefix back when building the lookup.
hookpoints = [
    "layers.0.mlp", "layers.1.mlp", "layers.2.mlp", "layers.3.mlp",
    "layers.4.mlp", "layers.5.mlp", "layers.6.mlp", "layers.7.mlp",
    "layers.8.mlp", "layers.9.mlp", "layers.10.mlp", "layers.11.mlp",
    "layers.12.mlp", "layers.13.mlp", "layers.14.mlp", "layers.15.mlp",
]

In [None]:
# Load the feature explanations and their scores, then reshape the per-latent
# metrics into {"model.<hookpoint>": {latent_idx: [metric dicts]}} so the
# report loop below can index them as metrics[layer][feature_index].
interpretations = load_all_interpretations(interpretation_folder)
latent_df, counts = load_data(scores_path, hookpoints)
df_metrics = get_metrics_per_latent(latent_df)
metrics = convert_df_metrics_to_nested_dict(df_metrics)

In [None]:
# Render one HTML activation report per language-specific SAE feature listed in
# `lape_all_result["final_indice"]`, showing its activations on XNLI, PAWS-X and
# FLORES+. Existing reports are skipped, so re-running the cell resumes an
# interrupted run.
model = config_flores["model"]
sae_model = config_flores["sae"]["model"]
model_name = config_flores["model"].split("/")[-1]
sae_model_name = config_flores["sae"]["model"].split("/")[-1]
layers = config_flores["layers"]

sorted_lang = lape_all_result["sorted_lang"]

# enumerate() yields each language's position directly instead of rescanning
# the list with `sorted_lang.index(lang)` on every iteration.
for lang_index, lang in enumerate(tqdm(sorted_lang, desc="Processing languages")):
    for layer in tqdm(layers, desc="Processing layers", leave=False):
        layer_index = layer_to_index[layer]
        # "final_indice" is indexed [language][layer]; presumably a tensor of
        # feature ids, hence `.tolist()`.
        lang_final_indices = lape_all_result["final_indice"][lang_index][
            layer_index
        ].tolist()

        if len(lang_final_indices) == 0:
            continue

        # Presumably a no-op when `layer_to_index` agrees with `layers`; kept so
        # paths always use the config's spelling of the layer name.
        layer = layers[layer_index]

        out_path = (
            project_dir
            / "visualization"
            / "feature_index"
            / model_name
            / sae_model_name
            / layer
            / lang
        )

        # Loading activations for three datasets is the expensive part; skip it
        # when every report for this (language, layer) pair already exists.
        if all(
            (out_path / f"feature_{idx}.html").exists() for idx in lang_final_indices
        ):
            continue

        lang_to_dataset_token_activations_xnli = load_lang_to_dataset_token_activations(
            data_path_dataset_token_activations_xnli,
            layer,
            config_xnli["languages"],
            lang_final_indices,
        )

        lang_to_dataset_token_activations_pawsx = (
            load_lang_to_dataset_token_activations(
                data_path_dataset_token_activations_pawsx,
                layer,
                config_pawsx["languages"],
                lang_final_indices,
            )
        )

        lang_to_dataset_token_activations_flores = (
            load_lang_to_dataset_token_activations(
                data_path_dataset_token_activations_flores,
                layer,
                config_flores["languages"],
                lang_final_indices,
            )
        )

        dataset_lang_to_dataset_token_activations = {
            "xnli": {
                "dataset_token_activations": lang_to_dataset_token_activations_xnli,
                "config": {**config_xnli},
            },
            "paws-x": {
                "dataset_token_activations": lang_to_dataset_token_activations_pawsx,
                "config": {**config_pawsx},
            },
            "flores": {
                "dataset_token_activations": lang_to_dataset_token_activations_flores,
                "config": {**config_flores},
            },
        }

        selected_probs = lape_all_result['features_info'][lang]["selected_probs"]
        entropies = lape_all_result['features_info'][lang]["entropies"]

        for feature_index in tqdm(lang_final_indices, desc="Processing indices", leave=False):
            try:
                file_path = out_path / f"feature_{feature_index}.html"

                if file_path.exists():
                    continue

                # Position of this (layer, feature) pair in the per-language
                # arrays ("indicies" is the key as spelled in the result file).
                arg_index = lape_all_result['features_info'][lang]["indicies"].index(
                    (layer_index, feature_index)
                )

                feature_info = {
                    "feature_index": feature_index,
                    "layer": layer,
                    "lang": lang,
                    "selected_prob": round(selected_probs[arg_index].item(), ndigits=3),
                    "entropy": round(entropies[arg_index].item(), ndigits=3),
                    "interpretation": interpretations[layer][feature_index],
                    "metrics": metrics[layer][feature_index],
                }

                generate_feature_activations_visualization(
                    dataset_lang_to_dataset_token_activations,
                    feature_index,
                    feature_info,
                    model,
                    layer,
                    sae_model,
                    out_path,
                    lang_choices_to_qualified_name,
                    examples_per_section=40,
                )
            except Exception as e:
                # Best-effort batch job: report the failing feature and continue.
                print(f"Error processing {lang} - {layer} - {feature_index}")
                print(e)

### UMAP

In [None]:
lape_all_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_umap.pt"
)

In [None]:
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

In [None]:
umap_output_dir = (
    project_dir
    / "visualization"
    / "umap"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
)

In [None]:
model = config_xnli["model"]
sae_model = config_xnli["sae"]["model"]
layers = config_xnli["layers"]

plot_umap(lape_all_result, layers, model, sae_model, umap_output_dir)

### Perplexity

In [62]:
# Baseline (no-intervention) perplexity on FLORES+; passed as the reference to
# the PPL-change plots below.
normal_ppl_output_path = project_dir.joinpath(
    "ppl", config_flores["model"], config_flores["dataset"], "normal", "ppl.pt"
)
normal_ppl_result = torch.load(normal_ppl_output_path, weights_only=False)

#### Neuron Intervention

In [68]:
# Perplexity results after intervening on each language's neurons, keyed by
# the qualified language name. This is an *input* directory, so it is not
# called `out_path` (the next cell uses that name for the figure output).
neuron_intervention_ppl_dir = (
    project_dir
    / "ppl"
    / config_flores["model"]
    / config_flores["dataset"]
    / "neuron_intervention"
)

intervened_neuron_ppl_results = {
    lang_choices_to_qualified_name[intervened_lang]: torch.load(
        neuron_intervention_ppl_dir / f"ppl_{intervened_lang}.pt", weights_only=False
    )
    for intervened_lang in config_flores["languages"]
}

In [69]:
# HTML output for the neuron-intervention PPL change matrix.
out_path = project_dir.joinpath(
    "visualization",
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "neuron_intervention",
    "ppl_change_matrix.html",
)

plot_ppl_change_matrix(
    config_flores["languages"], normal_ppl_result, intervened_neuron_ppl_results, out_path,
    title="PPL Change Matrix for Neuron Interventions",
    num_examples=1000,
)

#### SAE Feature Intervention

In [65]:
# Root directory of the SAE-feature intervention perplexity results.
in_path = project_dir.joinpath(
    "ppl", config_flores["model"], config_flores["dataset"], "sae_intervention"
)

#### All Layers

In [67]:
# SAE-intervention settings to summarize. Each entry looks like
# "<feature selection>/<ranking>/<aggregation>/mult_<factor>" and is presumably
# a sub-path under `in_path` — TODO confirm against generate_ppl_change_matrix.
configs = [
    "top_10/entropy/max/mult_0.2",
    "top_10/entropy/max/mult_-0.2",
    "top_1_per_layer/entropy/avg/mult_1",
    "top_1_per_layer/entropy/avg/mult_-1",
    "top_1_per_layer/entropy/max/mult_0.2",
    "top_1_per_layer/entropy/max/mult_-0.2",
    "top_1_per_layer/freq/avg/mult_-1",
    "all/entropy/max/mult_0.2",
    "all/entropy/max/mult_-0.2",
    "all/entropy/max/mult_-0.3",
    "all/entropy/max/mult_-0.4",
]

# Compare each setting against the no-intervention baseline perplexity.
generate_ppl_change_matrix(
    configs,
    config_flores["model"],
    config_flores["dataset"],
    config_flores["languages"],
    in_path,
    normal_ppl_result,
)

### Classification

In [None]:
# Classification metrics on MartinThoma/wili_2018 for the "min-max" run.
# The input (classification/...) and output (visualization/classification/...)
# trees share the same run-specific path segments.
classification_run = (
    "meta-llama",
    "Llama-3.2-1B",
    "EleutherAI",
    "sae-Llama-3.2-1B-131k",
    "MartinThoma",
    "wili_2018",
    "min-max",
)

metric_path = project_dir.joinpath("classification", *classification_run, "metrics.json")
output_path = project_dir.joinpath("visualization", "classification", *classification_run)

# Close the file deterministically instead of leaving the handle to the GC.
# JSON is UTF-8 by spec, so don't rely on the platform's default encoding.
with open(metric_path, "r", encoding="utf-8") as metric_file:
    metric = json.load(metric_file)

plot_metrics(metric, output_path)

In [13]:
# Classification metrics on MartinThoma/wili_2018 for the "count" run.
# The input (classification/...) and output (visualization/classification/...)
# trees share the same run-specific path segments.
classification_run = (
    "meta-llama",
    "Llama-3.2-1B",
    "EleutherAI",
    "sae-Llama-3.2-1B-131k",
    "MartinThoma",
    "wili_2018",
    "count",
)

metric_path = project_dir.joinpath("classification", *classification_run, "metrics.json")
output_path = project_dir.joinpath("visualization", "classification", *classification_run)

# Close the file deterministically instead of leaving the handle to the GC.
# JSON is UTF-8 by spec, so don't rely on the platform's default encoding.
with open(metric_path, "r", encoding="utf-8") as metric_file:
    metric = json.load(metric_file)

plot_metrics(metric, output_path)

## Text Generation Visualization

In [48]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Coherence annotations of generated texts after steering toward each target
# language, at a lower and a higher scaling factor α (two result tables).
# NOTE(review): the Unchange_* buckets do not sum to Unchange_Count for 'hi'
# in data_lower (2 + 3 + 64 = 69 vs 71) or 'th' in data_higher
# (2 + 11 + 41 = 54 vs 55) — TODO verify against the source tables.
data_lower = {
    'Language': ['de', 'fr', 'it', 'pt', 'hi', 'es', 'th', 'bg', 'ru', 'tr', 'vi', 'ja', 'ko', 'zh'],
    'Alpha': [0.4, 0.3, 0.4, 0.2, 0.175, 0.5, 0.375, 0.4, 0.5, 0.25, 0.3, 0.3, 0.4, 0.3],
    'Change_Count': [17, 29, 10, 17, 29, 42, 18, 8, 27, 16, 11, 9, 16, 28],
    'Change_Incoherent': [0, 0, 3, 0, 19, 0, 6, 0, 2, 1, 0, 4, 4, 1],
    'Change_Partially_Coherent': [3, 15, 4, 7, 7, 10, 12, 5, 15, 13, 7, 3, 11, 12],
    'Change_Coherent': [14, 14, 3, 10, 3, 32, 0, 3, 10, 2, 4, 2, 1, 15],
    'Unchange_Count': [83, 71, 90, 83, 71, 58, 82, 92, 73, 84, 89, 91, 84, 72],
    'Unchange_Incoherent': [2, 0, 0, 5, 2, 0, 0, 1, 2, 0, 2, 6, 3, 1],
    'Unchange_Partially_Coherent': [7, 0, 8, 16, 3, 3, 11, 30, 4, 17, 7, 10, 18, 3],
    'Unchange_Coherent': [74, 71, 82, 62, 64, 55, 71, 61, 67, 67, 80, 75, 63, 68]
}

data_higher = {
    'Language': ['en', 'de', 'fr', 'it', 'pt', 'hi', 'es', 'th', 'bg', 'ru', 'tr', 'vi', 'ja', 'ko', 'zh'],
    'Alpha': [-1.2, 0.5, 0.4, 0.5, 0.25, 0.2, 0.8, 0.4, 0.5, 0.6, 0.3, 0.4, 0.4, 0.5, 0.4],
    'Change_Count': [5, 27, 38, 23, 12, 30, 45, 45, 22, 34, 27, 18, 19, 43, 66],
    'Change_Incoherent': [1, 2, 14, 20, 3, 23, 1, 37, 12, 9, 11, 4, 8, 24, 7],
    'Change_Partially_Coherent': [3, 13, 14, 3, 7, 5, 26, 8, 10, 12, 11, 9, 10, 18, 32],
    'Change_Coherent': [1, 12, 10, 0, 2, 2, 18, 0, 0, 13, 5, 5, 1, 1, 27],
    'Unchange_Count': [95, 73, 62, 77, 88, 70, 55, 55, 78, 66, 73, 82, 81, 57, 34],
    'Unchange_Incoherent': [8, 3, 1, 2, 28, 9, 0, 2, 3, 15, 6, 11, 9, 6, 1],
    'Unchange_Partially_Coherent': [10, 15, 0, 4, 17, 3, 1, 11, 32, 5, 8, 12, 10, 8, 2],
    'Unchange_Coherent': [77, 55, 61, 71, 43, 58, 54, 41, 43, 46, 59, 59, 62, 43, 31]
}

# Combine both datasets
df_lower = pd.DataFrame(data_lower)
df_lower['Table'] = 'Lower α'

df_higher = pd.DataFrame(data_higher)
df_higher['Table'] = 'Higher α'

# Sanity check: the three coherence buckets must partition the changed texts,
# otherwise the percentages computed below are meaningless.
for table in (df_lower, df_higher):
    bucket_sum = table[['Change_Incoherent', 'Change_Partially_Coherent', 'Change_Coherent']].sum(axis=1)
    assert (bucket_sum == table['Change_Count']).all(), "Change_* buckets do not sum to Change_Count"

# Remove English from higher as it's a special case with negative alpha
df_higher_no_en = df_higher[df_higher['Language'] != 'en'].copy()

# Combine dataframes
df_combined = pd.concat([df_lower, df_higher_no_en])

# Find languages that appear in both tables
common_languages = set(df_lower['Language']).intersection(set(df_higher_no_en['Language']))

# Pair the two tables per language. Iterate in df_lower's row order rather than
# over the set: set order depends on string hashing and can change between
# kernel restarts, which made the row order of df_paired non-reproducible.
paired_data = []
for lang in df_lower['Language'].unique():
    if lang not in common_languages:
        continue
    lower_row = df_lower[df_lower['Language'] == lang].iloc[0]
    higher_row = df_higher_no_en[df_higher_no_en['Language'] == lang].iloc[0]

    paired_data.append({
        'Language': lang,
        'Alpha_Lower': lower_row['Alpha'],
        'Alpha_Higher': higher_row['Alpha'],
        'Change_Count_Lower': lower_row['Change_Count'],
        'Change_Count_Higher': higher_row['Change_Count'],
        'Change_Incoherent_Lower': lower_row['Change_Incoherent'],
        'Change_Incoherent_Higher': higher_row['Change_Incoherent'],
        'Change_Partially_Coherent_Lower': lower_row['Change_Partially_Coherent'],
        'Change_Partially_Coherent_Higher': higher_row['Change_Partially_Coherent'],
        'Change_Coherent_Lower': lower_row['Change_Coherent'],
        'Change_Coherent_Higher': higher_row['Change_Coherent']
    })

df_paired = pd.DataFrame(paired_data)

# Calculate percentages of coherence categories within the Changed texts
df_paired['Change_Incoherent_Pct_Lower'] = df_paired['Change_Incoherent_Lower'] / df_paired['Change_Count_Lower'] * 100
df_paired['Change_Partially_Coherent_Pct_Lower'] = df_paired['Change_Partially_Coherent_Lower'] / df_paired['Change_Count_Lower'] * 100
df_paired['Change_Coherent_Pct_Lower'] = df_paired['Change_Coherent_Lower'] / df_paired['Change_Count_Lower'] * 100

df_paired['Change_Incoherent_Pct_Higher'] = df_paired['Change_Incoherent_Higher'] / df_paired['Change_Count_Higher'] * 100
df_paired['Change_Partially_Coherent_Pct_Higher'] = df_paired['Change_Partially_Coherent_Higher'] / df_paired['Change_Count_Higher'] * 100
df_paired['Change_Coherent_Pct_Higher'] = df_paired['Change_Coherent_Higher'] / df_paired['Change_Count_Higher'] * 100

# Sort by alpha difference to see the effect of increasing alpha. A stable sort
# keeps languages with equal differences in table order (reproducible output).
df_paired['Alpha_Diff'] = df_paired['Alpha_Higher'] - df_paired['Alpha_Lower']
df_paired = df_paired.sort_values(by='Alpha_Diff', ascending=False, kind='stable')

In [51]:
# Define the ordered language list
ordered_languages = [
    "de",
    "fr",
    "it",
    "pt",
    "hi",
    "es",
    "th",
    "bg",
    "ru",
    "tr",
    "vi",
    "ja",
    "ko",
    "zh",
]

# Sort the paired dataframe according to the ordered language list
df_paired["Language_Order"] = df_paired["Language"].apply(
    lambda x: (
        ordered_languages.index(x) if x in ordered_languages else len(ordered_languages)
    )
)
df_paired = df_paired.sort_values("Language_Order").reset_index(drop=True)

# Figure 1: Comparing change in language generation at different alpha values
fig1 = go.Figure()

# Add lines for each coherence category
fig1.add_trace(
    go.Scatter(
        x=df_paired["Language"],
        y=df_paired["Change_Count_Lower"],
        mode="markers+lines",
        name="Changed Text Count (Lower α)",
        marker=dict(size=10, color="blue"),
        line=dict(width=2),
    )
)

fig1.add_trace(
    go.Scatter(
        x=df_paired["Language"],
        y=df_paired["Change_Count_Higher"],
        mode="markers+lines",
        name="Changed Text Count (Higher α)",
        marker=dict(size=10, color="red"),
        line=dict(width=2),
    )
)

# Add alpha values as annotations
for i, row in df_paired.iterrows():
    fig1.add_annotation(
        x=row["Language"],
        y=row["Change_Count_Lower"],
        text=f"α={row['Alpha_Lower']}",
        showarrow=False,
        yshift=-20,
        font=dict(size=10, color="blue"),
    )
    fig1.add_annotation(
        x=row["Language"],
        y=row["Change_Count_Higher"],
        text=f"α={row['Alpha_Higher']}",
        showarrow=False,
        yshift=10,
        font=dict(size=10, color="red"),
    )

fig1.update_layout(
    title="Impact of Increasing Scaling Factor (α) on Language Generation",
    xaxis_title="Target Language",
    yaxis_title="Count of Texts Changed to Target Language",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    width=1000,
    height=600,
    hovermode="x unified",
    plot_bgcolor="white",
)

fig1.update_xaxes(
    categoryorder="array",
    categoryarray=ordered_languages,
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)

fig1.update_yaxes(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)

output_path = (
    project_dir
    / "images"
    / "visualization"
    / "text_generation"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "EleutherAI"
    / "sae-Llama-3.2-1B-131k"
    / "all"
)

os.makedirs(output_path, exist_ok=True)

fig1.write_image(
    output_path / "impact_of_increasing_scaling_factor_on_language_generation.pdf",
)

In [58]:
# Figure 2: coherence breakdown of the changed texts as stacked bars, one
# panel per α level.
fig2 = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=('Lower α Values', 'Higher α Values'),
    specs=[[{'type': 'bar'}, {'type': 'bar'}]],
)

# Long format: one record per (language, α level, coherence category), in the
# order Lower/Higher × Coherent / Partially Coherent / Incoherent.
coherence_data = [
    {
        'Language': record['Language'],
        'Alpha Value': level,
        'Alpha': record[f'Alpha_{level}'],
        'Category': category,
        'Percentage': record[f'{pct_column}_{level}'],
    }
    for _, record in df_paired.iterrows()
    for level in ('Lower', 'Higher')
    for category, pct_column in (
        ('Coherent', 'Change_Coherent_Pct'),
        ('Partially Coherent', 'Change_Partially_Coherent_Pct'),
        ('Incoherent', 'Change_Incoherent_Pct'),
    )
]

df_coherence = pd.DataFrame(coherence_data)


# Bar colour per coherence category (green / orange / red).
color_map = {
    'Coherent': 'rgb(53, 167, 107)',
    'Partially Coherent': 'rgb(253, 174, 97)',
    'Incoherent': 'rgb(215, 48, 39)'
}

# Plot lower alpha coherence breakdown
for category in ['Coherent', 'Partially Coherent', 'Incoherent']:
    df_cat = df_coherence[(df_coherence['Alpha Value'] == 'Lower') & 
                          (df_coherence['Category'] == category)]
    
    # Reorder data according to language_order
    df_cat = df_cat.set_index('Language').reindex(ordered_languages).reset_index()
    
    fig2.add_trace(
        go.Bar(
            x=df_cat['Language'],
            y=df_cat['Percentage'],
            name=category,
            marker_color=color_map[category],
            legendgroup=category,
            showlegend=True,
            text=[f"{val:.1f}%" for val in df_cat['Percentage']],
            textposition='inside',
            textfont=dict(color='white', size=10),
        ),
        row=1, col=1
    )

# Plot higher alpha coherence breakdown
for category in ['Coherent', 'Partially Coherent', 'Incoherent']:
    df_cat = df_coherence[(df_coherence['Alpha Value'] == 'Higher') & 
                          (df_coherence['Category'] == category)]
    
    # Reorder data according to language_order
    df_cat = df_cat.set_index('Language').reindex(ordered_languages).reset_index()
    
    fig2.add_trace(
        go.Bar(
            x=df_cat['Language'],
            y=df_cat['Percentage'],
            name=category,
            marker_color=color_map[category],
            legendgroup=category,
            showlegend=False,
            text=[f"{val:.1f}%" for val in df_cat['Percentage']],
            textposition='inside',
            textfont=dict(color='white', size=10),
        ),
        row=1, col=2
    )

# Add alpha values as annotations on x-axis
for col, alpha_val in enumerate(['Lower', 'Higher'], 1):
    for i, lang in enumerate(ordered_languages):
        alpha = df_paired[df_paired['Language'] == lang][f'Alpha_{alpha_val}'].values[0]
        fig2.add_annotation(
            x=lang,
            y=-10,
            text=f"α={alpha}",
            showarrow=False,
            xref=f'x{col}',
            yref=f'y{col}',
            font=dict(size=10)
        )

fig2.update_layout(
    title='Impact of Scaling Factor (α) on Text Coherence in Changed Languages',
    barmode='stack',
    legend=dict(orientation="h", yanchor="top", y=1.115, xanchor="right", x=1),
    width=1000,
    height=600,
    yaxis=dict(title='Percentage (%)', range=[0, 100]),
    yaxis2=dict(title='Percentage (%)', range=[0, 100]),
    xaxis=dict(title='Target Language'),
    xaxis2=dict(title='Target Language')
)

output_path = (
    project_dir
    / "images"
    / "visualization"
    / "text_generation"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "EleutherAI"
    / "sae-Llama-3.2-1B-131k"
    / "all"
)

os.makedirs(output_path, exist_ok=True)

fig2.write_image(
    output_path / "impact_of_scaling_factor_on_text_coherence_changed.pdf",
)

In [56]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Figure 2b: coherence breakdown of the texts whose language was NOT changed,
# for the lower vs. higher α setting. Each (language, α setting) is normalised
# over the sum of its three coherence counts.
# Requires `df_paired`, `df_lower`, `df_higher_no_en`, `ordered_languages` and
# `project_dir` from earlier cells.

coherence_categories = ["Coherent", "Partially Coherent", "Incoherent"]

# 1) Build Figure 2b
fig2b = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Lower α Values", "Higher α Values"),
    specs=[[{"type": "bar"}, {"type": "bar"}]],
)

# 2) Collect "unchanged" coherence percentages in long format.
# Source columns follow the pattern Unchange_<Category>, with spaces in the
# category name replaced by underscores.
unchanged_coherence_data = []
for _, row in df_paired.iterrows():
    lang = row["Language"]
    counts_per_setting = {
        "Lower": df_lower.loc[df_lower["Language"] == lang].iloc[0],
        "Higher": df_higher_no_en.loc[df_higher_no_en["Language"] == lang].iloc[0],
    }
    for alpha_val, counts in counts_per_setting.items():
        category_counts = {
            category: counts[f"Unchange_{category.replace(' ', '_')}"]
            for category in coherence_categories
        }
        total = sum(category_counts.values())
        # With no unchanged texts there is nothing to normalise; skip the setting
        # (its bars are left empty) instead of dividing by zero.
        if total <= 0:
            continue
        for category, count in category_counts.items():
            unchanged_coherence_data.append(
                {
                    "Language": lang,
                    "Alpha Value": alpha_val,
                    "Category": category,
                    "Percentage": count / total * 100,
                }
            )

# Explicit columns keep the .query() below valid even if no rows were collected.
df_unchanged_coherence = pd.DataFrame(
    unchanged_coherence_data,
    columns=["Language", "Alpha Value", "Category", "Percentage"],
)

# 3) Plot it
color_map = {
    "Coherent": "rgb(53, 167, 107)",
    "Partially Coherent": "rgb(253, 174, 97)",
    "Incoherent": "rgb(215, 48, 39)",
}
# Same x-axis order as Figures 1 and 2, so the changed/unchanged figures line up.
language_order = list(ordered_languages)

for col, alpha_val in enumerate(["Lower", "Higher"], start=1):
    for category in coherence_categories:
        df_cat = (
            df_unchanged_coherence.query(
                "`Alpha Value` == @alpha_val and Category == @category"
            )
            .set_index("Language")
            .reindex(language_order)
            .reset_index()
        )
        fig2b.add_trace(
            go.Bar(
                x=df_cat["Language"],
                y=df_cat["Percentage"],
                name=category,
                marker_color=color_map[category],
                legendgroup=category,
                showlegend=(col == 1),  # only show legend in first subplot
                # Skipped (zero-total) settings would otherwise be labelled "nan%".
                text=[f"{v:.1f}%" if pd.notna(v) else "" for v in df_cat["Percentage"]],
                textposition="inside",
                textfont=dict(color="white", size=10),
            ),
            row=1,
            col=col,
        )

    # Add the α annotation under each language tick.
    for lang in language_order:
        alpha = df_paired.loc[df_paired["Language"] == lang, f"Alpha_{alpha_val}"].iloc[0]
        fig2b.add_annotation(
            x=lang,
            y=-5,
            text=f"α={alpha}",
            showarrow=False,
            xref=f"x{col}",
            yref=f"y{col}",
            font=dict(size=10),
        )

fig2b.update_layout(
    title="Impact of Scaling Factor (α) on Text Coherence in Unchanged Languages",
    barmode="stack",
    legend=dict(orientation="h", yanchor="top", y=1.115, xanchor="right", x=1),
    width=1000,
    height=600,
    yaxis=dict(title="Percentage (%)", range=[0, 100]),
    yaxis2=dict(title="Percentage (%)", range=[0, 100]),
    xaxis=dict(title="Target Language"),
    xaxis2=dict(title="Target Language"),
    plot_bgcolor="white",
)


output_path = (
    project_dir
    / "images"
    / "visualization"
    / "text_generation"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "EleutherAI"
    / "sae-Llama-3.2-1B-131k"
    / "all"
)

os.makedirs(output_path, exist_ok=True)

fig2b.write_image(
    output_path / "impact_of_scaling_factor_on_text_coherence_unchanged.pdf",
)