# Heatmap & line plot

In [None]:
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [None]:
# Root directory of the benchmark result tables.
RESULT_DIR = '02.results/00.Benchmark_Results/'
# Per-model sub-directories relative to RESULT_DIR. The loading code indexes
# this list by position, so the order here is significant.
RESULT_NAME = ['99.Gemini/', 
               '01.Llama/',
               '02.Mistral/',
               '03.Qwen/Qwen-8B/',
               '03.Qwen/Qwen-14B/']               

In [None]:
# Directory of the benchmark source files.
BENCHMARK_DIR = '00.data/02.WikiBench/'
# Parquet file names of the five benchmark variants; the same file names are
# used to locate each model's result tables under RESULT_DIR.
BENCHMARK_NAME = ['00.original_benchmark_TF_500.parquet',
                  '01.subject_shuffled_benchmark_TF_500.parquet',
                  '02.object_shuffled_benchmark_TF_500.parquet',
                  '03.property_scoped_subject_shuffled_benchmark_TF_500.parquet',
                  '04.property_scoped_object_shuffled_benchmark_TF_500.parquet']

In [None]:
def _load_model_benchmarks(model_subdir):
    """Read every benchmark result file of one model into pandas DataFrames.

    The returned list follows the order of BENCHMARK_NAME.
    """
    return [
        pq.read_table(f"{RESULT_DIR}{model_subdir}{file_name}").to_pandas()
        for file_name in BENCHMARK_NAME
    ]


# One list of DataFrames per model (positions refer to RESULT_NAME).
BENCHMARK_GEMINI = _load_model_benchmarks(RESULT_NAME[0])
BENCHMARK_LLAMA = _load_model_benchmarks(RESULT_NAME[1])
BENCHMARK_MISTRAL = _load_model_benchmarks(RESULT_NAME[2])
BENCHMARK_QWEN3_8B = _load_model_benchmarks(RESULT_NAME[3])
BENCHMARK_QWEN3_14B = _load_model_benchmarks(RESULT_NAME[4])

# All models in the order used for plotting: Mistral, LLaMA, Gemini, Qwen3-8B,
# Qwen3-14B. The entries are the very same list objects as above.
BENCHMARK_RESULTS = [
    BENCHMARK_MISTRAL,
    BENCHMARK_LLAMA,
    BENCHMARK_GEMINI,
    BENCHMARK_QWEN3_8B,
    BENCHMARK_QWEN3_14B,
]

In [None]:
# Display names of the models, listed in the same order as BENCHMARK_RESULTS.
MODEL_NAMES = ["Mistral", "LLaMA", "Gemini", "Qwen3-8B", "Qwen3-14B"]

# Sanity check: both lengths should be equal.
print(len(MODEL_NAMES), len(BENCHMARK_RESULTS))
# Short benchmark labels; looked up by the position of a benchmark within a
# model's result list, so the order mirrors BENCHMARK_NAME.
BENCHMARK_LABELS = [
    "ORG",
    "SS",
    "SO",
    "PSS",
    "POS",
]


In [None]:
# Language codes; used as suffixes of the response_TF_<lang> / correct_<lang> columns.
LANGUAGE_LIST = ['en','de','fr','es','it','pt','ko','ja']

In [None]:
# Show every column when a DataFrame is rendered.
pd.set_option('display.max_columns', None)

In [None]:
# Collect every distinct raw response value over all models, languages and
# benchmarks, to see which answer formats occur.
# NOTE: the loop variables (notably `df`) stay bound at module level after
# this cell; a later cell displays `df`.
all_values = set()
for model_results in BENCHMARK_RESULTS:
    for lang in LANGUAGE_LIST:
        col = f"response_TF_{lang}"
        for df in model_results:
            # Tables lacking the column for this language are skipped.
            if col in df.columns:
                all_values.update(df[col].unique())

print(all_values)
# Prints the number of distinct values (the label is Korean for "total count").
print("총 개수:", len(all_values))

