# Fig 4 & Fig 5

In [None]:
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [None]:
# Root folder holding one sub-folder of benchmark result parquet files per model.
RESULT_DIR = '02.results/00.Benchmark_Results/'
# Per-model result sub-folders (relative to RESULT_DIR), each with a trailing slash.
RESULT_NAME = [
    f"{model_dir}/"
    for model_dir in (
        "99.Gemini",
        "01.Llama",
        "02.Mistral",
        "03.Qwen/Qwen-8B",
        "03.Qwen/Qwen-14B",
    )
]

In [None]:
BENCHMARK_DIR = '00.data/02.WikiBench/'
# File names of the five True/False benchmark variants (500 items each), in the
# order matched by BENCHMARK_LABELS: original, subject-/object-shuffled, and the
# property-scoped subject-/object-shuffled variants.
BENCHMARK_NAME = [
    f"{stem}_benchmark_TF_500.parquet"
    for stem in (
        "00.original",
        "01.subject_shuffled",
        "02.object_shuffled",
        "03.property_scoped_subject_shuffled",
        "04.property_scoped_object_shuffled",
    )
]

In [None]:
# Load every benchmark variant for each model folder, in RESULT_NAME order
# (Gemini, LLaMA, Mistral, Qwen-8B, Qwen-14B). Each model gets a list of
# DataFrames aligned with BENCHMARK_NAME.
(
    BENCHMARK_GEMINI,
    BENCHMARK_LLAMA,
    BENCHMARK_MISTRAL,
    BENCHMARK_QWEN3_8B,
    BENCHMARK_QWEN3_14B,
) = (
    [pq.read_table(f"{RESULT_DIR}{model_dir}{name}").to_pandas() for name in BENCHMARK_NAME]
    for model_dir in RESULT_NAME
)

# Analysis order of the models; must stay aligned with MODEL_NAMES.
BENCHMARK_RESULTS = [
    BENCHMARK_MISTRAL,
    BENCHMARK_LLAMA,
    BENCHMARK_GEMINI,
    BENCHMARK_QWEN3_8B,
    BENCHMARK_QWEN3_14B,
]

In [None]:
# Model labels, in the same order as BENCHMARK_RESULTS.
MODEL_NAMES = "Mistral LLaMA Gemini Qwen3-8B Qwen3-14B".split()
# Sanity check: the two counts must match, since the builders pair them up.
print(len(MODEL_NAMES), len(BENCHMARK_RESULTS))
# Short labels for the benchmark variants, in BENCHMARK_NAME order.
BENCHMARK_LABELS = "ORG SS SO PSS POS".split()


In [None]:
LANGUAGE_LIST = "en de fr es it pt ko ja".split()  # language codes suffixing the response_TF_/correct_ columns

In [None]:
pd.options.display.max_columns = None  # show every column when displaying DataFrames

In [None]:
# Gather every distinct raw response value across all models, benchmark
# variants and languages, then report them and their count.
all_values = set()
for model_results in BENCHMARK_RESULTS:
    for lang in LANGUAGE_LIST:
        col = f"response_TF_{lang}"
        for df in model_results:
            # Skip tables that have no response column for this language.
            if col not in df.columns:
                continue
            all_values.update(df[col].unique())

print(all_values)
print("총 개수:", len(all_values))  # label reads "total count"
print("총 개수:", len(all_values))

