# SAE Features Visualization

## Dependencies

In [None]:
import json
import torch
import os
from ast import literal_eval

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
from pathlib import Path

# Resolve the current working directory once; the project root is its parent.
_cwd = Path().resolve()
project_dir = _cwd.parent

# Project sub-directories: statistics are read from the first, and the second
# is added to sys.path so the helper modules can be imported.
statistic_dir = project_dir.joinpath("statistics")
script_dir = project_dir.joinpath("scripts")

In [None]:
import sys

# Make the modules under `script_dir` importable. The membership check keeps
# the cell idempotent: re-running it no longer appends a duplicate entry to
# sys.path every time.
_script_dir_str = str(script_dir)
if _script_dir_str not in sys.path:
    sys.path.append(_script_dir_str)

In [None]:
%reload_ext autoreload
%autoreload 2

from visualization import (
    plot_all_layers,
    plot_all_lang_feature_overlap,
    plot_lang_feature_overlap_trend,
    plot_all_co_occurrence,
    plot_all_cross_co_occurrence,
    plot_all_count_box_plots,
    plot_lape_result,
    plot_umap,
    plot_ppl_change_matrix,
    generate_ppl_change_matrix,
    plot_metrics,
    plot_features_similarity,
    plot_sae_features_entropy_score_correlation,
    plot_intersection_heatmap,
    plot_iou_heatmap,
    plot_shared_count_bar_chart,
    plot_fastext_vs_sae_metrics,
    plot_entropy_distribution,
)

from feature_visualizer import (
    generate_feature_activations_visualization,
)

from loader import (
    load_layer_to_summary,
    load_lang_to_dataset_token_activations,
    load_lang_to_dataset_token_activations_aggregate,
    load_all_interpretations,
    load_sae_features_info_df,
    load_lang_to_sae_features_info,
)

# Lookup tables shared across the notebook (each name imported exactly once).
from const import (
    lang_choices_to_qualified_name,
    layer_to_index,
    lang_choices_to_iso639_1,
    hookpoint_to_layer,
    lang_choices_to_flores,
)

from delphi.log.result_analysis import get_metrics_per_latent, load_data

## Llama 3.2-1B

In [None]:
def _llama_config(dataset, split, languages):
    """Build one dataset config for Llama-3.2-1B and its 131k-latent SAE.

    Every call returns fresh ``sae``, ``languages`` and ``layers`` objects so
    the three configs below never share mutable state. ``layers`` lists the
    MLP hookpoints ``model.layers.0.mlp`` through ``model.layers.15.mlp``.
    """
    return {
        "model": "meta-llama/Llama-3.2-1B",
        "sae": {
            "model": "EleutherAI/sae-Llama-3.2-1B-131k",
            "num_latents": 131072,
        },
        "dataset": dataset,
        "split": split,
        "languages": list(languages),
        "layers": [f"model.layers.{layer_idx}.mlp" for layer_idx in range(16)],
    }


# XNLI: train split, two-letter language codes.
config_xnli = _llama_config(
    "facebook/xnli",
    "train",
    ("en", "de", "fr", "hi", "es", "th", "bg", "ru", "tr", "vi"),
)

# PAWS-X: train split, two-letter language codes.
config_pawsx = _llama_config(
    "google-research-datasets/paws-x",
    "train",
    ("en", "de", "fr", "es", "ja", "ko", "zh"),
)

# FLORES+: dev split, languages given as "<lang>_<Script>" identifiers.
config_flores = _llama_config(
    "openlanguagedata/flores_plus",
    "dev",
    (
        "eng_Latn",
        "deu_Latn",
        "fra_Latn",
        "ita_Latn",
        "por_Latn",
        "hin_Deva",
        "spa_Latn",
        "tha_Thai",
        "bul_Cyrl",
        "rus_Cyrl",
        "tur_Latn",
        "vie_Latn",
        "jpn_Jpan",
        "kor_Hang",
        "cmn_Hans",
    ),
)

## Visualizations

### XNLI

In [None]:
# Directory holding the "summary" statistics for the XNLI run.
data_path_summary_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "summary",
)

# Load the summaries for the configured layers and languages.
_xnli_layers, _xnli_languages = config_xnli["layers"], config_xnli["languages"]
df_layers_llama_xnli = load_layer_to_summary(
    data_path_summary_xnli, _xnli_layers, _xnli_languages
)

In [None]:
plot_all_layers(df_layers_llama_xnli, config_xnli)

In [None]:
plot_all_lang_feature_overlap(df_layers_llama_xnli, config_xnli, range_y=[0, 40_000])

In [None]:
plot_lang_feature_overlap_trend(df_layers_llama_xnli, config_xnli)

In [None]:
plot_all_co_occurrence(df_layers_llama_xnli, config_xnli)

In [None]:
plot_all_count_box_plots(df_layers_llama_xnli, config_xnli)

In [None]:
# Directory holding the "dataset_token_activations" statistics for XNLI.
data_path_dataset_token_activations_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "dataset_token_activations",
)

# Aggregate the token activations over the configured layers and languages.
_xnli_layers, _xnli_languages = config_xnli["layers"], config_xnli["languages"]
df_dataset_token_activations_xnli = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_xnli, _xnli_layers, _xnli_languages
)

In [None]:
# Write the aggregated XNLI activations to CSV (in the working directory)
# under more descriptive column names; the row index is not written.
_csv_column_names = {"index": "sae_feature_number", "count": "token_count"}
_xnli_export = df_dataset_token_activations_xnli.rename(columns=_csv_column_names)
_xnli_export.to_csv("sae_features_facebook_xnli.csv", index=False)

### PAWS-X

In [None]:
# Directory holding the "summary" statistics for the PAWS-X run.
data_path_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "summary",
)

In [None]:
df_layers_llama_pawsx = load_layer_to_summary(
    data_path_pawsx, config_pawsx["layers"], config_pawsx["languages"]
)

In [None]:
plot_all_layers(df_layers_llama_pawsx, config_pawsx)

In [None]:
plot_all_lang_feature_overlap(df_layers_llama_pawsx, config_pawsx, range_y=[0, 40_000])

In [None]:
plot_lang_feature_overlap_trend(
    df_layers_llama_pawsx,
    config_pawsx,
)

In [None]:
plot_all_co_occurrence(df_layers_llama_pawsx, config_pawsx)

In [None]:
plot_all_count_box_plots(df_layers_llama_pawsx, config_pawsx)

In [None]:
# Directory holding the "dataset_token_activations" statistics for PAWS-X.
data_path_dataset_token_activations_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "dataset_token_activations",
)

# Aggregate the token activations over the configured layers and languages.
_pawsx_layers, _pawsx_languages = config_pawsx["layers"], config_pawsx["languages"]
df_dataset_token_activations_pawsx = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_pawsx, _pawsx_layers, _pawsx_languages
)

In [None]:
# Write the aggregated PAWS-X activations to CSV (in the working directory)
# under more descriptive column names; the row index is not written.
_csv_column_names = {"index": "sae_feature_number", "count": "token_count"}
_pawsx_export = df_dataset_token_activations_pawsx.rename(columns=_csv_column_names)
_pawsx_export.to_csv("sae_features_google-research-datasets_paws-x.csv", index=False)

#### XNLI and PAWS-X

In [None]:
plot_all_cross_co_occurrence(
    df_layers_llama_xnli, config_xnli, df_layers_llama_pawsx, config_pawsx
)

In [None]:
plot_all_cross_co_occurrence(
    df_layers_llama_xnli,
    config_xnli,
    df_layers_llama_pawsx,
    config_pawsx,
    specific_feature_lang_count=1,
)

### FLORES+

In [None]:
# Directory holding the "summary" statistics for the FLORES+ run.
data_path_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "summary",
)

In [None]:
df_layers_llama_flores = load_layer_to_summary(
    data_path_flores, config_flores["layers"], config_flores["languages"]
)

In [None]:
plot_all_layers(df_layers_llama_flores, config_flores)

In [None]:
plot_all_lang_feature_overlap(
    df_layers_llama_flores, config_flores, range_y=[0, 40_000]
)

In [None]:
plot_lang_feature_overlap_trend(
    df_layers_llama_flores,
    config_flores,
)

In [None]:
plot_all_co_occurrence(df_layers_llama_flores, config_flores)

In [None]:
plot_all_count_box_plots(df_layers_llama_flores, config_flores)

In [None]:
# Directory holding the "dataset_token_activations" statistics for FLORES+.
data_path_dataset_token_activations_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "dataset_token_activations",
)

# Aggregate the token activations over the configured layers and languages.
_flores_layers, _flores_languages = config_flores["layers"], config_flores["languages"]
df_dataset_token_activations_flores = load_lang_to_dataset_token_activations_aggregate(
    data_path_dataset_token_activations_flores, _flores_layers, _flores_languages
)

In [None]:
# Write the aggregated FLORES+ activations to CSV (in the working directory)
# under more descriptive column names, formatting floats with two decimals.
# NOTE(review): the file name still says "gsarti_flores_101" although
# config_flores["dataset"] is "openlanguagedata/flores_plus", and this is the
# only CSV export in the notebook that passes float_format -- confirm both
# are intentional before relying on the file name downstream.
df_dataset_token_activations_flores.rename(
    columns={
        "index": "sae_feature_number",
        "count": "token_count",
    }
).to_csv("sae_features_gsarti_flores_101.csv", index=False, float_format="%.2f")

#### FLORES+ with XNLI and PAWS-X

In [None]:
plot_all_cross_co_occurrence(
    df_layers_llama_flores, config_flores, df_layers_llama_xnli, config_xnli
)

In [None]:
plot_all_cross_co_occurrence(
    df_layers_llama_flores,
    config_flores,
    df_layers_llama_xnli,
    config_xnli,
    specific_feature_lang_count=1,
)

In [None]:
plot_all_cross_co_occurrence(
    df_layers_llama_flores, config_flores, df_layers_llama_pawsx, config_pawsx
)

In [None]:
plot_all_cross_co_occurrence(
    df_layers_llama_flores,
    config_flores,
    df_layers_llama_pawsx,
    config_pawsx,
    specific_feature_lang_count=1,
)

### Feature Index Visualization

In [None]:
def _token_activation_dir(dataset_config):
    """Return the "dataset_token_activations" directory for one dataset config."""
    return statistic_dir.joinpath(
        dataset_config["model"],
        dataset_config["sae"]["model"],
        dataset_config["dataset"],
        "dataset_token_activations",
    )


# One activation directory per dataset, used by the feature-level cells below.
data_path_dataset_token_activations_xnli = _token_activation_dir(config_xnli)
data_path_dataset_token_activations_pawsx = _token_activation_dir(config_pawsx)
data_path_dataset_token_activations_flores = _token_activation_dir(config_flores)

In [None]:
# Keep only the part after the last "/" of the model and SAE identifiers.
model = config_xnli["model"].rsplit("/", 1)[-1]
sae_model_name = config_xnli["sae"]["model"].rsplit("/", 1)[-1]

# Output root for the feature-index visualizations of this model/SAE pair.
out_path = project_dir.joinpath(
    "visualization", "feature_index", model, sae_model_name
)

In [None]:
# Feature/layer pair inspected by the following cells.
layer, feature_index = "model.layers.0.mlp", 25

# Full model / SAE identifiers and the layer list, taken from the FLORES+
# config. Note that `model` is rebound here from the short name to the full id.
model, sae_model, layers = (
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["layers"],
)

In [None]:
lang_to_dataset_token_activations_xnli = load_lang_to_dataset_token_activations(
    data_path_dataset_token_activations_xnli,
    layer,
    config_xnli["languages"],
    [25, 100],
)

In [None]:
def _load_feature_activations(data_path, dataset_config):
    """Load the activations of `feature_index` at `layer` for one dataset."""
    return load_lang_to_dataset_token_activations(
        data_path,
        layer,
        dataset_config["languages"],
        [feature_index],
    )


lang_to_dataset_token_activations_xnli = _load_feature_activations(
    data_path_dataset_token_activations_xnli, config_xnli
)
lang_to_dataset_token_activations_pawsx = _load_feature_activations(
    data_path_dataset_token_activations_pawsx, config_pawsx
)
lang_to_dataset_token_activations_flores = _load_feature_activations(
    data_path_dataset_token_activations_flores, config_flores
)

