In [None]:
import os
from pathlib import Path

# Keep a PROJECT_ROOT that already exists in the namespace (e.g. injected before
# this cell ran); otherwise take the parent of the current working directory.
# The guard also keeps re-runs from climbing one more directory each time.
PROJECT_ROOT = (
    globals()["PROJECT_ROOT"]
    if "PROJECT_ROOT" in globals()
    else Path.cwd().parent.resolve()
)
os.chdir(PROJECT_ROOT)

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
from paths import DATA_DIR
from pyrepseq.metric import tcr_metric
from sceptr import variant
from scipy import stats

# Styles in a list are applied first-to-last, so the local style file
# overrides anything it shares with ggplot.
plt.style.use(["ggplot", "my.mplstyle"])

In [None]:
def get_nn_pairs(anchors: DataFrame, comparisons: DataFrame, model) -> list[tuple[int, int]]:
    """Pair each anchor TCR with its nearest neighbour among ``comparisons``.

    An anchor's own row in ``comparisons`` (matched by index label) is excluded,
    so an anchor is never paired with itself.

    Args:
        anchors: TCRs to find neighbours for. Their index labels are used to
            locate (and exclude) the same rows in ``comparisons``.
        comparisons: Pool of candidate neighbours. Its index must be unique.
        model: Object exposing ``calc_cdist_matrix(anchors, comparisons)``,
            expected to return a ``(len(anchors), len(comparisons))`` distance
            matrix.

    Returns:
        One tuple of ``comparisons`` index labels per anchor. Each tuple is
        sorted, so the same two TCRs always give the same tuple whichever of
        them was the anchor. The first element is therefore NOT necessarily
        the anchor.
    """
    anchor_labels = anchors.index.to_list()

    # Work on a float copy so the self-match can be masked with +inf. An
    # integer matrix cannot hold inf, and a finite sentinel (previously 1000)
    # can be smaller than genuine distances, which would produce self-pairs.
    cdist_matrix = np.array(model.calc_cdist_matrix(anchors, comparisons), dtype=float)

    # Translate anchor labels into column positions of the distance matrix.
    # get_indexer returns -1 for anchors absent from comparisons; those have
    # no self-column to mask.
    self_cols = comparisons.index.get_indexer(anchor_labels)
    present = self_cols >= 0
    cdist_matrix[np.flatnonzero(present), self_cols[present]] = np.inf

    # argmin yields column positions; map them back to comparisons labels.
    nn_positions = np.argmin(cdist_matrix, axis=1)
    nn_labels = comparisons.index[nn_positions]
    return [tuple(sorted((a_label, c_label))) for a_label, c_label in zip(anchor_labels, nn_labels)]

In [None]:
# Load every TCR in the Tanno test split
tanno_test = pd.read_csv(DATA_DIR / "preprocessed" / "tanno" / "test.csv")

# SCEPTR: draw 500 random anchors (fixed seed) and pair each with its nearest
# neighbour in the full test set
sceptr_model = variant.default()
sceptr_500_tcrs = tanno_test.sample(n=500, random_state=420)
sceptr_500_pairs = get_nn_pairs(
    anchors=sceptr_500_tcrs,
    comparisons=tanno_test,
    model=sceptr_model,
)

In [None]:
# TCRdist: a second, independently seeded set of 500 anchors, each paired with
# its nearest neighbour in the full test set
tcrdist_model = tcr_metric.Tcrdist()
tcrdist_500_tcrs = tanno_test.sample(n=500, random_state=421)
tcrdist_500_pairs = get_nn_pairs(
    anchors=tcrdist_500_tcrs,
    comparisons=tanno_test,
    model=tcrdist_model,
)

In [None]:
# Union of both pair sets with duplicates removed. The set's iteration order is
# arbitrary, so sort it: later cells break argmax/argmin ties by first
# occurrence, and a sorted order makes those picks reproducible and obvious.
combined_pairs = sorted(set(sceptr_500_pairs + tcrdist_500_pairs))

sceptr_dists = []
tcrdist_dists = []

for pair in combined_pairs:
    tcr_pair = tanno_test.loc[list(pair)]
    # A pdist over two TCRs holds a single distance; .item() unwraps it (and
    # would raise if more than one value came back).
    sceptr_dists.append(sceptr_model.calc_pdist_vector(tcr_pair).item())
    tcrdist_dists.append(tcrdist_model.calc_pdist_vector(tcr_pair).item())

In [None]:
# Fit a 2-D Gaussian KDE to the (SCEPTR, TCRdist) points and evaluate it at
# those same points, so the scatter plot can be coloured by local density.
coords = np.vstack([sceptr_dists, tcrdist_dists])
kde = stats.gaussian_kde(coords)
density_estimates = kde(coords)

In [None]:
# Every pair as one point, coloured by KDE density
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(sceptr_dists, tcrdist_dists, s=10, c=density_estimates)
ax.set(xlabel="SCEPTR distance", ylabel="TCRdist distance")
plt.show()

