In [1]:
import os
from pathlib import Path

# Anchor the working directory at the project root, taken to be the parent of
# the working directory at the time this cell first runs. The root is computed
# only once per kernel session, so re-running the cell does not climb one
# directory higher each time.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

os.chdir(PROJECT_ROOT)

In [58]:
import numpy as np
import pandas as pd
from paths import DATA_DIR
from pyrepseq.metric import tcr_metric
import tidytcells as tt

In [3]:
# Both benchmarking splits are read from the same preprocessed data folder.
benchmarking_dir = DATA_DIR / "preprocessed" / "benchmarking"
train = pd.read_csv(benchmarking_dir / "train.csv")
test = pd.read_csv(benchmarking_dir / "test.csv")

In [4]:
# Keep only the test rows whose epitope also occurs in the training split.
train_epitopes = train["Epitope"].unique()
seen_in_train = test["Epitope"].isin(train_epitopes)
test_discrimination = test[seen_in_train]

In [None]:
# Preview the first five rows of the filtered test split.
test_discrimination.head(n=5)

In [8]:
# Instantiate the Cdr3Levenshtein metric from pyrepseq's tcr_metric module.
# NOTE(review): the name suggests a Levenshtein (edit) distance over CDR3
# sequences — confirm against the pyrepseq documentation.
cdr3_levenshtein = tcr_metric.Cdr3Levenshtein()

In [21]:
# Distances between the filtered test rows and the training rows under the
# metric above. Later cells index this matrix as [test row, train row].
cdist_matrix = cdr3_levenshtein.calc_cdist_matrix(test_discrimination, train)

In [59]:
def _gene_match_matrix(anchor_genes, comparison_genes, dtype):
    """Return a 0/1 matrix flagging which pairs of rows carry the same gene.

    Each input column is standardised to gene-level precision with tidytcells
    exactly once, and the two results are then compared with NumPy
    broadcasting. This avoids re-standardising the whole comparison column for
    every anchor row and avoids an element-by-element Python loop.

    Parameters
    ----------
    anchor_genes : pandas.Series
        Gene symbols, one per row of the result.
    comparison_genes : pandas.Series
        Gene symbols, one per column of the result.
    dtype : numpy dtype
        Dtype of the returned matrix.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(anchor_genes), len(comparison_genes)) holding 1
        where the standardised symbols compare equal and 0 otherwise.
    """
    def to_gene_level(symbol):
        return tt.tr.standardize(symbol, precision="gene")

    # Object arrays keep the standardised values as plain Python objects, so
    # the comparison below follows ordinary element-wise `==` semantics.
    anchor = anchor_genes.map(to_gene_level).to_numpy(dtype=object)
    comparison = comparison_genes.map(to_gene_level).to_numpy(dtype=object)

    # (n_anchor, 1) against (1, n_comparison) broadcasts to the full grid.
    return (anchor[:, np.newaxis] == comparison[np.newaxis, :]).astype(dtype)


# The indicator matrices share cdist_matrix's dtype so they can be combined
# arithmetically with the distances in later cells.
same_trav = _gene_match_matrix(test_discrimination.TRAV, train.TRAV, cdist_matrix.dtype)
same_trbv = _gene_match_matrix(test_discrimination.TRBV, train.TRBV, cdist_matrix.dtype)

# 1 only where both the TRAV and the TRBV gene match.
same_vs = same_trav * same_trbv

In [60]:
# Push every train TCR whose V genes do not both match the test TCR out of
# contention by adding a large constant to its distance, then take the
# smallest remaining distance for each test row.
v_mismatch_penalty = 99999
v_mismatch = 1 - same_vs
updated_cdist = cdist_matrix + v_mismatch * v_mismatch_penalty
nn_dists = np.min(updated_cdist, axis=1)

In [61]:
# Combined CDR3 length (CDR3A + CDR3B) of every filtered test row, computed
# with the vectorised string accessor instead of a row-by-row apply.
cdr3_lengths = test_discrimination.CDR3A.str.len() + test_discrimination.CDR3B.str.len()

# A missing CDR3 would give NaN here rather than the TypeError that len()
# raises on a non-string value, so fail loudly instead of propagating NaN.
assert cdr3_lengths.notna().all(), "CDR3A/CDR3B contain missing values"

combined_cdr3_length = cdr3_lengths.to_numpy(dtype="int64")

In [62]:
# Nearest-neighbour distance expressed as a fraction of the test row's
# combined CDR3 length; identity is the complement of that fraction.
normalised_nn_dists = nn_dists / combined_cdr3_length
seq_identity = 1 - normalised_nn_dists

In [63]:
# A test row is kept when its identity to the nearest V-matched training row
# is strictly below the cutoff.
threshold = 0.95
legal_test_seq_mask = np.less(seq_identity, threshold)

In [None]:
# Per-epitope non-null counts for the test rows that pass the identity filter.
legal_test_discrimination = test_discrimination.loc[legal_test_seq_mask]
legal_test_discrimination.groupby("Epitope").count()

In [None]:
# Per-epitope non-null counts before the identity filter, for comparison.
test_discrimination.groupby(by="Epitope").count()