# Bundle each dataset's activations with a shallow copy of its config.
dataset_lang_to_dataset_token_activations = {
    _dataset_key: {
        "dataset_token_activations": _activations,
        "config": dict(_dataset_config),
    }
    for _dataset_key, _activations, _dataset_config in (
        ("xnli", lang_to_dataset_token_activations_xnli, config_xnli),
        ("paws-x", lang_to_dataset_token_activations_pawsx, config_pawsx),
        ("flores", lang_to_dataset_token_activations_flores, config_flores),
    )
}

In [None]:
# Placeholder metrics entry: every field is "-" except "auc", which is None.
_placeholder_metrics = dict.fromkeys(
    (
        "score_type",
        "true_positives",
        "true_negatives",
        "false_positives",
        "false_negatives",
        "total_examples",
        "total_positives",
        "total_negatives",
        "failed_count",
        "precision",
        "recall",
        "f1_score",
        "accuracy",
        "true_positive_rate",
        "true_negative_rate",
        "false_positive_rate",
        "false_negative_rate",
        "positive_class_ratio",
        "negative_class_ratio",
    ),
    "-",
)
_placeholder_metrics["auc"] = None

# Feature description with no language, probability, entropy or
# interpretation filled in.
feature_info = {
    "feature_index": feature_index,
    "layer": layer,
    "lang": "None",
    "selected_prob": "-",
    "entropy": "-",
    "interpretation": "-",
    "metrics": [_placeholder_metrics],
}

generate_feature_activations_visualization(
    dataset_lang_to_dataset_token_activations,
    feature_index,
    feature_info,
    model,
    layer,
    sae_model,
    out_path,
    lang_choices_to_qualified_name,
)

### LAPE

In [None]:
# Saved LAPE result "lape_top_10_by_entropy.pt" for this model/SAE pair.
lape_top_10_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_10_by_entropy.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
lape_top_10_result = torch.load(lape_top_10_result_path, weights_only=False)

# NOTE(review): out_dir is relative to the working directory, unlike the
# project_dir-anchored output paths used in other cells -- confirm.
_lape_out_dir = Path(
    r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_10_by_entropy"
)

plot_lape_result(
    lape_top_10_result,
    out_dir=_lape_out_dir,
    title="Distribution of Top-10 Language-Specific Features by Entropy",
)

In [None]:
# Saved LAPE result "lape_top_10_by_freq.pt" for this model/SAE pair.
lape_top_10_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_10_by_freq.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
lape_top_10_result = torch.load(lape_top_10_result_path, weights_only=False)

_lape_out_dir = Path(
    r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_10_by_freq"
)

plot_lape_result(
    lape_top_10_result,
    out_dir=_lape_out_dir,
    title="Distribution of Top-10 Language-Specific Features by Frequency",
)

In [None]:
# Saved neuron-level LAPE result for this model.
lape_neuron_result_path = project_dir.joinpath(
    "mlp_acts_specific", config_xnli["model"], "lape_neuron.pt"
)

# weights_only=False performs full unpickling; only load trusted files.
lape_neuron_result = torch.load(lape_neuron_result_path, weights_only=False)

_lape_out_dir = Path(
    r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/lape_neuron"
)

plot_lape_result(
    lape_neuron_result,
    out_dir=_lape_out_dir,
    title="Distribution of Language-specific Neurons",
)

In [None]:
# Saved LAPE result "lape_top_1_per_layer_by_entropy.pt" for this model/SAE pair.
lape_top_1_per_layer_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_1_per_layer_by_entropy.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
lape_result_top_1_per_layer = torch.load(
    lape_top_1_per_layer_result_path,
    weights_only=False,
)

_lape_out_dir = Path(
    r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_1_per_layer_by_entropy"
)

plot_lape_result(
    lape_result_top_1_per_layer,
    out_dir=_lape_out_dir,
    title="Distribution of Top-1 per Layer Language-Specific Features by Entropy",
)

In [None]:
# Saved LAPE result "lape_top_1_per_layer_by_freq.pt" for this model/SAE pair.
lape_top_1_per_layer_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_top_1_per_layer_by_freq.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
lape_result_top_1_per_layer = torch.load(
    lape_top_1_per_layer_result_path,
    weights_only=False,
)

_lape_out_dir = Path(
    r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_top_1_per_layer_by_freq"
)

plot_lape_result(
    lape_result_top_1_per_layer,
    out_dir=_lape_out_dir,
    title="Distribution of Top-1 per Layer Language-Specific Features by Frequency",
)

In [None]:
# Saved LAPE result "lape_all.pt" for this model/SAE pair.
lape_all_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_all.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

_lape_out_dir = Path(
    r"visualization/lape/meta-llama/Llama-3.2-1B/EleutherAI/sae-Llama-3.2-1B-131k/sae_features/lape_all"
)

plot_lape_result(
    lape_all_result,
    out_dir=_lape_out_dir,
    title="Distribution of LAPE for all Language-Specific Features",
)

### Language-Specific Features Visualization

In [None]:
# Load "lape_all.pt" for this model/SAE pair; the cells below read its
# "sorted_lang", "final_indice" and "features_info" entries.
lape_all_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_all.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

In [None]:
def convert_df_metrics_to_nested_dict(df):
    """Group metric rows as ``{"model.<layer>": {latent_idx: [metrics, ...]}}``.

    Args:
        df: DataFrame with ``layer`` and ``latent_idx`` columns; every other
            column is treated as a metric value.

    Returns:
        A nested dict keyed first by ``"model."`` + the row's layer and then
        by latent index. Each leaf is a list with one dict per matching row
        (in row order) mapping the remaining column names to their values,
        with floats rounded to three decimals.
    """

    def _round_floats(value):
        return round(value, 3) if isinstance(value, float) else value

    nested = {}

    for _, record in df.iterrows():
        metric_values = record.drop(["layer", "latent_idx"]).apply(_round_floats)

        per_layer = nested.setdefault(f"model.{record['layer']}", {})
        per_latent = per_layer.setdefault(record["latent_idx"], [])
        per_latent.append(metric_values.to_dict())

    return nested

In [None]:
# Explanation and score directories under "interpret_sae_features".
_interpret_root = project_dir / "interpret_sae_features"
interpretation_folder = _interpret_root / "explanations"
scores_path = _interpret_root / "scores"

visualize_path = project_dir.joinpath(
    "visualization", "interpret_sae_features", "scores"
)

# Hookpoint names "layers.0.mlp" .. "layers.15.mlp" (no "model." prefix).
hookpoints = [f"layers.{layer_idx}.mlp" for layer_idx in range(16)]

In [None]:
# Load the interpretations from `interpretation_folder`.
interpretations = load_all_interpretations(interpretation_folder)
# Load the scoring data for the hookpoints listed above (delphi helper).
latent_df, counts = load_data(scores_path, hookpoints)
# Compute per-latent metrics, then regroup them into the nested
# {"model.<layer>": {latent_idx: [...]}} dict built by
# convert_df_metrics_to_nested_dict.
df_metrics = get_metrics_per_latent(latent_df)
metrics = convert_df_metrics_to_nested_dict(df_metrics)

In [None]:
# Generate one HTML page per language-specific feature, for every language in
# the LAPE result and every configured layer. Pages that already exist are
# skipped, and a failure for one feature is reported without aborting the run.

# `tqdm` draws the progress bars below but is not imported anywhere else in
# this notebook, so a fresh kernel raised NameError. Import it here and fall
# back to plain iteration when the package is unavailable.
try:
    from tqdm import tqdm
except ImportError:

    def tqdm(iterable, **_kwargs):
        """Fallback: iterate without a progress bar."""
        return iterable


model = config_flores["model"]
sae_model = config_flores["sae"]["model"]
model_name = config_flores["model"].split("/")[-1]
sae_model_name = config_flores["sae"]["model"].split("/")[-1]
layers = config_flores["layers"]

sorted_lang = lape_all_result["sorted_lang"]

# enumerate() yields the language position directly instead of searching
# `sorted_lang` with .index() on every iteration.
for lang_index, lang in enumerate(tqdm(sorted_lang, desc="Processing languages")):
    for layer in tqdm(layers, desc="Processing layers", leave=False):
        layer_index = layer_to_index[layer]
        lang_final_indices = lape_all_result["final_indice"][lang_index][
            layer_index
        ].tolist()

        # No features selected for this language at this layer.
        if len(lang_final_indices) == 0:
            continue

        layer = layers[layer_index]

        # Load the activations of the selected features from all three datasets.
        lang_to_dataset_token_activations_xnli = load_lang_to_dataset_token_activations(
            data_path_dataset_token_activations_xnli,
            layer,
            config_xnli["languages"],
            lang_final_indices,
        )

        lang_to_dataset_token_activations_pawsx = (
            load_lang_to_dataset_token_activations(
                data_path_dataset_token_activations_pawsx,
                layer,
                config_pawsx["languages"],
                lang_final_indices,
            )
        )

        lang_to_dataset_token_activations_flores = (
            load_lang_to_dataset_token_activations(
                data_path_dataset_token_activations_flores,
                layer,
                config_flores["languages"],
                lang_final_indices,
            )
        )

        dataset_lang_to_dataset_token_activations = {
            "xnli": {
                "dataset_token_activations": lang_to_dataset_token_activations_xnli,
                "config": {**config_xnli},
            },
            "paws-x": {
                "dataset_token_activations": lang_to_dataset_token_activations_pawsx,
                "config": {**config_pawsx},
            },
            "flores": {
                "dataset_token_activations": lang_to_dataset_token_activations_flores,
                "config": {**config_flores},
            },
        }

        out_path = (
            project_dir
            / "visualization"
            / "feature_index"
            / model_name
            / sae_model_name
            / layer
            / lang
        )

        selected_probs = lape_all_result["features_info"][lang]["selected_probs"]
        entropies = lape_all_result["features_info"][lang]["entropies"]

        for feature_index in tqdm(
            lang_final_indices, desc="Processing indices", leave=False
        ):
            try:
                file_path = out_path / f"feature_{feature_index}.html"

                # Skip features whose page was written by an earlier run.
                if file_path.exists():
                    continue

                # Position of this (layer, feature) pair in the per-language
                # feature list; it indexes selected_probs and entropies.
                arg_index = lape_all_result["features_info"][lang]["indicies"].index(
                    (layer_index, feature_index)
                )

                feature_info = {
                    "feature_index": feature_index,
                    "layer": layer,
                    "lang": lang,
                    "selected_prob": round(selected_probs[arg_index].item(), ndigits=3),
                    "entropy": round(entropies[arg_index].item(), ndigits=3),
                    "interpretation": interpretations[layer][feature_index],
                    "metrics": metrics[layer][feature_index],
                }

                generate_feature_activations_visualization(
                    dataset_lang_to_dataset_token_activations,
                    feature_index,
                    feature_info,
                    model,
                    layer,
                    sae_model,
                    out_path,
                    lang_choices_to_qualified_name,
                    examples_per_section=40,
                )
            except Exception as e:
                # Best effort: report the failing feature and keep going.
                print(f"Error processing {lang} - {layer} - {feature_index}")
                print(e)

### UMAP

In [None]:
# Path of the saved "lape_all.pt" result for this model/SAE pair.
lape_all_result_path = project_dir.joinpath(
    "sae_features_specific",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "lape_all.pt",
)

In [None]:
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

In [None]:
# Output directory for the UMAP plots of this model/SAE pair.
umap_output_dir = project_dir.joinpath(
    "visualization",
    "umap",
    config_xnli["model"],
    config_xnli["sae"]["model"],
)

In [None]:
# Model / SAE identifiers and layer list from the XNLI config.
model, sae_model, layers = (
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["layers"],
)

plot_umap(lape_all_result, layers, model, sae_model, umap_output_dir)

### Perplexity

In [None]:
# Perplexity result stored under ".../normal/ppl.pt"; the PPL change matrices
# below are computed against it.
normal_ppl_output_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "normal",
    "ppl.pt",
)

# weights_only=False performs full unpickling; only load trusted files.
normal_ppl_result = torch.load(normal_ppl_output_path, weights_only=False)

#### Neuron Intervention

In [None]:
# Directory with one "ppl_<language>.pt" file per FLORES+ language for the
# "fixed_0" neuron intervention.
out_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "neuron_intervention",
    "fixed_0",
)

# Map each language's qualified name to its loaded result.
intervened_neuron_ppl_results = {
    lang_choices_to_qualified_name[lang_code]: torch.load(
        out_path / f"ppl_{lang_code}.pt", weights_only=False
    )
    for lang_code in config_flores["languages"]
}

