In [1]:
import os
from pathlib import Path

# Resolve the project root only once per kernel session. Re-running this cell
# after the chdir below would otherwise walk up one more directory level.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

# All relative paths used later in the notebook are resolved from the project root.
os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib import pyplot as plt
from matplotlib.transforms import ScaledTranslation
import pandas as pd
import utils
from utils import ModelForAnalysis

# Apply the ggplot base style, then the project's own overrides on top of it
# (styles in the list are applied first to last).
plt.style.use(["ggplot", "my.mplstyle"])

In [3]:
# Numbers of reference examples per epitope at which performance is plotted.
NUM_SHOTS_OF_INTEREST = [1, 2, 5, 10, 20, 50, 100, 200]

# Distinct epitopes that appear in the 200-shot results file (path is relative
# to PROJECT_ROOT, which cell 1 made the working directory).
LARGELY_SAMPLED_EPITOPES = (
    pd.read_csv("analysis_results/CDR3 Levenshtein/ovr_nn_200_shot.csv")["epitope"]
    .unique()
)

In [None]:
# Baseline models drawn in every ablation panel.
reference_models = [
    ModelForAnalysis("SCEPTR", "ovr_nn", "#7048e8", "d", zorder=2),
    ModelForAnalysis("TCRdist", "ovr_nn", "#f03e3e", "o"),
    ModelForAnalysis("TCR-BERT", "ovr_nn", "#74b816", "s"),
]

# SCEPTR ablations sharing SCEPTR's colour and a dashed line; built in the
# listed order and unpacked into one name each.
mlm_only, avg_pool, shuffled_data = (
    ModelForAnalysis(variant, "ovr_nn", "#7048e8", linestyle="--", zorder=2)
    for variant in (
        "SCEPTR (MLM only)",
        "SCEPTR (average pooling)",
        "SCEPTR (shuffled data)",
    )
)
synthetic_data = ModelForAnalysis("SCEPTR (synthetic data)", "ovr_nn", "#7048e8", linestyle=":", zorder=2)

# CDR3-only variants use a lighter purple.
cdr3_only = ModelForAnalysis("SCEPTR (CDR3 only)", "ovr_nn", "#b197fc", zorder=2)
cdr3_mlm_only = ModelForAnalysis("SCEPTR (CDR3 only, MLM only)", "ovr_nn", "#b197fc", linestyle="--", zorder=2)

# One panel per ablation family, each overlaid on the same reference curves.
# Listed in row-major order, matching axs.flat and the panel letters a-d.
ablation_panels = [
    ("Training Ablation", [mlm_only]),
    ("Architectural Ablation", [avg_pool]),
    ("Data Ablation", [shuffled_data, synthetic_data]),
    ("Feature Ablation", [cdr3_only, cdr3_mlm_only]),
]
# Number of leading legend entries that belong to the reference models.
# Assumes utils.plot_performance_curves draws models in list order -- TODO confirm.
num_reference = len(reference_models)
# Horizontal offsets (inches) of the bold panel letters, indexed by column.
# The left column sits further out to clear the y-axis label.
panel_label_offsets = (-50/100, -20/100)

fig, axs = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(14/2.54, 12/2.54))

for ax, (title, ablations) in zip(axs.flat, ablation_panels):
    utils.plot_performance_curves(reference_models + ablations, NUM_SHOTS_OF_INTEREST, LARGELY_SAMPLED_EPITOPES, ax)
    ax.set_title(title)
    # The per-panel legend lists only the ablations.
    # The reference models get a single shared figure legend below.
    handles, labels = utils.get_legend_handles_labels_without_errorbars(ax)
    ax.legend(handles[num_reference:], labels[num_reference:], loc="lower right")

# Axes are shared, so keep x labels on the bottom row and y labels on the left column only.
for ax in axs[0]:
    ax.set_xlabel("")

for ax in axs[:, 1]:
    ax.set_ylabel("")

for idx, (ax, label) in enumerate(zip(axs.flat, "abcd")):
    col = idx % axs.shape[1]
    trans = ScaledTranslation(panel_label_offsets[col], 0, fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', fontweight="bold", va='top')

# Lay out the subplot grid before adding the figure-level legend.
# The legend is anchored to the right of the grid, and bbox_inches="tight"
# presumably enlarges the saved canvas to include it.
fig.tight_layout()

# handles/labels come from the last panel. Every panel's model list starts with
# reference_models, so its first num_reference entries are the shared references.
fig.legend(handles[:num_reference], labels[:num_reference], loc="center left", bbox_to_anchor=(1, 0, 0.5, 1), ncols=1)
fig.savefig("ablation_summary.pdf", bbox_inches="tight")