In [None]:
def is_correct(df, languages=None):
    """Add integer ``correct_<lang>`` columns that encode each parsed True/False answer.

    The benchmark kind is read from the first row's ``kind`` column (every row
    is assumed to share it):

    * ``kind == 'original'``: a ``True`` answer maps to 0 and ``False`` to 1.
    * any other kind: ``True`` maps to 1 and ``False`` to 0.

    Only ``"True"``, ``"False"`` and their ``<answer>...</answer>``-wrapped forms
    are recognised. Anything else maps to 2. The Jaccard helpers in this notebook
    select rows with ``correct == 0``. Presumably 0 therefore means "answered
    correctly" (True for unmodified facts, False for shuffled ones). Confirm this
    against the benchmark definition.

    Parameters
    ----------
    df : pandas.DataFrame
        Results for one benchmark variant, with a ``kind`` column and one
        ``response_TF_<lang>`` column per processed language.
    languages : sequence of str, optional
        Language codes to process. Defaults to the notebook-level ``LANGUAGE_LIST``.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, modified in place and returned for convenience.
    """
    langs = LANGUAGE_LIST if languages is None else languages

    # Guard the empty case: df.iloc[0] would raise IndexError on a frame with no rows.
    is_original = (not df.empty) and df.iloc[0]['kind'] == 'original'
    true_code, false_code = (0, 1) if is_original else (1, 0)
    unsure_code = 2

    code_for = {
        "True": true_code,
        "<answer>True</answer>": true_code,
        "False": false_code,
        "<answer>False</answer>": false_code,
    }

    for lang in langs:
        responses = df[f"response_TF_{lang}"].tolist()
        df[f"correct_{lang}"] = [code_for.get(resp, unsure_code) for resp in responses]
    return df

In [None]:
# Attach correct_<lang> columns to every (model, benchmark) table. The inner
# lists are the same objects as BENCHMARK_GEMINI, BENCHMARK_LLAMA, etc., so those see the update too.
for model_idx, model_benches in enumerate(BENCHMARK_RESULTS):
    for bench_idx, bench_df in enumerate(model_benches):
        model_benches[bench_idx] = is_correct(bench_df)

In [None]:
# Preview the last model's last benchmark table. This is the same object the
# response-value scan cell ended on. It is indexed explicitly instead of
# relying on that cell's leaked loop variable `df`.
BENCHMARK_RESULTS[-1][-1]

In [None]:
# Per-model result lists keyed by a short model name.
# NOTE(review): the Qwen keys here ("Qwen-8B"/"Qwen-14B") differ from
# MODEL_NAMES ("Qwen3-8B"/"Qwen3-14B"). Neither dict is read elsewhere in
# this notebook as shown.
MODEL_DFS = dict([
    ("Gemini", BENCHMARK_GEMINI),
    ("LLaMA", BENCHMARK_LLAMA),
    ("Mistral", BENCHMARK_MISTRAL),
    ("Qwen-8B", BENCHMARK_QWEN3_8B),
    ("Qwen-14B", BENCHMARK_QWEN3_14B),
])

# Gemini is tagged "gemini"; every other model is tagged "local".
MODEL_TYPE = {
    model_key: ("gemini" if model_key == "Gemini" else "local")
    for model_key in MODEL_DFS
}


In [None]:
def modelwise_jaccard_correct(
    row_df,
    benchmark,
    language,
    models
):
    """Return a model x model matrix of Jaccard indices over rows with ``correct == 0``.

    For each model, collect the set of ``row_id`` values whose ``correct`` code
    is 0 on the given benchmark and language. Cell ``(m1, m2)`` is
    ``|A & B| / |A | B|``. When both sets are empty the cell is 1.0 by convention.

    Parameters
    ----------
    row_df : pandas.DataFrame
        Long-format table with ``row_id``, ``model``, ``benchmark``,
        ``language`` and ``correct`` columns.
    benchmark : str
        Benchmark label to filter on (e.g. "ORG").
    language : str
        Language code to filter on (e.g. "en").
    models : sequence of str
        Models to compare. Also used, in order, as the matrix index and columns.

    Returns
    -------
    pandas.DataFrame
        Float matrix indexed and columned by ``models``.
    """
    # Filter once and build each model's set once, instead of re-scanning the
    # whole long table for every (m1, m2) pair.
    hits = row_df[
        (row_df["benchmark"] == benchmark) &
        (row_df["language"] == language) &
        (row_df["correct"] == 0)
    ]
    id_sets = {m: set(hits.loc[hits["model"] == m, "row_id"]) for m in models}

    mat = pd.DataFrame(index=models, columns=models, dtype=float)
    for m1 in models:
        a = id_sets[m1]
        for m2 in models:
            b = id_sets[m2]
            union = a | b
            # An empty union means both sets are empty: defined as identical (1.0).
            mat.loc[m1, m2] = len(a & b) / len(union) if union else 1.0

    return mat