In [None]:
# HTML file that receives the PPL change matrix for "fixed_0".
out_path = project_dir.joinpath(
    "visualization",
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "neuron_intervention",
    "fixed_0",
    "ppl_change_matrix.html",
)

_flores_languages = config_flores["languages"]
plot_ppl_change_matrix(
    _flores_languages,
    normal_ppl_result,
    intervened_neuron_ppl_results,
    out_path,
    title="PPL Change Matrix for Neuron Interventions",
    num_examples=1000,
)

In [None]:
# Baseline intervention settings: "plus"/"min" crossed with four magnitudes,
# i.e. plus_0.1 .. plus_0.4 followed by min_0.1 .. min_0.4.
configs = [
    f"{_direction}_{_magnitude}"
    for _direction in ("plus", "min")
    for _magnitude in ("0.1", "0.2", "0.3", "0.4")
]

# Path segments shared by the input and output location of every setting.
_baseline_parts = (
    config_flores["model"],
    config_flores["dataset"],
    "neuron_intervention",
    "baseline",
)

for config in configs:
    # Input: one "ppl_<language>.pt" file per language for this setting.
    out_path = project_dir.joinpath("ppl", *_baseline_parts, config)

    intervened_neuron_ppl_results = {
        lang_choices_to_qualified_name[lang_code]: torch.load(
            out_path / f"ppl_{lang_code}.pt", weights_only=False
        )
        for lang_code in config_flores["languages"]
    }

    # Output: the HTML matrix for this setting (out_path is reused).
    out_path = project_dir.joinpath(
        "visualization", "ppl", *_baseline_parts, config, "ppl_change_matrix.html"
    )

    plot_ppl_change_matrix(
        config_flores["languages"],
        normal_ppl_result,
        intervened_neuron_ppl_results,
        out_path,
        title="PPL Change Matrix for Neuron Interventions",
        num_examples=1000,
    )

In [None]:
# Plot the PPL change matrix for a single baseline neuron-intervention setting.
# NOTE(review): this cell reads from and writes to a directory named "min_0_2"
# (underscore), whereas the loop over `configs` in this notebook uses
# "min_0.2" (dot) -- confirm the directory exists and that this is not a
# stale duplicate of that loop.
out_path = (
    project_dir
    / "ppl"
    / config_flores["model"]
    / config_flores["dataset"]
    / "neuron_intervention"
    / "baseline"
    / "min_0_2"
)

# Map each language's qualified name to its loaded "ppl_<language>.pt" result.
intervened_neuron_ppl_results = {
    lang_choices_to_qualified_name[intervened_lang]: torch.load(
        out_path / f"ppl_{intervened_lang}.pt", weights_only=False
    )
    for intervened_lang in config_flores["languages"]
}

# out_path is reused: from here on it is the HTML output file.
out_path = (
    project_dir
    / "visualization"
    / "ppl"
    / config_flores["model"]
    / config_flores["dataset"]
    / "neuron_intervention"
    / "baseline"
    / "min_0_2"
    / "ppl_change_matrix.html"
)

plot_ppl_change_matrix(
    config_flores["languages"],
    normal_ppl_result,
    intervened_neuron_ppl_results,
    out_path,
    title="PPL Change Matrix for Neuron Interventions",
    num_examples=1000,
)

#### SAE Feature Intervention

In [None]:
# Root directory of the SAE-intervention perplexity results.
in_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "sae_intervention",
)

#### All Layers

In [None]:
# SAE-intervention settings to process: seven hand-picked ones, followed by
# the "all/entropy/max" sweep over multipliers 0.1 .. 0.4 and -0.1 .. -0.4.
configs = [
    "top_10/entropy/max/mult_0.2",
    "top_10/entropy/max/mult_-0.2",
    "top_1_per_layer/entropy/avg/mult_1",
    "top_1_per_layer/entropy/avg/mult_-1",
    "top_1_per_layer/entropy/max/mult_0.2",
    "top_1_per_layer/entropy/max/mult_-0.2",
    "top_1_per_layer/freq/avg/mult_-1",
    *(
        f"all/entropy/max/mult_{_sign}{_magnitude}"
        for _sign in ("", "-")
        for _magnitude in ("0.1", "0.2", "0.3", "0.4")
    ),
]

_flores_model, _flores_dataset, _flores_languages = (
    config_flores["model"],
    config_flores["dataset"],
    config_flores["languages"],
)

generate_ppl_change_matrix(
    configs,
    _flores_model,
    _flores_dataset,
    _flores_languages,
    in_path,
    normal_ppl_result,
)

### Classification

In [None]:
# Metrics of the SAE-count classifier on WiLI-2018.
metric_path = (
    project_dir
    / "classification"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "EleutherAI"
    / "sae-Llama-3.2-1B-131k"
    / "MartinThoma"
    / "wili_2018"
    / "sae-count"
    / "metrics.json"
)

# The context manager closes the file handle; the previous
# json.load(open(...)) left it open until garbage collection.
with open(metric_path, "r", encoding="utf-8") as metric_file:
    sae_based_lid_metric = json.load(metric_file)

In [None]:
# Metrics of the fastText classifier on WiLI-2018. The path gets its own name
# so `fastext_classifier_metric` only ever holds the parsed JSON.
fastext_classifier_metric_path = (
    project_dir
    / "classification"
    / "MartinThoma"
    / "wili_2018"
    / "fasttext"
    / "metrics.json"
)

# The context manager closes the file handle; the previous
# json.load(open(...)) left it open until garbage collection.
with open(fastext_classifier_metric_path, "r", encoding="utf-8") as metric_file:
    fastext_classifier_metric = json.load(metric_file)

In [None]:
# Metrics of the neuron-count classifier on WiLI-2018. The path gets its own
# name so `neuron_based_lid_metric` only ever holds the parsed JSON.
neuron_based_lid_metric_path = (
    project_dir
    / "classification"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "MartinThoma"
    / "wili_2018"
    / "neuron-count"
    / "metrics.json"
)

# The context manager closes the file handle; the previous
# json.load(open(...)) left it open until garbage collection.
with open(neuron_based_lid_metric_path, "r", encoding="utf-8") as metric_file:
    neuron_based_lid_metric = json.load(metric_file)

In [None]:

# Output directory for the comparison of the three classifiers.
output_path = project_dir.joinpath(
    "visualization",
    "classification",
    "meta-llama",
    "Llama-3.2-1B",
    "MartinThoma",
    "wili_2018",
)

_classifier_metrics = (
    sae_based_lid_metric,
    fastext_classifier_metric,
    neuron_based_lid_metric,
)
plot_fastext_vs_sae_metrics(*_classifier_metrics, output_path)

In [None]:

# Output directory for the SAE-count classifier metrics.
output_path = project_dir.joinpath(
    "visualization",
    "classification",
    "meta-llama",
    "Llama-3.2-1B",
    "EleutherAI",
    "sae-Llama-3.2-1B-131k",
    "MartinThoma",
    "wili_2018",
    "sae-count",
)

plot_metrics(sae_based_lid_metric, output_path)

In [None]:

# Output directory for the neuron-count classifier metrics.
output_path = project_dir.joinpath(
    "visualization",
    "classification",
    "meta-llama",
    "Llama-3.2-1B",
    "MartinThoma",
    "wili_2018",
    "neuron-count",
)

plot_metrics(neuron_based_lid_metric, output_path)

### Text Generation Visualization

In [None]:
# Hand-entered text-generation results: one table for the lower alpha
# setting and one for the higher alpha setting. All lists are column-aligned
# with the table's "Language" list.

# Display order of the non-English languages.
ordered_languages = [
    "de", "fr", "it", "pt", "hi", "es", "th",
    "bg", "ru", "tr", "vi", "ja", "ko", "zh",
]

# NOTE(review): for "hi" the three Unchange_* categories sum to 69 while
# Unchange_Count is 71 -- confirm against the source table.
data_lower = {
    "Language": list(ordered_languages),
    "Alpha": [
        0.4, 0.3, 0.4, 0.2, 0.175, 0.5, 0.375,
        0.4, 0.5, 0.25, 0.3, 0.3, 0.4, 0.3,
    ],
    "Change_Count": [17, 29, 10, 17, 29, 42, 18, 8, 27, 16, 11, 9, 16, 28],
    "Change_Incoherent": [0, 0, 3, 0, 19, 0, 6, 0, 2, 1, 0, 4, 4, 1],
    "Change_Partially_Coherent": [3, 15, 4, 7, 7, 10, 12, 5, 15, 13, 7, 3, 11, 12],
    "Change_Coherent": [14, 14, 3, 10, 3, 32, 0, 3, 10, 2, 4, 2, 1, 15],
    "Unchange_Count": [83, 71, 90, 83, 71, 58, 82, 92, 73, 84, 89, 91, 84, 72],
    "Unchange_Incoherent": [2, 0, 0, 5, 2, 0, 0, 1, 2, 0, 2, 6, 3, 1],
    "Unchange_Partially_Coherent": [7, 0, 8, 16, 3, 3, 11, 30, 4, 17, 7, 10, 18, 3],
    "Unchange_Coherent": [74, 71, 82, 62, 64, 55, 71, 61, 67, 67, 80, 75, 63, 68],
}

# The higher-alpha table additionally has "en" (negative alpha) as first row.
# NOTE(review): for "th" the three Unchange_* categories sum to 54 while
# Unchange_Count is 55 -- confirm against the source table.
data_higher = {
    "Language": ["en", *ordered_languages],
    "Alpha": [
        -1.2, 0.5, 0.4, 0.5, 0.25, 0.2, 0.8, 0.4,
        0.5, 0.6, 0.3, 0.4, 0.4, 0.5, 0.4,
    ],
    "Change_Count": [5, 27, 38, 23, 12, 30, 45, 45, 22, 34, 27, 18, 19, 43, 66],
    "Change_Incoherent": [1, 2, 14, 20, 3, 23, 1, 37, 12, 9, 11, 4, 8, 24, 7],
    "Change_Partially_Coherent": [3, 13, 14, 3, 7, 5, 26, 8, 10, 12, 11, 9, 10, 18, 32],
    "Change_Coherent": [1, 12, 10, 0, 2, 2, 18, 0, 0, 13, 5, 5, 1, 1, 27],
    "Unchange_Count": [95, 73, 62, 77, 88, 70, 55, 55, 78, 66, 73, 82, 81, 57, 34],
    "Unchange_Incoherent": [8, 3, 1, 2, 28, 9, 0, 2, 3, 15, 6, 11, 9, 6, 1],
    "Unchange_Partially_Coherent": [10, 15, 0, 4, 17, 3, 1, 11, 32, 5, 8, 12, 10, 8, 2],
    "Unchange_Coherent": [77, 55, 61, 71, 43, 58, 54, 41, 43, 46, 59, 59, 62, 43, 31],
}

# Bar colors per coherence category.
color_map = dict(
    [
        ("Coherent", "rgb(53, 167, 107)"),
        ("Partially Coherent", "rgb(253, 174, 97)"),
        ("Incoherent", "rgb(215, 48, 39)"),
    ]
)

In [None]:
# Stacked bar chart of changed-text counts per language (lower alpha table),
# split by coherence category and written to a PDF.
df_lower = (
    pd.DataFrame(data_lower)
    .set_index("Language")
    .reindex(ordered_languages)
    .reset_index()
)

fig = go.Figure()

# One trace per category, stacked in this order; each bar shows its count.
for _category in ("Coherent", "Partially Coherent", "Incoherent"):
    _counts = df_lower["Change_" + _category.replace(" ", "_")]
    fig.add_trace(
        go.Bar(
            x=df_lower["Language"],
            y=_counts,
            name=_category,
            marker_color=color_map[_category],
            text=_counts,
            textposition="inside",
            textfont=dict(color="white", size=10),
        )
    )

_legend_style = dict(
    orientation="h",
    yanchor="bottom",
    y=0.907,
    xanchor="right",
    x=1,
    traceorder="normal",
)
fig.update_layout(
    barmode="stack",
    title="Changed Texts Count (Lower α)",
    xaxis_title="Language",
    yaxis_title="Count of Changed Texts",
    legend=_legend_style,
    width=900,
    height=500,
    plot_bgcolor="white",
)

