In [5]:
import os
from pathlib import Path

# Anchor the kernel's working directory at the project root (the parent of the
# directory the notebook was launched from). The lookup only raises NameError
# the first time the cell runs in a kernel, so re-running the cell after the
# chdir below does not walk one directory further up.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

# Every relative path used later in the notebook is resolved against this.
os.chdir(PROJECT_ROOT)

In [6]:
from matplotlib import pyplot as plt
from matplotlib.transforms import ScaledTranslation
import pandas as pd
import utils
from utils import ModelForAnalysis

# Start from the stock "ggplot" theme and layer the local "my.mplstyle" file on
# top. Styles in the list are applied first to last, so settings in the local
# file take precedence where the two overlap. "my.mplstyle" is given as a
# relative path, so it is looked up from the current working directory.
plt.style.use(["ggplot", "my.mplstyle"])

In [7]:
# Shot counts handed to utils.plot_performance_curves in the plotting cell.
NUM_SHOTS_OF_INTEREST = [1, 2, 5, 10, 20, 50, 100, 200]

# Distinct values of the "epitope" column in the 200-shot CDR3 Levenshtein
# results file; also handed to utils.plot_performance_curves.
# NOTE(review): presumably these are the epitopes with enough data to be
# evaluated at the largest shot count — confirm against how the CSV is produced.
LARGELY_SAMPLED_EPITOPES_10X = (
    pd.read_csv("analysis_results/10x/CDR3 Levenshtein/ovr_nn_200_shot.csv")
    ["epitope"]
    .unique()
)

In [None]:
# Figure size is written in centimetres and divided down to the inches that
# matplotlib's figsize expects (10 cm x 8 cm).
CM_PER_INCH = 2.54
fig, ax = plt.subplots(figsize=(10 / CM_PER_INCH, 8 / CM_PER_INCH))

# One row per benchmarked model: (first positional arg, third positional arg,
# fourth positional arg, extra keyword arguments) for ModelForAnalysis.
# NOTE(review): the positional values look like display name, hex colour and
# matplotlib marker — confirm against utils.ModelForAnalysis.
model_specs = [
    ("SCEPTR", "#7048e8", "d", {"zorder": 2}),
    ("TCRdist", "#f03e3e", "o", {}),
    ("CDR3 Levenshtein", "#f76707", "^", {}),
    ("TCR-BERT", "#74b816", "s", {}),
    ("ESM2 (T6 8M)", "#37b24d", "p", {}),
    ("ProtBert", "#0ca678", "x", {}),
]
# Every model shares the "ovr_nn" second argument and ten_x=True.
models = tuple(
    ModelForAnalysis(name, "ovr_nn", colour, marker, **extras, ten_x=True)
    for name, colour, marker, extras in model_specs
)

utils.plot_performance_curves(
    models,
    NUM_SHOTS_OF_INTEREST,
    LARGELY_SAMPLED_EPITOPES_10X,
    ax,
)


def _line_only_proxy(handle):
    """Build a standalone Line2D that copies the look of ``handle[0]``.

    Colour, line width, line style, marker and marker size are read from the
    first element of ``handle`` and applied to a fresh single-point line.

    NOTE(review): assumes each legend handle is an indexable container whose
    first element is the Line2D of the plotted curve (e.g. an errorbar
    container) — confirm against utils.plot_performance_curves.
    """
    source_line = handle[0]
    return plt.Line2D(
        [0], [0],
        color=source_line.get_color(),
        lw=source_line.get_linewidth(),
        linestyle=source_line.get_linestyle(),
        marker=source_line.get_marker(),
        markersize=source_line.get_markersize(),
    )


# Swap the handles the axes reports for plain line proxies, keeping the labels.
handles, labels = ax.get_legend_handles_labels()
new_handles = [_line_only_proxy(handle) for handle in handles]
ax.legend(handles=new_handles, labels=labels, loc="lower right")

# The output path is relative, so the PDF lands in the current working directory.
fig.tight_layout()
fig.savefig("benchmarking_10x.pdf", bbox_inches="tight")