## Ablation of Top-k Guided Generation: Influence of k

We analyze how the value of `k` affects top-k guided generation, measuring the effect with concept probing.

In [1]:
import pandas as pd

from frames.nlp.datasets import load_multilingual_question_dataset
from frames.representations import FrameUnembeddingRepresentation
from frames.utils.plotting import lineplot_and_save

In [None]:
# Experiment configuration.
MODEL_ID = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"

# Sizes are powers of two.
N = 2**9  # 512: number of dataset rows kept (used as `.iloc[:N]`)
STEPS = 2**5  # 32: forwarded as `steps` to the guided generator
BATCH_SIZE = 2**5  # 32: forwarded as `batch_size` to the guided generator

# Concept pair (WordNet-style synset names) handed to the top-k guide.
GUIDE = ("woman.n.01", "male.n.01")

# Column names of the long-format probe table; also the plot's axes and legend.
X = "token index"
Y = "total projection"
HUE = "guidance level"

# Values of k to sweep; k=1 is labelled as the unguided baseline in the plot.
TOPK_ABLATION_K = (3, 2, 1)

In [None]:
fur = FrameUnembeddingRepresentation.from_model_id(MODEL_ID)
# `>> 20` divides by 2**20, so (assuming `memory_footprint` is a byte count) the
# value is in mebibytes (MiB); the previous "Mb" label denotes megabits.
print(f"memory cost: {fur.model.memory_footprint >> 20} MiB")

In [None]:
# Load one question column per language supported by the representation, then
# keep only the first N rows.
dataset = load_multilingual_question_dataset(fur.data.languages)
dataset = dataset.iloc[:N]
samples = len(dataset.index)  # rows
langs = len(dataset.columns)  # one column per language
print("Loaded dataset with", samples, "samples and", langs, "languages")

In [None]:
def gen_guided_probe(inputs, k):
    """Generate with the top-k guide and return the probe values in long format.

    Args:
        inputs: prompts forwarded to ``fur.quick_generate_with_topk_guide``.
        k: top-k value for the guide.

    Returns:
        A DataFrame with columns ``X`` (the original probe column label) and
        ``Y`` (the probe value). Assumes the probe has one column per
        generated token — TODO confirm against the representation's API.
    """
    generation_options = {
        "guide": GUIDE,
        "k": k,
        "steps": STEPS,
        "batch_size": BATCH_SIZE,
        "min_lemmas_per_synset": 3,
        "max_token_count": 3,
    }
    # The generated text (first element) is not needed for this ablation.
    _generated, probe_values = fur.quick_generate_with_topk_guide(
        inputs, **generation_options
    )
    wide = pd.DataFrame(probe_values)
    return wide.melt(var_name=X, value_name=Y)

In [None]:
inputs = dataset["English"].tolist()


def _guidance_label(k):
    """Return the legend label for top-k value ``k`` (k=1 is the unguided baseline)."""
    label = f"{k=}"
    return f"{label} (no guidance)" if k == 1 else label


# ignore_index=True gives the combined frame a fresh RangeIndex. Without it,
# every per-k frame repeats the same index labels, which can break index-aligned
# operations downstream (some seaborn releases have raised "cannot reindex on an
# axis with duplicate labels" in that situation).
df = pd.concat(
    [
        gen_guided_probe(inputs, k=k).assign(**{HUE: _guidance_label(k)})
        for k in TOPK_ABLATION_K
    ],
    ignore_index=True,
)

lineplot_and_save("06_guided_generation_top_k_ablation", df, x=X, y=Y, hue=HUE)