# Both axes share the same frame styling; only the y axis gets grid lines.
_axis_style = dict(mirror=True, ticks="outside", showline=True, linecolor="black")
fig.update_xaxes(**_axis_style)
fig.update_yaxes(**_axis_style, gridcolor="lightgrey")

output_path = project_dir.joinpath(
    "images",
    "visualization",
    "text_generation",
    "meta-llama",
    "Llama-3.2-1B",
    "EleutherAI",
    "sae-Llama-3.2-1B-131k",
    "all",
)

# Create the output directory if needed before writing the figure.
os.makedirs(output_path, exist_ok=True)

fig.write_image(output_path / "changed_texts_count.pdf")

In [None]:
# Tag each table with its α regime so the rows stay distinguishable once
# the two tables are stacked.
df_lower = pd.DataFrame(data_lower).assign(Table="Lower α")
df_higher = pd.DataFrame(data_higher).assign(Table="Higher α")

# English is dropped from the higher-α table: it is a special case that uses
# a negative alpha.
df_higher_no_en = df_higher.loc[df_higher["Language"] != "en"].copy()

df_combined = pd.concat([df_lower, df_higher_no_en])

# Only languages present in both tables can be paired.
common_languages = set(df_lower["Language"]) & set(df_higher_no_en["Language"])

# One record per common language, holding each field twice: once from the
# lower-α table (suffix "_Lower") and once from the higher-α one ("_Higher").
paired_fields = (
    "Alpha",
    "Change_Count",
    "Change_Incoherent",
    "Change_Partially_Coherent",
    "Change_Coherent",
)

paired_data = []
for lang in common_languages:
    lower_row = df_lower.loc[df_lower["Language"] == lang].iloc[0]
    higher_row = df_higher_no_en.loc[df_higher_no_en["Language"] == lang].iloc[0]

    entry = {"Language": lang}
    for field in paired_fields:
        entry[f"{field}_Lower"] = lower_row[field]
        entry[f"{field}_Higher"] = higher_row[field]
    paired_data.append(entry)

df_paired = pd.DataFrame(paired_data)

# Share of each coherence category within the changed texts, per α regime.
for regime in ("Lower", "Higher"):
    changed_total = df_paired[f"Change_Count_{regime}"]
    for coherence in ("Incoherent", "Partially_Coherent", "Coherent"):
        df_paired[f"Change_{coherence}_Pct_{regime}"] = (
            df_paired[f"Change_{coherence}_{regime}"] / changed_total * 100
        )

# Largest α increase first, to make the effect of raising α easy to read.
df_paired["Alpha_Diff"] = df_paired["Alpha_Higher"] - df_paired["Alpha_Lower"]
df_paired = df_paired.sort_values(by="Alpha_Diff", ascending=False)

In [None]:
# Order the paired rows like `ordered_languages`; languages that are not in
# that list are pushed to the end.
language_rank = {code: position for position, code in enumerate(ordered_languages)}
df_paired["Language_Order"] = df_paired["Language"].map(
    lambda code: language_rank.get(code, len(ordered_languages))
)
df_paired = df_paired.sort_values("Language_Order").reset_index(drop=True)

# Figure 1: changed-text counts per language at the lower vs. the higher α.
fig1 = go.Figure()

# (α regime, line colour, vertical label offset in px)
regime_styles = (("Lower", "blue", -20), ("Higher", "red", 10))

for regime, colour, _ in regime_styles:
    fig1.add_trace(
        go.Scatter(
            x=df_paired["Language"],
            y=df_paired[f"Change_Count_{regime}"],
            mode="markers+lines",
            name=f"Changed Text Count ({regime} α)",
            marker={"size": 10, "color": colour},
            line={"width": 2},
        )
    )

# Label every marker with the α that produced it: lower-α labels sit below
# their marker, higher-α labels above.
for _, row in df_paired.iterrows():
    for regime, colour, offset in regime_styles:
        alpha_value = row[f"Alpha_{regime}"]
        fig1.add_annotation(
            x=row["Language"],
            y=row[f"Change_Count_{regime}"],
            text=f"α={alpha_value}",
            showarrow=False,
            yshift=offset,
            font={"size": 10, "color": colour},
        )

fig1.update_layout(
    title="Impact of Increasing Scaling Factor (α) on Language Generation",
    xaxis_title="Target Language",
    yaxis_title="Count of Texts Changed to Target Language",
    legend={
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.02,
        "xanchor": "right",
        "x": 1,
    },
    width=1000,
    height=600,
    hovermode="x unified",
    plot_bgcolor="white",
)

# Boxed frame and light grid shared by both axes; the x-axis additionally
# pins its category order to `ordered_languages`.
axis_frame = {
    "mirror": True,
    "ticks": "outside",
    "showline": True,
    "linecolor": "black",
    "gridcolor": "lightgrey",
}
fig1.update_xaxes(
    categoryorder="array",
    categoryarray=ordered_languages,
    **axis_frame,
)
fig1.update_yaxes(**axis_frame)

output_path = project_dir.joinpath(
    "images",
    "visualization",
    "text_generation",
    "meta-llama",
    "Llama-3.2-1B",
    "EleutherAI",
    "sae-Llama-3.2-1B-131k",
    "all",
)
os.makedirs(output_path, exist_ok=True)

fig1.write_image(
    output_path / "impact_of_increasing_scaling_factor_on_language_generation.pdf"
)

In [None]:
# Figure 2: coherence breakdown of the *changed* texts as stacked bars, one
# subplot per α regime (left: lower α, right: higher α).
fig2 = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Lower α Values", "Higher α Values"),
    specs=[[{"type": "bar"}, {"type": "bar"}]],
)

coherence_categories = ["Coherent", "Partially Coherent", "Incoherent"]

# Convert df_paired to long format: one record per (language, α regime,
# coherence category). The percentage columns are named
# "Change_<Category with underscores>_Pct_<regime>".
coherence_data = [
    {
        "Language": row["Language"],
        "Alpha Value": alpha_val,
        "Alpha": row[f"Alpha_{alpha_val}"],
        "Category": category,
        "Percentage": row[f"Change_{category.replace(' ', '_')}_Pct_{alpha_val}"],
    }
    for _, row in df_paired.iterrows()
    for alpha_val in ("Lower", "Higher")
    for category in coherence_categories
]

# Explicit columns keep the filters below valid even when there are no rows.
df_coherence = pd.DataFrame(
    coherence_data,
    columns=["Language", "Alpha Value", "Alpha", "Category", "Percentage"],
)

# Plot only languages that actually have paired data, in the canonical
# display order. Reindexing on the full `ordered_languages` list would add
# NaN bars ("nan%") for a missing language and make the α lookup below fail
# with an IndexError.
paired_languages = set(df_paired["Language"])
plot_languages = [lang for lang in ordered_languages if lang in paired_languages]

for col, alpha_val in enumerate(["Lower", "Higher"], start=1):
    for category in coherence_categories:
        df_cat = df_coherence[
            (df_coherence["Alpha Value"] == alpha_val)
            & (df_coherence["Category"] == category)
        ]

        # Reorder the bars according to the display order.
        df_cat = df_cat.set_index("Language").reindex(plot_languages).reset_index()

        fig2.add_trace(
            go.Bar(
                x=df_cat["Language"],
                y=df_cat["Percentage"],
                name=category,
                marker_color=color_map[category],
                legendgroup=category,
                # One legend entry per category: only the left subplot
                # contributes entries, the right one reuses the group.
                showlegend=(col == 1),
                text=[f"{val:.1f}%" for val in df_cat["Percentage"]],
                textposition="inside",
                textfont=dict(color="white", size=10),
            ),
            row=1,
            col=col,
        )

# Add the α used for each language as an annotation under its bar.
for col, alpha_val in enumerate(["Lower", "Higher"], start=1):
    for lang in plot_languages:
        alpha = df_paired.loc[
            df_paired["Language"] == lang, f"Alpha_{alpha_val}"
        ].iloc[0]
        fig2.add_annotation(
            x=lang,
            y=-10,
            text=f"α={alpha}",
            showarrow=False,
            xref=f"x{col}",
            yref=f"y{col}",
            font=dict(size=10),
        )

fig2.update_layout(
    title="Impact of Scaling Factor (α) on Text Coherence in Changed Languages",
    barmode="stack",
    legend=dict(orientation="h", yanchor="top", y=1.115, xanchor="right", x=1),
    width=1000,
    height=600,
    yaxis=dict(title="Percentage (%)", range=[0, 100]),
    yaxis2=dict(title="Percentage (%)", range=[0, 100]),
    xaxis=dict(title="Target Language"),
    xaxis2=dict(title="Target Language"),
    plot_bgcolor="white",
)

output_path = (
    project_dir
    / "images"
    / "visualization"
    / "text_generation"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "EleutherAI"
    / "sae-Llama-3.2-1B-131k"
    / "all"
)

os.makedirs(output_path, exist_ok=True)

fig2.write_image(
    output_path / "impact_of_scaling_factor_on_text_coherence_changed.pdf",
)

In [None]:
# Figure 2b: coherence breakdown of the *unchanged* texts as stacked bars,
# one subplot per α regime (left: lower α, right: higher α).
fig2b = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Lower α Values", "Higher α Values"),
    specs=[[{"type": "bar"}, {"type": "bar"}]],
)

# Legend label -> count column in the source tables.
unchanged_count_columns = {
    "Coherent": "Unchange_Coherent",
    "Partially Coherent": "Unchange_Partially_Coherent",
    "Incoherent": "Unchange_Incoherent",
}

unchanged_coherence_data = []
for _, row in df_paired.iterrows():
    lang = row["Language"]
    # The unchanged counts are not carried in df_paired, so they are read
    # back from the source tables.
    source_rows = {
        "Lower": df_lower.loc[df_lower["Language"] == lang].iloc[0],
        "Higher": df_higher_no_en.loc[df_higher_no_en["Language"] == lang].iloc[0],
    }

    for alpha_val in ("Lower", "Higher"):
        counts = source_rows[alpha_val]
        total = (
            counts["Unchange_Coherent"]
            + counts["Unchange_Partially_Coherent"]
            + counts["Unchange_Incoherent"]
        )
        # No unchanged texts means no percentages to compute; skipping also
        # avoids a division by zero.
        if total > 0:
            unchanged_coherence_data += [
                {
                    "Language": lang,
                    "Alpha Value": alpha_val,
                    "Category": category,
                    "Percentage": counts[column] / total * 100,
                }
                for category, column in unchanged_count_columns.items()
            ]

# Explicit columns keep the query below valid even when every language was
# skipped and the record list is empty.
df_unchanged_coherence = pd.DataFrame(
    unchanged_coherence_data,
    columns=["Language", "Alpha Value", "Category", "Percentage"],
)

language_order = df_paired["Language"].tolist()

for col, alpha_val in enumerate(["Lower", "Higher"], start=1):
    for category in ["Coherent", "Partially Coherent", "Incoherent"]:
        # `@alpha_val` / `@category` are resolved by DataFrame.query from the
        # variables of the same name in this scope.
        df_cat = (
            df_unchanged_coherence.query(
                "`Alpha Value` == @alpha_val and Category == @category"
            )
            .set_index("Language")
            .reindex(language_order)
            .reset_index()
        )
        fig2b.add_trace(
            go.Bar(
                x=df_cat["Language"],
                y=df_cat["Percentage"],
                name=category,
                marker_color=color_map[category],
                legendgroup=category,
                # One legend entry per category (left subplot only).
                showlegend=(col == 1),
                text=[f"{v:.1f}%" for v in df_cat["Percentage"]],
                textposition="inside",
                textfont=dict(color="white", size=10),
            ),
            row=1,
            col=col,
        )

    # Add the α used for each language as an annotation under its bar.
    for lang in language_order:
        alpha = df_paired.loc[
            df_paired["Language"] == lang, f"Alpha_{alpha_val}"
        ].iloc[0]
        fig2b.add_annotation(
            x=lang,
            y=-5,
            text=f"α={alpha}",
            showarrow=False,
            xref=f"x{col}",
            yref=f"y{col}",
            font=dict(size=10),
        )

fig2b.update_layout(
    title="Impact of Scaling Factor (α) on Text Coherence in Unchanged Languages",
    barmode="stack",
    legend=dict(orientation="h", yanchor="top", y=1.115, xanchor="right", x=1),
    width=1000,
    height=600,
    yaxis=dict(title="Percentage (%)", range=[0, 100]),
    yaxis2=dict(title="Percentage (%)", range=[0, 100]),
    xaxis=dict(title="Target Language"),
    xaxis2=dict(title="Target Language"),
    plot_bgcolor="white",
)