In [None]:
def is_correct(df, languages=None):
    """Add one ``correct_<lang>`` code column per language to *df*, in place.

    Each ``response_TF_<lang>`` value is translated into an integer code:

    * ``"True"`` / ``"<answer>True</answer>"``  -> 0 when the first row's
      ``kind`` is ``"original"``, otherwise 1
    * ``"False"`` / ``"<answer>False</answer>"`` -> 1 when the first row's
      ``kind`` is ``"original"``, otherwise 0
    * anything else (including missing values) -> 2

    The benchmark kind is read from the first row only, i.e. the frame is
    treated as holding a single kind.

    Args:
        df: Benchmark result table with a ``kind`` column and one
            ``response_TF_<lang>`` column per language. Modified in place.
        languages: Language codes to process. Defaults to the module-level
            ``LANGUAGE_LIST``.

    Returns:
        The same DataFrame object, with the ``correct_<lang>`` columns added.
    """
    if languages is None:
        languages = LANGUAGE_LIST

    # An empty frame has no first row to inspect; the mapping is irrelevant
    # then because there is nothing to encode.
    is_original = len(df) > 0 and df.iloc[0]['kind'] == 'original'
    code_for_true, code_for_false = (0, 1) if is_original else (1, 0)
    code_unsure = 2

    for lang in languages:
        responses = df[f"response_TF_{lang}"]
        said_true = responses.eq("True") | responses.eq("<answer>True</answer>")
        said_false = responses.eq("False") | responses.eq("<answer>False</answer>")

        codes = np.where(
            said_true,
            code_for_true,
            np.where(said_false, code_for_false, code_unsure),
        )
        df[f"correct_{lang}"] = codes.astype("int64")

    return df

In [None]:
# Attach the correct_<lang> code columns to every loaded result table,
# storing the returned frame back at the same position.
for model_results in BENCHMARK_RESULTS:
    for bench_idx, bench_df in enumerate(model_results):
        model_results[bench_idx] = is_correct(bench_df)

In [None]:
# Display `df`, the variable left over from the last iteration of an earlier
# module-level loop (not a deliberately chosen table).
df

In [None]:
# Model label -> list of that model's benchmark DataFrames.
# NOTE(review): the Qwen keys here ("Qwen-8B", "Qwen-14B") differ from the
# labels in MODEL_NAMES ("Qwen3-8B", "Qwen3-14B") — confirm which spelling is
# intended before joining outputs keyed by model name.
MODEL_DFS = {
    "Gemini": BENCHMARK_GEMINI,
    "LLaMA": BENCHMARK_LLAMA,
    "Mistral": BENCHMARK_MISTRAL,
    "Qwen-8B": BENCHMARK_QWEN3_8B,
    "Qwen-14B": BENCHMARK_QWEN3_14B,
}

# Model label -> model category ("gemini" or "local").
MODEL_TYPE = {
    "Gemini": "gemini",
    "LLaMA": "local",
    "Mistral": "local",
    "Qwen-8B": "local",
    "Qwen-14B": "local",
}


In [None]:
def compute_row_multilingual_accuracy(df, languages=None):
    """Return, for every row, the fraction of languages whose code is 0.

    Args:
        df: DataFrame containing one ``correct_<lang>`` column per language.
        languages: Language codes to include. Defaults to the module-level
            ``LANGUAGE_LIST``.

    Returns:
        A Series aligned with ``df.index`` holding, per row, the share
        (0 to 1) of the selected ``correct_<lang>`` columns that equal 0.
    """
    if languages is None:
        languages = LANGUAGE_LIST

    corr_cols = [f"correct_{lang}" for lang in languages]
    return df[corr_cols].eq(0).mean(axis=1)


In [None]:
# Bin labels in display order, from all eight languages down to two or fewer.
BIN_LABELS = [
    "8/8",
    "7/8",
    "6/8",
    "5/8",
    "4/8",
    "3/8",
    "≤2/8"]