In [None]:
def build_rowwise_correct_df(
    benchmark_results,
    model_names,
    benchmark_labels,
    language_list,
):
    """Flatten per-model, per-benchmark result tables into one long DataFrame.

    The output has one row per (model, benchmark, source row, language), with
    columns ``row_id`` (taken from each source table's ``row_id`` column),
    ``model``, ``benchmark``, ``language`` and ``correct`` (the source
    ``correct_<lang>`` value).

    Parameters
    ----------
    benchmark_results : sequence of sequence of pandas.DataFrame
        ``benchmark_results[k]`` holds model ``k``'s tables, one per benchmark label.
    model_names : sequence of str
        Model labels, aligned with ``benchmark_results``.
    benchmark_labels : sequence of str
        Benchmark labels, aligned with each model's list of tables.
    language_list : sequence of str
        Language codes. Each table needs a ``correct_<lang>`` column for each code.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If the number of model names and result lists differ, or a model's
        table count differs from the number of benchmark labels. Plain ``zip``
        would silently drop the unmatched tail.
    """
    if len(model_names) != len(benchmark_results):
        raise ValueError(
            f"{len(model_names)} model names but {len(benchmark_results)} result lists"
        )

    rows = []
    for model_name, model_benches in zip(model_names, benchmark_results):
        if len(model_benches) != len(benchmark_labels):
            raise ValueError(
                f"model {model_name!r} has {len(model_benches)} tables but "
                f"{len(benchmark_labels)} benchmark labels were given"
            )
        for bench_label, df in zip(benchmark_labels, model_benches):
            if df.empty:
                continue
            # Extract columns once rather than using DataFrame.iterrows, which
            # builds a Series for every row.
            per_lang = [df[f"correct_{lang}"].tolist() for lang in language_list]
            for pos, row_id in enumerate(df["row_id"].tolist()):
                for lang, values in zip(language_list, per_lang):
                    rows.append({
                        "row_id": row_id,
                        "model": model_name,
                        "benchmark": bench_label,
                        "language": lang,
                        "correct": values[pos],
                    })

    return pd.DataFrame(rows)


In [None]:
# Long-format correctness table (one row per model / benchmark / source row /
# language), keyed by each table's row_id column.
row_df = build_rowwise_correct_df(BENCHMARK_RESULTS, MODEL_NAMES, BENCHMARK_LABELS, LANGUAGE_LIST)

row_df.head(50)


# Fig 4  Model-model Jaccard similarity per benchmark and language (renumbered from Fig 5)

In [None]:
# Language code -> colour group, used to colour the per-language column titles.
LANG_GROUPS = {
    lang: group
    for group, members in (
        ("g1", ("en", "de")),
        ("g2", ("fr", "es", "it", "pt")),
        ("g3", ("ko", "ja")),
    )
    for lang in members
}

# Title text colour for each group.
LANG_GROUP_TEXT_COLORS = dict(
    g1="#4A6FB3",
    g2="#3A7F5C",
    g3="#B24A4A",
)

In [None]:
# Model order for the Fig 4 panel axes. This is an alias (the same list object)
# of MODEL_NAMES, which follows the model order of BENCHMARK_RESULTS.
MODEL_ORDER = MODEL_NAMES


In [None]:
def vertical_text(s):
    """Return *s* with a newline between consecutive characters (stacked text)."""
    return "\n".join(s)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns  # the cell calls `sns.heatmap`; importing it as `sn` left `sns` undefined