output_path = (
    project_dir
    / "images"
    / "visualization"
    / "text_generation"
    / "meta-llama"
    / "Llama-3.2-1B"
    / "EleutherAI"
    / "sae-Llama-3.2-1B-131k"
    / "all"
)

os.makedirs(output_path, exist_ok=True)

fig2b.write_image(
    output_path / "impact_of_scaling_factor_on_text_coherence_unchanged.pdf",
)

### Language-Specific Feature Properties

In [None]:
# Locations of the auto-interpretation outputs (explanations and scores) and
# of the folder reserved for their visualisations.
interpret_root = project_dir / "interpret_sae_features"
interpretation_folder = interpret_root / "explanations"
scores_path = interpret_root / "scores"

visualize_path = project_dir / "visualization" / "interpret_sae_features" / "scores"

# MLP hookpoint names for layers 0 through 15.
hookpoints = [f"layers.{layer_number}.mlp" for layer_number in range(16)]

# Load the explanations and raw scores, then derive per-latent metrics in
# both tabular and nested-dict form.
interpretations = load_all_interpretations(interpretation_folder)
latent_df, counts = load_data(scores_path, hookpoints)
df_metrics = get_metrics_per_latent(latent_df)
metrics = convert_df_metrics_to_nested_dict(df_metrics)

In [None]:
# Path of the "lape_all.pt" result file for the model/SAE pair named in the
# XNLI configuration.
lape_all_result_path = (
    project_dir
    / "sae_features_specific"
    / config_xnli["model"]
    / config_xnli["sae"]["model"]
    / "lape_all.pt"
)

# NOTE: weights_only=False makes torch.load fall back to full unpickling,
# which can execute arbitrary code embedded in the file. Only load result
# files that you produced yourself or otherwise trust.
lape_all_result = torch.load(lape_all_result_path, weights_only=False)

#### Scoring Result

##### Mean

In [None]:
def get_mean_score_result(lape_result, df_metrics):
    """Average the scoring metrics of each language's selected SAE features.

    Args:
        lape_result: Mapping with ``"sorted_lang"`` (language names) and
            ``"final_indice"``, where ``final_indice[lang_idx][layer_idx]``
            holds the selected latent indices of that language in that layer
            (any object exposing ``.tolist()``).
        df_metrics: Frame with ``layer``, ``latent_idx``, ``score_type``,
            ``accuracy``, ``f1_score``, ``precision`` and ``recall`` columns.
            Layers are matched by the name ``"layers.<layer_idx>.mlp"``.

    Returns:
        DataFrame with one row per (language, score type) and the columns
        ``Language``, ``Score Type``, ``accuracy``, ``f1_score``,
        ``precision``, ``recall``; metric values are rounded to two
        decimals. When nothing matches, an empty frame with the same
        columns is returned.
    """
    metric_cols = ["accuracy", "f1_score", "precision", "recall"]
    score_results = {}

    for lang_idx, lang in enumerate(lape_result["sorted_lang"]):
        # Gather the per-layer slices first and concatenate once, instead of
        # growing a frame with pd.concat on every iteration.
        layer_frames = []
        for layer_idx, layer_indices in enumerate(lape_result["final_indice"][lang_idx]):
            layer_mask = (df_metrics["layer"] == f"layers.{layer_idx}.mlp") & (
                df_metrics["latent_idx"].isin(layer_indices.tolist())
            )
            layer_frames.append(df_metrics[layer_mask])

        # A language without any layer entry has nothing to average.
        if not layer_frames:
            continue

        score_results[lang] = (
            pd.concat(layer_frames)
            .groupby("score_type")[metric_cols]
            .mean()
            .to_dict(orient="index")
        )

    rows = [
        {"Language": lang, "Score Type": score_type, **values}
        for lang, score_types in score_results.items()
        for score_type, values in score_types.items()
    ]

    # Passing the columns explicitly keeps the schema stable for empty input.
    df_scores = pd.DataFrame(rows, columns=["Language", "Score Type", *metric_cols])
    if not df_scores.empty:
        df_scores[metric_cols] = df_scores[metric_cols].round(2)

    return df_scores

In [None]:
# Per-language mean of every metric, one row per (language, score type).
average_score_result = get_mean_score_result(lape_all_result, df_metrics)

In [None]:
# Display the per-language mean scores.
average_score_result

In [None]:
# Macro average
macro_averages = average_score_result.groupby("Score Type")[["accuracy", "precision", "recall", "f1_score"]].mean().round(2)
macro_averages

In [None]:
# Check average scores
macro_averages.reset_index()[
    ["accuracy", "precision", "recall", "f1_score"]
].mean().round(
    2
)

In [None]:
# Micro average
df_metrics.groupby("score_type")[["accuracy", "f1_score", "precision", "recall"]].mean().round(2)

In [None]:
# Persist the per-language mean scores as CSV (without the index column)
# inside the scores folder.
output_path = project_dir / "interpret_sae_features" / "scores" / "average_scores.csv"
average_score_result.to_csv(output_path, index=False)

##### Median

In [None]:
def get_median_score_result(lape_result, df_metrics):
    """Summarise each language's selected SAE features by quartiles.

    Args:
        lape_result: Mapping with ``"sorted_lang"`` (language names) and
            ``"final_indice"``, where ``final_indice[lang_idx][layer_idx]``
            holds the selected latent indices of that language in that layer
            (any object exposing ``.tolist()``).
        df_metrics: Frame with ``layer``, ``latent_idx``, ``score_type``,
            ``accuracy``, ``f1_score``, ``precision`` and ``recall`` columns.
            Layers are matched by the name ``"layers.<layer_idx>.mlp"``.

    Returns:
        DataFrame with one row per (language, score type) and the columns
        ``Language``, ``Score Type`` followed by ``<metric>_q1``,
        ``<metric>_median`` and ``<metric>_q3`` for each metric; values are
        rounded to two decimals. When nothing matches, an empty frame with
        the same columns is returned.
    """
    metric_names = ["accuracy", "f1_score", "precision", "recall"]
    stat_names = ["q1", "median", "q3"]
    metric_cols = [f"{metric}_{stat}" for metric in metric_names for stat in stat_names]

    score_results = {}

    for lang_idx, lang in enumerate(lape_result["sorted_lang"]):
        # Gather the per-layer slices first and concatenate once, instead of
        # growing a frame with pd.concat on every iteration.
        layer_frames = []
        for layer_idx, layer_indices in enumerate(lape_result["final_indice"][lang_idx]):
            layer_mask = (df_metrics["layer"] == f"layers.{layer_idx}.mlp") & (
                df_metrics["latent_idx"].isin(layer_indices.tolist())
            )
            layer_frames.append(df_metrics[layer_mask])

        # A language without any layer entry has nothing to summarise.
        if not layer_frames:
            continue

        # The aggregation yields (metric, statistic) column pairs, so each
        # per-score-type dict below is keyed by such tuples.
        score_results[lang] = (
            pd.concat(layer_frames)
            .groupby("score_type")[metric_names]
            .agg(
                [
                    ("q1", lambda x: x.quantile(0.25)),
                    ("median", "median"),
                    ("q3", lambda x: x.quantile(0.75)),
                ]
            )
            .to_dict(orient="index")
        )

    rows = []
    for lang, score_types in score_results.items():
        for score_type, stats in score_types.items():
            row = {"Language": lang, "Score Type": score_type}
            # Flatten ("accuracy", "q1") into "accuracy_q1" right away so the
            # frame never carries mixed tuple/string column labels.
            row.update(
                {f"{metric}_{stat}": value for (metric, stat), value in stats.items()}
            )
            rows.append(row)

    # Passing the columns explicitly keeps the schema stable for empty input.
    df_scores = pd.DataFrame(rows, columns=["Language", "Score Type", *metric_cols])
    if not df_scores.empty:
        df_scores[metric_cols] = df_scores[metric_cols].round(2)

    return df_scores

In [None]:
# Per-language q1 / median / q3 of every metric, one row per
# (language, score type).
median_score_result = get_median_score_result(lape_all_result, df_metrics)

In [None]:
# Display the per-language quartile scores.
median_score_result

In [None]:
# Macro average
metric_cols = [
    "accuracy_q1",
    "accuracy_median",
    "accuracy_q3",
    "f1_score_q1",
    "f1_score_median",
    "f1_score_q3",
    "precision_q1",
    "precision_median",
    "precision_q3",
    "recall_q1",
    "recall_median",
    "recall_q3",
]

macro_averages = median_score_result.groupby("Score Type")[metric_cols].mean().round(2)
macro_averages

In [None]:
# Median of median
macro_averages = median_score_result.groupby("Score Type")[metric_cols].median().round(2)
macro_averages

In [None]:
# Overall Median
df_metrics.groupby("score_type")[["accuracy", "f1_score", "precision", "recall"]].agg(
    [
        ("q1", lambda x: x.quantile(0.25)),
        ("median", "median"),
        ("q3", lambda x: x.quantile(0.75)),
    ]
).round(2)

#### Features similarity (IoU and Pearson)

In [None]:
def _dataset_token_activations_dir(config):
    """Return the dataset_token_activations folder for one run configuration."""
    return (
        statistic_dir
        / config["model"]
        / config["sae"]["model"]
        / config["dataset"]
        / "dataset_token_activations"
    )


# One statistics folder per dataset, all laid out as
# <statistic_dir>/<model>/<sae model>/<dataset>/dataset_token_activations.
data_path_dataset_token_activations_xnli = _dataset_token_activations_dir(config_xnli)
data_path_dataset_token_activations_pawsx = _dataset_token_activations_dir(config_pawsx)
data_path_dataset_token_activations_flores = _dataset_token_activations_dir(config_flores)

In [None]:
# Position of "Hindi" in the language ordering of the LAPE result.
# NOTE(review): presumably used to choose the start_index/end_index passed to
# plot_features_similarity — confirm.
lape_all_result["sorted_lang"].index("Hindi")

In [None]:
# Similarity plots are written under
# <project_dir>/visualization/similarity/<model>/<sae model>.
output_dir = project_dir.joinpath(
    "visualization",
    "similarity",
    config_flores["model"],
    config_flores["sae"]["model"],
)

# Task name -> token-activation folder and run configuration.
task_configs = {
    task_name: {"path": activations_path, "config": task_config}
    for task_name, activations_path, task_config in (
        ("xnli", data_path_dataset_token_activations_xnli, config_xnli),
        ("paws-x", data_path_dataset_token_activations_pawsx, config_pawsx),
        ("flores", data_path_dataset_token_activations_flores, config_flores),
    )
}

plot_features_similarity(
    lape_all_result,
    config_flores["layers"],
    output_dir,
    task_configs,
    start_index=5,
    end_index=6,
)

#### Cosine Similarity

##### Cosine Similarity of Language-Specific Features

In [None]:
sorted_lang = lape_all_result["sorted_lang"]

# For every language, the per-layer selected latent indices and the matching
# per-layer feature vectors, both converted to plain Python lists.
lang_to_sae_features = {
    lang: {
        "final_indices": [
            layer_indices.tolist()
            for layer_indices in lape_all_result["final_indice"][lang_idx]
        ],
        "stacked_sae_features": [
            layer_features.tolist()
            for layer_features in lape_all_result["sae_features"][lang_idx]
        ],
    }
    for lang_idx, lang in enumerate(sorted_lang)
}

In [None]:
import numpy as np
import plotly.graph_objects as go
from sklearn.metrics.pairwise import cosine_similarity

# Flatten every language-specific feature into parallel lists (vector, latent
# index, layer, language) so that row/column k of the similarity matrix can
# be traced back to its origin. `language_boundaries[k]` is the exclusive end
# position of the k-th language's block.
all_features = []
all_feature_indices = []
all_layer_numbers = []
language_boundaries = []
current_position = 0
lang_names_for_features = []

