In [1]:
import os
from pathlib import Path

# Assumes the kernel starts one directory below the project root (hence `.parent`).
# Remember the root only the first time this cell runs: after the chdir below the
# working directory *is* the root, so recomputing it would climb one level too far.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

# Work from the project root so relative paths used by later cells
# (e.g. the "my.mplstyle" style sheet) resolve against it.
os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib.figure import Figure
from matplotlib import pyplot as plt
import pandas as pd
from typing import Iterable
import utils
from utils import ModelForAnalysis

# Apply the built-in "ggplot" style first, then the project's own style sheet on top;
# with a list, styles further right override parameters set by those on the left.
plt.style.use(["ggplot", "my.mplstyle"])

In [3]:
def plot_predetermined_split_results(models: Iterable[ModelForAnalysis]) -> Figure:
    """Plot per-epitope AUROC curves for several models on one shared axis.

    Each model's results are loaded with ``model.load_data()`` and indexed by their
    ``"epitope"`` column. The results are then joined into a table with one column
    per model. Epitopes are placed along the x-axis in descending order of their
    mean AUROC across all models.

    Args:
        models: Models to compare. Each must provide ``load_data()``, ``name``,
            ``style``, ``colour`` and ``zorder``. Any iterable is accepted,
            including one-shot iterators such as generators.

    Returns:
        The newly created figure (8 cm x 8 cm).

    Raises:
        ValueError: If ``models`` is empty.
    """
    # Materialise once. The models are traversed several times below, and a
    # generator would otherwise be exhausted after the first pass.
    models = tuple(models)
    if not models:
        raise ValueError("models must contain at least one model")

    auc_frames = []
    for model in models:
        aucs = model.load_data().set_index("epitope")
        # Assumes load_data() yields exactly one column besides "epitope" (the AUROC);
        # otherwise this assignment raises a length-mismatch error. TODO confirm.
        aucs.columns = [model.name]
        auc_frames.append(aucs)
    aucs_per_model = pd.concat(auc_frames, axis="columns")

    # Order epitopes by mean AUROC (highest first). The mean is kept out of the table
    # so it cannot overwrite a model column that happens to be named "avg". The rows
    # are reordered by position, which also works if an epitope label repeats.
    mean_auc = aucs_per_model.mean(axis="columns")
    row_order = mean_auc.reset_index(drop=True).sort_values(ascending=False).index
    aucs_per_model = aucs_per_model.iloc[row_order]

    fig, ax = plt.subplots(figsize=(8 / 2.54, 8 / 2.54))  # 8 cm square; figsize is in inches

    for position, model in enumerate(models):
        # Select the column by position. Several models may share a ``name`` (e.g. the
        # same model scored two ways). Indexing by name would then return every
        # matching column and draw all of them in this model's style.
        # NOTE(review): assumes ModelForAnalysis exposes the optional ``display_name``
        # it can be constructed with; falls back to ``name`` when absent or empty.
        label = getattr(model, "display_name", None) or model.name
        ax.plot(
            aucs_per_model.iloc[:, position],
            model.style,
            c=model.colour,
            label=label,
            zorder=model.zorder,
        )

    ax.set_ylabel("AUROC")
    ax.set_xlabel("Epitope")
    ax.tick_params(axis="x", labelrotation=45)

    # Always include the chance level (AUROC 0.5) without cropping lower values.
    ax.set_ylim(bottom=min(ax.get_ylim()[0], 0.5))

    ax.legend()
    fig.tight_layout()

    return fig

In [None]:
# Per-epitope AUROC on the predetermined split. Fine-tuned SCEPTR results come from
# the "avg_dist" task and all other models from the "nn" task. Higher zorder values
# draw a model's line on top of the others.
fig = plot_predetermined_split_results([
    ModelForAnalysis("SCEPTR (finetuned)", "ovr_predetermined_split_avg_dist", "#5f3dc4", "P", zorder=2),
    ModelForAnalysis("SCEPTR", "ovr_predetermined_split_nn", "#7048e8", "d", zorder=1.9),
    ModelForAnalysis("TCRdist", "ovr_predetermined_split_nn", "#f03e3e", "o", zorder=1.8),
    ModelForAnalysis("TCR-BERT", "ovr_predetermined_split_nn", "#74b816", "s"),
])

In [None]:
# Same comparison as above, but using the "filtered" variants of the predetermined-split
# tasks. Fine-tuned SCEPTR again uses the "avg_dist" task and the rest use "nn".
fig = plot_predetermined_split_results([
    ModelForAnalysis("SCEPTR (finetuned)", "ovr_predetermined_split_filtered_avg_dist", "#5f3dc4", "P", zorder=2),
    ModelForAnalysis("SCEPTR", "ovr_predetermined_split_filtered_nn", "#7048e8", "d", zorder=1.9),
    ModelForAnalysis("TCRdist", "ovr_predetermined_split_filtered_nn", "#f03e3e", "o", zorder=1.8),
    ModelForAnalysis("TCR-BERT", "ovr_predetermined_split_filtered_nn", "#74b816", "s"),
])

