In [1]:
import os
from pathlib import Path

# Resolve the project root once per kernel session. After the chdir below the
# working directory *is* the root, so recomputing cwd().parent on a re-run would
# climb one level too high; reusing an already-defined PROJECT_ROOT (from an
# earlier run, or set before this cell) keeps the cell idempotent.
PROJECT_ROOT = (
    globals()["PROJECT_ROOT"]
    if "PROJECT_ROOT" in globals()
    else Path.cwd().parent.resolve()
)
# Relative paths used by later cells (e.g. "my.mplstyle") resolve against this.
os.chdir(PROJECT_ROOT)

In [2]:
from matplotlib import pyplot as plt
from matplotlib.transforms import ScaledTranslation
import numpy as np
import pandas as pd
from utils import ModelForAnalysis

# Base "ggplot" theme first, then the project's own sheet on top: a list of
# styles is applied first-to-last, so settings in "my.mplstyle" win.
plt.style.use(["ggplot", "my.mplstyle"])

In [None]:
# Figure: for each model, plot the per-CDR3-region "ins" / "del" / "sub" values
# (with their "*_std" columns as error bars) from the model's EPC analyser
# summary, one panel per model.
models = (
    ModelForAnalysis("SCEPTR", None, None),
    ModelForAnalysis("TCR-BERT", None, None),
    ModelForAnalysis("SCEPTR (left-aligned)", None, None),
)
# Summary DataFrame keyed by model name. Below, its index supplies the CDR3
# region tick labels and columns "ins", "del", "sub" (+ "<col>_std") are plotted.
epc_insights_per_model = {model.name: model.load_epc_analyser().get_summary_df() for model in models}

# (summary column, legend label, colour) for each edit type, in plotting order.
EDIT_TYPES = (
    ("ins", "insertion", "#7048e8"),
    ("del", "deletion", "#f76707"),
    ("sub", "substitution", "#37b24d"),
)
# x-distance between consecutive regions: one slot per edit type plus a gap.
GROUP_STRIDE = 6
# Panel letters, one per subplot; the middle panel is left blank.
# NOTE(review): presumably panels 1-2 together form part "a" — confirm.
PANEL_LABELS = ("a", "", "b")

# figsize is specified in cm and converted to inches.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(17/2.54, 6/2.54), sharex=True)

errorbar_kwargs = {
    "fmt": "-",
    "capsize": 3,
}

for ax, (model_name, epc_insights) in zip(axs, epc_insights_per_model.items()):
    # Derive group positions from the data instead of assuming exactly 5 regions.
    position_array = np.arange(len(epc_insights)) * GROUP_STRIDE

    for offset, (column, label, colour) in enumerate(EDIT_TYPES):
        ax.errorbar(
            position_array + offset,
            epc_insights[column],
            yerr=epc_insights[f"{column}_std"],
            label=label,
            c=colour,
            **errorbar_kwargs,
        )

    # Centre each region's tick under the middle series of its group.
    ax.set_xticks(
        position_array + (len(EDIT_TYPES) - 1) / 2,
        epc_insights.index.str.replace("_", "-"),
    )

    ax.set_title(model_name)
    ax.set_ylim(bottom=0)

axs[0].set_ylabel("distance")
axs[1].set_xlabel("CDR3 region")

# Offset each panel letter 0.2 in left and 0.2 in up from its axes' top-left
# corner (dpi_scale_trans works in inches). Built once, shared by all labels.
trans = ScaledTranslation(-20/100, 20/100, fig.dpi_scale_trans)
for ax, label in zip(axs, PANEL_LABELS):
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize="large", fontweight="bold", va="top")

# Shared legend in a strip below the figure (bbox spans y = -0.5..0 in figure
# coordinates, i.e. outside the canvas). tight_layout only arranges the
# subplots, so when exporting pass bbox_inches="tight" to avoid clipping it.
fig.legend(
    *axs[0].get_legend_handles_labels(),
    loc="upper center",
    bbox_to_anchor=(0, -0.5, 1, 0.5),
    ncols=len(EDIT_TYPES),
)
fig.tight_layout()