import matplotlib as mpl
import numpy as np
from pathlib import Path

# Fig 4: grid of model x model Jaccard heatmaps, one per (benchmark, language).
# Grid rows follow BENCHMARK_LABELS; grid columns follow LANGUAGE_LIST.
models = MODEL_ORDER


fig, axes = plt.subplots(
    nrows=len(BENCHMARK_LABELS),
    ncols=len(LANGUAGE_LIST),
    figsize=(3 * len(LANGUAGE_LIST), 3 * len(BENCHMARK_LABELS)),
    sharex=True,
    sharey=False,
    squeeze=False,  # keep `axes` 2-D so axes[i, j] also works with a single row/column
)

plt.subplots_adjust(wspace=0.08, hspace=0.12)

# Dedicated axes (figure coordinates) for the one colorbar shared by all panels.
cbar_ax = fig.add_axes([0.93, 0.15, 0.015, 0.7])


for i, bench_label in enumerate(BENCHMARK_LABELS):
    for j, lang in enumerate(LANGUAGE_LIST):

        ax = axes[i, j]

        jac_mat = modelwise_jaccard_correct(
            row_df=row_df,
            benchmark=bench_label,
            language=lang,
            models=models
        )

        # Fixed 0..1 range so every panel matches the shared colorbar below.
        sns.heatmap(
            jac_mat,
            ax=ax,
            vmin=0,
            vmax=1,
            cmap="Blues",
            square=True,
            cbar=False
        )

        # Column titles: the language code, coloured by its LANG_GROUPS group.
        if i == 0:
            group = LANG_GROUPS[lang]
            color = LANG_GROUP_TEXT_COLORS[group]
            ax.set_title(
                lang.upper(),
                fontsize=30,
                color=color,
                fontweight="bold"
            )

        # Model names on the y axis only in the first column.
        if j == 0:
            ax.set_yticks(np.arange(len(models)) + 0.5)
            ax.set_yticklabels(models, fontsize=14)
        else:
            ax.set_yticks([])
            ax.set_yticklabels([])

        ax.set_xticks(np.arange(len(models)) + 0.5)
        ax.set_xticklabels(models, rotation=90, fontsize=12)

        ax.tick_params(length=0)
        ax.set_xlabel("")
        ax.set_ylabel("")

# Row labels (benchmark names) along the left edge, in figure coordinates.
for i, bench_label in enumerate(BENCHMARK_LABELS):
    y_pos = 1 - (i + 0.5) / len(BENCHMARK_LABELS)

    fig.text(
        0.035,
        y_pos,
        bench_label,
        rotation=90,
        va="center",
        ha="center",
        fontsize=30,
        fontweight="bold"
    )


norm = mpl.colors.Normalize(vmin=0, vmax=1)
sm = mpl.cm.ScalarMappable(norm=norm, cmap="Blues")
sm.set_array([])

cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label("Jaccard Index", fontsize=24)
cbar.ax.tick_params(labelsize=20)

plt.tight_layout(rect=[0.06, 0, 0.91, 0.95])

# savefig does not create missing directories, so create the target folder first.
out_path = Path("03.notebooks/0208_result/Fig4/Fig4_model-wise_jaccard0208_FINAL.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(
    out_path,
    bbox_inches="tight",
    dpi=300
)

plt.show()


# Fig 5  Language-language Jaccard similarity per benchmark and model (renumbered from Fig 4)

