In [1]:
import os
from pathlib import Path

# Resolve the project root only once per kernel session: after the chdir below,
# re-running this cell must not climb one more directory up.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

# Make the relative paths used in later cells resolve against the project root.
os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.transforms import ScaledTranslation
import numpy as np
from numpy import ndarray
import pandas as pd
from pandas import Series
from typing import Iterable, Tuple, Literal
from utils import ModelForAnalysis

# Layer the project's own style sheet on top of the stock "ggplot" theme; styles
# in the list are applied from first to last, so "my.mplstyle" wins on conflicts.
plt.style.use(["ggplot", "my.mplstyle"])

In [3]:
# Folder holding the analysis outputs.
RESULTS_DIR = PROJECT_ROOT / "analysis_results"

# Unique epitopes listed in the 200-shot CDR3 Levenshtein result files (one file
# at the top level of the results folder, one under "10x"). The paths are
# relative, so they resolve against the current working directory.
LARGELY_SAMPLED_EPITOPES = (
    pd.read_csv("analysis_results/CDR3 Levenshtein/ovr_nn_200_shot.csv")["epitope"].unique()
)
LARGELY_SAMPLED_EPITOPES_10X = (
    pd.read_csv("analysis_results/10x/CDR3 Levenshtein/ovr_nn_200_shot.csv")["epitope"].unique()
)

# Numbers of reference TCRs ("shots") that the summaries below iterate over.
NUM_SHOTS_OF_INTEREST = [1, 2, 5, 10, 20, 50, 100, 200]

In [4]:
# Models compared in the summaries below; the tuple order is the order in which
# they are plotted and listed in the legend.
# NOTE(review): the ModelForAnalysis constructor is not visible here. The
# positional arguments look like (display name, benchmark/task identifier —
# cf. the "ovr_nn_200_shot.csv" result files —, line colour, matplotlib marker);
# confirm against utils.ModelForAnalysis.
# `zorder` ends up in `ax.errorbar(..., zorder=model.zorder)`, so the explicit
# values for SCEPTR and TCRdist control how their markers stack in the figure.
models = (
    ModelForAnalysis("SCEPTR", "ovr_nn", "#7048e8", "d", zorder=2),
    ModelForAnalysis("TCRdist", "ovr_nn", "#f03e3e", "o", zorder=1.9),
    ModelForAnalysis("CDR3 Levenshtein", "ovr_nn", "#f76707", "^"),
    ModelForAnalysis("TCR-BERT", "ovr_nn", "#74b816", "s"),
    ModelForAnalysis("ESM2 (T6 8M)", "ovr_nn", "#37b24d", "p"),
    ModelForAnalysis("ProtBert", "ovr_nn", "#0ca678", "x"),
)

In [5]:
def _epitope_mean_and_std(
        models: Iterable[ModelForAnalysis],
        k: int,
        epitopes: Iterable[str],
) -> Tuple[Series, Series]:
    """Return the per-model mean AUC and its error bar for one value of ``k``.

    Both returned Series are indexed by model name.

    * mean: AUC averaged within each epitope first, then across epitopes.
    * std: each row's AUC has the across-model mean of that row subtracted
      (removing variation shared by all models), the per-epitope sample
      variances of these deltas are summed, and the square root of that sum is
      divided by the number of epitopes actually present in the data.

    Rows of the different models are combined by index, exactly as loaded;
    this assumes every model's frame lists the same epitope/split rows under
    the same index — TODO confirm against ``ModelForAnalysis.load_data``.
    """
    models = tuple(models)
    model_names = [model.name for model in models]
    model_dfs = [model.load_data(k) for model in models]

    # One table: epitope/split labels from the first model, one AUC column per model.
    summary_df = pd.DataFrame()
    summary_df["epitope"] = model_dfs[0]["epitope"]
    summary_df["split"] = model_dfs[0]["split"]
    for name, model_df in zip(model_names, model_dfs):
        summary_df[name] = model_df["auc"]

    summary_df = summary_df[summary_df["epitope"].isin(list(epitopes))]

    # Average within each epitope first so that every epitope weighs the same.
    means = summary_df.groupby("epitope")[model_names].mean().mean()

    # Select the model columns by name rather than by position.
    row_means = summary_df[model_names].mean(axis="columns")
    variance_by_epitope = (
        summary_df[model_names]
        .sub(row_means, axis="index")
        .groupby(summary_df["epitope"])
        .var()
    )

    # Divide by the number of epitopes that contributed data (not the number
    # requested) so the error bar matches the mean computed above. With no
    # matching rows this yields NaN.
    stds = np.sqrt(variance_by_epitope.sum()) / len(variance_by_epitope)

    return means, stds


