In [1]:
import os
from pathlib import Path

# Resolve the project root only once per kernel: if the name already exists
# (e.g. this cell was run before, after which the working directory has
# already moved), the existing value is kept instead of climbing one more level.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

# Make the project root the working directory so relative paths used further
# down resolve against it.
os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib import pyplot as plt
import pandas as pd
import utils
from utils import ModelForAnalysis

# Styles in the list are applied first to last, so settings from the custom
# "my.mplstyle" sheet override the stock ggplot theme wherever both define one.
plt.style.use(["ggplot", "my.mplstyle"])

In [3]:
# Directory under the project root that holds the analysis outputs.
RESULTS_DIR = PROJECT_ROOT / "analysis_results"

# Unique epitopes listed in the 200-shot CDR3 Levenshtein results file
# (presumably the epitopes with enough samples for a 200-shot benchmark --
# confirm against the script that writes this file).
# The path is built from RESULTS_DIR rather than a cwd-relative string so the
# load does not depend on the working directory having been changed earlier.
LARGELY_SAMPLED_EPITOPES = pd.read_csv(
    RESULTS_DIR / "CDR3 Levenshtein" / "ovr_nn_200_shot.csv"
)["epitope"].unique()

# Shot counts passed to the performance-curve plots.
NUM_SHOTS_OF_INTEREST = [1, 2, 5, 10, 20, 50, 100, 200]

In [None]:
# Task keys, one per panel, paired position-wise with the panel titles below.
tasks = ("ovr_nn", "ovr_nn_a", "ovr_nn_b")
headers = (r"$\alpha\beta$", r"$\alpha$", r"$\beta$")

# Three side-by-side panels sharing both axes. figsize is in inches, so the
# centimetre dimensions (15 x 10) are divided by 2.54.
fig, ax = plt.subplots(
    ncols=3,
    sharex=True,
    sharey=True,
    figsize=(15 / 2.54, 10 / 2.54),
)

for panel, task, header in zip(ax, tasks, headers):
    # The same three models are drawn in every panel; only the task differs.
    models = (
        ModelForAnalysis("SCEPTR", task, "#7048e8", "d", zorder=2),
        ModelForAnalysis("TCRdist", task, "#f03e3e", "o"),
        ModelForAnalysis(
            "SCEPTR (dropout noise only)",
            task,
            "#7048e8",
            linestyle="--",
            zorder=2,
        ),
    )

    utils.plot_performance_curves(
        models, NUM_SHOTS_OF_INTEREST, LARGELY_SAMPLED_EPITOPES, panel
    )

    # Blank the per-panel axis labels (the figure-level labels set after the
    # loop are used instead) and title the panel with its header.
    panel.set(ylabel="", xlabel="", title=header)

# Single shared axis labels for the whole figure.
fig.supylabel("Mean AUROC")
fig.supxlabel("Number of Reference TCRs")

def _legend_proxy(handle):
    """Build a standalone Line2D that copies the look of ``handle[0]``.

    Only colour, line width, line style, marker and marker size are copied
    from the first artist of the handle.
    NOTE(review): assumes each legend handle is an indexable container whose
    first element is a line (e.g. an errorbar container) -- confirm against
    utils.plot_performance_curves.
    """
    drawn_line = handle[0]
    return plt.Line2D(
        [0],
        [0],
        color=drawn_line.get_color(),
        lw=drawn_line.get_linewidth(),
        linestyle=drawn_line.get_linestyle(),
        marker=drawn_line.get_marker(),
        markersize=drawn_line.get_markersize(),
    )


# Every panel is drawn with the same models, so the first panel's legend
# entries are reused for one figure-level legend.
handles, labels = ax[0].get_legend_handles_labels()
new_handles = [_legend_proxy(handle) for handle in handles]

# The anchor box (x=0, y=-1, width=1, height=1) is the region directly beneath
# the figure, so "upper center" pins the legend just under the bottom edge.
fig.legend(
    handles=new_handles,
    labels=labels,
    loc="upper center",
    bbox_to_anchor=(0, -1, 1, 1),
    ncols=3,
)

fig.tight_layout()