# Benchmarking simple CDR3 Levenshtein distance

## Importing modules and data

In [1]:
import sys
import os
# Make the project's modules importable and resolve relative paths used
# later in the notebook (e.g. 'benchmarks/...') against the project root.
PROJECT_ROOT = '/home/yutanagano/Projects/tcr_embedder'
sys.path.append(PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

In [2]:
from itertools import combinations
import json
from math import exp, log
from pathlib import Path
import random
from statistics import mean
from typing import Optional

import numpy as np
import pandas as pd
from pandas import DataFrame, notna
from polyleven import levenshtein
import seaborn
from sklearn.neighbors import KNeighborsClassifier
import torch
from tqdm import tqdm

# Start from seaborn's default theme, then switch every subsequent plot to
# the "white" axes style.
seaborn.set_theme()
seaborn.set_style(style='white')

In [3]:
# Background TCRs (first 1000 rows of the processed test split) and the
# VDJdb evaluation set; only the CDR3 and epitope columns are kept.
DATA_DIR = Path('/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/data')
KEEP_COLUMNS = ['CDR3A', 'CDR3B', 'Epitope']

back_df = pd.read_csv(DATA_DIR / 'tanno_processed' / 'test.csv').iloc[:1000][KEEP_COLUMNS]
ep_df = pd.read_csv(DATA_DIR / 'vdjdb' / 'evaluation.csv')[KEEP_COLUMNS]

### Define the CDR3 Levenshtein distance

In [4]:
def cdr3_leven_dist(
    cdr3a_1: str,
    cdr3a_2: str,
    cdr3b_1: str,
    cdr3b_2: str
) -> Optional[float]:
    """Return the mean CDR3 Levenshtein distance between two TCRs.

    The alpha- and beta-chain CDR3 distances are computed separately. A chain
    only contributes when both TCRs have a non-missing sequence for it;
    missing values (None / NaN) are detected with ``pandas.notna``.

    Args:
        cdr3a_1: CDR3 alpha sequence of the first TCR (may be missing).
        cdr3a_2: CDR3 alpha sequence of the second TCR (may be missing).
        cdr3b_1: CDR3 beta sequence of the first TCR (may be missing).
        cdr3b_2: CDR3 beta sequence of the second TCR (may be missing).

    Returns:
        The mean of the available per-chain distances as a float, or None
        when neither chain is present in both TCRs.
    """
    dists = []

    if notna(cdr3a_1) and notna(cdr3a_2):
        dists.append(levenshtein(cdr3a_1, cdr3a_2))

    if notna(cdr3b_1) and notna(cdr3b_2):
        dists.append(levenshtein(cdr3b_1, cdr3b_2))

    if not dists:
        return None

    # Plain sum/len rather than statistics.mean: this runs once per TCR pair
    # inside the k-NN metric, and statistics.mean is pure Python and slow.
    return sum(dists) / len(dists)

### Create benchmarking directory

In [5]:
# Output directory for this benchmark's metrics, relative to the current
# working directory.
BENCHMARK_DIR = Path('benchmarks/cdr3_levenshtein')
# parents=True also creates 'benchmarks/' when it does not exist yet, and
# exist_ok=True replaces the racy is_dir()-then-mkdir() check (a non-directory
# at this path still raises FileExistsError, as before).
BENCHMARK_DIR.mkdir(parents=True, exist_ok=True)

## Calculate alignment and uniformity (currently disabled)

In [6]:
# def cdr3_leven_alignment(ep_df: DataFrame) -> float:
#     pairs_by_ep = ep_df.groupby('Epitope').apply(lambda subdf: subdf.apply(lambda x: list(combinations(x, 2))))
#     pairs_by_ep = pairs_by_ep.reset_index(drop=True)

#     pairs_by_ep['Epitope'] = pairs_by_ep['Epitope'].map(lambda x: x[0])
#     pairs_by_ep['d'] = pairs_by_ep.apply(lambda row: cdr3_leven_dist(row['CDR3A'][0], row['CDR3A'][1], row['CDR3B'][0], row['CDR3B'][1]), axis=1)

#     distances_by_ep = pairs_by_ep.groupby('Epitope')['d'].mean()

#     return distances_by_ep.mean()

In [7]:
# def cdr3_leven_uniformity(back_df: DataFrame) -> float:
#     pairs = back_df.apply(lambda x: list(combinations(x, 2)))
#     pairs = pairs.reset_index(drop=True)

#     pairs['Epitope'] = pairs['Epitope'].map(lambda x: x[0])
#     pairs['d'] = pairs.apply(lambda row: cdr3_leven_dist(row['CDR3A'][0], row['CDR3A'][1], row['CDR3B'][0], row['CDR3B'][1]), axis=1)
#     exp_neg_dists = pairs['d'].map(lambda d: exp(-d))

#     return log(exp_neg_dists.mean())

In [8]:
# alignment = cdr3_leven_alignment(ep_df)

In [9]:
# uniformity = cdr3_leven_uniformity(ep_df)

## Leave-one-out k-NN evaluation using the CDR3 Levenshtein distance

In [10]:
def knn_cdr3_leven_dist(x: np.ndarray, y: np.ndarray) -> float:
    """Distance callable for KNeighborsClassifier over index-encoded TCRs.

    Each argument is a 1-element feature vector whose single entry is a row
    position into the global ``ep_df``. The entry is truncated with int(),
    so it must be an integral position.

    NOTE: sklearn's tree-based neighbour search (its default for callable
    metrics) also evaluates the metric against node centroids. Those are
    averaged positions that correspond to no TCR, so use this metric with
    ``algorithm='brute'``.

    Returns:
        The mean CDR3 Levenshtein distance between the two TCRs, or
        ``np.inf`` when they share no chain with sequences on both sides.
    """
    i = int(x[0])
    j = int(y[0])

    # Index single columns positionally instead of materialising a whole row
    # Series via ep_df.iloc[...] on every call; this function is evaluated
    # for many TCR pairs per k-NN query.
    cdr3a = ep_df['CDR3A']
    cdr3b = ep_df['CDR3B']
    dist = cdr3_leven_dist(cdr3a.iat[i], cdr3a.iat[j], cdr3b.iat[i], cdr3b.iat[j])

    if dist is None:
        return np.inf

    return dist

In [11]:
# Leave-one-out k-NN epitope classification, using the CDR3 Levenshtein
# distance directly as the k-NN metric.
N_NEIGHBOURS = 5      # reported as '5nn_accuracy' in the metrics below
N_LOO_SAMPLES = 1000  # number of TCRs held out (capped at the dataset size)
LOO_SEED = 0          # fixed so repeated runs evaluate the same TCRs

scores = []
ep_len = len(ep_df)
# Each TCR is encoded as a 1-element feature vector holding its row
# *position* in ep_df, matching the positional lookups in
# knn_cdr3_leven_dist (ep_df.index holds labels, which only coincide with
# positions for a default RangeIndex).
tcr_ivecs = np.arange(ep_len, dtype=float).reshape(-1, 1)
ep_cat_codes = ep_df['Epitope'].astype('category').cat.codes
# Plain array so that boolean masks and [i] lookups are positional as well.
ep_cat_code_arr = ep_cat_codes.to_numpy()

# Brute-force search is required. For a callable metric sklearn otherwise
# builds a ball tree, which also evaluates the metric against node centroids.
# Those are averaged row positions that correspond to no TCR, and the tree
# prunes candidate neighbours using these meaningless distances.
knn = KNeighborsClassifier(
    n_neighbors=N_NEIGHBOURS,
    algorithm='brute',
    metric=knn_cdr3_leven_dist,
)

loo_rng = random.Random(LOO_SEED)
for i in tqdm(loo_rng.sample(range(ep_len), k=min(N_LOO_SAMPLES, ep_len))):
    # Hold TCR i out of the reference set...
    filt = np.ones(ep_len, dtype=bool)
    filt[i] = False

    loo_ivecs = tcr_ivecs[filt]
    loo_cat_codes = ep_cat_code_arr[filt]

    test_ivec = tcr_ivecs[[i]]
    expected_cat_code = ep_cat_code_arr[i].item()

    # ...and predict its epitope from its nearest neighbours among the rest.
    knn.fit(loo_ivecs, loo_cat_codes)
    scores.append(knn.predict(test_ivec).item() == expected_cat_code)

# Fraction of held-out TCRs whose epitope was predicted correctly.
knn_accuracy = float(np.mean(scores)) if scores else float('nan')

  4%|▍         | 44/1000 [02:38<57:34,  3.61s/it]  


KeyboardInterrupt: 

## Write evaluation metrics to JSON

In [None]:
# Metrics reported for this benchmark, keyed as they appear in metrics.json.
metrics_dict = {'model_name': 'cdr3_levenshtein'}
# Alignment / uniformity are omitted while their computation is disabled:
# metrics_dict['alignment'] = alignment
# metrics_dict['uniformity'] = uniformity
# metrics_dict['alignment + uniformity'] = alignment + uniformity
metrics_dict['5nn_accuracy'] = knn_accuracy

In [None]:
# Serialise the metrics with 4-space indentation into the benchmark directory.
metrics_path = BENCHMARK_DIR / 'metrics.json'
metrics_path.write_text(json.dumps(metrics_dict, indent=4))