## Ablation of top-k guided generation: influence of "language" and "k"

We analyze how Concept-Guided Generation affects each language at different values of "k".

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from frames.nlp.datasets import load_multilingual_question_dataset
from frames.representations import FrameUnembeddingRepresentation
from frames.utils.plotting import lineplot_and_save

In [None]:
# --- Experiment configuration -------------------------------------------------
# Alternative (larger, quantized) model: "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

N = 256  # number of dataset rows (questions) kept for the experiment
STEPS = 16  # forwarded as `steps` to the guided generator
BATCH_SIZE = 32  # forwarded as `batch_size` to the guided generator

# Pair of WordNet synset names forwarded as `guide` to the generator.
# NOTE(review): "woman.n.01" is paired with "male.n.01" rather than "man.n.01";
# confirm this asymmetry is intended.
GUIDE = ("woman.n.01", "male.n.01")

# Column names of the long-format results frame (also used as plot labels).
X = "Guidance Level"
Y = "Concept Relative Projection (Guided - Unguided)"
HUE = "language"

# Values of k swept by the ablation, largest first: [7, 6, 5, 4, 3, 2, 1].
TOPK_ABLATION_K = list(range(7, 0, -1))

In [None]:
# Build the frame/unembedding representation for the configured model and
# report its memory footprint shifted right by 20 bits.
# NOTE(review): presumably `memory_footprint` is in bytes, which makes the
# printed value MiB rather than megabits ("Mb"); confirm.
fur = FrameUnembeddingRepresentation.from_model_id(MODEL_ID)
print(
    "memory cost: ",
    fur.model.memory_footprint >> 20,
    "Mb",
)

In [None]:
# Load the multilingual question set for every language the representation
# knows, and keep only the first N rows. The later cells iterate
# `dataset.columns` as languages, so presumably there is one column per language.
dataset = load_multilingual_question_dataset(fur.data.languages)
dataset = dataset.iloc[:N]
samples, langs = dataset.shape
print("Loaded dataset with", samples, "samples and", langs, "languages")

In [None]:
def gen_guided_probe(inputs, k):
    """Run top-k concept-guided generation and return its probe in long format.

    Args:
        inputs: Prompts passed straight through to
            ``fur.quick_generate_with_topk_guide`` (one dataset column as a list).
        k: The ``k`` forwarded to the top-k guided generator.

    Returns:
        A DataFrame with two columns: ``X``, holding the column label of the
        probe matrix, and ``Y``, holding the corresponding probe value.
    """
    generator_kwargs = dict(
        guide=GUIDE,
        k=k,
        steps=STEPS,
        batch_size=BATCH_SIZE,
        min_lemmas_per_synset=3,
        max_token_count=3,
    )
    # Only the second returned value (the probe) is used here.
    _unused, probe = fur.quick_generate_with_topk_guide(inputs, **generator_kwargs)
    # `.float()` suggests a torch tensor; presumably 2-D (samples x steps) — TODO confirm.
    probe_frame = pd.DataFrame(probe.float())
    return probe_frame.melt(var_name=X, value_name=Y)

In [None]:
# Run guided generation for every (language, k) pair and stack the results
# into one long-format frame. `.assign(**{X: k})` overwrites the per-step
# column produced by `gen_guided_probe`, so from here on X holds k and the
# steps are pooled.
df = pd.concat(
    [
        gen_guided_probe(dataset[lang].tolist(), k=k).assign(**{HUE: lang, X: k})
        for lang in dataset.columns
        for k in TOPK_ABLATION_K
    ],
    # Every piece starts its index at 0. Use a fresh RangeIndex instead of
    # duplicate labels, which can confuse later alignment and plotting.
    ignore_index=True,
)

# These are the raw values: the k == 1 baseline has not been subtracted yet.
# Save them under their own name. Otherwise the baseline-subtracted figure,
# which is saved later as "09_guided_generation_language_topk_ablation", would
# overwrite this output.
lineplot_and_save("09_guided_generation_language_topk_ablation_raw", df, x=X, y=Y, hue=HUE)

In [None]:
# Express every value relative to the k == 1 mean of the same language, so
# each language's mean is 0 at k = 1.
# NOTE: this overwrites df[Y] in place. Re-running this cell without first
# rebuilding df subtracts the baseline a second time.
values = df[df[X] == 1].groupby(HUE)[Y].mean()
# Vectorized per-row lookup of each row's language baseline
# (replaces a row-wise `df.apply(..., axis=1)`).
df[Y] = df[Y] - df[HUE].map(values)

sns.set_style("whitegrid")
sns.set_context("paper", font_scale=1.5)

plt.figure(figsize=(5, 7))

sns.lineplot(
    df,
    x=X,
    y=Y,
    hue=HUE,
    style=HUE,
    markers=True,
    dashes=False,
    # NOTE(review): a 10% confidence interval draws very narrow bands
    # (seaborn's default is 95%); confirm this is intended.
    errorbar=("ci", 10),
    palette="colorblind",
    linewidth=2,
)
# NOTE(review): TOPK_ABLATION_K also contains 6 and 7, which this x-range hides.
plt.xlim(1, 5)
plt.xticks(TOPK_ABLATION_K)
plt.yticks(range(0, 11))

In [None]:
# Switch the current axes to a symmetric-log y scale, then plot and save df
# (baseline-subtracted, if the previous cell has run).
# NOTE(review): whether the symlog scale reaches the saved figure depends on
# whether `lineplot_and_save` draws on the current axes or creates its own
# figure; confirm. With the inline notebook backend, the previous cell's
# figure is typically already closed, so `plt.yscale` may act on a new, empty axes.
# NOTE(review): make sure no earlier cell saves under this same name, or its
# output is overwritten here.
plt.yscale("symlog")
lineplot_and_save("09_guided_generation_language_topk_ablation", df, x=X, y=Y, hue=HUE)