def generate_revised_summary_figure(
        models: Iterable[ModelForAnalysis],
        ks: Iterable[int],
        epitopes: Iterable[str],
) -> Figure:
    """Plot mean AUROC (with error bars) per model against the number of shots.

    Args:
        models: Models to compare. Each must expose ``name``, ``style``,
            ``colour``, ``zorder`` and ``load_data(k)``; the loaded frame must
            contain ``epitope``, ``split`` and ``auc`` columns.
        ks: Numbers of reference TCRs to show. They are placed at evenly spaced
            x positions and used as tick labels.
        epitopes: Epitopes to include; rows for other epitopes are dropped.

    Returns:
        The figure, with one errorbar series per model and a legend placed to
        the right of the axes.
    """
    # The arguments are iterated several times below, so materialise them once;
    # this also makes one-shot iterables (generators) safe to pass.
    models = tuple(models)
    ks = list(ks)
    epitopes = list(epitopes)

    per_k = [_epitope_mean_and_std(models, k, epitopes) for k in ks]
    means_by_k = pd.DataFrame([means for means, _ in per_k], index=ks)
    stds_by_k = pd.DataFrame([stds for _, stds in per_k], index=ks)

    # plot results
    fig, ax = plt.subplots(figsize=(7/2.54, 8/2.54))  # 7 cm x 8 cm, in inches
    x_positions = range(len(ks))

    for model in models:
        ax.errorbar(
            x=x_positions,
            y=means_by_k[model.name],
            yerr=stds_by_k[model.name],
            fmt=model.style,
            markersize=5,
            c=model.colour,
            label=model.name,
            zorder=model.zorder,
            capsize=5
        )

    ax.set_ylabel("Mean AUROC")
    ax.set_xlabel("Number of Reference TCRs")
    ax.set_xticks(x_positions, ks)

    # Each legend handle is an errorbar container whose first element is the
    # data line; rebuild plain line handles from it so that the legend shows
    # the marker without the error bars.
    handles, labels = ax.get_legend_handles_labels()
    legend_handles = []
    for handle in handles:
        data_line = handle[0]
        legend_handles.append(
            plt.Line2D(
                [0], [0],
                color=data_line.get_color(),
                lw=data_line.get_linewidth(),
                linestyle=data_line.get_linestyle(),
                marker=data_line.get_marker(),
                markersize=data_line.get_markersize()
            )
        )
    fig.legend(handles=legend_handles, labels=labels, loc="center left", bbox_to_anchor=(1, 0, 1, 1))

    fig.tight_layout()

    return fig



In [None]:
# Summary figure for all models over the shot counts of interest, restricted to
# the largely sampled epitopes.
fig = generate_revised_summary_figure(
    models=models,
    ks=NUM_SHOTS_OF_INTEREST,
    epitopes=LARGELY_SAMPLED_EPITOPES,
)

In [13]:
def get_benchmark_summary_with_errorbars(
    models: Iterable[ModelForAnalysis],
    ks: Iterable[int],
    epitopes: Iterable[str],
) -> pd.DataFrame:
    """Tabulate the mean AUC and its error bar per model for every ``k``.

    For each ``k`` and model:

    * mean: AUC averaged within each epitope first, then across epitopes.
    * std: each row's AUC has the across-model mean of that row subtracted
      (removing variation shared by all models), the per-epitope sample
      variances of these deltas are summed, and the square root of that sum is
      divided by the number of epitopes actually present in the data.

    Rows of the different models are combined by index, exactly as loaded;
    this assumes every model's frame lists the same epitope/split rows under
    the same index — TODO confirm against ``ModelForAnalysis.load_data``.

    Args:
        models: Models to compare. Each must expose ``name`` and
            ``load_data(k)``; the loaded frame must contain ``epitope``,
            ``split`` and ``auc`` columns.
        ks: Numbers of reference TCRs; one output row per value.
        epitopes: Epitopes to include; rows for other epitopes are dropped.

    Returns:
        A DataFrame indexed by ``ks`` whose columns are a two-level index of
        (model name, ``"mean"`` / ``"std"``).
    """
    # The arguments are iterated several times below, so materialise them once;
    # this also makes one-shot iterables (generators) safe to pass.
    models = tuple(models)
    ks = list(ks)
    epitopes = list(epitopes)
    model_names = [model.name for model in models]

    rows = []

    for k in ks:
        model_dfs = [model.load_data(k) for model in models]

        # One table: epitope/split labels from the first model, one AUC column per model.
        summary_df = pd.DataFrame()
        summary_df["epitope"] = model_dfs[0]["epitope"]
        summary_df["split"] = model_dfs[0]["split"]
        for name, model_df in zip(model_names, model_dfs):
            summary_df[name] = model_df["auc"]

        summary_df = summary_df[summary_df["epitope"].isin(epitopes)]

        # get average performance across epitopes per model
        avg_performances = summary_df.groupby("epitope")[model_names].mean().mean()

        # get error bars across epitopes per model; the model columns are
        # selected by name rather than by position
        row_means = summary_df[model_names].mean(axis="columns")
        variance_by_epitope = (
            summary_df[model_names]
            .sub(row_means, axis="index")
            .groupby(summary_df["epitope"])
            .var()
        )

        # Divide by the number of epitopes that contributed data (not the
        # number requested) so the error bar matches the mean computed above.
        # With no matching rows this yields NaN.
        stds = np.sqrt(variance_by_epitope.sum()) / len(variance_by_epitope)

        # Flatten to a (model, statistic)-indexed Series: one output row per k.
        rows.append(pd.DataFrame({"mean": avg_performances, "std": stds}).stack())

    return pd.DataFrame(rows, index=ks)

In [None]:
# Table of mean AUC and error bar per model and shot count, restricted to the
# largely sampled epitopes; left as the last expression so it is displayed.
get_benchmark_summary_with_errorbars(
    models=models,
    ks=NUM_SHOTS_OF_INTEREST,
    epitopes=LARGELY_SAMPLED_EPITOPES,
)