In [1]:
import os
from pathlib import Path

# Resolve the project root once: it is taken to be the parent of the directory
# the kernel is currently in. The globals() guard matters because os.chdir
# below changes the working directory, so recomputing Path.cwd().parent on a
# re-run of this cell would climb one directory further up.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parent.resolve()

# From here on, relative paths resolve against the project root.
os.chdir(PROJECT_ROOT)

In [None]:
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
from matplotlib.transforms import ScaledTranslation
from matplotlib.gridspec import GridSpec
import numpy as np
from numpy import ndarray
import pandas as pd
from pandas import DataFrame
from paths import DATA_DIR
from pyrepseq.metric import tcr_metric
from sceptr import variant
from scipy import stats
import seaborn as sns

# Apply the "ggplot" style first, then apply "my.mplstyle" after it.
# NOTE(review): "my.mplstyle" looks like a style file resolved relative to the
# working directory — confirm that it lives in PROJECT_ROOT.
plt.style.use("ggplot")
plt.style.use("my.mplstyle")

In [3]:
# Read the preprocessed Tanno test split, then keep a fixed-seed draw of 1000 rows
tanno_test_path = DATA_DIR / "preprocessed" / "tanno" / "test.csv"
tanno_test = pd.read_csv(tanno_test_path)
tanno_sample = tanno_test.sample(n=1000, random_state=420)

In [None]:
# Preview the first rows of the subsample as a quick check on its columns
tanno_sample.head()

In [5]:
# Load in all necessary models.
# NOTE(review): judging by the class names, the three TCRdist objects score
# both chains together, the alpha chain only and the beta chain only,
# respectively — confirm against the pyrepseq documentation.
sceptr_model = variant.default()
tcrdist_model = tcr_metric.Tcrdist()
tcrdist_a_model = tcr_metric.AlphaTcrdist()
tcrdist_b_model = tcr_metric.BetaTcrdist()

In [None]:
# Compute pdists: one pairwise-distance vector per model, all over the same sample.
# NOTE(review): later cells use these vectors element-for-element alongside
# arrays of length n*(n-1)/2 built by looping over pairs (i < j), so they are
# assumed to be in condensed upper-triangle order — confirm against
# calc_pdist_vector.
sceptr_pdist = sceptr_model.calc_pdist_vector(tanno_sample)
tcrdist_pdist = tcrdist_model.calc_pdist_vector(tanno_sample)
tcrdist_a_pdist = tcrdist_a_model.calc_pdist_vector(tanno_sample)
tcrdist_b_pdist = tcrdist_b_model.calc_pdist_vector(tanno_sample)

In [7]:
def random_subsample_indices(k: int, out_of: int, seed = None):
    """Draw ``k`` distinct indices uniformly at random from ``range(out_of)``.

    Args:
        k: Number of indices to draw.
        out_of: Population size; indices are drawn from ``0 .. out_of - 1``.
        seed: Optional seed. When given, the draw is reproducible and comes
            from a private generator, so NumPy's global random state is left
            untouched. When ``None``, the draw comes from NumPy's global state.

    Returns:
        A 1-D integer array of ``k`` unique indices.

    Raises:
        ValueError: If ``k`` exceeds ``out_of`` (raised by NumPy, because the
            sampling is without replacement).
    """
    # A private RandomState(seed) produces the same stream as
    # np.random.seed(seed) followed by np.random.choice, but does not reseed
    # the global generator as a side effect.
    rng = np.random if seed is None else np.random.RandomState(seed)
    return rng.choice(out_of, k, replace=False)

In [8]:
def plot_line_best_fit(x, y, ax = None, c = None):
    """Draw the ordinary least-squares line of ``y`` against ``x``.

    Args:
        x: 1-D array-like of predictor values (list, ndarray or Series).
        y: 1-D array-like of responses, the same length as ``x``.
        ax: Axes to draw on; defaults to the current pyplot axes.
        c: Colour forwarded to ``ax.plot``.

    Returns:
        None. The fitted line is drawn over the observed range of ``x``.
    """
    if ax is None:
        ax = plt.gca()

    # np.asarray accepts plain sequences as well as ndarrays / pandas Series
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Design matrix [x, 1] so the solution is (slope, intercept)
    design = np.column_stack([x, np.ones_like(x)])
    slope, intercept = np.linalg.lstsq(design, y, rcond=None)[0]

    xx = np.linspace(x.min(), x.max())
    yy = slope * xx + intercept

    ax.plot(xx, yy, c=c)