In [None]:
def build_rowwise_correct_df_from_results(
    benchmark_results,
    model_names,
    benchmark_labels,
    language_list
):
    """Flatten per-model result tables into one long DataFrame keyed by table index.

    The output has one row per (model, benchmark, source row, language), with
    columns ``row_id``, ``model``, ``benchmark``, ``language`` and ``correct``.
    ``row_id`` is the source DataFrame's index label, not a ``row_id`` column.
    ``correct`` is the source ``correct_<lang>`` value.

    Parameters
    ----------
    benchmark_results : sequence of sequence of pandas.DataFrame
        ``benchmark_results[k][b]`` is model ``k``'s table for benchmark ``b``.
    model_names : sequence of str
        Model labels, aligned with ``benchmark_results``.
    benchmark_labels : sequence of str
        Benchmark labels, aligned with each model's list of tables.
    language_list : sequence of str
        Language codes. Each table needs a ``correct_<lang>`` column for each code.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If the number of model names and result lists differ, or a model's
        table count differs from the number of benchmark labels.
    """
    if len(model_names) != len(benchmark_results):
        raise ValueError(
            f"{len(model_names)} model names but {len(benchmark_results)} result lists"
        )

    rows = []
    for model, model_benches in zip(model_names, benchmark_results):
        if len(model_benches) != len(benchmark_labels):
            raise ValueError(
                f"model {model!r} has {len(model_benches)} tables but "
                f"{len(benchmark_labels)} benchmark labels were given"
            )
        for bench, df in zip(benchmark_labels, model_benches):
            if df.empty:
                continue
            # Extract columns once instead of building a Series per row via iterrows.
            per_lang = [df[f"correct_{lang}"].tolist() for lang in language_list]
            for pos, row_id in enumerate(df.index.tolist()):
                for lang, values in zip(language_list, per_lang):
                    rows.append({
                        "row_id": row_id,
                        "model": model,
                        "benchmark": bench,
                        "language": lang,
                        "correct": values[pos],
                    })

    return pd.DataFrame(rows)


# Rebuild the long-format table, this time taking row_id from each table's index.
row_df = build_rowwise_correct_df_from_results(
    benchmark_results=BENCHMARK_RESULTS,
    model_names=MODEL_NAMES,
    benchmark_labels=BENCHMARK_LABELS,
    language_list=LANGUAGE_LIST,
)


In [None]:
def modelwise_jaccard_correct(row_df, benchmark, language, models):
    """Return a model x model Jaccard matrix over ``row_id`` sets with ``correct == 0``.

    This cell redefines the function of the same name from the earlier cell. For each model, the
    set of ``row_id`` values with ``correct == 0`` on ``benchmark``/``language``
    is collected. Cell ``(m1, m2)`` is ``|A & B| / |A | B|``, or 1.0 when both
    sets are empty.

    Parameters
    ----------
    row_df : pandas.DataFrame
        Long table with ``row_id``, ``model``, ``benchmark``, ``language``
        and ``correct`` columns.
    benchmark, language : str
        Filter values for the ``benchmark`` and ``language`` columns.
    models : sequence of str
        Models to compare, used in order as the matrix index and columns.

    Returns
    -------
    pandas.DataFrame
        Float matrix indexed and columned by ``models``.
    """
    # Filter once and build each model's set once, rather than per (m1, m2) pair.
    hits = row_df[
        (row_df.benchmark == benchmark) &
        (row_df.language == language) &
        (row_df.correct == 0)
    ]
    id_sets = {m: set(hits[hits.model == m].row_id) for m in models}

    mat = pd.DataFrame(index=models, columns=models, dtype=float)
    for m1 in models:
        a = id_sets[m1]
        for m2 in models:
            b = id_sets[m2]
            union = a | b
            mat.loc[m1, m2] = len(a & b) / len(union) if union else 1.0

    return mat


In [None]:
# Columns of the Gemini / ORG results table (BENCHMARK_RESULTS[2][0]).
BENCHMARK_RESULTS[MODEL_NAMES.index("Gemini")][BENCHMARK_LABELS.index("ORG")].columns


In [None]:
# On a fresh copy of row_df, add a per-language key of the form "<row_id>__<language>".
row_df = row_df.assign(
    row_key=lambda frame: frame["row_id"].astype(str) + "__" + frame["language"].astype(str)
)


