# Preprocessing Labelled Data For SCEPTR Benchmarking And Finetuning

## Sources
- VDJdb (as of the 28th of January, 2024)
- Preprocessed 10xGenomics whitepaper data (from Montemurro et al. 2023)

## Inclusion criteria
- Paired chain data

## Exclusion criteria
- 10xGenomics whitepaper entries in VDJdb (these are later replaced with Morten Nielsen's ITRAP-filtered version of the data)
- Non-human

All data will be cleaned with tidytcells and only standardizable and functional TCR data will be passed on.

In [None]:
import pandas as pd
from pandas import DataFrame
from pathlib import Path
import tidytcells as tt
from tqdm import tqdm

In [None]:
tcr_data_path = Path("../tcr_data/")

In [None]:
# Load the raw VDJdb export (tab-separated file named for 2024-01-28).
vdjdb = pd.read_csv(
    tcr_data_path/"raw"/"vdjdb"/"vdjdb_20240128.tsv",
    sep="\t"
)

In [None]:
vdjdb.head()

In [None]:
# Collect every non-missing "Reference" value that mentions "10x", so the
# exact reference string used for the 10x data can be inspected.
references = vdjdb["Reference"].dropna()
references_for_10x = references[references.str.contains("10x")]

In [None]:
references_for_10x.unique()

In [None]:
def preprocess(df: DataFrame) -> DataFrame:
    """Run the full cleaning pipeline on a raw VDJdb table.

    Steps, in order:

    1. apply the exclusion criteria (non-human, 10x whitepaper, single chain),
    2. merge the TRA and TRB rows of every complex into one paired row,
    3. drop rows that lack a required TCR field or the epitope,
    4. standardize gene, junction and MHC nomenclature,
    5. drop rows and duplicates again, because step 4 can change values.

    Args:
        df: Raw VDJdb table, one row per chain.

    Returns:
        A new DataFrame with one row per paired TCR and a fresh 0..n-1 index.
    """
    df = enforce_exclusion_criteria(df)

    df = group_paired_chains(df)
    # Missing values are removed first so the standardizers only ever see
    # real symbols.
    df = drop_rows_with_missing_data(df)
    df = standardize_nomenclature(df)

    # NOTE(review): tidytcells is expected to yield a missing value for
    # symbols it rejects (preprocess_10x relies on the same behaviour via
    # dropna) -- confirm against the tidytcells docs. Filtering again keeps
    # only standardizable, functional TCRs; if nothing was rejected this is
    # a no-op.
    df = drop_rows_with_missing_data(df)
    # Distinct raw spellings can collapse onto the same standardized row.
    df = df.drop_duplicates()

    df = df.reset_index(drop=True)
    return df

def enforce_exclusion_criteria(df: DataFrame) -> DataFrame:
    """Apply all exclusion filters and return the surviving rows as a copy.

    The filters run in this order: non-human rows, rows citing the 10x
    whitepaper, then single-chain rows.
    """
    human_only = remove_non_human_data(df)
    without_10x = remove_10x_data(human_only)
    paired_only = remove_single_chain_data(without_10x)
    return paired_only.copy()

def remove_non_human_data(df: DataFrame) -> DataFrame:
    """Return a copy of the rows whose "Species" is exactly "HomoSapiens"."""
    is_human = df["Species"] == "HomoSapiens"
    human_rows = df.loc[is_human]
    return human_rows.copy()

def remove_10x_data(df: DataFrame) -> DataFrame:
    """Return a copy of *df* without rows citing the 10x Genomics application note.

    Rows are matched on the exact "Reference" URL below; rows with any other
    reference, or with a missing one, are kept.
    """
    whitepaper_reference = "https://www.10xgenomics.com/resources/application-notes/a-new-way-of-exploring-immunity-linking-highly-multiplexed-antigen-recognition-to-immune-repertoire-and-phenotype/#"
    not_from_whitepaper = df["Reference"] != whitepaper_reference
    return df.loc[not_from_whitepaper].copy()

def remove_single_chain_data(df: DataFrame) -> DataFrame:
    """Return a copy of *df* without the rows whose "complex.id" equals 0."""
    is_unpaired = df["complex.id"] == 0
    return df.loc[~is_unpaired].copy()

def group_paired_chains(df: DataFrame) -> DataFrame:
    """Merge the TRA and TRB rows of each complex into one paired-chain row.

    Rows are grouped by "complex.id". Every group must consist of exactly one
    row with Gene == "TRA" and one with Gene == "TRB". V, CDR3 and J are taken
    from the respective chain row; Epitope, MHC A, MHC B and Reference are
    taken from the TRA row.

    Args:
        df: Table with the columns "complex.id", "Gene", "V", "CDR3", "J",
            "Epitope", "MHC A", "MHC B" and "Reference".

    Returns:
        A new DataFrame with the columns TRAV, CDR3A, TRAJ, TRBV, CDR3B,
        TRBJ, Epitope, MHCA, MHCB and Reference, with exact duplicate rows
        removed. An empty input yields an empty frame with these columns.

    Raises:
        RuntimeError: If a row has no complex.id, or if a complex does not
            consist of exactly one TRA row and one TRB row.
    """
    output_columns = [
        "TRAV",
        "CDR3A",
        "TRAJ",
        "TRBV",
        "CDR3B",
        "TRBJ",
        "Epitope",
        "MHCA",
        "MHCB",
        "Reference",
    ]

    # groupby silently skips missing keys, so reject them explicitly rather
    # than lose rows without notice.
    if df["complex.id"].isna().any():
        raise RuntimeError("Cannot pair chains: some rows have a missing complex.id.")

    reformatted_rows = []

    # One groupby pass instead of re-filtering the whole frame per complex;
    # sort=False keeps the complexes in order of first appearance.
    for complex_id, tcr_info in tqdm(df.groupby("complex.id", sort=False)):
        tra_rows = tcr_info[tcr_info["Gene"] == "TRA"]
        trb_rows = tcr_info[tcr_info["Gene"] == "TRB"]

        if len(tcr_info) != 2 or len(tra_rows) != 1 or len(trb_rows) != 1:
            raise RuntimeError(
                f"complex.id {complex_id!r} must have exactly one TRA row and "
                f"one TRB row, but got:\n{tcr_info}"
            )

        tra_info = tra_rows.iloc[0]
        trb_info = trb_rows.iloc[0]

        reformatted_rows.append(
            {
                "TRAV": tra_info["V"],
                "CDR3A": tra_info["CDR3"],
                "TRAJ": tra_info["J"],
                "TRBV": trb_info["V"],
                "CDR3B": trb_info["CDR3"],
                "TRBJ": trb_info["J"],
                "Epitope": tra_info["Epitope"],
                "MHCA": tra_info["MHC A"],
                "MHCB": tra_info["MHC B"],
                "Reference": tra_info["Reference"]
            }
        )

    # Passing the columns explicitly keeps the schema when there are no rows.
    reformatted_df = DataFrame.from_records(reformatted_rows, columns=output_columns)
    reformatted_df = reformatted_df.drop_duplicates()
    return reformatted_df

def drop_rows_with_missing_data(df: DataFrame) -> DataFrame:
    """Return *df* without rows that lack any TCR gene, CDR3 or the epitope.

    Missing values in other columns (for example the MHC columns) do not
    cause a row to be dropped.
    """
    required_columns = [
        "TRAV",
        "CDR3A",
        "TRAJ",
        "TRBV",
        "CDR3B",
        "TRBJ",
        "Epitope",
    ]
    return df.dropna(subset=required_columns)

def standardize_nomenclature(df: DataFrame) -> DataFrame:
    """Return a copy of *df* with gene, junction and MHC symbols standardized.

    Each column is passed value by value through a tidytcells standardizer:

    - TRAV, TRAJ, TRBV, TRBJ: ``tt.tr.standardize`` with
      ``enforce_functional=True``
    - CDR3A, CDR3B: ``tt.junction.standardize``
    - MHCA, MHCB: ``tt.mh.standardize``

    The input frame is not modified; all work happens on a copy.

    Args:
        df: Paired-chain table containing the eight columns listed above.

    Returns:
        A new DataFrame with the same rows, columns and index as *df*.
    """

    def standardize_functional_gene(symbol):
        return tt.tr.standardize(symbol, enforce_functional=True)

    # Column -> standardizer, in the order the columns are processed.
    standardizers = {
        "TRAV": standardize_functional_gene,
        "TRAJ": standardize_functional_gene,
        "TRBV": standardize_functional_gene,
        "TRBJ": standardize_functional_gene,
        "CDR3A": tt.junction.standardize,
        "CDR3B": tt.junction.standardize,
        "MHCA": tt.mh.standardize,
        "MHCB": tt.mh.standardize,
    }

    # Copy first so the caller's frame is never overwritten in place.
    standardized = df.copy()
    for column, standardizer in standardizers.items():
        standardized[column] = standardized[column].map(standardizer)

    return standardized

In [None]:
vdjdb_cleaned = preprocess(vdjdb)

In [None]:
vdjdb_cleaned.head()

In [None]:
# Load the filtered 10x TCR table (comma-separated) from the raw data folder.
clean_10x = pd.read_csv(tcr_data_path/"raw"/"10x_filtered"/"tcr.csv")

In [None]:
def preprocess_10x(df: DataFrame) -> DataFrame:
    """Convert the filtered 10x table into the paired-chain column layout.

    Reads the columns "genes_TRA", "cdr3_TRA", "genes_TRB", "cdr3_TRB" and
    "peptide_HLA" and produces TRAV, TRAJ, CDR3A, TRBV, TRBJ, CDR3B, Epitope,
    MHCA, MHCB and Reference. Genes, junctions and MHCA go through the
    tidytcells standardizers (genes with ``enforce_functional=True``). MHCB is
    always "B2M" and Reference is always "10x Genomics Whitepaper".

    Returns:
        A new DataFrame with duplicate rows removed, then rows containing any
        missing value removed, and a fresh index.
    """

    def standardize_gene(symbol):
        return tt.tr.standardize(symbol, enforce_functional=True)

    def alpha_v_and_j(row):
        # The first two ";"-separated fields are used as the V and J genes.
        return row["genes_TRA"].split(";")[:2]

    def beta_v_and_j(row):
        # Fields 0 and 2 are used as V and J; field 1 is skipped
        # (presumably the D gene -- TODO confirm against the input format).
        return (row["genes_TRB"].split(";")[0], row["genes_TRB"].split(";")[2])

    def epitope_and_mhc(row):
        # Whitespace-separated: the epitope first, then the MHC.
        return row["peptide_HLA"].split()

    result = pd.DataFrame()

    alpha_genes = df.apply(alpha_v_and_j, axis="columns", result_type="expand")
    result[["TRAV", "TRAJ"]] = alpha_genes.map(standardize_gene)
    result["CDR3A"] = df["cdr3_TRA"].map(tt.junction.standardize)

    beta_genes = df.apply(beta_v_and_j, axis="columns", result_type="expand")
    result[["TRBV", "TRBJ"]] = beta_genes.map(standardize_gene)
    result["CDR3B"] = df["cdr3_TRB"].map(tt.junction.standardize)

    result[["Epitope", "MHCA"]] = df.apply(
        epitope_and_mhc, axis="columns", result_type="expand"
    )
    result["MHCA"] = result["MHCA"].map(tt.mh.standardize)
    result["MHCB"] = "B2M"
    result["Reference"] = "10x Genomics Whitepaper"

    deduplicated = result.drop_duplicates()
    return deduplicated.dropna(ignore_index=True)

preprocessed_10x = preprocess_10x(clean_10x)

In [None]:
preprocessed_10x.head()

In [None]:
# Stack the cleaned VDJdb rows and the 10x rows into one table with a fresh index.
combined = pd.concat([vdjdb_cleaned, preprocessed_10x], ignore_index=True)

In [None]:
combined.head()

In [None]:
vdjdb_cleaned.Epitope.nunique()

In [None]:
# Write the VDJdb-only and the combined tables to the benchmarking folder,
# without the index column.
vdjdb_cleaned.to_csv(tcr_data_path/"preprocessed"/"benchmarking"/"vdjdb_cleaned.csv", index=False)
combined.to_csv(tcr_data_path/"preprocessed"/"benchmarking"/"combined.csv", index=False)

## Generate specific split for SCEPTR finetuning

In [None]:
# Keep only the rows of epitopes that have more than 300 TCRs and whose TCRs
# come from more than one reference.
highly_sampled_epitopes = vdjdb_cleaned.groupby("Epitope").filter(lambda ep_group: len(ep_group) > 300 and ep_group["Reference"].nunique() > 1)

In [None]:
# Split by reference: for every highly sampled epitope, take whole
# references (largest first) until more than 200 TCRs are collected. Those
# rows form the train/valid pool; every other row of vdjdb_cleaned is test data.
train_valid_split = []

for epitope, tcr_indices in highly_sampled_epitopes.groupby("Epitope").groups.items():
    tcrs = highly_sampled_epitopes.loc[tcr_indices]

    # Running total of TCRs per reference, largest reference first.
    cumulative_num_tcrs = (
        tcrs.groupby("Reference")
        .size()
        .sort_values(ascending=False)
        .cumsum()
    )

    references_to_use_for_training = []
    for reference, cumsum in cumulative_num_tcrs.items():
        references_to_use_for_training.append(reference)
        if cumsum > 200:
            # This reference pushed the total past 200; stop adding more.
            break

    train_split_for_epitope = tcrs[tcrs["Reference"].isin(references_to_use_for_training)]
    train_valid_split.append(train_split_for_epitope)

train_valid_split = pd.concat(train_valid_split)
test_split = vdjdb_cleaned[~vdjdb_cleaned.index.isin(train_valid_split.index)]

In [None]:
# Sanity check: should display an empty set (no row is in both train/valid and test).
set(train_valid_split.index).intersection(set(test_split.index))

In [None]:
# Sanity check: should display True (train/valid and test together cover every row).
len(vdjdb_cleaned) == len(train_valid_split) + len(test_split)

In [None]:
# Draw exactly 200 TCRs per epitope for training (fixed seed, so the draw is
# reproducible); the rest of the train/valid pool becomes the validation split.
train_split = train_valid_split.groupby("Epitope").sample(n=200, random_state=420)
valid_split = train_valid_split[~train_valid_split.index.isin(train_split.index)]

In [None]:
# Sanity check: should display an empty set (no row is in both train and valid).
set(train_split.index).intersection(set(valid_split.index))

In [None]:
# Sanity check: should display True (train and valid together cover the whole pool).
len(train_valid_split) == len(train_split) + len(valid_split)

In [None]:
# For every epitope in the train/valid pool, list the references that
# contribute its TCRs to the test split.
# isin() evaluates the epitope list once, rather than calling unique() again
# for every row as a per-row lambda would.
test_split[test_split.Epitope.isin(train_valid_split.Epitope.unique())].groupby("Epitope").aggregate({"Reference": "unique"}).to_dict()

In [None]:
train_valid_split.groupby("Epitope").aggregate({"Reference": "unique"}).to_dict()

In [None]:
valid_split.groupby("Epitope").size()

In [None]:
# Write the train/valid pool, the train, valid and test splits to the
# benchmarking folder, without the index column.
train_valid_split.to_csv(tcr_data_path/"preprocessed"/"benchmarking"/"train_valid.csv", index=False)
train_split.to_csv(tcr_data_path/"preprocessed"/"benchmarking"/"train.csv", index=False)
valid_split.to_csv(tcr_data_path/"preprocessed"/"benchmarking"/"valid.csv", index=False)
test_split.to_csv(tcr_data_path/"preprocessed"/"benchmarking"/"test.csv", index=False)