## Evaluation of Words as Frames

The Frame Representation Hypothesis assumes we can build concepts by "averaging" word frames.

One way to verify this claim is to check whether a Concept Frame correlates more strongly with the word frames used
to generate it than with other words' frames or with random frames.

In [None]:
import pandas as pd
import torch
from matplotlib import pyplot as plt

from frames.representations import FrameUnembeddingRepresentation
from frames.utils.memory import garbage_collection_cuda
from frames.utils.plotting import histplot_and_save
from frames.utils.settings import load_models

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODELS = load_models()

# Per-model value for `min_lemmas_per_synset`, matched against the model id by
# substring. These limits avoid GPU OOM errors (+80GB required otherwise).
LEMMA_COUNT = dict(gemma=22, llama=54, phi=15)

# Passed as `max_token_count` when building concept and word frames.
FLAG_RANK = 3
# How many random frames form the baseline comparison.
N_RANDOM = 100

# Column names / labels shared by the DataFrames and the histogram.
X, HUE = "total projection", "model"
RAND_PROJ = "Random Projection"

In [None]:
def compute_similarities(concepts, words):
    """Score `concepts` against `words` and keep only the non-zero scores.

    Zero entries are assumed to come from padded null vectors and are removed.
    Note that any genuinely zero similarity is removed as well. Boolean-mask
    indexing returns the surviving values as a flat 1-D result.
    """
    similarity_matrix = concepts.similarity(words)
    keep = similarity_matrix != 0
    return similarity_matrix[keep]


def compare_words_and_random_frames(*args, **kwargs):
    """Compare concept/word-frame similarities against a random-frame baseline.

    Loads a ``FrameUnembeddingRepresentation`` for one model and builds its
    concept frames and word frames. Every concept is then scored against
    (a) the word frames and (b) ``N_RANDOM`` random tensors shaped like the
    word-frame tensor.

    Args:
        *args, **kwargs: Forwarded unchanged to
            ``FrameUnembeddingRepresentation.from_model_id``. ``kwargs["id"]``
            is required. It selects the ``LEMMA_COUNT`` entry whose key is a
            substring of the lower-cased id.

    Returns:
        A long-form DataFrame with columns ``X`` (similarity values) and
        ``HUE``. ``HUE`` is ``RAND_PROJ`` for random-baseline rows and
        ``str(fur.model)`` for word-frame rows.

    Raises:
        KeyError: if ``"id"`` is not passed as a keyword argument.
        ValueError: if no ``LEMMA_COUNT`` key occurs in the model id.
    """
    # Resolve the per-model lemma threshold before loading the model, so an
    # unknown id fails fast with a clear message. A bare ``next()`` would
    # leak StopIteration, and only after the model had already been loaded.
    model_id = kwargs["id"]
    min_lemmas = next(
        (count for key, count in LEMMA_COUNT.items() if key in model_id.lower()),
        None,
    )
    if min_lemmas is None:
        raise ValueError(
            f"No LEMMA_COUNT entry matches model id {model_id!r}; "
            f"known keys: {sorted(LEMMA_COUNT)}"
        )
    kw = dict(min_lemmas_per_synset=min_lemmas, max_token_count=FLAG_RANK)

    # Presumably releases CUDA memory left over from a previous model.
    garbage_collection_cuda()

    fur = FrameUnembeddingRepresentation.from_model_id(*args, **kwargs)

    all_concepts = fur.get_all_concepts(**kw)
    print("Concepts =", all_concepts)

    all_word_frames = fur.get_all_words_frames(**kw)
    print("Words =", all_word_frames)

    # Baseline: the first N_RANDOM slices of a random tensor matching the
    # size/dtype/device of the word-frame tensor.
    # NOTE(review): rand_like samples U[0, 1), so every entry is positive and
    # the vectors are not isotropic random directions. Confirm whether
    # randn_like was intended.
    random_word_frames = torch.rand_like(all_word_frames.tensor[:N_RANDOM])

    concept_word_sim = compute_similarities(all_concepts, all_word_frames).float()
    concept_random_sim = compute_similarities(all_concepts, random_word_frames).float()

    rand_df = pd.DataFrame({X: concept_random_sim.cpu()}).assign(**{HUE: RAND_PROJ})
    word_df = pd.DataFrame({X: concept_word_sim.cpu()}).assign(**{HUE: str(fur.model)})

    return pd.concat([rand_df, word_df])


# Run the comparison once per configured model and stack the long-form
# results. pd.concat consumes the generator lazily, one model at a time.
model_configs = MODELS.to_dict(orient="index").values()
df = pd.concat(
    compare_words_and_random_frames(**config) for config in model_configs
)

# NOTE(review): the log y-scale is set on the current figure *before*
# histplot_and_save runs. If that helper creates its own figure, this call
# will not affect the saved plot. Confirm.
plt.yscale("log")
# NOTE(review): if `discrete=True` is forwarded to seaborn.histplot, it
# implies unit-width bins. Confirm that suits these float similarities.
histplot_and_save(
    "03_concept_vs_word_frames_relationship", df, X, HUE, discrete=True
)