for lang in sorted_lang:
    lang_features = lang_to_sae_features[lang]["stacked_sae_features"]
    lang_indices = lang_to_sae_features[lang]["final_indices"]

    # Each element of stacked_sae_features corresponds to one layer and holds
    # the feature vectors for the indices selected in that layer.
    for layer_idx, feature_vectors in enumerate(lang_features):
        layer_indices = lang_indices[layer_idx]

        for feat_idx, feature_vector in enumerate(feature_vectors):
            all_features.append(feature_vector)
            # Fall back to a placeholder label when a vector has no matching
            # entry in final_indices.
            if feat_idx < len(layer_indices):
                actual_index = layer_indices[feat_idx]
            else:
                actual_index = f"unknown_{current_position}"
            all_feature_indices.append(actual_index)
            all_layer_numbers.append(layer_idx)
            lang_names_for_features.append(lang)
            current_position += 1

    # Mark the boundary after this language
    language_boundaries.append(current_position)

# Convert to numpy array for cosine similarity calculation
features_array = np.array(all_features)
print(f"Feature array shape: {features_array.shape}")

# Pairwise cosine similarity between all collected features
similarity_matrix = cosine_similarity(features_array)

# Hover metadata: for every cell (i, j) the language, layer and latent index
# of feature i followed by those of feature j.
feature_count_total = len(all_feature_indices)
feature_meta = list(
    zip(lang_names_for_features, all_layer_numbers, all_feature_indices)
)
hover_customdata = np.array(
    [
        [*feature_meta[i], *feature_meta[j]]
        for i in range(feature_count_total)
        for j in range(feature_count_total)
    ]
).reshape(feature_count_total, feature_count_total, 6)

# Create the heatmap
fig = go.Figure(
    data=go.Heatmap(
        z=similarity_matrix,
        colorscale="RdBu",
        zmid=0,  # Center the colorscale at 0
        colorbar=dict(
            title="Cosine Similarity",
        ),
        hovertemplate="Lang: %{customdata[0]}<br>Layer: %{customdata[1]}, Index: %{customdata[2]}<br>Lang: %{customdata[3]}<br>Layer: %{customdata[4]}, Index: %{customdata[5]}<br>Similarity: %{z:.3f}<extra></extra>",
        customdata=hover_customdata,
    )
)

# Add language boundaries as lines
boundary_color = "black"
boundary_width = 2

# Vertical separators; the last boundary is the end of the data and needs
# no line. The -0.5 offset puts the line between two heatmap cells.
for boundary in language_boundaries[:-1]:
    fig.add_vline(
        x=boundary - 0.5,
        line=dict(color=boundary_color, width=boundary_width),
        layer="above",
    )

# Horizontal separators, same positions.
for boundary in language_boundaries[:-1]:
    fig.add_hline(
        y=boundary - 0.5,
        line=dict(color=boundary_color, width=boundary_width),
        layer="above",
    )

# Tick position (middle of each language block) and label per language
lang_positions = []
lang_labels = []
prev_boundary = 0

for lang, boundary in zip(sorted_lang, language_boundaries):
    middle_pos = (prev_boundary + boundary) / 2
    lang_positions.append(middle_pos)
    lang_labels.append(lang)
    prev_boundary = boundary

# Update layout
fig.update_layout(
    title={
        "text": "Cosine Similarity Heatmap of Language-Specific Features",
        "x": 0.5,
        "xanchor": "center",
        "font": {"size": 16},
    },
    xaxis=dict(
        title="Feature Index",
        tickfont=dict(size=6),
        # Add language labels at appropriate positions
        tickmode="array",
        tickvals=lang_positions,
        ticktext=lang_labels,
        tickangle=45,
    ),
    yaxis=dict(
        title="Feature Index",
        tickfont=dict(size=6),
        # Add language labels at appropriate positions
        tickmode="array",
        tickvals=lang_positions,
        ticktext=lang_labels,
        autorange="reversed",  # Reverse y-axis to match typical matrix visualization
    ),
    # width=800,
    height=1000,
    font=dict(size=10),
)

# Label each language's diagonal block with its name
annotations = [
    dict(
        x=pos,
        y=pos,
        text=lang,
        showarrow=False,
        font=dict(color="white", size=8, family="Arial Black"),
        bgcolor="rgba(0,0,0,0.7)",
        bordercolor="white",
        borderwidth=1,
    )
    for lang, pos in zip(sorted_lang, lang_positions)
]

fig.update_layout(annotations=annotations)

# Show the plot
fig.show()

# Optional: Print some statistics
print("\nSimilarity Matrix Statistics:")
print(f"Shape: {similarity_matrix.shape}")
print(f"Min similarity: {similarity_matrix.min():.3f}")
print(f"Max similarity: {similarity_matrix.max():.3f}")
print(f"Mean similarity: {similarity_matrix.mean():.3f}")

# Print language boundaries for reference
print("\nLanguage boundaries:")
start = 0
for lang, boundary in zip(sorted_lang, language_boundaries):
    feature_count = boundary - start
    print(f"{lang}: features {start} to {boundary-1} (total: {feature_count} features)")

    # Show some example feature indices and layers for this language
    lang_feature_indices = all_feature_indices[start:boundary]
    lang_layer_numbers = all_layer_numbers[start:boundary]
    examples = list(zip(lang_layer_numbers[:5], lang_feature_indices[:5]))
    print(f"  Example (layer, index): {examples}{'...' if feature_count > 5 else ''}")

    start = boundary

# Print total feature count
print(f"\nTotal features collected: {len(all_features)}")
print(f"Languages: {len(sorted_lang)}")

# The "unknown_<n>" placeholders are strings and cannot be compared with the
# integer indices, so they are left out of the range summary.
known_feature_indices = [
    idx for idx in all_feature_indices if not isinstance(idx, str)
]
if known_feature_indices:
    print(
        f"Feature indices range: {min(known_feature_indices)} to {max(known_feature_indices)}"
    )
print(f"Layer numbers range: {min(all_layer_numbers)} to {max(all_layer_numbers)}")

In [None]:
from sparsify import Sae

def _load_decoder_bias(hookpoint):
    """Load the SAE for one hookpoint and return its detached decoder bias."""
    # The SAE object is local to this call, so it is released on return and
    # only the bias tensor is kept.
    hub_sae = Sae.load_from_hub("EleutherAI/sae-Llama-3.2-1B-131k", hookpoint=hookpoint)
    return hub_sae.b_dec.detach()


sae_vectors = {}

# MLP hookpoint names for layers 0 through 15.
layers = [f"layers.{layer_number}.mlp" for layer_number in range(16)]

for layer in layers:
    sae_vectors[layer] = {
        "bias": _load_decoder_bias(layer),
    }

In [None]:
# Column-oriented records: one entry per (language, layer, latent index) with
# the cosine similarity between that feature vector and the layer's SAE
# decoder bias.
records = {column: [] for column in ("Language", "Layer", "Index", "Similarity")}

for lang, lang_entry in lang_to_sae_features.items():
    for layer_idx, lang_layer_feature_indices in enumerate(lang_entry["final_indices"]):
        layer_str = f"layers.{layer_idx}.mlp"
        bias_vector = sae_vectors[layer_str]["bias"]

        # Feature vectors selected for this language in this layer
        feature_vectors = lang_entry["stacked_sae_features"][layer_idx]

        # Nothing selected in this layer
        if len(feature_vectors) == 0:
            continue

        # Shape (1, n_features): the bias against every feature vector
        similarity_scores = cosine_similarity(bias_vector.unsqueeze(0), feature_vectors)

        scored = list(zip(lang_layer_feature_indices, similarity_scores[0]))
        records["Language"].extend(lang for _ in scored)
        records["Layer"].extend(layer_str for _ in scored)
        records["Index"].extend(idx for idx, _ in scored)
        records["Similarity"].extend(score.item() for _, score in scored)

In [None]:
# Write the bias-similarity records as CSV with three-decimal floats.
output_path = project_dir / "hfls" / "sae_bias_similarity.csv"

# Create the target folder first; to_csv does not create missing directories
# and would otherwise fail on a fresh checkout.
os.makedirs(output_path.parent, exist_ok=True)

pd.DataFrame(records).to_csv(output_path, index=False, float_format="%.3f")

##### Cosine Similarity of Language-Specific Features in particular layers

In [None]:
from sparsify import Sae

# NOTE: this cell deliberately does not touch `sae_vectors`; the decoder biases
# loaded earlier in the notebook are still read by later cells.

# One MLP hookpoint per transformer block, indices 0 through 15.
layers = [f"layers.{block_idx}.mlp" for block_idx in range(16)]

sorted_lang = lape_all_result["sorted_lang"]

# Column-oriented accumulator: one entry per language-specific feature.
records = {
    "Language": [],
    "Layer": [],
    "Feature Index": [],
    "Opposing Feature Index": [],
    "Cosine Similarity": [],
}

for layer in layers:
    print(f"Processing layer: {layer}")

    sae = Sae.load_from_hub("EleutherAI/sae-Llama-3.2-1B-131k", hookpoint=layer)
    layer_idx = hookpoint_to_layer[layer]

    # Detach once per layer instead of once per language.
    decoder_weights = sae.W_dec.detach()

    for lang_index, lang in enumerate(sorted_lang):
        lang_layer_sae_features = lape_all_result["sae_features"][lang_index][layer_idx]
        lang_layer_sae_feature_indices = lape_all_result["final_indice"][lang_index][
            layer_idx
        ]

        if len(lang_layer_sae_features) == 0:
            continue

        # One row per language-specific feature, one column per decoder row.
        similarity_matrix = cosine_similarity(lang_layer_sae_features, decoder_weights)

        # The "opposing" feature is the decoder row with the lowest cosine
        # similarity to each language-specific feature.
        opposing_vectors = similarity_matrix.argmin(axis=-1)
        opposing_values = similarity_matrix.min(axis=-1)

        for feature, opposing_vector, opposing_value in zip(
            lang_layer_sae_feature_indices, opposing_vectors, opposing_values
        ):
            records["Language"].append(lang)
            records["Layer"].append(layer)
            records["Feature Index"].append(feature.item())
            records["Opposing Feature Index"].append(opposing_vector.item())
            records["Cosine Similarity"].append(opposing_value.item())

    # Release this layer's SAE before the next one is loaded so two SAEs are
    # never alive at the same time.
    del decoder_weights, sae

