## Ablation of Top-k Guided Generation: Comparison Across Model Families

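We analyze how different models are affected by Concept Guided Generation. For each model, we compare top-k guided generation (k = `MODEL_ABLATION_K`) against an unguided baseline (k = 1). We plot the difference in concept projection against parameter count, grouped by model family.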

In [None]:
import torch

from frames.nlp.datasets import load_multilingual_question_dataset
from frames.representations import FrameUnembeddingRepresentation
from frames.utils.memory import gc_cuda
from frames.utils.plotting import lineplot_and_save
from frames.utils.settings import load_models

In [None]:
# One row per model; each row is later passed as keyword arguments to
# `FrameUnembeddingRepresentation.from_model_id` (via `MODELS.apply(get, axis=1)`).
MODELS = load_models()

# Experiment sizes, kept as powers of two.
N = 2**8  # number of dataset rows taken from the question dataset
STEPS = 2**4  # forwarded as `steps=` to the guided generator
BATCH_SIZE = 2**4  # forwarded as `batch_size=` to the guided generator

# WordNet synset names forwarded as `guide=` to the guided generator.
GUIDE = ("woman.n.01", "male.n.01")

# Column names used for the final plot.
X = "Parameter Count (1e9)"
Y = "Concept Relative Projection (Guided - Unguided)"
HUE = "family"

# Forwarded as `min_lemmas_per_synset=` and `max_token_count=` respectively.
LEMMA_COUNT = 3
TOKEN_COUNT = 3

# Top-k value for the guided run; the baseline run always uses k=1.
MODEL_ABLATION_K = 3

In [None]:
# Keep the first N rows of the English questions as a flat array of sentences.
dataset = (
    load_multilingual_question_dataset(("English",))
    .iloc[:N]
    .values
    .flatten()
)

print(f"Loaded English dataset with {len(dataset)} samples")

In [None]:
def get_relative_concept_projection(model_kwargs, k, *args, **kwargs):
    """Measure how much top-k guidance shifts a model's concept projection.

    Loads the model described by ``model_kwargs``, then calls
    ``quick_generate_with_topk_guide`` twice with the same inputs: once with
    the requested ``k`` (guided) and once with ``k=1`` (unguided baseline).
    The difference of the two results is returned.

    Args:
        model_kwargs: Keyword arguments for
            ``FrameUnembeddingRepresentation.from_model_id``; must contain
            an ``"id"`` key.
        k: Top-k value for the guided run.
        *args, **kwargs: Forwarded unchanged to both generation calls.

    Returns:
        ``(relative_projection, param_count)``: a float32 NumPy array of
        guided minus unguided values, and the model's
        ``parameter_count_string`` with a trailing ``"B"`` removed.
    """
    # NOTE(review): gemma-2-9b is pinned to bfloat16 while every other model
    # uses "auto"; the reason is not visible here — confirm before changing.
    if "gemma-2-9b" in model_kwargs["id"]:
        dtype = torch.bfloat16
    else:
        dtype = "auto"
    representation = FrameUnembeddingRepresentation.from_model_id(
        **model_kwargs, torch_dtype=dtype
    )

    def last_projection(top_k):
        # Element 1 of the generator's result, last entry along dim 1
        # (presumably the final generation step — confirm).
        outputs = representation.quick_generate_with_topk_guide(
            *args, k=top_k, **kwargs
        )
        return outputs[1][:, -1]

    # Guided run first, then the k=1 baseline (same order as the two calls).
    delta = last_projection(k) - last_projection(1)
    delta_array = delta.float().cpu().numpy()

    count = representation.model.parameter_count_string.removesuffix("B")
    return delta_array, count


def get(model_kwargs):
    """Run the guided-generation ablation for a single model row.

    Calls ``get_relative_concept_projection`` with this notebook's settings
    (``dataset``, ``MODEL_ABLATION_K``, ``STEPS``, ``BATCH_SIZE``, ``GUIDE``,
    ``LEMMA_COUNT``, ``TOKEN_COUNT``) inside a ``gc_cuda()`` context
    (presumably releasing GPU memory between models — confirm).

    Returns:
        The ``(relative_projection, param_count)`` tuple produced by
        ``get_relative_concept_projection``.
    """
    settings = {
        "sentences": dataset,
        "k": MODEL_ABLATION_K,
        "steps": STEPS,
        "batch_size": BATCH_SIZE,
        "guide": GUIDE,
        "min_lemmas_per_synset": LEMMA_COUNT,
        "max_token_count": TOKEN_COUNT,
    }
    with gc_cuda():
        return get_relative_concept_projection(model_kwargs=model_kwargs, **settings)

In [None]:
# Run the ablation once per model row; each entry of `data` is the
# (relative_projection, param_count) tuple returned by `get`.
data = MODELS.apply(get, axis=1)
MODELS[Y] = data.apply(lambda result: result[0])
# float() rather than int(): once the "B" suffix is stripped, a parameter-count
# string may be fractional (e.g. "2.6"), which int() would reject.
MODELS[X] = data.apply(lambda result: float(result[1]))

# One row per projection value. Drop NaNs only in the plotted columns, so a
# missing value in an unrelated MODELS column cannot silently remove a model
# from the figure. explode() leaves Y as an object column, so cast it to float.
df = MODELS.explode(Y).dropna(subset=[X, Y, HUE]).astype({Y: float})

lineplot_and_save("08_guided_generation_model_comparison", df, x=X, y=Y, hue=HUE)