## Evaluation of Words as Natural Flags

The flag representation hypothesis assumes that words are naturally represented as flags.

One way to verify this claim is by checking if the matrix representation of a word is naturally orthogonal. We can measure this by computing the ratio between the matrix rank and the number of vectors in the matrix.

In this experiment, we show words can be naturally orthogonal, although this is an approximation because the longer the word becomes (higher token count), the less likely it is to be orthogonal. However, since most words have a low token count, the average word orthogonality is quite high, which is a good indicator that words are naturally represented as flags.

In [1]:
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from flags.lrh import FlagUnembeddingRepresentation
from flags.nlp.synsets import SupportedLanguages
from flags.utils.memory import gc_cuda
from flags.utils.settings import load_models

# Apply seaborn's "whitegrid" style to every figure produced in this notebook.
sns.set_style("whitegrid")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Model configurations to evaluate, as returned by the project settings loader.
MODELS = load_models()

# Dataframe column names used as plot axes.
X = "token count"
Y = "Relative Rank (Matrix Rank / Num Tokens)"

# Dataframe column names used for grouping / colouring.
HUE1 = "Model Family"
HUE2 = "Model"

# Upper bound passed as `max_token_count` when building each model's dataframe;
# the plots derive their x-axis ticks from it.
X_MAX = 16

In [None]:
def get_tokenization_data(**kwargs):
    """Build the per-word orthogonality dataframe for a single model.

    Args:
        **kwargs: Forwarded unchanged to
            ``FlagUnembeddingRepresentation.from_model_id`` (one row of
            ``MODELS`` at the call site).

    Returns:
        A pandas DataFrame with the columns produced by
        ``fur.data.get_dataframe`` plus ``HUE1`` (model family), ``HUE2``
        (model name) and ``Y`` (the orthogonality score of each row's
        ``"tokens"`` entry).
    """
    # Everything runs inside gc_cuda() so the helper can clean up once the
    # model is no longer needed (presumably frees GPU memory - confirm in
    # flags.utils.memory).
    with gc_cuda():
        representation = FlagUnembeddingRepresentation.from_model_id(
            language_codes=SupportedLanguages.Llama,
            synsets_kwargs={"variations": None},
            **kwargs,
        )
        model = representation.model

        # Tag every row with the model family and the model's string form so
        # the frames of several models can be concatenated and told apart.
        labels = {HUE1: model.family, HUE2: str(model)}
        frame = representation.data.get_dataframe(
            model.tokenizer, max_token_count=X_MAX
        )
        frame = frame.assign(**labels).to_pandas()

        # Scores are moved to the CPU and converted to a NumPy array before
        # being stored as a column.
        scores = representation.compute_orthogonality(
            frame["tokens"], batch_size=4096
        )
        frame[Y] = scores.cpu().numpy()
        return frame


# Each row of MODELS (keyed by its index) becomes the keyword arguments for
# one get_tokenization_data call; the per-model frames are then stacked.
df = pd.concat(
    get_tokenization_data(**model_kwargs)
    for model_kwargs in MODELS.to_dict(orient="index").values()
)

df.shape

In [None]:
# Histogram of token counts per model family.
# Keep one row per (lemma, family) and drop token counts >= X_MAX - 1 so the
# plotted range matches the x ticks set below.
df_hist = df[df[X] < X_MAX - 1].drop_duplicates(["lemma", HUE1])

df_hist.to_pickle("resources/01_token_count_per_model_family.pkl")

# Per-family summary statistics of the token count.
# NOTE(review): these are computed on the full `df`, not on the filtered and
# de-duplicated `df_hist` shown in the histogram - confirm this is intended.
descriptive_stats = df.groupby(HUE1)[X].describe()
model_families = list(descriptive_stats.index)

# Build the palette once and hand the same family order and colours to the
# histogram, the percentile lines and the legend. Previously the overlay used
# seaborn's default palette and the sorted groupby order, while the histogram
# used "colorblind" and the order of first appearance, so colours could be
# attached to the wrong family.
palette = sns.color_palette("colorblind", n_colors=len(model_families))

g = sns.displot(
    df_hist,
    x=X,
    hue=HUE1,
    hue_order=model_families,
    kind="hist",
    binwidth=1,
    shrink=1,
    palette=palette,
    multiple="dodge",
    height=5,
    aspect=1.1,
    facet_kws=dict(legend_out=False),
)

for color, model_family in zip(palette, model_families):
    percentile_75 = descriptive_stats.loc[model_family, "75%"]

    # Dashed vertical line at the family's 75th percentile, in the family colour.
    g.ax.axvline(percentile_75, color=color, linestyle="--", linewidth=2)

# Legend: one coloured dot per family plus a neutral dashed entry that explains
# the vertical lines.
legend_data = {
    model_family: mlines.Line2D([], [], color=color, marker="o", linestyle="")
    for model_family, color in zip(model_families, palette)
}
legend_data["75% Percentile"] = mlines.Line2D(
    [], [], color="black", linestyle="--", linewidth=1.5, label="75% Percentile"
)

g.add_legend(legend_data=legend_data, title=HUE1)

plt.xticks(range(1, X_MAX - 1))

plt.savefig(
    "resources/01_token_count_per_model_family.png", dpi=300, bbox_inches="tight"
)

plt.show()

In [None]:
# Mean relative rank per (lemma, model, token count), keeping a single row for
# each lemma/model pair before aggregating.
df_lineplot = (
    df.drop_duplicates(["lemma", HUE2])
    .groupby(["lemma", HUE2, X], as_index=False)[Y]
    .mean()
)

df_lineplot.to_pickle("resources/02_tokenization_natural_flag.pkl")

# One line per model: both colour and marker style encode the model.
sns.lineplot(
    data=df_lineplot,
    x=X,
    y=Y,
    hue=HUE2,
    style=HUE2,
    markers=True,
    dashes=False,
    palette="colorblind",
)

plt.xticks(ticks=range(1, X_MAX))

plt.savefig(
    "resources/02_tokenization_natural_flag.png",
    dpi=300,
    bbox_inches="tight",
)

plt.show()