In [None]:
# Persist the opposing-feature table, creating the target folder on demand.
output_path = project_dir.joinpath("hfls", "sae_opposing_features.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)

pd.DataFrame(records).to_csv(
    output_path,
    index=False,
    float_format="%.3f",
)

In [None]:
# Token-activation folders, one per dataset configuration:
# <statistic_dir>/<model>/<sae model>/<dataset>/dataset_token_activations
data_path_dataset_token_activations_xnli = statistic_dir.joinpath(
    config_xnli["model"],
    config_xnli["sae"]["model"],
    config_xnli["dataset"],
    "dataset_token_activations",
)

data_path_dataset_token_activations_pawsx = statistic_dir.joinpath(
    config_pawsx["model"],
    config_pawsx["sae"]["model"],
    config_pawsx["dataset"],
    "dataset_token_activations",
)

data_path_dataset_token_activations_flores = statistic_dir.joinpath(
    config_flores["model"],
    config_flores["sae"]["model"],
    config_flores["dataset"],
    "dataset_token_activations",
)

In [None]:
def feature_count_and_intersection(
    feature_indices: tuple[int, int],
    layer: str,
):
    """Count activations of a feature pair and the positions they share.

    The token activations of `layer` are loaded for each of the three dataset
    configurations (`config_xnli`, `config_pawsx`, `config_flores`) and every
    language they list. For each per-language table the `count` of the row
    matching each feature index is accumulated, and when both features are
    present the number of common (row id, token id) positions is added.

    Args:
        feature_indices: `(feature index, opposing feature index)`.
        layer: Layer identifier passed to `load_lang_to_dataset_token_activations`.

    Returns:
        A tuple `(feature_token_count, opposing_feature_token_count,
        intersection_count)`.
    """

    def activation_positions(single_row):
        # The column stores a Python-literal list of
        # (row id, token id, activation value) triples; the value is dropped.
        triples = literal_eval(single_row["dataset_row_id_token_id_act_val"].item())
        return {(row_id, token_id) for row_id, token_id, _ in triples}

    datasets = (
        (config_xnli, data_path_dataset_token_activations_xnli),
        (config_pawsx, data_path_dataset_token_activations_pawsx),
        (config_flores, data_path_dataset_token_activations_flores),
    )

    feature_token_count = 0
    opposing_feature_token_count = 0
    intersection_count = 0

    for config, data_path in datasets:
        per_lang_activations = load_lang_to_dataset_token_activations(
            data_path,
            layer,
            config["languages"],
            feature_indices,
        )

        for activations in per_lang_activations.values():
            # The queries read `feature_indices` from this function's scope.
            feature = activations.query("index == @feature_indices[0]")
            opposing_feature = activations.query("index == @feature_indices[1]")

            has_feature = not feature.empty
            has_opposing = not opposing_feature.empty

            if has_feature:
                feature_token_count += feature["count"].item()

            if has_opposing:
                opposing_feature_token_count += opposing_feature["count"].item()

            if has_feature and has_opposing:
                shared_positions = activation_positions(
                    feature
                ) & activation_positions(opposing_feature)
                intersection_count += len(shared_positions)

    return (feature_token_count, opposing_feature_token_count, intersection_count)

In [None]:
# Reload the opposing-feature table that was written to disk earlier.
output_path = project_dir.joinpath("hfls", "sae_opposing_features.csv")

records = pd.read_csv(output_path)

In [None]:
# Columns identifying a (feature, opposing feature) pair and its layer ...
pair_columns = ["Feature Index", "Opposing Feature Index", "Layer"]
# ... and the three columns filled from feature_count_and_intersection().
count_columns = [
    "Feature Token Count",
    "Opposing Feature Token Count",
    "Intersection Count",
]

# One call per row; result_type="expand" spreads the returned 3-tuple over
# three columns, which are then stored under `count_columns`.
records[count_columns] = records[pair_columns].apply(
    lambda row: feature_count_and_intersection(
        (row["Feature Index"], row["Opposing Feature Index"]),
        f"model.{row['Layer']}",
    ),
    axis=1,
    result_type="expand",
)

In [None]:
# Show the table, including the count columns, as the cell output.
pd.DataFrame(records)

In [None]:
# Write the table back to the same CSV it was read from, this time without a
# float format.
output_path = project_dir.joinpath("hfls", "sae_opposing_features.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)

pd.DataFrame(records).to_csv(output_path, index=False)

##### Top 10 Tokens of Features x W_embed

In [None]:
from nnsight import LanguageModel

# Load Llama 3.2-1B through nnsight on the CPU. Later cells call `llm.model.norm`,
# `llm.lm_head` and `llm.tokenizer` directly.
# NOTE(review): dispatch=True presumably loads the weights immediately instead
# of lazily — confirm against the nnsight documentation.
llm = LanguageModel("meta-llama/Llama-3.2-1B", device_map="cpu", dispatch=True)

In [None]:
# Column-oriented accumulator: one entry per language-specific feature.
top_token_records = {
    column: [] for column in ("Lang", "Layer", "Feature ID", "Top Tokens")
}

for lang, layer_final_indices in lang_to_sae_features.items():
    print(f"Language: {lang}")

    for layer_idx, final_indices in enumerate(layer_final_indices["final_indices"]):
        layer_str = f"layers.{layer_idx}.mlp"

        # Feature vectors selected for this language in this layer.
        feature_vectors = layer_final_indices["stacked_sae_features"][layer_idx]

        if len(feature_vectors) == 0:
            continue

        # Final norm followed by lm_head turns each feature vector into
        # vocabulary logits.
        with torch.no_grad():
            logits = llm.lm_head(llm.model.norm(torch.tensor(feature_vectors)))

        # Ids of the 10 highest-scoring tokens of every feature vector.
        top_token_indices = torch.topk(logits, 10, dim=-1).indices

        # Record the decoded top tokens next to the feature's id.
        for i in range(len(feature_vectors)):
            decoded_tokens = [
                llm.tokenizer.decode([idx]) for idx in top_token_indices[i].tolist()
            ]

            top_token_records["Lang"].append(lang)
            top_token_records["Layer"].append(layer_str)
            top_token_records["Feature ID"].append(final_indices[i])
            top_token_records["Top Tokens"].append(decoded_tokens)

top_token_df = pd.DataFrame(top_token_records)

In [None]:
import fasttext

# fastText language-identification model, expected in the working directory.
model_path = "lid.176.bin"
local_path = Path(model_path)

# Download the model only when it is missing. This replaces a bare `except:`
# around the first load attempt, which also swallowed unrelated errors
# (including KeyboardInterrupt) before reaching the same existence check.
if not local_path.exists():
    import urllib.request

    url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"

    # Download to a temporary name and rename afterwards, so an interrupted
    # download is never mistaken for a complete model on the next run.
    partial_path = local_path.with_name(local_path.name + ".part")
    urllib.request.urlretrieve(url, partial_path)
    partial_path.replace(local_path)

lang_model = fasttext.load_model(str(local_path))

In [None]:
def token_is_language(token, lang):
    """Return True when fastText's single best language guess for `token` is `lang`.

    `lang` is a key of `lang_choices_to_iso639_1`; the mapped code is compared
    with the predicted `__label__<code>` label.
    """
    # Newlines are removed before prediction (presumably because fastText's
    # predict only accepts single-line input — confirm).
    cleaned_token = token.replace("\n", "")
    predicted_labels = lang_model.predict(cleaned_token, k=1)[0]

    iso_code_lang = lang_choices_to_iso639_1[lang]
    return f"__label__{iso_code_lang}" in predicted_labels

In [None]:
def _sanitize_sheet_name(name):
    """Return `name` as a valid Excel sheet name (forbidden characters replaced, at most 31 characters)."""
    sheet_name = str(name)
    # Excel forbids [ ] * ? : / and \ in sheet names.
    for char in ("[", "]", "*", "?", ":", "/", "\\"):
        sheet_name = sheet_name.replace(char, "_")
    return sheet_name[:31]


def _rich_token_fragments(tokens, lang, bold_format):
    """Build the write_rich_string fragments: tokens joined by ", ", with `bold_format` before tokens detected as `lang`."""
    fragments = []
    for token in tokens:
        if token_is_language(token, lang):
            fragments.append(bold_format)
        fragments.append(token)
        fragments.append(", ")
    return fragments[:-1]  # Drop the trailing separator


def create_language_excel(top_token_df, output_filename="language_feature_tokens.xlsx"):
    """Write the top tokens of every feature to an Excel workbook, one sheet per language.

    Each sheet has the columns Layer, Lang, Feature ID and Top Tokens. Tokens
    that `token_is_language` attributes to the sheet's language are written in
    bold. Problems are reported with `print` instead of being raised.

    Args:
        top_token_df: DataFrame with the columns "Lang", "Layer", "Feature ID"
            and "Top Tokens"; each "Top Tokens" cell is iterated token by token.
        output_filename: Path of the workbook to create.

    Returns:
        None.
    """
    if not isinstance(top_token_df, pd.DataFrame):
        print("Error: Input 'top_token_df' must be a pandas DataFrame.")
        return

    required_columns = ["Lang", "Layer", "Feature ID", "Top Tokens"]
    if not all(col in top_token_df.columns for col in required_columns):
        print(
            f"Error: DataFrame must contain the following columns: {required_columns}"
        )
        return

    try:
        # Create a Pandas Excel writer using XlsxWriter as the engine.
        with pd.ExcelWriter(output_filename, engine="xlsxwriter") as writer:
            workbook = writer.book

            # Formats belong to the workbook, so they are created once and
            # shared by every sheet and row.
            header_format = workbook.add_format(
                {
                    "bold": True,
                    "align": "center",
                    "valign": "vcenter",
                    "border": 1,
                }
            )
            # 'Layer', 'Lang', 'Feature ID' data cells: centered, bordered.
            data_center_bordered_format = workbook.add_format(
                {
                    "align": "center",
                    "valign": "vcenter",
                    "border": 1,
                }
            )
            # 'Top Tokens' data cells: wrapped, top-aligned, bordered.
            data_wrap_bordered_format = workbook.add_format(
                {
                    "text_wrap": True,
                    "valign": "top",
                    "border": 1,
                }
            )
            # Applied to tokens detected as the sheet's language.
            bold_format = workbook.add_format({"bold": True})

            # Sorted so the sheet order is the same on every run.
            for lang in sorted(top_token_df["Lang"].unique()):
                lang_df = top_token_df[top_token_df["Lang"] == lang]

                if lang_df.empty:
                    print(
                        f"No data found for language: {lang}. Skipping sheet creation."
                    )
                    continue

                # Column order of the Excel sheet.
                df_to_write = lang_df[["Layer", "Lang", "Feature ID", "Top Tokens"]]
                sheet_name = _sanitize_sheet_name(lang)

                # Pandas creates the sheet and writes data plus default headers;
                # every cell is rewritten below to attach the formats.
                df_to_write.to_excel(
                    writer, sheet_name=sheet_name, index=False, startrow=0
                )
                worksheet = writer.sheets[sheet_name]

                worksheet.set_column("A:A", 17.56)  # Layer
                worksheet.set_column("B:B", 8.67)  # Lang
                worksheet.set_column("C:C", 11.89)  # Feature ID
                worksheet.set_column("D:D", 60)  # Top Tokens

                # Header row with border.
                for col_num, value in enumerate(df_to_write.columns.values):
                    worksheet.write(0, col_num, value, header_format)

                # Data rows with border; row 0 of the sheet is the header.
                for row in range(len(df_to_write)):
                    for col_num in range(3):
                        worksheet.write(
                            row + 1,
                            col_num,
                            df_to_write.iloc[row, col_num],
                            data_center_bordered_format,
                        )

                    worksheet.write_rich_string(
                        row + 1,
                        3,
                        *_rich_token_fragments(
                            df_to_write.iloc[row, 3], lang, bold_format
                        ),
                        data_wrap_bordered_format,
                    )

        print(
            f"Excel file '{output_filename}' created successfully. Borders applied only to header and filled rows of each sheet."
        )

    except Exception as e:
        # Best-effort export: report the failure without interrupting the notebook.
        print(f"An error occurred: {e}")


# Workbook location: top_token_features/<model>/<sae model>/top_token_features.xlsx
output_path = project_dir.joinpath(
    "top_token_features",
    config_xnli["model"],
    config_xnli["sae"]["model"],
    "top_token_features.xlsx",
)
output_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: the generated excel must be refined
create_language_excel(top_token_df, output_path)

In [None]:
# Print, per language and layer, the top tokens of the SUM of that layer's
# language-specific feature vectors.
for lang, layer_final_indices in lang_to_sae_features.items():
    print(f"Language: {lang}")
    for layer_idx, _ in enumerate(layer_final_indices["final_indices"]):
        layer_str = f"layers.{layer_idx}.mlp"

        # Get the feature vectors for this language and layer
        feature_vectors = layer_final_indices["stacked_sae_features"][layer_idx]

        if len(feature_vectors) == 0:
            continue

        # Sum the feature vectors, then feed the result through the final norm
        # and lm_head to get token logits
        with torch.no_grad():
            norm = llm.model.norm(torch.tensor(feature_vectors).sum(dim=0))
            logits = llm.lm_head(norm)

        # Ids of the 10 highest-scoring tokens of the summed vector
        top_tokens = torch.topk(logits, 10, dim=-1).indices.tolist()

        token_display = ", ".join(llm.tokenizer.decode([idx]) for idx in top_tokens)
        print(f"Layer: {layer_str}")
        print(f"  Top tokens: {token_display}")

#### Entropies and scores

In [None]:
# Gather per-feature information from the LAPE result, the FLORES layer list
# and the scoring metrics; the result feeds the correlation and entropy plots
# below.
# NOTE(review): judging by the loader's name this is a DataFrame — confirm in loader.py.
sae_features_info = load_sae_features_info_df(
    lape_all_result,
    config_flores["layers"],
    metrics,
)

In [None]:
# Plots go to visualization/correlation/<model>/<sae model>.
output_dir = project_dir.joinpath(
    "visualization",
    "correlation",
    config_flores["model"],
    config_flores["sae"]["model"],
)

plot_sae_features_entropy_score_correlation(sae_features_info, output_dir)

#### Entropy Distribution

In [None]:
# Plots go to visualization/entropy/<model>/<sae model>.
output_dir = project_dir.joinpath(
    "visualization",
    "entropy",
    config_flores["model"],
    config_flores["sae"]["model"],
)

plot_entropy_distribution(sae_features_info, output_dir)

### All Language-Specific Features Information

In [None]:
# Combine the LAPE result, the FLORES layer list, the feature interpretations
# and the scoring metrics into one structure; it is dumped to JSON in the next
# cell, so it is expected to be JSON-serialisable.
lang_to_sae_features_info = load_lang_to_sae_features_info(
    lape_all_result,
    config_flores["layers"],
    interpretations,
    metrics,
)

In [None]:
output_dir = project_dir / "interpret_sae_features" / "language_specific_features"

# open(..., "w") fails when the folder is missing, so create it on demand like
# the other export cells do.
output_dir.mkdir(parents=True, exist_ok=True)

with open(output_dir / "lang_to_sae_features_info.json", "w", encoding="utf-8") as f:
    json.dump(lang_to_sae_features_info, f, indent=4)

In [None]:
# Same loader call as for `lang_to_sae_features_info`, but with extra=True.
# The Excel export below reads "Layer", "Lang", "Feature ID", "Interpretation"
# and "Metrics" from each entry of this result.
lang_to_sae_features_info_extra = load_lang_to_sae_features_info(
    lape_all_result,
    config_flores["layers"],
    interpretations,
    metrics,
    extra=True,
)

In [None]:
output_dir = project_dir / "interpret_sae_features" / "language_specific_features"

# open(..., "w") fails when the folder is missing, so create it on demand like
# the other export cells do.
output_dir.mkdir(parents=True, exist_ok=True)

with open(
    output_dir / "lang_to_sae_features_info_extra.json", "w", encoding="utf-8"
) as f:
    json.dump(lang_to_sae_features_info_extra, f, indent=4)

In [None]:
import xlsxwriter

# Layout shared by every sheet: the detection block occupies columns E-H and
# the fuzzing block columns I-L, each with the same four metrics in this order.
metric_keys = ["accuracy", "f1_score", "precision", "recall"]
metric_sub_headers = ["Accuracy", "F1 score", "Precision", "Recall"]
detection_start_col = 4  # Column E
fuzzing_start_col = 8  # Column I

# The context manager closes (and thereby writes) the workbook even when one
# of the writes below raises.
with xlsxwriter.Workbook(
    output_dir / "lang_to_sae_features_info_extra.xlsx"
) as workbook:
    # Define cell formats
    header_format = workbook.add_format(
        {
            "bold": True,
            "text_wrap": True,
            "valign": "vcenter",
            "align": "center",
            "border": 1,
        }
    )

    center_aligned_format = workbook.add_format(
        {
            "align": "center",
            "valign": "vcenter",
            "border": 1,
        }
    )

    left_wrap_format = workbook.add_format(
        {
            "align": "left",
            "valign": "vcenter",
            "text_wrap": True,
            "border": 1,
        }
    )

    # One worksheet per language
    for lang_name, layers_data in lang_to_sae_features_info_extra.items():
        # Sheet names have a max length of 31.
        worksheet_name = lang_name[:31]
        worksheet = workbook.add_worksheet(worksheet_name)

        # Set column widths
        worksheet.set_column("A:A", 17.56)  # Layer
        worksheet.set_column("B:B", 8.67)  # Lang
        worksheet.set_column("C:C", 11.89)  # Feature ID
        worksheet.set_column("D:D", 52.22)  # Interpretation
        worksheet.set_column("E:H", 8)  # Detection metrics
        worksheet.set_column("I:L", 8)  # Fuzzing metrics

        # Row 1: main headers; the first four span both header rows
        worksheet.merge_range("A1:A2", "Layer", header_format)
        worksheet.merge_range("B1:B2", "Lang", header_format)
        worksheet.merge_range("C1:C2", "Feature ID", header_format)
        worksheet.merge_range("D1:D2", "Interpretation", header_format)

        # Merged cells for Detection and Fuzzing
        worksheet.merge_range("E1:H1", "Detection", header_format)
        worksheet.merge_range("I1:L1", "Fuzzing", header_format)

        # Row 2: sub-headers for metrics
        for i, sub_header in enumerate(metric_sub_headers):
            worksheet.write(1, detection_start_col + i, sub_header, header_format)
            worksheet.write(1, fuzzing_start_col + i, sub_header, header_format)

        # Data starts on the third row (index 2)
        current_row = 2
        for feature_id_dict in layers_data.values():
            for details in feature_id_dict.values():
                # Extract basic info
                layer_val = details.get("Layer", "")
                lang_val = details.get("Lang", "")
                feature_id_val = details.get("Feature ID", "")
                interpretation_val = details.get("Interpretation", "")
                # Remove surrounding quotes from the interpretation if present
                if interpretation_val.startswith('"') and interpretation_val.endswith(
                    '"'
                ):
                    interpretation_val = interpretation_val[1:-1]

                # Metrics per score type; a type that is absent keeps an empty
                # dict, so its cells are written empty.
                metrics_by_type = {"detection": {}, "fuzz": {}}
                for metric_set in details.get("Metrics", []):
                    score_type = metric_set.get("score_type")
                    if score_type in metrics_by_type:
                        metrics_by_type[score_type] = {
                            key: metric_set.get(key) for key in metric_keys
                        }

                # Write data to cells with specified formats
                worksheet.write(current_row, 0, layer_val, center_aligned_format)
                worksheet.write(current_row, 1, lang_val, center_aligned_format)
                worksheet.write(current_row, 2, feature_id_val, center_aligned_format)
                worksheet.write(current_row, 3, interpretation_val, left_wrap_format)

                # Detection metrics first, then fuzzing metrics
                for start_col, score_type in (
                    (detection_start_col, "detection"),
                    (fuzzing_start_col, "fuzz"),
                ):
                    for offset, key in enumerate(metric_keys):
                        worksheet.write(
                            current_row,
                            start_col + offset,
                            metrics_by_type[score_type].get(key),
                            center_aligned_format,
                        )

                current_row += 1

## Language-Shared Features Visualization

#### Shared-Feature Intersection and Jaccard Similarity

In [None]:
# language -> union of its shared-feature indices over every shared count.
lang_to_count_final_indicies = {}

for shared_count in range(2, len(config_flores["languages"]) + 1):
    lape_shared_path = project_dir.joinpath(
        "sae_features_shared",
        config_xnli["model"],
        config_xnli["sae"]["model"],
        f"lape_shared_{shared_count}.pt",
    )

    lape_shared_result = torch.load(lape_shared_path)
    sorted_lang = lape_shared_result["sorted_lang"]

    for lang in sorted_lang:
        if lang in lape_shared_result["features_info"]:
            lang_indices = lape_shared_result["features_info"][lang]["indicies"]

            # NOTE(review): assumes the indices are hashable by value (ints or
            # tuples) so the set really de-duplicates them — verify.
            lang_to_count_final_indicies.setdefault(lang, set()).update(lang_indices)
        else:
            print(
                f"Language {lang} not found in features_info for shared count {shared_count}."
            )

In [None]:
# visualization/shared_features/<model>/<sae model>/feature_intersection_heatmap.html
output_path = project_dir.joinpath(
    "visualization",
    "shared_features",
    config_flores["model"],
    config_flores["sae"]["model"],
    "feature_intersection_heatmap.html",
)

plot_intersection_heatmap(lang_to_count_final_indicies, output_path)

In [None]:
# visualization/shared_features/<model>/<sae model>/iou_heatmap.html
output_path = project_dir.joinpath(
    "visualization",
    "shared_features",
    config_flores["model"],
    config_flores["sae"]["model"],
    "iou_heatmap.html",
)

plot_iou_heatmap(lang_to_count_final_indicies, output_path)

#### Shared-Feature Distribution

In [None]:
# shared count -> union of the shared-feature indices of every language.
count_final_indicies = {}

for shared_count in range(2, len(config_flores["languages"]) + 1):
    lape_shared_path = project_dir.joinpath(
        "sae_features_shared",
        config_xnli["model"],
        config_xnli["sae"]["model"],
        f"lape_shared_{shared_count}.pt",
    )

    lape_shared_result = torch.load(lape_shared_path)
    sorted_lang = lape_shared_result["sorted_lang"]

    for lang in sorted_lang:
        if lang in lape_shared_result["features_info"]:
            lang_indices = lape_shared_result["features_info"][lang]["indicies"]

            # NOTE(review): assumes the indices are hashable by value (ints or
            # tuples) so the set really de-duplicates them — verify.
            count_final_indicies.setdefault(shared_count, set()).update(lang_indices)
        else:
            print(
                f"Language {lang} not found in features_info for shared count {shared_count}."
            )

In [None]:
# visualization/shared_features/<model>/<sae model>/feature_distribution.html
output_path = project_dir.joinpath(
    "visualization",
    "shared_features",
    config_flores["model"],
    config_flores["sae"]["model"],
    "feature_distribution.html",
)

plot_shared_count_bar_chart(count_final_indicies, output_path)

In [None]:
# Split the per-count index sets into two halves: shared counts 2-8 and 9-15.
half_low_indices = list(range(2, 9))
half_high_indices = list(range(9, 16))

count_final_indicies_half_low = {
    count: count_final_indicies[count] for count in half_low_indices
}
count_final_indicies_half_high = {
    count: count_final_indicies[count] for count in half_high_indices
}

In [None]:
# visualization/shared_features/<model>/<sae model>/feature_distribution_half_low.html
output_path = project_dir.joinpath(
    "visualization",
    "shared_features",
    config_flores["model"],
    config_flores["sae"]["model"],
    "feature_distribution_half_low.html",
)

plot_shared_count_bar_chart(count_final_indicies_half_low, output_path)

In [None]:
# visualization/shared_features/<model>/<sae model>/feature_distribution_half_high.html
output_path = project_dir.joinpath(
    "visualization",
    "shared_features",
    config_flores["model"],
    config_flores["sae"]["model"],
    "feature_distribution_half_high.html",
)

plot_shared_count_bar_chart(count_final_indicies_half_high, output_path)

#### PPL Change

In [None]:
# Perplexity of the unmodified model: ppl/<model>/<dataset>/normal/ppl.pt
normal_ppl_output_path = project_dir.joinpath(
    "ppl",
    config_flores["model"],
    config_flores["dataset"],
    "normal",
    "ppl.pt",
)

normal_ppl_result = torch.load(normal_ppl_output_path, weights_only=False)

# configuration name -> language -> mean perplexity rounded to two decimals;
# "normal" is the first entry, further configurations are added below.
config_lang_to_ppl_avg_ppl = {
    "normal": {
        lang: round(lang_result["mean_perplexity"], 2)
        for lang, lang_result in normal_ppl_result.items()
    }
}

In [None]:
# Show the baseline ("normal") mean perplexities as the cell output.
config_lang_to_ppl_avg_ppl

In [None]:
# Result folders (relative to project_dir) of the shared-feature runs, one per
# multiplier.
shared_features_ppl_output_path_list = [
    Path(relative_dir)
    for relative_dir in (
        r"ppl/meta-llama/Llama-3.2-1B/openlanguagedata/flores_plus/shared_features/entropy/max/mult_-0.2",
        r"ppl/meta-llama/Llama-3.2-1B/openlanguagedata/flores_plus/shared_features/entropy/max/mult_-0.3",
        r"ppl/meta-llama/Llama-3.2-1B/openlanguagedata/flores_plus/shared_features/entropy/max/mult_-0.4",
        r"ppl/meta-llama/Llama-3.2-1B/openlanguagedata/flores_plus/shared_features/entropy/max/mult_0.2",
        r"ppl/meta-llama/Llama-3.2-1B/openlanguagedata/flores_plus/shared_features/entropy/max/mult_0.3",
        r"ppl/meta-llama/Llama-3.2-1B/openlanguagedata/flores_plus/shared_features/entropy/max/mult_0.4",
    )
]

# Folder name (e.g. "mult_-0.2") -> content of that folder's ppl.pt.
shared_features_ppl_to_result = {
    result_dir.name: torch.load(project_dir / result_dir / "ppl.pt", weights_only=False)
    for result_dir in shared_features_ppl_output_path_list
}

In [None]:
# Add one entry per configuration next to "normal": language -> mean
# perplexity rounded to two decimals.
for config, ppl_results in shared_features_ppl_to_result.items():
    langs = list(ppl_results.keys())

    for intervened_lang in langs:
        # The configuration's dict is created with its first language.
        lang_to_ppl = config_lang_to_ppl_avg_ppl.setdefault(config, {})
        lang_to_ppl[intervened_lang] = round(
            ppl_results[intervened_lang]["mean_perplexity"], 2
        )

In [None]:
import pandas as pd

# Show the perplexities as a table: one column per configuration, one row per
# language.
pd.DataFrame(config_lang_to_ppl_avg_ppl)