In [None]:
import pandas as pd

def languagewise_jaccard_correct(row_df, benchmark, model, languages):
    """Return a language x language Jaccard matrix of ``correct == 0`` rows.

    The rows are those of one model on one benchmark. For each language, the set of ``row_id``
    values with ``correct == 0`` is collected. Cell ``(l1, l2)`` is
    ``|A & B| / |A | B|``, or 1.0 when both sets are empty.

    Parameters
    ----------
    row_df : pandas.DataFrame
        Long table with ``row_id``, ``model``, ``benchmark``, ``language``
        and ``correct`` columns.
    benchmark : str
        Benchmark label to filter on.
    model : str
        Model label to filter on.
    languages : sequence of str
        Languages to compare, used in order as the matrix index and columns.

    Returns
    -------
    pandas.DataFrame
        Float matrix indexed and columned by ``languages``.
    """
    sub = row_df[
        (row_df["model"] == model) &
        (row_df["benchmark"] == benchmark) &
        (row_df["correct"] == 0)
    ]

    # Compare on row_id, not row_key. row_key embeds the language
    # ("<row_id>__<lang>"), so sets from two different languages could never
    # intersect, and every off-diagonal cell would come out as 0.
    id_sets = {lang: set(sub.loc[sub["language"] == lang, "row_id"]) for lang in languages}

    mat = pd.DataFrame(index=languages, columns=languages, dtype=float)
    for l1 in languages:
        a = id_sets[l1]
        for l2 in languages:
            b = id_sets[l2]
            union = a | b
            mat.loc[l1, l2] = len(a & b) / len(union) if union else 1.0

    return mat


In [None]:
def languagewise_jaccard_correct(row_df, benchmark, model, languages):
    """Return a language x language Jaccard matrix of ``correct == 0`` rows.

    The rows are those of one model on one benchmark, compared by ``row_id``.
    Cell ``(l1, l2)`` is ``|A & B| / |A | B|``, where A and B are the ``row_id`` sets with
    ``correct == 0`` in languages ``l1`` and ``l2``. It is 1.0 when both are empty.

    Parameters
    ----------
    row_df : pandas.DataFrame
        Long table with ``row_id``, ``model``, ``benchmark``, ``language``
        and ``correct`` columns.
    benchmark, model : str
        Filter values for the ``benchmark`` and ``model`` columns.
    languages : sequence of str
        Languages to compare, used in order as the matrix index and columns.

    Returns
    -------
    pandas.DataFrame
        Float matrix indexed and columned by ``languages``.
    """
    sub = row_df[
        (row_df["benchmark"] == benchmark) &
        (row_df["model"] == model) &
        (row_df["correct"] == 0)
    ]

    # Build each language's set once instead of once per (l1, l2) pair.
    id_sets = {lang: set(sub[sub.language == lang]["row_id"]) for lang in languages}

    mat = pd.DataFrame(index=languages, columns=languages, dtype=float)
    for l1 in languages:
        a = id_sets[l1]
        for l2 in languages:
            b = id_sets[l2]
            union = a | b
            mat.loc[l1, l2] = len(a & b) / len(union) if union else 1.0

    return mat


In [None]:
# Model labels as they appear in row_df["model"], in plotting order.
MODEL_KEYS = "Mistral LLaMA Gemini Qwen3-8B Qwen3-14B".split()


In [None]:
# Model label -> full model name for display.
MODEL_NAME_MAP = dict(zip(
    ("Mistral", "LLaMA", "Gemini", "Qwen3-8B", "Qwen3-14B"),
    (
        "Mistral-Nemo-Instruct-2407",
        "LLaMA 3.1 Instruct-8B",
        "Gemini 2.5 Flash",
        "Qwen-8B",
        "Qwen-14B",
    ),
))


