In [None]:
import os
from pathlib import Path

# Resolve the project root once per kernel session: the parent of the directory the
# kernel was launched from (NOTE(review): assumes the notebook sits one level below the
# project root — confirm). The globals() guard matters: after os.chdir below has run,
# re-executing this cell without the guard would take the parent of the *new* cwd and
# climb one directory too far. It also means a PROJECT_ROOT defined earlier in the
# session is kept as-is.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parent.resolve()

# Relative paths used later (e.g. the pd.read_csv calls on "tcr_data/...") resolve against this.
os.chdir(PROJECT_ROOT)

In [None]:
import numpy as np
import pandas as pd
from sceptr import variant

In [None]:
# Size and seed of the background subsample used by `uniformity`. The fixed seed keeps
# the metric reproducible under Restart & Run All.
BACKGROUND_SAMPLE_SIZE = 10_000
BACKGROUND_SAMPLE_SEED = 420

# Background TCRs; DataFrame.sample (without replacement) raises if the file has fewer
# than BACKGROUND_SAMPLE_SIZE rows.
background_data = pd.read_csv("tcr_data/preprocessed/tanno/test.csv")
background_sample = background_data.sample(n=BACKGROUND_SAMPLE_SIZE, random_state=BACKGROUND_SAMPLE_SEED)

# Epitope-labelled TCRs used by `alignment`, which groups them by the "Epitope" column;
# fail here with a clear message rather than deep inside the groupby later.
labelled_data = pd.read_csv("tcr_data/preprocessed/benchmarking/combined.csv")
assert "Epitope" in labelled_data.columns, "labelled_data must have an 'Epitope' column"

In [None]:
def uniformity(model, alpha: float = 1, tcrs: "pd.DataFrame | None" = None) -> float:
    """Log-mean-exp of negative pairwise distances over a set of background TCRs.

    Computes ``log(mean(exp(-d ** alpha)))`` over every distance ``d`` returned by
    ``model.calc_pdist_vector``. Because the function is monotone decreasing in each
    distance, lower values mean the TCRs are spread further apart.

    Args:
        model: Object exposing ``calc_pdist_vector(DataFrame)`` returning a 1-D array
            of pairwise distances (as the SCEPTR variants in this notebook do).
        alpha: Exponent applied to each distance before the exponential.
        tcrs: TCRs to measure. Defaults to the notebook-level ``background_sample``.

    Returns:
        The uniformity score as a float, or NaN when there are no pairwise distances.
    """
    if tcrs is None:
        tcrs = background_sample
    # Upcast so float32 model outputs are reduced in double precision.
    pdist = np.asarray(model.calc_pdist_vector(tcrs), dtype=np.float64)
    if pdist.size == 0:
        # No pairs (e.g. fewer than two TCRs): the metric is undefined.
        return float("nan")
    # log(mean(exp(x))) == max(x) + log(mean(exp(x - max(x)))) exactly, but the shifted
    # form cannot underflow every term to 0 (which made the naive version return -inf
    # when all distances are large).
    exponents = -(pdist ** alpha)
    shift = exponents.max()
    return float(shift + np.log(np.mean(np.exp(exponents - shift))))

def alignment(
    model,
    alpha: float = 1,
    tcrs: "pd.DataFrame | None" = None,
    group_col: str = "Epitope",
) -> float:
    """Mean of ``d ** alpha`` over all within-group pairwise distances.

    TCRs are grouped by ``group_col``; pairwise distances are computed inside each group
    only and then pooled, so lower values mean TCRs sharing a label sit closer together.

    Args:
        model: Object exposing ``calc_pdist_vector(DataFrame)`` returning a 1-D array
            of pairwise distances (as the SCEPTR variants in this notebook do).
        alpha: Exponent applied to each distance before averaging.
        tcrs: Labelled TCRs. Defaults to the notebook-level ``labelled_data``.
        group_col: Column whose values define the groups. The column itself is dropped
            before the frame is handed to the model.

    Returns:
        The alignment score as a float, or NaN when no group has at least two members.
    """
    if tcrs is None:
        tcrs = labelled_data
    per_group = [
        np.asarray(model.calc_pdist_vector(group.drop(columns=group_col)), dtype=np.float64)
        for _, group in tcrs.groupby(group_col)
        # A single TCR has no pairs, so it contributes nothing; skip the model call.
        if len(group) >= 2
    ]
    if not per_group:
        # Empty input or only singleton groups: np.concatenate([]) would raise here.
        return float("nan")
    pooled = np.concatenate(per_group)
    if pooled.size == 0:
        return float("nan")
    return float(np.mean(pooled ** alpha))

In [None]:
# Model variants to benchmark against each other.
VARIANTS = (variant.default(), variant.mlm_only())

# One row per variant (indexed by its name) with both metrics and their sum,
# built as a single chain so re-running the cell never mutates an existing frame.
results = (
    pd.DataFrame.from_records(
        [
            {
                "variant": model.name,
                "alignment": alignment(model),
                "uniformity": uniformity(model),
            }
            for model in VARIANTS
        ]
    )
    .assign(sum=lambda frame: frame["alignment"] + frame["uniformity"])
    .set_index("variant")
)

In [None]:
# Per-variant metrics: a lower `alignment` means TCRs sharing an epitope are closer
# together; a lower `uniformity` means background TCR pairs are spread further apart.
# `sum` is simply alignment + uniformity.
results