In [None]:
def multilingual_acc_to_bin(acc):
    """Map a multilingual accuracy percentage to its bin label.

    Args:
        acc: Multilingual accuracy in percent (float), expected to be a
            multiple of 12.5 (k out of 8 languages).

    Returns:
        ``"8/8"`` ... ``"3/8"`` when *acc* matches 100, 87.5, 75, 62.5, 50 or
        37.5; ``"≤2/8"`` for every other value (including NaN).
    """
    # Local import: the surrounding file does not import math at the top.
    import math

    exact_bins = (
        (100.0, "8/8"),
        (87.5, "7/8"),
        (75.0, "6/8"),
        (62.5, "5/8"),
        (50.0, "4/8"),
        (37.5, "3/8"),
    )
    for percent, label in exact_bins:
        # Tolerant comparison: a computed percentage that is off by a
        # rounding error must not silently fall through to the last bin.
        if math.isclose(acc, percent, rel_tol=0.0, abs_tol=1e-9):
            return label
    return "≤2/8"


In [None]:
def percent_bin_distribution_optionA(dfs_corrected, model_name):
    """Count, per benchmark file, how many rows fall into each accuracy bin.

    Args:
        dfs_corrected: Mapping of file name -> DataFrame that already carries
            the ``correct_<lang>`` columns.
        model_name: Label written into the ``model`` column of the result.

    Returns:
        A DataFrame with columns ``model``, ``file``, ``bin`` and ``count``;
        one row per (file, bin) pair, with bins that received no rows
        reported as 0.
    """
    correct_columns = ["correct_" + code for code in LANGUAGE_LIST]
    records = []

    for file_name, frame in dfs_corrected.items():
        # Per-row share of languages with code 0, as a percentage.
        row_percent = (frame[correct_columns] == 0).mean(axis=1) * 100

        # Bin label of every row, then number of rows per label.
        per_bin = row_percent.map(multilingual_acc_to_bin).value_counts().to_dict()

        records.extend(
            {
                "model": model_name,
                "file": file_name,
                "bin": label,
                "count": per_bin.get(label, 0),
            }
            for label in BIN_LABELS
        )

    return pd.DataFrame(records)


In [None]:
def build_corrected_dfs_for_model(model_name):
    """Return ``{benchmark file name: coded copy}`` for one model.

    The model's tables are looked up in MODEL_DFS; each one is copied and the
    copy is passed through is_correct, so the stored tables are not modified
    by this call.
    """
    raw_frames = MODEL_DFS[model_name]  # e.g. BENCHMARK_GEMINI
    return {
        file_name: is_correct(frame.copy())
        for file_name, frame in zip(BENCHMARK_NAME, raw_frames)
    }

In [None]:
# Quick check on a single model: code fresh copies of the Gemini tables and
# display their bin distribution.
dfs_corr_gemini = build_corrected_dfs_for_model("Gemini")
bin_dist_gemini = percent_bin_distribution_optionA(dfs_corr_gemini, "Gemini")
bin_dist_gemini

# bin distribution 

In [None]:
def build_heatmap_data_by_bin():
    """Build the long-format table of per-language accuracy within each bin.

    For every model and benchmark, rows are grouped by their multilingual
    accuracy bin; within each non-empty bin the share of rows with code 0 is
    computed separately for every language.

    Returns:
        A DataFrame with columns ``model``, ``benchmark``, ``bin``, ``lang``
        and ``accuracy`` (a fraction between 0 and 1). Bins without rows are
        omitted.
    """
    correct_columns = [f"correct_{code}" for code in LANGUAGE_LIST]
    records = []

    for model_label, frames in zip(MODEL_NAMES, BENCHMARK_RESULTS):
        for position, frame in enumerate(frames):
            # Boolean table: True where a language's code is 0.
            hits = frame[correct_columns].eq(0)
            bin_of_row = (hits.mean(axis=1) * 100).map(multilingual_acc_to_bin)

            for bin_label in BIN_LABELS:
                in_bin = hits[bin_of_row == bin_label]
                if len(in_bin) == 0:
                    continue

                for code in LANGUAGE_LIST:
                    records.append({
                        "model": model_label,
                        "benchmark": BENCHMARK_LABELS[position],
                        "bin": bin_label,
                        "lang": code,
                        "accuracy": in_bin[f"correct_{code}"].mean(),
                    })

    return pd.DataFrame(records)


