In [1]:
import os
from pathlib import Path

# Anchor the working directory at the project root so that relative paths used
# later in the notebook resolve from there. The root is taken to be the parent
# of the directory the kernel is currently in. Probing for the name first makes
# the cell safe to re-run: once PROJECT_ROOT exists it is reused, so repeated
# executions do not climb one directory higher each time.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib import pyplot as plt
from matplotlib.transforms import ScaledTranslation
import pandas as pd
from utils import ModelForAnalysis

# Apply the stock "ggplot" style first and the "my.mplstyle" style sheet second.
# Styles in the list are applied left to right, so any setting defined in
# "my.mplstyle" overrides the ggplot value for the same key.
plt.style.use(["ggplot", "my.mplstyle"])

In [3]:
# Distinct epitopes listed in the 200-shot CDR3 Levenshtein results file. The
# plotting cell uses this array to restrict each model's per-epitope AUROCs
# before averaging.
# NOTE(review): presumably these are the epitopes with enough samples to
# support a 200-shot evaluation -- confirm against how the CSV is produced.
LARGELY_SAMPLED_EPITOPES = pd.read_csv(
    "analysis_results/CDR3 Levenshtein/ovr_nn_200_shot.csv",
    usecols=["epitope"],  # only this column is needed; skip parsing the rest
)["epitope"].unique()

In [None]:
# Models to compare. This cell reads .name, .colour and .marker from each entry
# and queries it for parameter count, representation dimensionality and results.
# NOTE(review): SCEPTR is constructed with zorder=2, but the scatter calls below
# never pass a zorder -- confirm whether that is intentional.
models = [
    ModelForAnalysis("SCEPTR", "ovr_nn", "#7048e8", "d", zorder=2),
    ModelForAnalysis("TCR-BERT", "ovr_nn", "#74b816", "s"),
    ModelForAnalysis("ESM2 (T6 8M)", "ovr_nn", "#37b24d", "p"),
    ModelForAnalysis("ProtBert", "ovr_nn", "#0ca678", "x"),
]

# Two side-by-side panels sharing the AUROC axis. The figure size is given in
# centimetres (8 x 5.5) and divided by 2.54 to obtain the inches matplotlib expects.
fig, axs = plt.subplots(
    nrows=1,
    ncols=2,
    sharey=True,
    figsize=(8 / 2.54, 5.5 / 2.54),
)

# One point per model in each panel: mean AUROC against parameter count (left)
# and against representation dimensionality (right).
for model in models:
    num_params = model.get_num_parameters()
    dimensionality = model.get_model_dimensionality()

    # Average "auc" within each epitope, keep only the epitopes listed in
    # LARGELY_SAMPLED_EPITOPES, then average across those epitopes.
    # NOTE(review): the argument 200 presumably selects the 200-shot results,
    # matching the y-axis label -- confirm in utils.ModelForAnalysis.load_data.
    aurocs_per_epitope = model.load_data(200).groupby("epitope").agg({"auc": "mean"})
    is_largely_sampled = aurocs_per_epitope.index.isin(LARGELY_SAMPLED_EPITOPES)
    avg_auroc = aurocs_per_epitope.loc[is_largely_sampled].mean().item()

    # Same colour/marker in both panels; only the left panel carries the
    # label, so each model appears once in the figure legend.
    point_style = dict(c=model.colour, marker=model.marker)
    axs[0].scatter(num_params, avg_auroc, label=model.name, **point_style)
    axs[1].scatter(dimensionality, avg_auroc, **point_style)

# The y axis is shared, so its label and limits are set on the left panel only.
axs[0].set_ylabel("200-shot Mean AUROC")
axs[0].set_ylim(0.5, 0.83)

# Per-panel x-axis settings: (axes, limits, label). Both x axes are logarithmic.
panel_settings = (
    (axs[0], (10**4, 10**9), "Parameter Count"),
    (axs[1], (10**1, 5*10**3), "Representation\nDimensionality"),
)
for ax, xlim, xlabel in panel_settings:
    ax.set_xlim(*xlim)
    ax.set_xlabel(xlabel)
    ax.set_xscale("log")

# Bold panel letters at the top-left corner of each axes, shifted up by
# 0.2 inch (the offset is expressed in inches via fig.dpi_scale_trans).
trans = ScaledTranslation(0, 20/100, fig.dpi_scale_trans)
for ax, label in zip(axs, ("a", "b")):
    ax.text(
        0.0,
        1.0,
        label,
        transform=ax.transAxes + trans,
        fontsize='large',
        fontweight="bold",
        va='top',
    )

# Two-column legend anchored to a full-width box directly below the figure
# (y from -0.5 to 0), with the legend's top centre at the top of that box.
fig.legend(loc="upper center", bbox_to_anchor=(0,-0.5,1,0.5), ncols=2)

fig.tight_layout()