In [17]:
import os
from pathlib import Path

# Resolve the project root only once per kernel session. On the first run it
# is the parent of the current working directory; on re-runs the stored value
# is reused, so the chdir below never climbs a further level up.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

os.chdir(PROJECT_ROOT)

In [18]:
import json
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
import numpy as np
from paths import RESULTS_DIR
from typing import Literal

# Apply ggplot first, then the local "my.mplstyle" sheet on top of it, so the
# local sheet wins wherever the two set the same option.
plt.style.use(["ggplot", "my.mplstyle"])

In [65]:
# Models compared in this notebook. Each name is used as the model's
# sub-directory under RESULTS_DIR and as its key in MODEL_COLOURS.
# The tuple order is the order in which the curves are drawn on each subplot.
MODEL_NAMES = (
    "ProtBert",
    "ESM2 (T6 8M)",
    "TCR-BERT",
    "CDR3 Levenshtein",
    "TCRdist",
    "SCEPTR",
)

# Epitopes whose ROC results are plotted. Each string is a key into the
# per-shot-count results, and the tuple order sets the subplot column order.
EPITOPES = (
    "TFEYVSQPFLMDLE",
    "GILGFVFTL",
    "YLQPRTFLL",
    "NLVPMVATV",
    "SPRWYFYYL",
    "TTDPSFLGRY",
)

# Hex colour for each model, shared by its ROC line, its shaded band and its
# legend handle. Keys are the same strings as in MODEL_NAMES.
MODEL_COLOURS = {
    "TCRdist": "#f03e3e",
    "CDR3 Levenshtein": "#f76707",
    "TCR-BERT": "#74b816",
    "ESM2 (T6 8M)": "#37b24d",
    "ProtBert": "#0ca678",
    "SCEPTR": "#7048e8",
}