In [None]:
# Long-format accuracy table shared by the heatmap and the line plot.
heatmap_all = build_heatmap_data_by_bin()
heatmap_all.head()


In [None]:
# Display the full table.
heatmap_all


In [None]:
# (label, index into BENCHMARK_RESULTS) for every model except Gemini
# (index 2), which is processed separately.
# NOTE(review): the Qwen labels here lack the "3" used in MODEL_NAMES
# ("Qwen3-8B", "Qwen3-14B") — confirm the intended spelling.
LOCAL_MODEL_INFO = [
    ("Mistral", 0),
    ("LLaMA", 1),
    ("Qwen-8B", 3),
    ("Qwen-14B", 4),
]


In [None]:
# Bin distribution of every local model, stacked into one table.
all_bin_dists = []

for model_name, model_idx in LOCAL_MODEL_INFO:
    # Pair each benchmark file name with the model's table at that position.
    dfs_corr = {
        file_name: BENCHMARK_RESULTS[model_idx][position]
        for position, file_name in enumerate(BENCHMARK_NAME)
    }
    bin_dist = percent_bin_distribution_optionA(dfs_corr, model_name)
    all_bin_dists.append(bin_dist)

bin_dist_local_all = pd.concat(all_bin_dists, ignore_index=True)


In [None]:
# NOTE: in a notebook only the value of a cell's last expression is rendered,
# so this head() preview is evaluated but not shown.
bin_dist_local_all.head()
# Sum of the bin counts for every (model, file) pair.
bin_dist_local_all.groupby(["model", "file"])["count"].sum()


In [None]:
# Display the stacked bin distribution of the local models.
bin_dist_local_all

In [None]:
# Gemini's tables sit at index 2 of BENCHMARK_RESULTS; pair them with the
# benchmark file names and compute the same bin distribution.
dfs_corr_gemini = {
    file_name: BENCHMARK_RESULTS[2][position]
    for position, file_name in enumerate(BENCHMARK_NAME)
}

bin_dist_gemini = percent_bin_distribution_optionA(dfs_corr_gemini, model_name="Gemini")


In [None]:
# Local models followed by Gemini, in one table with a fresh 0..n-1 index.
bin_dist_all = pd.concat(
    [bin_dist_local_all, bin_dist_gemini],
    ignore_index=True
)

bin_dist_all


In [None]:
from pathlib import Path

# Write the combined bin distribution to CSV, creating the target directory
# if it does not exist yet.
out_path = Path("03.notebooks/0208_result/Fig3/bin_distribution_0208.csv")
out_path.parent.mkdir(parents=True, exist_ok=True)

bin_dist_all.to_csv(out_path, index=False)

print(f"Saved to {out_path}")


In [None]:
# Left-to-right order of the model blocks in the figures.
model_order = [
    "Mistral",
    "LLaMA",
    "Gemini",
    "Qwen3-8B",
    "Qwen3-14B",
]


In [None]:
# Model label -> title shown in the figure (currently an identity mapping).
model_display_names = {
    "Mistral": "Mistral",
    "LLaMA": "LLaMA",
    "Gemini": "Gemini",
    "Qwen3-8B": "Qwen3-8B",
    "Qwen3-14B": "Qwen3-14B",
}


In [None]:
# The plotting order must name exactly the models present in heatmap_all.
assert set(model_order) == set(heatmap_all["model"].unique())


