In [39]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from pandas import DataFrame
from pathlib import Path
from tqdm import tqdm

# Styles are applied left to right: start from the stock ggplot look, then let
# the project-local style sheet override whatever settings it defines.
plt.style.use(["ggplot", "../my.mplstyle"])

In [2]:
# Root directory of the TCR datasets, relative to this notebook's working directory.
tcr_data_path = Path("..") / "tcr_data"

In [3]:
# Load the raw VDJdb export; the file is tab-separated despite going through read_csv.
vdjdb = pd.read_csv(
    tcr_data_path.joinpath("raw", "vdjdb", "vdjdb_20240128.tsv"),
    sep="\t",
)

In [None]:
# Bare expression: renders the loaded table as this cell's output for a visual check.
vdjdb

In [9]:
# Single-chain records: rows with complex.id == 0 are the ones this notebook
# treats as having no paired partner chain. Keep only the fields that identify
# a chain and its epitope, and drop exact repeats.
chain_columns = ["V", "J", "CDR3", "Epitope"]
is_unpaired = vdjdb["complex.id"] == 0

alphas = vdjdb.loc[is_unpaired & (vdjdb["Gene"] == "TRA"), chain_columns].drop_duplicates()
betas = vdjdb.loc[is_unpaired & (vdjdb["Gene"] == "TRB"), chain_columns].drop_duplicates()

In [None]:
def group_paired_chains(df: DataFrame) -> DataFrame:
    """Collapse paired TRA/TRB rows into one row per alpha/beta TCR.

    Rows with ``complex.id == 0`` are ignored. Every remaining ``complex.id``
    must own exactly one ``TRA`` row and one ``TRB`` row; the two are joined
    into a single record and exact duplicate records are dropped.

    Args:
        df: Table with at least the columns ``complex.id``, ``Gene``, ``V``,
            ``J``, ``CDR3`` and ``Epitope``.

    Returns:
        DataFrame with columns ``TRAV``, ``CDR3A``, ``TRAJ``, ``TRBV``,
        ``CDR3B``, ``TRBJ`` and ``Epitope`` (the epitope is read from the TRA
        row). Rows follow the order in which each ``complex.id`` first appears
        in ``df``. The columns are present even when no paired rows exist.

    Raises:
        RuntimeError: If a non-zero ``complex.id`` does not consist of exactly
            one TRA row and one TRB row. The message names the first offending
            id and includes its rows.
    """
    output_columns = ["TRAV", "CDR3A", "TRAJ", "TRBV", "CDR3B", "TRBJ", "Epitope"]

    paired = df[df["complex.id"] != 0]
    if paired.empty:
        # Keep the schema so downstream code such as groupby("Epitope") still works.
        return DataFrame(columns=output_columns)

    is_tra = paired["Gene"] == "TRA"
    is_trb = paired["Gene"] == "TRB"

    # One pass over the data: count rows and chain types per complex instead of
    # re-filtering the whole frame once per complex id.
    # sort=False keeps the complexes in first-appearance order.
    chain_counts = (
        paired.assign(_is_tra=is_tra.to_numpy(), _is_trb=is_trb.to_numpy())
        .groupby("complex.id", sort=False)
        .agg(
            n_rows=("Gene", "size"),
            n_tra=("_is_tra", "sum"),
            n_trb=("_is_trb", "sum"),
        )
    )

    malformed = chain_counts[
        (chain_counts["n_rows"] != 2)
        | (chain_counts["n_tra"] != 1)
        | (chain_counts["n_trb"] != 1)
    ]
    if not malformed.empty:
        bad_id = malformed.index[0]
        bad_rows = paired[paired["complex.id"] == bad_id]
        raise RuntimeError(
            f"complex.id {bad_id} does not consist of exactly one TRA and one TRB row:\n"
            f"{bad_rows}"
        )

    # After validation each complex id occurs exactly once per chain, so both
    # lookups return one row per complex, aligned on the same ordering.
    complex_order = chain_counts.index
    tra_rows = paired[is_tra].set_index("complex.id").loc[complex_order]
    trb_rows = paired[is_trb].set_index("complex.id").loc[complex_order]

    reformatted_df = DataFrame(
        {
            "TRAV": tra_rows["V"].to_numpy(),
            "CDR3A": tra_rows["CDR3"].to_numpy(),
            "TRAJ": tra_rows["J"].to_numpy(),
            "TRBV": trb_rows["V"].to_numpy(),
            "CDR3B": trb_rows["CDR3"].to_numpy(),
            "TRBJ": trb_rows["J"].to_numpy(),
            "Epitope": tra_rows["Epitope"].to_numpy(),
        },
        columns=output_columns,
    )
    return reformatted_df.drop_duplicates()

# Build the paired alpha/beta table from the VDJdb rows whose complex.id is non-zero.
alpha_betas = group_paired_chains(vdjdb)

In [None]:
# Figure size is given in centimetres (8 x 6) and converted to the inches matplotlib expects.
fig, ax = plt.subplots(figsize=(8/2.54, 6/2.54))
ax.set_yscale("log")

# Number of distinct TCR records per epitope, largest first, for each chain type.
tcr_nums_alpha = alphas.groupby("Epitope").size().sort_values(ascending=False)
tcr_nums_beta = betas.groupby("Epitope").size().sort_values(ascending=False)
tcr_nums_alphabeta = alpha_betas.groupby("Epitope").size().sort_values(ascending=False)

# Rank (x) against count (y); the plotting order fixes both colour cycle and legend order.
for label, counts in (
    (r"$\alpha\beta$", tcr_nums_alphabeta),
    (r"$\alpha$", tcr_nums_alpha),
    (r"$\beta$", tcr_nums_beta),
):
    ax.plot(range(len(counts)), counts, label=label)

ax.set_xlabel("Ranked pMHCs")
ax.set_ylabel("Number of TCR binders")
ax.set_title("TCR binders by pMHC")

ax.legend()
fig.tight_layout()

fig.savefig("tcr_pmhc_data_distribution.svg", bbox_inches="tight")