In [None]:
# One row per unique pair with both distances. NOTE: get_nn_pairs stores each
# pair sorted, so "anchor_idx" is simply the smaller tanno_test index of the two
# TCRs, not necessarily the anchor that was sampled.
tcr_dist_df = DataFrame(combined_pairs, columns=["anchor_idx", "comparison_idx"]).assign(
    SCEPTR=sceptr_dists,
    TCRdist=tcrdist_dists,
)
tcr_dist_df.head()

In [None]:
# Accumulators filled by the "pair of interest" cells below: one summary
# Series and one (anchor_idx, comparison_idx) tuple per selected pair.
summaries, indices = [], []

def get_summary_for_pair_of_interest(pair_info, source_tcr_df) -> Series:
    """Build a flat summary of one TCR pair: both TCRs' V genes and CDR3s plus the two distances.

    Args:
        pair_info: One row of the pair table, with ``anchor_idx``,
            ``comparison_idx``, ``SCEPTR`` and ``TCRdist`` entries. Because the
            pair table mixes integer and float columns, such a row is upcast
            to float, so the index labels are cast back to ``int`` before
            lookup (this assumes integer index labels).
        source_tcr_df: Table the index labels refer to. It must have
            ``TRAV``, ``CDR3A``, ``TRBV`` and ``CDR3B`` columns.

    Returns:
        A Series with ``<col>_anchor`` and ``<col>_comp`` entries for each of
        those four columns, followed by ``SCEPTR`` and ``TCRdist``. Pairs come
        from ``get_nn_pairs``, which sorts them, so "anchor" here means the
        smaller index of the pair.
    """
    tcr_columns = ["TRAV", "CDR3A", "TRBV", "CDR3B"]

    # Use the table that was passed in rather than a notebook global.
    anchor_idx = int(pair_info["anchor_idx"])
    comparison_idx = int(pair_info["comparison_idx"])

    anchor_tcr = source_tcr_df.loc[anchor_idx, tcr_columns].add_suffix("_anchor")
    comp_tcr = source_tcr_df.loc[comparison_idx, tcr_columns].add_suffix("_comp")

    return pd.concat([anchor_tcr, comp_tcr, pair_info[["SCEPTR", "TCRdist"]]])

def get_pair_indices(pair_info) -> tuple[int, int]:
    """Return the ``(anchor_idx, comparison_idx)`` labels of a pair row as ints.

    The pair table mixes integer index columns with float distance columns, so
    a single row is upcast to float. The labels are converted back to ``int``
    to honour the annotated return type.
    """
    return int(pair_info["anchor_idx"]), int(pair_info["comparison_idx"])

In [None]:
# Pair of interest #1: the pair TCRdist considers farthest apart
max_tcrdist_pair = tcr_dist_df.loc[tcr_dist_df["TCRdist"].idxmax()]
summaries.append(get_summary_for_pair_of_interest(max_tcrdist_pair, tanno_test))
indices.append(get_pair_indices(max_tcrdist_pair))

In [None]:
# Pair of interest #2: among pairs at least as far apart under SCEPTR as #1,
# the one TCRdist considers closest
sceptr_at_least_as_far = tcr_dist_df["SCEPTR"] >= max_tcrdist_pair["SCEPTR"]
candidate_pairs = tcr_dist_df.loc[sceptr_at_least_as_far]
comparison_pair = candidate_pairs.loc[candidate_pairs["TCRdist"].idxmin()]

summaries.append(get_summary_for_pair_of_interest(comparison_pair, tanno_test))
indices.append(get_pair_indices(comparison_pair))

In [None]:
# Pair of interest #3: the pair SCEPTR considers farthest apart
max_sceptr_dist_pair = tcr_dist_df.loc[tcr_dist_df["SCEPTR"].idxmax()]
summaries.append(get_summary_for_pair_of_interest(max_sceptr_dist_pair, tanno_test))
indices.append(get_pair_indices(max_sceptr_dist_pair))

In [None]:
# Pair of interest #4: among pairs at least as far apart under TCRdist as #3,
# the one SCEPTR considers closest
tcrdist_at_least_as_far = tcr_dist_df["TCRdist"] >= max_sceptr_dist_pair["TCRdist"]
candidate_pairs = tcr_dist_df.loc[tcrdist_at_least_as_far]
comparison_pair = candidate_pairs.loc[candidate_pairs["SCEPTR"].idxmin()]

summaries.append(get_summary_for_pair_of_interest(comparison_pair, tanno_test))
indices.append(get_pair_indices(comparison_pair))

In [None]:
# Stack the collected per-pair Series into a table (one row per pair of
# interest) and write it out. The relative path resolves against the working
# directory, which the first cell set to PROJECT_ROOT.
summaries = DataFrame(data=summaries)
summaries.to_csv(path_or_buf="pairs_of_interest.csv", index=False)

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))

# All pairs coloured by KDE density, with the selected pairs of interest
# overlaid in red
ax.scatter(sceptr_dists, tcrdist_dists, s=10, c=density_estimates)
ax.scatter(summaries["SCEPTR"], summaries["TCRdist"], s=10, c="r")

ax.set(xlabel="SCEPTR distance", ylabel="TCRdist distance")
plt.show()