## Main text figure
- Overview scatter with correlation coefficient
- Scatter coloured by pGen
- pGen as a covariate of TCRdist

In [9]:
# Calculate density estimates.
# The KDE is fitted on a random subset of 10,000 pairs (presumably to keep the
# fit tractable) and then evaluated at every pair. The subset is drawn with a
# fixed seed so that the density colouring is identical on every run.
coords = np.vstack([sceptr_pdist, tcrdist_pdist])
coords_1k = coords[:, random_subsample_indices(10_000, coords.shape[1], seed=420)]  # 10,000 columns, despite the name
gaussian_kde = stats.gaussian_kde(coords_1k)
density_estimates = gaussian_kde(coords)

In [10]:
# Calculate per-pair pGen summaries (average and minimum).
# Each TCR's pGen is the product of its alpha-chain and beta-chain pGens.
pgens = (tanno_sample["alpha_pgen"] * tanno_sample["beta_pgen"]).to_numpy()

num_tcrs = len(tanno_sample)

# Enumerate every unordered pair (anchor < comparison) in row-major order:
# (0,1), (0,2), ..., (0,n-1), (1,2), ... This is the order a nested loop over
# anchor then comparison index visits, and it yields n*(n-1)/2 entries.
anchor_idcs, comparison_idcs = np.triu_indices(num_tcrs, k=1)

avg_pgens = (pgens[anchor_idcs] + pgens[comparison_idcs]) / 2
# NOTE: np.minimum propagates NaN if either pGen in a pair is NaN
min_pgens = np.minimum(pgens[anchor_idcs], pgens[comparison_idcs])

In [None]:
# Three stacked panels laid out on a 3x9 grid: each panel spans the first eight
# columns of its row, and the ninth column of the middle row holds the colour bar.
fig = plt.figure(figsize=(8/2.54, 18/2.54))
ax_overview = fig.add_subplot(3, 9, (1, 8))
ax_pgen_scatter = fig.add_subplot(3, 9, (10, 17))
ax_pgen_scatter_cbar = fig.add_subplot(3, 9, 18)
ax_pgen_vs_tcrdist = fig.add_subplot(3, 9, (19, 26))

# Panel a: every pair, coloured by the KDE density estimate
ax_overview.scatter(coords[0], coords[1], s=1, c=density_estimates, rasterized=True)
cor_results = stats.pearsonr(coords[0], coords[1])
ax_overview.text(1.35, 75, f"$r={cor_results.statistic:.3f}$\n$p < 1e-126$")
ax_overview.set_xlabel("SCEPTR distance")
ax_overview.set_ylabel("TCRdist distance")

# Panel b: the same scatter, coloured by the smaller pGen of each pair (log scale)
scatter_mappable = ax_pgen_scatter.scatter(
    coords[0],
    coords[1],
    s=1,
    c=min_pgens,
    norm=mcolors.LogNorm(),
    rasterized=True,
)
cb = plt.colorbar(scatter_mappable, cax=ax_pgen_scatter_cbar)
cb.set_label(r"$p_{Gen}$")
ax_pgen_scatter.set_xlabel("SCEPTR distance")
ax_pgen_scatter.set_ylabel("TCRdist distance")

# Panel c: pairs with SCEPTR distance within [0.98, 1.02] and non-zero pGen
# (log10 needs strictly positive values)
non_zero_pgen = min_pgens > 0
close_sceptr_mask = (sceptr_pdist >= 0.98) & (sceptr_pdist <= 1.02) & non_zero_pgen
log_min_pgens_close = np.log10(min_pgens[close_sceptr_mask])
tcrdist_close = tcrdist_pdist[close_sceptr_mask]

ax_pgen_vs_tcrdist.scatter(log_min_pgens_close, tcrdist_close, s=10)
plot_line_best_fit(log_min_pgens_close, tcrdist_close, c="k", ax=ax_pgen_vs_tcrdist)
cor_results = stats.pearsonr(log_min_pgens_close, tcrdist_close)
ax_pgen_vs_tcrdist.text(-29,100,f"$r = {cor_results.statistic:.3f}$\n$p = {cor_results.pvalue:.2e}$")
ax_pgen_vs_tcrdist.set_xlabel(r"$\log_{10}(p_{Gen})$")
ax_pgen_vs_tcrdist.set_ylabel("TCRdist distance")