In [None]:
# Inspect the long-table rows for one (benchmark, model) pair. The pair is set
# explicitly here. `ptype` and `model` are only bound by the Fig 5 plotting loop
# further down, so referencing them here raised NameError on a top-to-bottom run.
sub = row_df[
    (row_df.benchmark == "ORG") &
    (row_df.model == "Mistral")
]


In [None]:
# Fig 5 grid rows (benchmark labels) and heatmap axes (languages).
PROBLEM_TYPES = "ORG SS SO PSS POS".split()
LANGUAGES = "en de fr es it pt ko ja".split()
LANG_LABELS = list(map(str.upper, LANGUAGES))

# Fig 5 grid columns: model labels as stored in row_df["model"].
MODEL_KEYS = "Mistral LLaMA Gemini Qwen3-8B Qwen3-14B".split()

# Short display names. Only the Qwen labels change (the "3" is dropped).
MODEL_NAME_MAP = {key: key.replace("Qwen3", "Qwen") for key in MODEL_KEYS}


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
from pathlib import Path

# Fig 5: grid of language x language Jaccard heatmaps, one per
# (problem type, model). Grid rows follow PROBLEM_TYPES; columns follow MODEL_KEYS.
fig, axes = plt.subplots(
    nrows=len(PROBLEM_TYPES),
    ncols=len(MODEL_KEYS),
    figsize=(3.8 * len(MODEL_KEYS), 3.8 * len(PROBLEM_TYPES)),
    sharex=True,
    sharey=True,
    squeeze=False,  # keep `axes` 2-D so axes[i, j] also works with a single row/column
)

plt.subplots_adjust(wspace=0.08, hspace=0.12)


# One shared 0..1 colour scale; the colorbar below is drawn from this mappable.
norm = mpl.colors.Normalize(vmin=0, vmax=1)
sm = mpl.cm.ScalarMappable(norm=norm, cmap="Blues")
sm.set_array([])


for i, ptype in enumerate(PROBLEM_TYPES):
    for j, model in enumerate(MODEL_KEYS):

        ax = axes[i, j]

        # languagewise_jaccard_correct is the language x language helper defined
        # in this notebook. `compute_language_jaccard_matrix` does not exist here.
        jac = languagewise_jaccard_correct(
            row_df=row_df,
            benchmark=ptype,
            model=model,
            languages=LANGUAGES
        )

        sns.heatmap(
            jac,
            ax=ax,
            cmap="Blues",
            vmin=0,
            vmax=1,
            square=True,
            cbar=False
        )

        # Column titles use the short display names from MODEL_NAME_MAP.
        if i == 0:
            ax.set_title(
                MODEL_NAME_MAP.get(model, model),
                fontsize=30,
                pad=12,
                fontweight="bold"
            )

        # Row labels (problem type) only in the first column.
        if j == 0:
            ax.set_ylabel(
                ptype,
                fontsize=30,
                rotation=90,
                labelpad=18,
                va="center",
                fontweight="bold"
            )
        else:
            ax.set_ylabel("")

        ax.set_xticks(np.arange(len(LANGUAGES)) + 0.5)
        ax.set_yticks(np.arange(len(LANGUAGES)) + 0.5)

        ax.set_xticklabels(LANG_LABELS, fontsize=15)

        ax.set_yticklabels(
            LANG_LABELS,
            fontsize=19,
            rotation=0,
            va="center"
        )

        ax.tick_params(length=0)

cbar_ax = fig.add_axes([0.93, 0.15, 0.025, 0.7])
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label("Jaccard Index", fontsize=25)
cbar.ax.tick_params(labelsize=16)

# ===============================
# Save
# ===============================
# savefig does not create missing directories, so create the target folder first.
out_path = Path("03.notebooks/0208_result/Fig5/fig5_language_jaccard0208_confirm_2.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(
    out_path,
    bbox_inches="tight",
    dpi=300
)

plt.show()
