In [160]:
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from numpy import ndarray
from scipy import stats

from src.model import tcr_metric, tcr_representation_model

### Prepare data

In [2]:
# Data and checkpoint directories, both resolved relative to the parent of the
# current working directory.
processed_data_path = Path.cwd().parent.joinpath("tcr_data", "preprocessed")
model_saves_path = Path.cwd().parent.joinpath("model_saves")

In [3]:
# Read the "test.csv" file of the "gdb" dataset (its Epitope column supplies the
# labels used further down) and of the "tanno" dataset (used as background).
labelled_data = pd.read_csv(processed_data_path.joinpath("gdb", "test.csv"))
background_data = pd.read_csv(processed_data_path.joinpath("tanno", "test.csv"))

In [158]:
# Integer category code of each labelled row's Epitope, as a NumPy array.
# Note: pandas assigns the code -1 to missing values.
LABELLED_DATA_CAT_CODES = (
    labelled_data["Epitope"]
    .astype("category")
    .cat.codes
    .to_numpy()
)

In [4]:
# Draw 500 background rows with a fixed seed so the same sample is produced on
# every run.
background_sample = background_data.sample(random_state=420, n=500)

In [5]:
# Stack the labelled rows first, then the background sample, under a fresh
# 0..N-1 row index.
evaluation_data = pd.concat([labelled_data, background_sample], ignore_index=True)

### Load models

In [67]:
# Instantiate the tcrdist metric and restore the saved model found in the
# "Beta_CDR_BERT_Large" directory under model_saves_path.
tcrdist = tcr_metric.BetaTcrdist()
blastr = tcr_representation_model.load_blastr_save(
    model_saves_path.joinpath("Beta_CDR_BERT_Large")
)

### Benchmarking

In [68]:
# Distances between evaluation rows and labelled rows (presumably of shape
# (len(evaluation_data), len(labelled_data)) — TODO confirm against
# calc_cdist_matrix).
# NOTE(review): evaluation_data starts with labelled_data itself, so each
# labelled row is presumably compared against itself here; confirm whether such
# self-matches should be allowed to vote when predicting that row's label.
cdist_matrix = blastr.calc_cdist_matrix(evaluation_data, labelled_data)

In [159]:
def get_predictions_given_cdist_and_threshold(
    cdist: ndarray,
    threshold: float,
    label_codes: Optional[ndarray] = None,
) -> ndarray:
    """Predict a label code for each row of ``cdist`` by voting among near neighbours.

    For every row, the reference entries whose distance is ``<= threshold``
    are collected and the most frequent label code among them (as returned by
    ``scipy.stats.mode``) becomes the prediction. Rows with no reference entry
    within the threshold receive ``-1``.

    Args:
        cdist: 2-D distance matrix with one row per query and one column per
            reference entry.
        threshold: Inclusive distance cut-off for counting a reference entry
            as a neighbour.
        label_codes: 1-D array holding the integer label code of each
            reference entry (one per column of ``cdist``). When omitted, the
            notebook-level ``LABELLED_DATA_CAT_CODES`` is used, which keeps
            existing two-argument calls working.

    Returns:
        1-D integer array with one prediction per row of ``cdist``. The dtype
        is always the default integer type, including for empty input.

    Raises:
        ValueError: If ``cdist`` is not two-dimensional.

    Note:
        pandas category codes use ``-1`` for missing values, so a reference
        entry with a missing label is indistinguishable from the "no
        neighbour" sentinel in the output.
    """
    if label_codes is None:
        # Resolved at call time so the function can also be used (and tested)
        # without relying on notebook-global state.
        label_codes = LABELLED_DATA_CAT_CODES
    label_codes = np.asarray(label_codes)

    within_threshold = np.asarray(cdist) <= threshold
    if within_threshold.ndim != 2:
        # Iterating a non-2-D mask would index label_codes with scalars or
        # higher-rank masks and silently yield malformed predictions.
        raise ValueError(
            f"cdist must be a 2-D array, got {within_threshold.ndim} dimension(s)"
        )

    predictions = []
    for neighbour_mask in within_threshold:
        neighbour_codes = label_codes[neighbour_mask]

        if neighbour_codes.size == 0:
            predictions.append(-1)
        else:
            vote = stats.mode(neighbour_codes, keepdims=False)
            predictions.append(int(vote.mode))

    # An explicit dtype keeps the result integer-typed even when there are no
    # rows, and independent of the (possibly narrower) dtype of label_codes.
    return np.array(predictions, dtype=int)

In [161]:
# Predict a label code for every row of cdist_matrix with a distance cut-off of 0.1.
get_predictions_given_cdist_and_threshold(cdist=cdist_matrix, threshold=0.1)

array([ 0,  0,  0, ..., -1, -1, -1])