In [74]:
def load_roc_results(model_name: str, results_dir=None) -> dict:
    """Load the saved ROC results of one model.

    Reads ``<results_dir>/<model_name>/individual_rocs.json``.

    Args:
        model_name: Name of the model's sub-directory inside ``results_dir``.
        results_dir: Directory containing one sub-directory per model, as a
            ``str`` or path object. Defaults to the project's ``RESULTS_DIR``.

    Returns:
        The decoded content of the JSON file.

    Raises:
        FileNotFoundError: If the model has no ``individual_rocs.json`` file.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    base_dir = Path(RESULTS_DIR if results_dir is None else results_dir)
    roc_path = base_dir / model_name / "individual_rocs.json"

    # JSON is UTF-8 encoded; be explicit instead of inheriting whatever the
    # platform's locale encoding happens to be.
    with open(roc_path, "r", encoding="utf-8") as f:
        return json.load(f)

# ROC results of every model in MODEL_NAMES, keyed by model name and loaded
# up front so the plotting cells only do dictionary look-ups.
roc_results = dict(zip(MODEL_NAMES, map(load_roc_results, MODEL_NAMES)))

In [None]:
# Quick look at a single curve before building the full figure: the 1-shot
# CDR3 Levenshtein results for GILGFVFTL, with the interquartile band.
test_results = roc_results["CDR3 Levenshtein"]["1_shot_rocs"]["GILGFVFTL"]

tprs_mean = np.array(test_results["tprs_mean"])
tprs_upper_quartile = np.array(test_results["tprs_qt_0.75"])
tprs_lower_quartile = np.array(test_results["tprs_qt_0.25"])

# The results store no FPR axis, so the TPRs are drawn against an evenly
# spaced grid on [0, 1]. The grid is sized from the data instead of a
# hard-coded 101 points.
# NOTE(review): assumes the TPRs were sampled at evenly spaced FPRs — confirm
# against the code that writes individual_rocs.json.
fprs = np.linspace(0, 1, len(tprs_mean))

fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(fprs, tprs_mean)
ax.fill_between(fprs, tprs_lower_quartile, tprs_upper_quartile, alpha=0.2)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("CDR3 Levenshtein, GILGFVFTL ($k=1$)")
plt.show()

In [76]:
def plot_roc(
    model_name: str,
    num_shots: Literal["1", "200"],
    epitope: str,
    ax: Axes,
    band: Literal["std", "quartile"] = "std",
) -> None:
    """Draw one model's mean ROC curve with a shaded spread band on ``ax``.

    The curve is read from the module-level ``roc_results`` and coloured with
    the model's entry in ``MODEL_COLOURS``.

    Args:
        model_name: Key into ``roc_results`` and ``MODEL_COLOURS``; also used
            as the line's legend label.
        num_shots: Shot count as a string; selects the
            ``"<num_shots>_shot_rocs"`` entry of the model's results.
        epitope: Key of the epitope inside the selected shot-count entry.
        ax: Axes to draw on.
        band: Which spread to shade around the mean. ``"std"`` (default)
            shades mean ± one standard deviation, clipped to [0, 1];
            ``"quartile"`` shades between the stored 0.25 and 0.75 quantiles.

    Raises:
        ValueError: If ``band`` is neither ``"std"`` nor ``"quartile"``.
        KeyError: If the model, shot count, epitope or a required result
            field is missing.
    """
    if band not in ("std", "quartile"):
        raise ValueError(f"band must be 'std' or 'quartile', got {band!r}")

    results = roc_results[model_name][f"{num_shots}_shot_rocs"][epitope]
    tprs_mean = np.array(results["tprs_mean"])

    if band == "std":
        tprs_std = np.array(results["tprs_std"])
        # TPRs live in [0, 1]; keep the band from spilling outside the axes.
        roc_lower = np.clip(tprs_mean - tprs_std, 0, 1)
        roc_upper = np.clip(tprs_mean + tprs_std, 0, 1)
    else:
        roc_lower = np.array(results["tprs_qt_0.25"])
        roc_upper = np.array(results["tprs_qt_0.75"])

    # The results store no FPR axis, so draw against an evenly spaced grid on
    # [0, 1] sized from the data rather than a hard-coded 101 points.
    # NOTE(review): assumes the TPRs were sampled at evenly spaced FPRs — confirm.
    fprs = np.linspace(0, 1, len(tprs_mean))
    colour = MODEL_COLOURS[model_name]

    ax.plot(fprs, tprs_mean, label=model_name, c=colour)
    ax.fill_between(fprs, roc_lower, roc_upper, color=colour, alpha=0.2)

In [None]:
# Grid of ROC panels: one row per shot count, one column per epitope, with
# every model overlaid on each panel.
shot_counts = ("1", "200")

fig, axs = plt.subplots(
    nrows=len(shot_counts),
    ncols=len(EPITOPES),
    sharex=True,
    sharey=True,
    figsize=(15, 5),
    squeeze=False,  # keep axs 2-D even with a single row or column
)

for row, num_shots in enumerate(shot_counts):
    for col, epitope in enumerate(EPITOPES):
        ax = axs[row, col]
        ax.plot([0, 1], [0, 1], "--k")  # dashed diagonal reference line
        for model_name in MODEL_NAMES:
            plot_roc(model_name, num_shots, epitope, ax)
        ax.set_ylim(0, 1)
        ax.set_xlim(0, 1)

        ax.set_title(f"{epitope} ($k={num_shots}$)")

        # The axes are shared, so only the bottom row and the left-most
        # column carry axis labels.
        if row == len(shot_counts) - 1:
            ax.set_xlabel("False Positive Rate")

        if col == 0:
            ax.set_ylabel("True Positive Rate")

# The legend lists the models in its own order (SCEPTR first), independent of
# the draw order given by MODEL_NAMES.
legend_order = ("SCEPTR", "TCRdist", "CDR3 Levenshtein", "TCR-BERT", "ESM2 (T6 8M)", "ProtBert")
handles = [
    plt.Line2D(
        [0], [0],
        color=MODEL_COLOURS[name],
    )
    for name in legend_order
]
labels = ["SCEPTR (ours)" if name == "SCEPTR" else name for name in legend_order]

# bbox_to_anchor is (x, y, width, height): a full-width box directly below the
# figure. With loc="upper center" the legend hangs from the figure's bottom
# edge, laid out as a single row.
fig.legend(
    handles=handles,
    labels=labels,
    loc="upper center",
    bbox_to_anchor=(0, -1, 1, 1),
    ncols=len(legend_order),
)
plt.show()