In [None]:
# Baseline vs fine-tuned SCEPTR, each scored by two methods ("nn" and "avg_dist").
# Models with the same `name` are distinguished only by `display_name` and line style.
models = (
    ModelForAnalysis("SCEPTR (finetuned)", "ovr_predetermined_split_nn", "#5f3dc4", "P", zorder=2, display_name="SCEPTR (finetuned, NN)"),
    ModelForAnalysis("SCEPTR (finetuned)", "ovr_predetermined_split_avg_dist", "#5f3dc4", "P", ":", zorder=2, display_name="SCEPTR (finetuned, Avg Dist)"),
    ModelForAnalysis("SCEPTR", "ovr_predetermined_split_nn", "#b197fc", "d", display_name="SCEPTR (NN)"),
    ModelForAnalysis("SCEPTR", "ovr_predetermined_split_avg_dist", "#b197fc", "d", ":", display_name="SCEPTR (Avg Dist)"),
)

# Epitopes to show, in x-axis order.
EPITOPES_TO_SHOW = ["TFEYVSQPFLMDLE", "GILGFVFTL", "SPRWYFYYL", "YLQPRTFLL", "TTDPSFLGRY", "NLVPMVATV"]

# Set True to write the figure to disk.
SAVE_FIGURE = False

auc_frames = []
for model in models:
    aucs = model.load_data().set_index("epitope")
    aucs.columns = [model.name]
    auc_frames.append(aucs)

# Column labels are model names and therefore NOT unique here; use positions below.
aucs_per_model = pd.concat(auc_frames, axis="columns").loc[EPITOPES_TO_SHOW]

fig, ax = plt.subplots(figsize=(8/2.54, 8/2.54))  # 8 cm square; figsize is in inches

for position, model in enumerate(models):
    # Indexing by `model.name` would return every same-named column and draw all of
    # them in this model's style, so select this model's column by position instead.
    # NOTE(review): assumes ModelForAnalysis exposes the `display_name` passed above;
    # falls back to `name` when the attribute is absent or empty.
    label = getattr(model, "display_name", None) or model.name
    ax.plot(aucs_per_model.iloc[:, position], model.style, c=model.colour, label=label, zorder=model.zorder)

ax.set_ylabel("AUROC")
ax.set_xlabel("Epitope")
ax.tick_params(axis='x', labelrotation=45)

# Always include the chance level (AUROC 0.5) without cropping lower values.
current_min = ax.get_ylim()[0]
new_min = min(current_min, 0.5)
ax.set_ylim(new_min)

# Legend sits to the right of the figure; bbox_inches="tight" on save keeps it in frame.
fig.legend(loc="center left", bbox_to_anchor=(1,0,1,1))
fig.tight_layout()

if SAVE_FIGURE:
    fig.savefig("sceptr_baseline_vs_finetuned_nn_and_avg_dist.pdf", bbox_inches="tight")

In [None]:
models = (
    ModelForAnalysis("SCEPTR", "ovr_unseen_epitopes_nn", "#7048e8", "d", zorder=2),
    ModelForAnalysis("SCEPTR (finetuned)", "ovr_unseen_epitopes_avg_dist", "#5f3dc4", "P", zorder=2),
    ModelForAnalysis("TCRdist", "ovr_unseen_epitopes_nn", "#f03e3e", "o", zorder=1.9),
    ModelForAnalysis("TCR-BERT", "ovr_unseen_epitopes_nn", "#74b816", "s"),
)

# NOTE(review): the meaning of these x-axis points is defined by
# utils.plot_performance_curves (presumably numbers of reference sequences per
# epitope) — confirm there.
PERFORMANCE_CURVE_POINTS = (1, 2, 5, 10, 20)
UNSEEN_EPITOPES = ("CINGVCWTV", "GLCTLVAML", "LLWNGPMAV", "ATDALMTGF", "QYIKWPWYI", "LTDEMIAQY")

fig, ax = plt.subplots(figsize=(8/2.54, 8/2.54))  # 8 cm square; figsize is in inches

utils.plot_performance_curves(models, PERFORMANCE_CURVE_POINTS, UNSEEN_EPITOPES, ax)

# Rebuild the legend from plain Line2D proxies that copy each entry's line and marker
# styling.
# - Container handles (tuples, e.g. the ErrorbarContainer returned by ax.errorbar)
#   contribute only their first artist, so auxiliary glyphs such as error bars are
#   left out of the legend.
# - Handles that are already single artists are used directly; indexing them would
#   raise TypeError.
handles, labels = ax.get_legend_handles_labels()
legend_lines = [handle[0] if isinstance(handle, tuple) else handle for handle in handles]
new_handles = [
    plt.Line2D(
        [0], [0],
        color=line.get_color(),
        lw=line.get_linewidth(),
        linestyle=line.get_linestyle(),
        marker=line.get_marker(),
        markersize=line.get_markersize(),
    )
    for line in legend_lines
]
ax.legend(handles=new_handles, labels=labels)

fig.tight_layout()