# Bold panel letters, shifted left of each axes' top-left corner
labelled_axes = [ax_overview, ax_pgen_scatter, ax_pgen_vs_tcrdist]
for panel_idx, ax in enumerate(labelled_axes):
    label = "abc"[panel_idx]
    trans = ScaledTranslation(-40/100, 0, fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', fontweight="bold", va='top')

fig.tight_layout()

plt.show()

fig.savefig("sceptr_vs_tcrdist.pdf", bbox_inches="tight")

In [None]:
# Pearson r between TCRdist and log10(pGen), stratified by SCEPTR distance bin.
# Bins: [0, 1.00), then twelve bins of width 0.05 covering [1.00, 1.60), then >= 1.60.
bin_delimiters = np.linspace(1,1.6,13)
non_zero_pgen = min_pgens > 0  # log10 needs strictly positive pGen

# Lower/upper edge of every bin, including the two open-ended outer bins
bin_lower_edges = np.concatenate([[0], bin_delimiters])
bin_upper_edges = np.concatenate([bin_delimiters, [np.inf]])
num_bins = len(bin_lower_edges)

pearson_rs = np.empty(num_bins)
pearson_r_lower_bounds = np.empty(num_bins)
pearson_r_upper_bounds = np.empty(num_bins)

for i, (lower_edge, upper_edge) in enumerate(zip(bin_lower_edges, bin_upper_edges)):
    in_bin = (sceptr_pdist >= lower_edge) & (sceptr_pdist < upper_edge) & non_zero_pgen
    results = stats.pearsonr(tcrdist_pdist[in_bin], np.log10(min_pgens[in_bin]))
    bounds = results.confidence_interval()
    pearson_rs[i] = results.statistic
    pearson_r_lower_bounds[i] = bounds.low
    pearson_r_upper_bounds[i] = bounds.high

fig = plt.figure(figsize=(10/2.54,8/2.54))

# One x position per bin; the shaded band is the confidence interval of r
bin_positions = np.arange(num_bins)
plt.plot(bin_positions, pearson_rs)
plt.fill_between(bin_positions, pearson_r_lower_bounds, pearson_r_upper_bounds, alpha=0.3)

xtick_labels = [
    f"$[{low:.02f}, {high:0.2f})$"
    for low, high in zip(bin_delimiters[:-1], bin_delimiters[1:])
]

# Raw string: "\g" is not a valid escape sequence in a normal string literal
plt.xticks(bin_positions, ["$<1.00$"] + xtick_labels + [r"$\geq 1.60$"], rotation=90)

plt.ylabel("Pearson $r$")
plt.xlabel("SCEPTR distance")

plt.show()

fig.savefig("pgen_vs_tcrdist.pdf", bbox_inches="tight")

### Why do some pairs with similar TCRdist distances have different SCEPTR distances?

In [None]:
# Rescale each single-chain TCRdist vector by its own mean
tcrdist_a_pdist_normed = tcrdist_a_pdist / np.mean(tcrdist_a_pdist)
tcrdist_b_pdist_normed = tcrdist_b_pdist / np.mean(tcrdist_b_pdist)

# Per pair: the smaller and the larger of the two single-chain distances
tcrdist_min_chain = np.minimum(tcrdist_a_pdist, tcrdist_b_pdist)
tcrdist_max_chain = np.maximum(tcrdist_a_pdist, tcrdist_b_pdist)

# max - min is the absolute difference between the alpha and beta distances
chain_delta = tcrdist_max_chain - tcrdist_min_chain

fig = plt.figure(figsize=(9/2.54,8/2.54))

plt.scatter(
    sceptr_pdist,
    tcrdist_pdist,
    s=1,
    c=chain_delta,
    vmin=0,
    vmax=80,
    rasterized=True,
)

plt.xlabel("SCEPTR distance")
plt.ylabel("TCRdist distance")
cb = plt.colorbar()
cb.set_label(r"$|$ $d_\alpha$ - $d_\beta$ $|$ (TCRdist)")

plt.show()

fig.savefig("sceptr_vs_tcrdist_chain_delta.pdf", bbox_inches="tight")

In [14]:
# Investigate power means: correlate SCEPTR distance with the power mean of
# the two mean-normalised single-chain TCRdist vectors, for exponents -10..10.
alphas = np.linspace(-10,10,21)

# Row 0: alpha-chain distances, row 1: beta-chain distances
normed_chain_dists = np.vstack([tcrdist_a_pdist_normed, tcrdist_b_pdist_normed])

rhos = []
for alpha in alphas:
    combined_dists = stats.pmean(normed_chain_dists, p=alpha, axis=0)
    rhos.append(stats.pearsonr(sceptr_pdist, combined_dists).statistic)

# Element-wise minimum / maximum across the two chains
min_rho = stats.pearsonr(sceptr_pdist, normed_chain_dists.min(axis=0)).statistic
max_rho = stats.pearsonr(sceptr_pdist, normed_chain_dists.max(axis=0)).statistic

In [None]:
fig = plt.figure(figsize=(10/2.54,8/2.54))
plt.plot(alphas, rhos)

# The min/max results are drawn at x = -15 / +15, which the tick labels below
# present as -infinity / +infinity. Raw strings are used wherever a LaTeX
# command appears, because "\i" is not a valid escape sequence in a normal
# string literal.
plt.scatter(15,max_rho, marker="^", c="C0", label=r"maximum ($p=\infty$)")
for exponent, marker, label in [
    (1, "o", "arithmetic ($p=1$)"),
    (0, "x", "geometric ($p=0$)"),
    (-1, "s", "harmonic ($p=-1$)"),
]:
    # Look the exponent up in `alphas` instead of hard-coding its position
    exponent_idx = int(np.argmin(np.abs(alphas - exponent)))
    plt.scatter(exponent, rhos[exponent_idx], marker=marker, c="C0", label=label)
plt.scatter(-15,min_rho, marker="v", c="C0", label=r"minimum ($p=-\infty$)")

plt.xticks(np.linspace(-15,15,7), [r"$-\infty$", "$-10$", "$-5$", "$0$", "$5$", "$10$", r"$\infty$"])

plt.xlabel('Power mean exponent $p$')
plt.ylabel('Pearson r')

plt.legend(loc="lower left")

plt.show()
fig.savefig("sceptr_and_tcrdist_averaging.pdf", bbox_inches="tight")

## Fine-tuning distances

In [16]:
# NOTE(review): presumably the fine-tuned SCEPTR variant, to be compared against
# the default one — confirm against the sceptr documentation.
finetuned_model = variant.finetuned()

In [17]:
# Load labelled data
benchmarking_dir = DATA_DIR / "preprocessed" / "benchmarking"
labelled_training = pd.read_csv(benchmarking_dir / "train.csv")
labelled_testing = pd.read_csv(benchmarking_dir / "test.csv")

# Keep only the test rows whose epitope also occurs in the training set
training_epitopes = labelled_training["Epitope"].unique()
labelled_testing = labelled_testing[labelled_testing["Epitope"].isin(training_epitopes)]

In [18]:
# Compute co-specific and non-co-specific distances
def get_cross_cospecificity_dists(anchors: DataFrame, comparisons: DataFrame, metric: tcr_metric.TcrMetric) -> tuple[ndarray, ndarray]:
    """Split anchor-vs-comparison distances by whether the pair shares an epitope.

    Args:
        anchors: Frame with an "Epitope" column; one row per anchor.
        comparisons: Frame with an "Epitope" column; one row per comparison.
        metric: Object whose ``calc_cdist_matrix(anchors, comparisons)`` returns
            a 2-D array with one row per anchor and one column per comparison.

    Returns:
        ``(cospecific, cross_specific)``: two 1-D arrays holding the distances
        of pairs with equal and with different "Epitope" values respectively,
        each in row-major order of the distance matrix.
    """
    flat_dists = metric.calc_cdist_matrix(anchors, comparisons).ravel()

    anchor_epitopes = anchors["Epitope"].to_numpy()
    comparison_epitopes = comparisons["Epitope"].to_numpy()
    # Broadcast a column against a row to compare every anchor with every comparison
    same_epitope = (anchor_epitopes[:, np.newaxis] == comparison_epitopes[np.newaxis, :]).ravel()

    return (flat_dists[same_epitope], flat_dists[~same_epitope])

# Training TCRs act as anchors and the epitope-filtered test TCRs as
# comparisons. Each result is a (cospecific, cross-specific) pair of arrays.
default_dists = get_cross_cospecificity_dists(labelled_training, labelled_testing, sceptr_model)
finetuned_dists = get_cross_cospecificity_dists(labelled_training, labelled_testing, finetuned_model)

In [19]:
# One row per anchor/comparison pair: cospecific pairs first, then cross-specific
num_cospecific = len(default_dists[0])
num_cross_specific = len(default_dists[1])

data = DataFrame({
    "SCEPTR (default) distances": np.concatenate(default_dists),
    "SCEPTR (finetuned) distances": np.concatenate(finetuned_dists),
    "Cospecific": np.repeat([True, False], [num_cospecific, num_cross_specific]),
})

# Shuffle the rows with a fixed seed and renumber the index from zero
data = data.sample(frac=1, random_state=420).reset_index(drop=True)

In [None]:
# Joint scatter of default vs fine-tuned SCEPTR distances, with KDE marginals
# split by cospecificity.
marginal_ratio = 4  # the joint panel is this many grid cells wide/tall; the marginals take one
x_name = "SCEPTR (default) distances"
y_name = "SCEPTR (finetuned) distances"
cross_spec_colour = "C0"
cospec_colour = "C1"
kde_grid_size = 100  # evaluation points per marginal KDE curve, the same for x and y

fig = plt.figure(figsize=(8/2.54, 8/2.54))
gs = GridSpec(marginal_ratio+1, marginal_ratio+1)

ax_joint = fig.add_subplot(gs[1:, :-1])
ax_joint.set_ylabel(y_name)
ax_joint.set_xlabel(x_name)

# The scatter is drawn with per-point colours, so the legend is built from
# proxy handles (marker only; the connecting line is fully transparent).
handles = [
    plt.Line2D([0], [0], marker="o", color="#ffffff00", markerfacecolor=cospec_colour, markersize=3),
    plt.Line2D([0], [0], marker="o", color="#ffffff00", markerfacecolor=cross_spec_colour, markersize=3),
]
labels = ["True", "False"]
ax_joint.legend(handles, labels, loc="upper left", title="Cospecific")

ax_marginal_x = fig.add_subplot(gs[0, :-1])
ax_marginal_x.set_xticklabels([])
ax_marginal_x.set_yticks([])

ax_marginal_y = fig.add_subplot(gs[1:, -1])
ax_marginal_y.set_yticklabels([])
ax_marginal_y.set_xticks([])

clist = data["Cospecific"].map({
    True: cospec_colour,
    False: cross_spec_colour
}).array

ax_joint.scatter(x=data[x_name], y=data[y_name], s=5, c=clist, edgecolors="white", rasterized=True)

is_cospecific = data["Cospecific"]
cross_specific_points = data[~is_cospecific]
cospecific_points = data[is_cospecific]

# Top marginal: KDE of the x variable, evaluated over the full data range
x_linspace = np.linspace(data[x_name].min(), data[x_name].max(), kde_grid_size)
cospecific_x_kde = stats.gaussian_kde(cospecific_points[x_name])(x_linspace)
cross_specific_x_kde = stats.gaussian_kde(cross_specific_points[x_name])(x_linspace)
ax_marginal_x.fill_between(x_linspace, cross_specific_x_kde, color=cross_spec_colour, alpha=0.3)
ax_marginal_x.fill_between(x_linspace, cospecific_x_kde, color=cospec_colour, alpha=0.3)
ax_marginal_x.plot(x_linspace, cross_specific_x_kde, c=cross_spec_colour)
ax_marginal_x.plot(x_linspace, cospecific_x_kde, c=cospec_colour)

# Right marginal: KDE of the y variable, drawn sideways, on the same grid size
y_linspace = np.linspace(data[y_name].min(), data[y_name].max(), kde_grid_size)
cospecific_y_kde = stats.gaussian_kde(cospecific_points[y_name])(y_linspace)
cross_specific_y_kde = stats.gaussian_kde(cross_specific_points[y_name])(y_linspace)
ax_marginal_y.fill_betweenx(y_linspace, cross_specific_y_kde, color=cross_spec_colour, alpha=0.3)
ax_marginal_y.fill_betweenx(y_linspace, cospecific_y_kde, color=cospec_colour, alpha=0.3)
ax_marginal_y.plot(cross_specific_y_kde, y_linspace, c=cross_spec_colour)
ax_marginal_y.plot(cospecific_y_kde, y_linspace, c=cospec_colour)

fig.tight_layout()

fig.savefig("pretrained_vs_finetuned_sceptr.pdf", bbox_inches="tight")

plt.show()