## Ablation of Top-k Guided Generation: Influence of k

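We use concept probing to analyze how the value of k affects top-k guided generation. For each guide pair in `GUIDES` and each k in `TOPK_ABLATION_K`, we generate continuations of the English questions and record the total concept projection at every generated token index.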
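To build intuition for why k matters, here is a minimal sketch of a single top-k guided decoding step. It assumes that every vocabulary token has a precomputed scalar concept score, for example the projection of its unembedding vector onto the guide's concept direction. The function `topk_guided_step` is hypothetical and is not the `frames` implementation behind `quick_generate_with_topk_guide`. The model proposes its k most likely tokens, and the guide picks the candidate with the highest concept score. With k=1 only one candidate remains, so the step reduces to greedy decoding, which matches the "no guidance" label used for k=1 later in the notebook.

```python
import numpy as np


def topk_guided_step(logits: np.ndarray, concept_scores: np.ndarray, k: int) -> int:
    """Hypothetical top-k guided step: among the k highest-logit tokens,
    return the one with the largest (assumed precomputed) concept score."""
    # Indices of the k most likely tokens under the model, best first
    candidates = np.argsort(logits)[::-1][:k]
    # Steer toward the concept: pick the candidate that scores highest on it
    return int(candidates[np.argmax(concept_scores[candidates])])


# Toy vocabulary of 4 tokens (made-up numbers)
logits = np.array([2.0, 1.5, 0.5, 3.0])
concept_scores = np.array([0.1, 0.9, 0.8, -0.2])

for k in (3, 2, 1):
    print(k, topk_guided_step(logits, concept_scores, k))
# 3 1  -> the guide can reach token 1, the 3rd most likely token
# 2 0
# 1 3  -> same as greedy argmax(logits): no guidance
```

Larger k gives the guide more room to steer, at the price of accepting tokens the model ranks lower.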

In [1]:
import pandas as pd

from frames.nlp.datasets import load_multilingual_question_dataset
from frames.representations import FrameUnembeddingRepresentation
from frames.utils.plotting import lineplot_and_save

In [None]:
# parameters
MODEL_ID = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"  # larger alternative: "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"
N = 1 << 9
STEPS = 1 << 5
BATCH_SIZE = 1 << 5

GUIDES = [("woman.n.01", "male.n.01"), ("black.n.01", "white.n.01")]

X = "token index"
Y = "total projection"
HUE = "guidance level"
HUE2 = "guidance"

TOPK_ABLATION_K = (3, 2, 1)
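The bit shifts in the parameter cell are powers of two: `1 << 9` is 512 and `1 << 5` is 32. The sketch below works out the size of the ablation. It assumes that generation batches over prompts and that each probe has one row per prompt and one column per step. Neither assumption is confirmed by the code shown here.

```python
import math

# Same values as the parameter cell
N = 1 << 9           # 512 prompts taken from the dataset
STEPS = 1 << 5       # 32 generation steps (assumed: one generated token per step)
BATCH_SIZE = 1 << 5  # 32 prompts per batch
GUIDES = [("woman.n.01", "male.n.01"), ("black.n.01", "white.n.01")]
TOPK_ABLATION_K = (3, 2, 1)

batches_per_run = math.ceil(N / BATCH_SIZE)  # batches for one (k, guide) run, assuming batching over prompts
runs = len(TOPK_ABLATION_K) * len(GUIDES)    # number of (k, guide) combinations
rows_per_run = N * STEPS                     # long-format rows per run, if the probe is N x STEPS
print(batches_per_run, runs, rows_per_run, runs * rows_per_run)  # 16 6 16384 98304
```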

In [3]:
fur = FrameUnembeddingRepresentation.from_model_id(MODEL_ID)

[32m2025-02-17 17:58:53.341[0m | [1mINFO    [0m | [36mframes.models.hf.base[0m:[36m__init__[0m:[36m88[0m - [1mLoaded model: hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4[0m


In [4]:
dataset = load_multilingual_question_dataset(fur.data.languages).iloc[:N]
samples, langs = dataset.shape
print("Loaded dataset with", samples, "samples and", langs, "languages")

Loaded dataset with 512 samples and 8 languages


In [5]:
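def gen_guided_probe(inputs, guide, k):
    """Generate with top-k guidance toward `guide` and return the concept probe in long format."""
    # Only the probe values are used here; the other return value is discarded
    _, probe = fur.quick_generate_with_topk_guide(
        inputs,
        guide=guide,
        k=k,
        steps=STEPS,
        batch_size=BATCH_SIZE,
        min_lemmas_per_synset=3,
        max_token_count=3,
    )
    # Columns are token indices: melt into long format with columns (X, Y) for plotting
    return pd.DataFrame(probe).melt(var_name=X, value_name=Y)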
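`melt` turns the wide probe, with one column per token index, into the long format that line plots expect. Each column name becomes a value in the `X` column and each cell becomes a value in the `Y` column. The toy probe below uses made-up numbers for 2 prompts and 3 tokens, in place of real `quick_generate_with_topk_guide` output.

```python
import numpy as np
import pandas as pd

X = "token index"
Y = "total projection"

# Toy stand-in for a probe: 2 prompts (rows) x 3 generated tokens (columns)
probe = np.array([[0.1, 0.4, 0.2],
                  [0.3, 0.0, 0.5]])

long_df = pd.DataFrame(probe).melt(var_name=X, value_name=Y)
print(long_df)
# Columns are stacked one after another:
# token index 0 -> 0.1, 0.3; token index 1 -> 0.4, 0.0; token index 2 -> 0.2, 0.5
```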

In [None]:
inputs = dataset["English"].tolist()

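# One probe per (k, guide) pair, labeled so that the runs can be told apart
df = pd.concat(
    [
        gen_guided_probe(inputs, guide, k).assign(**{HUE: f"{k=}", HUE2: f"{guide=}"})
        for k in TOPK_ABLATION_K
        for guide in GUIDES
    ],
    ignore_index=True,
)
# With k=1 a single candidate remains, so the guide cannot change the choice
df[HUE] = df[HUE].replace("k=1", "k=1 (no guidance)")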

lineplot_and_save("06_guided_generation_top_k_ablation", df, x=X, y=Y, hue=HUE)
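Only `HUE` is passed to `lineplot_and_save`. Suppose it aggregates repeated y-values per x the way seaborn's `lineplot` does. That is an assumption, since its implementation is not shown here. In that case each curve pools both guide pairs, even though `HUE2` is stored in `df`. The pandas sketch below uses a made-up frame with the same column names to contrast the pooled means with per-guide means.

```python
import pandas as pd

X, Y = "token index", "total projection"
HUE, HUE2 = "guidance level", "guidance"

# Toy long-format frame with the same columns as `df` (values are made up)
df = pd.DataFrame({
    X: [0, 0, 1, 1, 0, 0, 1, 1],
    Y: [1.0, 3.0, 2.0, 4.0, -1.0, 1.0, 0.0, 2.0],
    HUE: ["k=3"] * 8,
    HUE2: ["guide=A"] * 4 + ["guide=B"] * 4,
})

pooled = df.groupby([HUE, X])[Y].mean()           # guides averaged together
per_guide = df.groupby([HUE, HUE2, X])[Y].mean()  # guides kept apart
print(pooled)     # k=3: token 0 -> 1.0, token 1 -> 2.0
print(per_guide)  # guide=A: 2.0, 3.0; guide=B: 0.0, 1.0
```

Pooling can hide opposite trends between guides, so keeping `HUE2` as a separate grouping is worth considering when reading the plot.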