In [1]:
import os
from pathlib import Path

# Resolve the project root once. The guard lets PROJECT_ROOT be supplied from
# outside (presumably injected as a notebook parameter) and keeps the cell safe
# to re-run: after the chdir below, cwd().parent would point one level too high.
# The default assumes the notebook is launched from a direct subdirectory of
# the project root.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parent.resolve()
# An injected value may be a plain string; normalise it to an absolute Path so
# later cells can join paths onto it with the `/` operator.
PROJECT_ROOT = Path(PROJECT_ROOT).resolve()

# Relative paths used later in the notebook (style sheet, result CSVs) are
# resolved against the project root.
os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib.figure import Figure
from matplotlib import pyplot as plt
import pandas as pd
from typing import Iterable
import utils
from utils import ModelForAnalysis

# Base ggplot look first, then the project's own style sheet layered on top so
# its settings win; style.use applies list entries left to right.
plt.style.use(["ggplot", "my.mplstyle"])

In [3]:
# Root directory of all analysis outputs. Path() accepts PROJECT_ROOT whether it
# is a Path or an injected string (assumed absolute, as the default built with
# .resolve() is).
RESULTS_DIR = Path(PROJECT_ROOT)/"analysis_results"

# Epitopes present in the 200-shot CDR3 Levenshtein results. Presumably these
# are the epitopes with enough data for every shot count — TODO confirm that
# all models share this epitope set. Paths are built from RESULTS_DIR rather
# than relying on the current working directory.
LARGELY_SAMPLED_EPITOPES = pd.read_csv(
    RESULTS_DIR/"CDR3 Levenshtein"/"ovr_nn_200_shot.csv"
).epitope.unique()
LARGELY_SAMPLED_EPITOPES_10X = pd.read_csv(
    RESULTS_DIR/"10x"/"CDR3 Levenshtein"/"ovr_nn_200_shot.csv"
).epitope.unique()

# Shot counts (k values) plotted on the performance curves.
NUM_SHOTS_OF_INTEREST = [1, 2, 5, 10, 20, 50, 100, 200]

In [4]:
# One entry per model to compare: (name, colour, marker, extra keyword args).
# Every model uses the "ovr_nn" setting, which matches the "ovr_nn_*" result
# files read above. Colour and marker are presumably the plot styling for that
# model's curve. SCEPTR and TCRdist get explicit zorders, presumably so that
# their curves are drawn above the others — confirm how ModelForAnalysis uses
# zorder.
models = tuple(
    ModelForAnalysis(model_name, "ovr_nn", colour, marker, **extra_kwargs)
    for model_name, colour, marker, extra_kwargs in (
        ("SCEPTR", "#7048e8", "d", {"zorder": 2}),
        ("TCRdist", "#f03e3e", "o", {"zorder": 1.9}),
        ("CDR3 Levenshtein", "#f76707", "^", {}),
        ("TCR-BERT", "#74b816", "s", {}),
        ("ESM2 (T6 8M)", "#37b24d", "p", {}),
        ("ProtBert", "#0ca678", "x", {}),
    )
)

In [None]:
def _legend_proxy(handle) -> plt.Line2D:
    """Build a plain ``Line2D`` that copies the line/marker style of a legend handle.

    ``handle`` may be a bare ``Line2D`` or a tuple-like artist container
    (e.g. an errorbar container). For a container, its first element is taken
    as the styled data line. NOTE(review): confirm this matches what
    ``utils.plot_performance_curves`` draws.

    The proxy carries only colour, line width, line style, marker and marker
    size, so any other parts of the original artist are not shown in the legend.
    """
    line = handle[0] if isinstance(handle, tuple) else handle
    return plt.Line2D(
        [0], [0],
        color=line.get_color(),
        lw=line.get_linewidth(),
        linestyle=line.get_linestyle(),
        marker=line.get_marker(),
        markersize=line.get_markersize(),
    )


def generate_revised_summary_figure(
    models: Iterable[ModelForAnalysis],
    ks: Iterable[int],
    epitopes: Iterable[str],
    figsize_cm: tuple = (7, 8),
) -> Figure:
    """Plot the performance curves of several models on a single axes.

    Parameters
    ----------
    models:
        Models whose curves are drawn by ``utils.plot_performance_curves``.
    ks:
        Shot counts forwarded to ``utils.plot_performance_curves``.
    epitopes:
        Epitopes forwarded to ``utils.plot_performance_curves``.
    figsize_cm:
        ``(width, height)`` of the figure in centimetres; converted to inches
        for matplotlib. Defaults to 7 x 8 cm.

    Returns
    -------
    Figure
        The figure, with a figure-level legend anchored just outside its right
        edge. The legend may be clipped unless the figure is saved with
        ``bbox_inches="tight"``.
    """
    width_cm, height_cm = figsize_cm
    fig, ax = plt.subplots(figsize=(width_cm/2.54, height_cm/2.54))

    utils.plot_performance_curves(models, ks, epitopes, ax)

    # Replace the axes' legend handles with simple line+marker proxies.
    handles, labels = ax.get_legend_handles_labels()
    proxies = [_legend_proxy(handle) for handle in handles]
    # loc="center left" with an anchor box starting at x=1 (figure coordinates)
    # puts the legend's left edge on the figure's right edge.
    fig.legend(handles=proxies, labels=labels, loc="center left", bbox_to_anchor=(1, 0, 1, 1))

    fig.tight_layout()

    return fig

# Summary figure over the epitopes from the 200-shot results, for every shot
# count of interest.
fig = generate_revised_summary_figure(
    models=models,
    ks=NUM_SHOTS_OF_INTEREST,
    epitopes=LARGELY_SAMPLED_EPITOPES,
)