# Heatmap

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def plot_heatmap_all_models_concatenated(
    heatmap_all: pd.DataFrame,

    model_order,
    model_display_names,

    save_dir="03.notebooks/0208_result/Fig3/",
    out_name="fig3_heatmap_0208_2.pdf",

    language_order=("en", "de", "fr", "es", "it", "pt", "ko", "ja"),
    benchmark_order=("ORG", "SS", "SO", "PSS", "POS"),
    bin_order=("8/8", "7/8", "6/8", "5/8", "4/8", "3/8", "≤2/8"),

    cmap="coolwarm_r",
    vmin=0.0,
    vmax=1.0,
):
    """Draw all models side by side in one annotated heatmap and save it.

    Rows are every (language, benchmark) combination, language-major.
    Columns are the bins of each model, with the model blocks concatenated
    left to right in ``model_order``. Each cell shows the accuracy as a
    percentage; combinations missing from ``heatmap_all`` are left blank.

    Args:
        heatmap_all: Long-format table with the columns ``model``, ``lang``,
            ``benchmark``, ``bin`` and ``accuracy`` (fraction 0 to 1).
        model_order: Model labels, in left-to-right block order.
        model_display_names: Mapping of model label -> title text; labels
            without an entry are shown unchanged.
        save_dir: Output directory (created if missing).
        out_name: Output file name inside ``save_dir``.
        language_order: Language codes, top to bottom.
        benchmark_order: Benchmark labels, repeated under each language.
        bin_order: Bin labels, left to right inside each model block.
        cmap: Matplotlib colormap name.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.

    Returns:
        None. The figure is written to ``save_dir/out_name`` and closed, and
        the output path is printed.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    LANG_GROUPS = {
        "en": "g1", "de": "g1",
        "fr": "g2", "es": "g2", "it": "g2", "pt": "g2",
        "ko": "g3", "ja": "g3",
    }

    LANG_GROUP_TEXT_COLORS = {
        "g1": "#4A6FB3",
        "g2": "#3A7F5C",
        "g3": "#B24A4A",
    }
    # Label colour for languages that are not part of the group table above.
    default_lang_color = "#333333"

    rows = [(lang, bench) for lang in language_order for bench in benchmark_order]
    n_rows = len(rows)
    n_bins = len(bin_order)
    bin_labels = list(bin_order)

    model_blocks = []
    model_centers = []

    for model_idx, model in enumerate(model_order):
        model_df = heatmap_all[heatmap_all["model"] == model]
        block_rows = []

        for lang, bench in rows:
            sub = model_df[
                (model_df["lang"] == lang)
                & (model_df["benchmark"] == bench)
            ]
            # Bins absent from the data become NaN and are left blank later.
            block_rows.append(
                sub.set_index("bin")["accuracy"]
                   .reindex(bin_labels)
                   .to_numpy(dtype=float)
            )

        model_blocks.append(np.vstack(block_rows))
        # Every block is n_bins wide, so the block centre follows directly
        # from the block's position.
        model_centers.append(model_idx * n_bins + (n_bins - 1) / 2)

    full_matrix = np.concatenate(model_blocks, axis=1)

    fig, ax = plt.subplots(
        figsize=(len(model_order) * n_bins * 0.75, n_rows * 0.40)
    )

    im = ax.imshow(
        full_matrix,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        aspect="auto",
    )

    # Thick white separators between the model blocks.
    for i in range(1, len(model_order)):
        x = i * n_bins - 0.5
        ax.vlines(x, -0.5, n_rows - 0.5, colors="white", linewidth=6)

    ax.set_xticks([])
    ax.set_yticks([])

    rows_per_lang = len(benchmark_order)

    # Language labels, vertically centred on their group of benchmark rows.
    for i, lang in enumerate(language_order):
        y_center = i * rows_per_lang + rows_per_lang / 2 - 0.5
        color = LANG_GROUP_TEXT_COLORS.get(LANG_GROUPS.get(lang), default_lang_color)

        ax.text(
            -2.5, y_center, lang.upper(),
            ha="center", va="center",
            fontsize=27, fontweight="bold",
            color=color,
        )

    # Short benchmark label next to every row.
    for row_idx, (_, bench) in enumerate(rows):
        ax.text(
            -1.0,
            row_idx,
            bench,
            ha="center",
            va="center",
            fontsize=12,
            color="#333333",
        )

    # Thin white separators between the language groups.
    for i in range(1, len(language_order)):
        y = i * rows_per_lang - 0.5
        ax.hlines(y, -0.5, full_matrix.shape[1] - 0.5, colors="white", linewidth=2)

    # Model titles above their blocks.
    for center, model in zip(model_centers, model_order):
        ax.text(
            center, -1.1,
            model_display_names.get(model, model),
            ha="center", va="bottom",
            fontsize=19, fontweight="semibold",
        )

    # The bin labels repeat under every model block.
    xticks, xticklabels = [], []
    for i in range(len(model_order)):
        for j, b in enumerate(bin_order):
            xticks.append(i * n_bins + j)
            xticklabels.append(b)

    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=11)

    # Write the percentage into every cell that has a value.
    for i in range(n_rows):
        for j in range(full_matrix.shape[1]):
            val = full_matrix[i, j]
            if np.isnan(val):
                continue

            label = "100%" if abs(val - 1.0) < 1e-6 else f"{val*100:.1f}%"
            ax.text(j, i, label, ha="center", va="center", fontsize=9.5)

    cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.03)
    cbar.ax.tick_params(labelsize=16)
    cbar.set_label("Accuracy (0–1)", fontsize=18)

    for spine in ax.spines.values():
        spine.set_visible(False)

    fig.tight_layout(rect=[0.08, 0.05, 0.95, 0.97])

    out_path = Path(save_dir) / out_name
    # Save this figure explicitly rather than whatever pyplot considers
    # the current figure.
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

    print(f"Saved: {out_path}")


In [None]:
# Render and save the heatmap with the default output location and orderings.
plot_heatmap_all_models_concatenated(
    heatmap_all=heatmap_all,
    model_order=model_order,
    model_display_names=model_display_names,
)


# line plot 

In [None]:
# Order in which the benchmark lines are drawn.
BENCHMARK_ORDER = ["ORG", "SS", "SO", "PSS", "POS"]

# Benchmark label -> line colour (hex).
BENCHMARK_COLOR_MAP = {
    "ORG": "#1f77b4", 
    "SS":  "#ff7f0e", 
    "SO":  "#d62728",  
    "PSS": "#2ca02c",  
    "POS": "#9467bd",  
}


In [None]:
# Language code -> group id; languages in one group share their colours.
LANG_GROUPS = {
    "en": "g1",
    "de": "g1",
    "fr": "g2",
    "es": "g2",
    "it": "g2",
    "pt": "g2",
    "ko": "g3",
    "ja": "g3",
}
# Group id -> colour of the language label text.
LANG_GROUP_TEXT_COLORS = {
    "g1": "#4A6FB3", 
    "g2": "#3A7F5C",  
    "g3": "#B24A4A",  
}

# Group id -> pale panel background colour.
LANG_GROUP_BG_COLORS = {
    "g1": "#EEF3FB",  
    "g2": "#EEF7F1",
    "g3": "#FBEDEE",  
}


In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import FormatStrFormatter
import numpy as np


def plot_line_all_models_v2(
    df,
    save_dir="03.notebooks/0208_result/line_plot"
):
    """Draw a language x model grid of accuracy-per-bin line plots and save it.

    Each panel shows one line per benchmark (accuracy on the y-axis, bin on
    the x-axis). Panel rows are languages, panel columns are models.

    The lines are plotted on explicit numeric x positions taken from the
    fixed bin order. Plotting the bin labels directly would let matplotlib
    order the shared category axis by first appearance, so a bin missing from
    an early line would scramble the axis.

    Relies on the module-level constants LANG_GROUPS, LANG_GROUP_BG_COLORS,
    LANG_GROUP_TEXT_COLORS, BENCHMARK_ORDER and BENCHMARK_COLOR_MAP.

    Args:
        df: Long-format table with the columns ``model``, ``lang``,
            ``benchmark``, ``bin`` and ``accuracy``.
        save_dir: Output directory (created if missing); the figure is saved
            there as ``line_plot_v6.pdf``.

    Returns:
        None. The figure is written to disk and closed, and the output path
        is printed.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    MODEL_ORDER = ["Mistral", "LLaMA", "Gemini", "Qwen3-8B", "Qwen3-14B"]

    LANG_ORDER = ['en','de','fr','es','it','pt','ko','ja']
    BIN_ORDER = ["8/8","7/8","6/8","5/8","4/8","3/8","≤2/8"]
    bin_position = {label: pos for pos, label in enumerate(BIN_ORDER)}

    n_rows = len(LANG_ORDER)
    n_cols = len(MODEL_ORDER)

    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(4.6 * n_cols, 2.8 * n_rows),
        sharex=True,
        sharey=True
    )

    # First line drawn for each benchmark, collected over all panels so the
    # legend is complete even when the top-left panel lacks a benchmark.
    legend_handles = {}

    for col, model in enumerate(MODEL_ORDER):
        for row, lang in enumerate(LANG_ORDER):
            ax = axes[row, col]

            lang_group = LANG_GROUPS[lang]
            ax.set_facecolor(LANG_GROUP_BG_COLORS[lang_group])

            sub = df[
                (df["model"] == model) &
                (df["lang"] == lang)
            ]

            for bench in BENCHMARK_ORDER:
                g = sub[sub["benchmark"] == bench]
                # Translate bin labels into positions; rows whose bin is not
                # in BIN_ORDER cannot be placed and are dropped.
                g = (
                    g.assign(_bin_pos=g["bin"].astype(object).map(bin_position))
                     .dropna(subset=["_bin_pos"])
                     .sort_values("_bin_pos", kind="stable")
                )
                if len(g) == 0:
                    continue

                is_org = (bench == "ORG")

                (line,) = ax.plot(
                    g["_bin_pos"].astype(float),
                    g["accuracy"],
                    marker="o",
                    linewidth=3.5 if is_org else 2.2,
                    alpha=1.0 if is_org else 0.85,
                    color=BENCHMARK_COLOR_MAP[bench],
                    label=bench
                )
                legend_handles.setdefault(bench, line)

            if col == 0:
                ax.set_ylabel(
                    lang.upper(),
                    fontsize=30,
                    rotation=0,
                    labelpad=38,
                    fontweight="bold",
                    color=LANG_GROUP_TEXT_COLORS[lang_group],
                )

            if row == 0:
                ax.set_title(
                    model,
                    fontsize=30,
                    fontweight="bold"
                )

            # Fixed x-axis: one tick per bin, in BIN_ORDER, in every panel.
            ax.set_xticks(range(len(BIN_ORDER)))
            ax.set_xticklabels(BIN_ORDER)
            ax.set_xlim(-0.3, len(BIN_ORDER) - 0.7)

            ax.set_ylim(0, 1.05)
            ax.set_yticks(np.linspace(0.0, 1.0, 6))
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

            ax.grid(alpha=0.3)

            ax.tick_params(axis="x", labelsize=18)
            ax.tick_params(axis="y", labelsize=16)

    labels = [bench for bench in BENCHMARK_ORDER if bench in legend_handles]
    if labels:
        fig.legend(
            [legend_handles[bench] for bench in labels],
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, 0.99),
            ncol=len(labels),
            frameon=False,
            fontsize=25
        )

    fig.tight_layout(rect=[0, 0, 1, 0.95])

    out_path = f"{save_dir}/line_plot_v6.pdf"
    # Save this figure explicitly rather than whatever pyplot considers
    # the current figure.
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

    print(f"Saved: {out_path}")


In [None]:
# Render and save the line-plot grid from the same long-format table.
plot_line_all_models_v2(heatmap_all)
