In [None]:
import os
from pathlib import Path

# Pin the working directory to the parent of the current working directory.
# PROJECT_ROOT is only computed when it does not exist yet: after the first run
# the cwd has already moved, so recomputing it from Path.cwd() on a re-run
# would climb one directory further each time.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

os.chdir(PROJECT_ROOT)

In [None]:
from matplotlib import pyplot as plt
import numpy as np
from numpy import ndarray
import pandas as pd
from pyrepseq.metric import tcr_metric
from sceptr import variant
from scipy.spatial import distance
from sklearn import metrics
from typing import Tuple
import utils

# Styles are applied in list order, so settings in my.mplstyle override the
# ggplot base wherever both define the same parameter.
plt.style.use(["ggplot", "my.mplstyle"])

In [None]:
# Load the cleaned VDJdb benchmarking table and keep a single row for each
# distinct (TRAV, CDR3A, TRBV, CDR3B) combination.
labelled_data = (
    pd.read_csv("tcr_data/preprocessed/benchmarking/vdjdb_cleaned.csv")
    .drop_duplicates(subset=["TRAV", "CDR3A", "TRBV", "CDR3B"])
)

# Pairwise labels in condensed form: one boolean per unordered pair of rows,
# True when both rows carry the same Epitope value. checks=False skips
# squareform's distance-matrix validation, which this matrix would not pass
# (its diagonal is True rather than zero).
epitopes = labelled_data["Epitope"].to_numpy()
ground_truth = distance.squareform(
    epitopes[:, None] == epitopes[None, :],
    checks=False,
)

In [None]:
def _calc_pair_scores(model):
    """Return the pairwise scores that ``model`` assigns to ``labelled_data``.

    Shared first step of both curve functions: the model's pairwise distance
    vector is passed through ``utils.convert_dists_to_scores`` so that it can
    be used as the score input of the scikit-learn curve functions.
    """
    pdist = model.calc_pdist_vector(labelled_data)
    return utils.convert_dists_to_scores(pdist)


def calc_precision_recall(model) -> Tuple[ndarray, ndarray]:
    """Compute the precision-recall curve of ``model`` against ``ground_truth``.

    Args:
        model: Object exposing ``calc_pdist_vector``, which is called with the
            module-level ``labelled_data`` frame.

    Returns:
        ``(recalls, precisions)``. Note the order: scikit-learn returns
        precision first, but recall is returned first here so the tuple can be
        unpacked directly as the x and y arguments of a plotting call.
    """
    precisions, recalls, _ = metrics.precision_recall_curve(
        ground_truth, _calc_pair_scores(model), drop_intermediate=True
    )
    return recalls, precisions


def calc_roc(model) -> Tuple[ndarray, ndarray]:
    """Compute the ROC curve of ``model`` against ``ground_truth``.

    Args:
        model: Object exposing ``calc_pdist_vector``, which is called with the
            module-level ``labelled_data`` frame.

    Returns:
        ``(fpr, tpr)``: false positive rates and true positive rates, in the
        x, y order expected by a plotting call.
    """
    fpr, tpr, _ = metrics.roc_curve(
        ground_truth, _calc_pair_scores(model), drop_intermediate=True
    )
    return fpr, tpr

In [None]:
# Single switch for the kind of curve to compute: calc_precision_recall or
# calc_roc. Routing both models through the same function guarantees the two
# curves are always of the same type.
calc_curve = calc_precision_recall

sceptr_curve = calc_curve(variant.default())
tcrdist_curve = calc_curve(tcr_metric.Tcrdist())

In [None]:
# Draw both models' curves on shared axes with a logarithmic x-axis.
fig, ax = plt.subplots()
ax.step(*sceptr_curve, where="post", label="SCEPTR")
ax.step(*tcrdist_curve, where="post", label="tcrdist")
ax.set_xscale("log")

# The axis labels match the (recalls, precisions) tuples produced by
# calc_precision_recall; update them if the curves are computed with calc_roc.
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
plt.show()