<a href="https://colab.research.google.com/github/perrin-isir/xomx-tutorials/blob/main/tutorials/xomx_tcr.ipynb"> <img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open in Google Colaboratory"></a>
<a id="raw-url" href="https://raw.githubusercontent.com/perrin-isir/xomx-tutorials/main/tutorials/xomx_tcr.ipynb" download> <img align="left" src="https://img.shields.io/badge/Github-Download%20(Right%20click%20%2B%20Save%20link%20as...)-blue" alt="Download (Right click + Save link as)" title="Download Notebook"></a>

# *xomx tutorial:* **predicting the epitope from a TCR $\beta$-chain CDR3 sequence**

In this tutorial, we train an extra-trees classifier to predict whether a TCR $\beta$-chain CDR3 sequence is associated with a given epitope. The classifier takes as input a one-hot encoding of the CDR3 sequence.

We use data taken from the [VDJdb](https://vdjdb.cdr3.net/) and [McPAS-TCR](http://friedmanlab.weizmann.ac.il/McPAS-TCR/) databases.

In [None]:
# imports:
import os
import joblib
from IPython.display import clear_output
import matplotlib.pyplot as plt
try:
    import xomx
except ImportError:
    !pip install git+https://github.com/perrin-isir/xomx.git
    clear_output()
    import xomx
try:
    import scanpy as sc
except ImportError:
    !pip install scanpy
    clear_output()
    import scanpy as sc
try:
    import logomaker
except ImportError:
    !pip install logomaker
    clear_output()
    import logomaker
import umap
import pandas as pd
import numpy as np

We first define `save_dir`, the folder in which everything will be saved.

In [None]:
save_dir = os.path.join(os.path.expanduser('~'), 'results', 'xomx-tutorials', 'tcr')
os.makedirs(save_dir, exist_ok=True)

We download both the VDJdb and the McPAS-TCR databases.

For the McPAS-TCR database, a download link must be provided. Please go to [http://friedmanlab.weizmann.ac.il/McPAS-TCR/](http://friedmanlab.weizmann.ac.il/McPAS-TCR/), right-click the "Download the complete database..." button, select "Copy link address", and paste the URL in the following definition:

In [None]:
mcpas_tcr_url = ""  # paste the McPAS-TCR download link here

In [None]:
vdjdb_file = 'vdjdb-2021-09-05.zip'
# build the URL by concatenation: os.path.join would insert backslashes on Windows
vdjdb_url = 'https://github.com/antigenomics/vdjdb-db/releases/download/2021-09-05/' + vdjdb_file
mcpas_tcr_file = 'McPAS-TCR.csv'

if not os.path.isfile(os.path.join(save_dir, vdjdb_file)):
    !wget {vdjdb_url} --directory-prefix={save_dir}
    !unzip {os.path.join(save_dir, vdjdb_file)} -d {save_dir}
    
if not os.path.isfile(os.path.join(save_dir, mcpas_tcr_file)):
    !wget {mcpas_tcr_url} --output-document={os.path.join(save_dir, mcpas_tcr_file)}
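If `wget` or `unzip` are not available (for example outside Colab, or on Windows), the downloads can also be done in pure Python. The sketch below uses only the standard library (`urllib.request.urlretrieve` and `zipfile`); the helper names `download_file` and `download_and_unzip` are hypothetical and not part of xomx. It assumes that `vdjdb_url` ends with the archive file name and that `mcpas_tcr_url` has been filled in.

```python
import os
import urllib.request
import zipfile


def download_file(url, dest_path):
    """Download `url` to `dest_path`, unless that file already exists."""
    if not os.path.isfile(dest_path):
        os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
        urllib.request.urlretrieve(url, dest_path)
    return dest_path


def download_and_unzip(url, dest_dir):
    """Download a zip archive into `dest_dir` (named after the URL) and extract it there."""
    archive = download_file(url, os.path.join(dest_dir, url.rsplit("/", 1)[-1]))
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)
    return archive
```

In this notebook, one would call `download_and_unzip(vdjdb_url, save_dir)` and `download_file(mcpas_tcr_url, os.path.join(save_dir, mcpas_tcr_file))`. Unlike the cell above, `download_and_unzip` re-extracts the archive on every call.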

We load the VDJdb and McPAS-TCR databases into pandas DataFrames, go through their rows, and build two dictionaries: `dic_cdr3beta` and `dic_epitopes`.

For every $\beta$-chain CDR3 sequence `s`, `dic_cdr3beta[s]` is the set of epitopes to which it is associated.  
For every epitope sequence `s`, `dic_epitopes[s]` is the set of $\beta$-chain CDR3 sequences to which it is associated.

**Remark: we only take into account CDR3 sequences of length at most 22 (arbitrary choice).**

In [None]:
dic_epitopes_file = 'dic_epitopes.joblib'
dic_cdr3beta_file = 'dic_cdr3beta.joblib'
cdr3_max_length = 22

if not os.path.isfile(os.path.join(save_dir, dic_epitopes_file)) or not os.path.isfile(os.path.join(save_dir, dic_cdr3beta_file)):
    vdjdb_df = pd.read_csv(os.path.join(save_dir, 'vdjdb_full.txt'), delimiter="\t", low_memory=False)
    mcpas_tcr_df = pd.read_csv(os.path.join(save_dir, mcpas_tcr_file), encoding='cp1252', delimiter=",", low_memory=False)

    dic_cdr3beta = {}
    dic_epitopes = {}

    def dic_iteration(cdr3beta_seq, epitope_seq):
        if cdr3beta_seq == cdr3beta_seq and epitope_seq == epitope_seq:  # filter NaNs
            if not set(epitope_seq).difference(xomx.tl.aminoacids) and not set(cdr3beta_seq).difference(xomx.tl.aminoacids):  #filter undefined symbols
                if len(cdr3beta_seq) <= cdr3_max_length:  # filter long sequences
                    dic_cdr3beta.setdefault(cdr3beta_seq, set())
                    dic_cdr3beta[cdr3beta_seq].add(epitope_seq)
                    dic_epitopes.setdefault(epitope_seq, set())
                    dic_epitopes[epitope_seq].add(cdr3beta_seq)

    for cdr3beta, epitope in zip(vdjdb_df["cdr3.beta"].values, vdjdb_df["antigen.epitope"].values):
        dic_iteration(cdr3beta, epitope)

    for cdr3beta, epitope in zip(mcpas_tcr_df["CDR3.beta.aa"].values, mcpas_tcr_df["Epitope.peptide"].values):
        dic_iteration(cdr3beta, epitope)
    
    joblib.dump(dic_epitopes, os.path.join(save_dir, dic_epitopes_file))
    joblib.dump(dic_cdr3beta, os.path.join(save_dir, dic_cdr3beta_file))
else:
    dic_epitopes = joblib.load(os.path.join(save_dir, dic_epitopes_file))
    dic_cdr3beta = joblib.load(os.path.join(save_dir, dic_cdr3beta_file))
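The three filters in `dic_iteration` and the two complementary dictionaries can be illustrated on a handful of toy pairs. The sketch below is a stand-alone re-implementation, not the notebook's code: it assumes that `xomx.tl.aminoacids` holds the 20 standard one-letter amino acid codes, and the pairs are made up for illustration rather than taken from the databases. It uses the names `toy_cdr3beta` and `toy_epitopes` so that running it does not overwrite the notebook's dictionaries.

```python
# Assumption: the same 20-letter alphabet as xomx.tl.aminoacids.
AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
MAX_LENGTH = 22


def add_pair(cdr3, epitope, cdr3_to_epitopes, epitope_to_cdr3s):
    """Record one (CDR3, epitope) pair in both dictionaries if it passes the filters."""
    # pandas reads missing entries as float NaN, the only value not equal to itself.
    if cdr3 != cdr3 or epitope != epitope:
        return
    # Reject sequences containing symbols outside the alphabet (e.g. "*" or "X").
    if set(cdr3) - AMINO_ACIDS or set(epitope) - AMINO_ACIDS:
        return
    # Drop CDR3 sequences longer than the maximum length.
    if len(cdr3) > MAX_LENGTH:
        return
    cdr3_to_epitopes.setdefault(cdr3, set()).add(epitope)
    epitope_to_cdr3s.setdefault(epitope, set()).add(cdr3)


toy_pairs = [
    ("CASSIRSSYEQYF", "GILGFVFTL"),
    ("CASSIRSSYEQYF", "NLVPMVATV"),  # same CDR3, second epitope
    ("CASSLGQAYEQYF", "GILGFVFTL"),
    (float("nan"), "GILGFVFTL"),     # missing CDR3: skipped
    ("CASS*RSSYEQYF", "GILGFVFTL"),  # undefined symbol: skipped
    ("C" + "A" * 25, "GILGFVFTL"),   # 26 residues > 22: skipped
]
toy_cdr3beta, toy_epitopes = {}, {}
for cdr3, epitope in toy_pairs:
    add_pair(cdr3, epitope, toy_cdr3beta, toy_epitopes)

# CDR3 sequences associated with exactly one epitope
toy_single = sorted(s for s, eps in toy_cdr3beta.items() if len(eps) == 1)
print(toy_single)  # ['CASSLGQAYEQYF']
```

The CDR3 sequence linked to two epitopes is the kind of entry that the single-epitope filtering step removes.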

In [None]:
print(f'{len(dic_cdr3beta)} beta-chain CDR3 sequences in total')
cdr3_single_epitope = set()
for key, value in dic_cdr3beta.items():
    if len(value) == 1:
        cdr3_single_epitope.add(key)
print(f'{len(cdr3_single_epitope)} beta-chain CDR3 sequences that recognize a unique epitope')

Optionally, we recompute `dic_epitopes` to keep only the beta-chain CDR3 sequences that are associated with a unique epitope:

In [None]:
dic_epitopes = {}
for key, epitopes in dic_cdr3beta.items():
    if len(epitopes) == 1:
        epitope = next(iter(epitopes))  # the single associated epitope
        dic_epitopes.setdefault(epitope, set()).add(key)
        assert key in cdr3_single_epitope

We sort the epitopes by decreasing number of associated beta-chain CDR3 sequences:

In [None]:
sorted_epitopes = sorted(dic_epitopes, key=lambda k: len(dic_epitopes[k]), reverse=True)

We use one-hot encodings to represent the CDR3 sequences. Their dimension is `cdr3_max_length` x `len(xomx.tl.aminoacids)`, where `xomx.tl.aminoacids` is the list of the 20 standard amino acid characters:

In [None]:
dimension = cdr3_max_length * len(xomx.tl.aminoacids)
dimension
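To make the encoding concrete, here is a minimal sketch of a padded one-hot encoding with this dimension. It is not `xomx.tl.onehot` itself: the layout (one block of 20 coordinates per position, sequence left-aligned, positions beyond its length left at zero) and the name `onehot_sketch` are assumptions for illustration. Under this layout, a CDR3 sequence of length L has exactly L non-zero coordinates. For a tree ensemble, the particular ordering of the coordinates does not matter as long as it is the same for every sequence.

```python
# Assumption: the 20 standard one-letter amino acid codes, as in xomx.tl.aminoacids.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def onehot_sketch(seq, max_length, alphabet=AMINO_ACIDS):
    """Flat 0/1 vector of size max_length * len(alphabet) (hypothetical layout).

    Coordinates [i * len(alphabet), (i + 1) * len(alphabet)) encode position i;
    positions past the end of `seq` stay all-zero (padding).
    """
    if len(seq) > max_length:
        raise ValueError(f"sequence longer than {max_length}: {seq}")
    col = {aa: j for j, aa in enumerate(alphabet)}
    vec = [0] * (max_length * len(alphabet))
    for i, aa in enumerate(seq):
        vec[i * len(alphabet) + col[aa]] = 1
    return vec


v = onehot_sketch("CASSF", 22)
print(len(v), sum(v))  # 440 coordinates, 5 of them set to 1
```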

We build annotated data (an AnnData object) containing the one-hot encodings of the beta-chain CDR3 sequences associated with the `K=30` most frequent epitopes, i.e. the epitopes with the largest numbers of associated sequences.

In [None]:
K = 30
nr_samples = sum([len(dic_epitopes[sorted_epitopes[j]]) for j in range(K)])
nr_samples

In [None]:
xd = sc.AnnData(shape=(nr_samples, dimension))
xd.obs_names = np.hstack([sorted(list(dic_epitopes[sorted_epitopes[j]])) for j in range(K)])
xd.obs['labels'] = np.hstack([[sorted_epitopes[j]] * len(dic_epitopes[sorted_epitopes[j]]) for j in range(K)])
xd.uns['all_labels'] = xomx.tl.all_labels(xd.obs['labels'])
xd.uns['obs_indices_per_label'] = xomx.tl.indices_per_label(xd.obs['labels'])
xd.X = np.zeros((xd.n_obs, xd.n_vars))
for i in range(xd.n_obs):
    xd.X[i, :] = xomx.tl.onehot(xd.obs_names[i], cdr3_max_length)

We randomly split the sequences into a training set (75%) and a test set (25%):

In [None]:
rng = np.random.RandomState(0)
xomx.tl.train_and_test_indices(xd, "obs_indices_per_label", test_train_ratio=0.25, rng=rng)
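Since `xomx.tl.train_and_test_indices()` receives the `obs_indices_per_label` field, it presumably splits each label separately (a stratified split), so that every epitope keeps roughly the same proportion in the training and test sets; the notebook later reads the result from `xd.uns['test_indices']`. The sketch below shows such a per-label split in plain Python. It is an assumption about the behaviour, not xomx's implementation, and `split_per_label` is a hypothetical name.

```python
import random


def split_per_label(indices_per_label, test_ratio=0.25, seed=0):
    """Stratified split: hold out `test_ratio` of the indices of every label."""
    rng = random.Random(seed)
    train, test = [], []
    for label in sorted(indices_per_label):
        idx = list(indices_per_label[label])
        rng.shuffle(idx)
        n_test = round(test_ratio * len(idx))
        test.extend(idx[:n_test])
        train.extend(idx[n_test:])
    return sorted(train), sorted(test)


# Toy example: 40 sequences for one epitope, 20 for another.
toy_indices = {"GILGFVFTL": range(0, 40), "NLVPMVATV": range(40, 60)}
train_idx, test_idx = split_per_label(toy_indices)
print(len(train_idx), len(test_idx))  # 45 15
```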

A UMAP plot based on 4000 randomly chosen sequences:

In [None]:
xomx.pl.plot_2d_embedding(xd, umap.UMAP(), subset_indices=rng.choice(xd.n_obs, 4000, replace=False), height=750)

In [None]:
classifier = {}

In [None]:
xd.uns['all_labels'][:10]

We define a binary classifier for the epitope GILGFVFTL:

In [None]:
index = list(xd.uns['all_labels']).index("GILGFVFTL")  # position of GILGFVFTL among the labels
classifier[xd.uns['all_labels'][index]] = xomx.fs.RFEExtraTrees(
    xd,
    xd.uns['all_labels'][index],
    n_estimators=450,
    random_state=rng,
)

The class `xomx.fs.RFEExtraTrees()` is normally used to perform recursive feature elimination, but if we only call its `init()` method, it simply trains an Extra-Trees classifier to discriminate between a reference label (here, "GILGFVFTL") and all the other labels.

Here, the classifier is therefore trained to discriminate between beta-chain CDR3 sequences associated with GILGFVFTL and all other beta-chain CDR3 sequences:

In [None]:
classifier[xd.uns['all_labels'][index]].init()

We plot the result of the classifier on 6000 random samples from the test set (points above the red line are classified as being associated with GILGFVFTL):

In [None]:
classifier[xd.uns['all_labels'][index]].plot(random_subset_size=6000, rng=rng, height=750)

We compute the Matthews correlation coefficient (MCC) on the test set:

In [None]:
xomx.tl.matthews_coef(classifier[xd.uns['all_labels'][index]].confusion_matrix)
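For reference, for a binary confusion matrix with counts TP, TN, FP and FN, the Matthews correlation coefficient is
$$\mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}},$$
which ranges from -1 (total disagreement) to 1 (perfect prediction), with 0 corresponding to predictions no better than random. Unlike accuracy, it stays informative when one class is much larger than the other. The sketch below assumes the layout `[[TN, FP], [FN, TP]]`; the formula is unchanged if the matrix is transposed or the two classes are swapped, so what matters is that the diagonal holds the correct predictions. `matthews_coef_sketch` is a hypothetical name, not the xomx function.

```python
import math


def matthews_coef_sketch(cm):
    """MCC of a 2x2 confusion matrix, assumed laid out as [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        # A whole row or column is empty; 0.0 is a common convention here.
        return 0.0
    return (tp * tn - fp * fn) / denom


print(matthews_coef_sketch([[10, 0], [0, 10]]))  # 1.0
print(matthews_coef_sketch([[0, 10], [10, 0]]))  # -1.0
```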

We gather the predictions on the test set in an array (`True` means that the classifier considers the sequence to be GILGFVFTL-related):

In [None]:
predictions_on_test_set = (classifier[xd.uns['all_labels'][index]].predict(xd.X[xd.uns['test_indices']]) == 1)
predictions_on_test_set

We define `ok_samples_test_set`, the indices of the beta-chain CDR3 sequences in the test set that are classified as being associated with GILGFVFTL by the classifier:

In [None]:
ok_samples_test_set = [xd.uns['test_indices'][j] for j, val in enumerate(predictions_on_test_set) if val]

We also define `all_GILGFVFTL_samples`, the indices of all the sequences labelled "GILGFVFTL":

In [None]:
all_GILGFVFTL_samples = xd.uns['obs_indices_per_label']["GILGFVFTL"]

We can now use logomaker to plot the sequence logo computed from all the beta-chain CDR3 sequences labelled "GILGFVFTL". The function `xomx.tl.compute_logomaker_df()` computes the pandas DataFrame required by logomaker (a table of probabilities for all characters and positions).

In [None]:
df_logo = xomx.tl.compute_logomaker_df(xd, all_GILGFVFTL_samples)
_, ax = plt.subplots(figsize=[12, 3])
logo = logomaker.Logo(df_logo, ax=ax)
plt.show()
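The table of probabilities for all characters and positions is a position probability matrix: row i gives the frequency of each amino acid at position i. The sketch below computes one in plain Python; it is not `xomx.tl.compute_logomaker_df()`. In particular, how xomx treats positions beyond the end of shorter sequences is not specified here, so the sketch makes an explicit assumption: sequences are left-aligned, and position i is estimated only from the sequences that reach it. The optional `fixed_length` argument mimics the `fixed_length` option of `xomx.tl.compute_logomaker_df()` by keeping only sequences of that length. The name `position_probabilities` is hypothetical.

```python
from collections import Counter

# Assumption: the 20 standard one-letter amino acid codes.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def position_probabilities(seqs, alphabet=AMINO_ACIDS, fixed_length=None):
    """One dict {character: probability} per position (left-aligned sequences)."""
    if fixed_length is not None:
        seqs = [s for s in seqs if len(s) == fixed_length]
    n_positions = max((len(s) for s in seqs), default=0)
    rows = []
    for i in range(n_positions):
        # Only sequences long enough to have a residue at position i contribute.
        counts = Counter(s[i] for s in seqs if len(s) > i)
        total = sum(counts.values())
        rows.append({aa: counts[aa] / total for aa in alphabet})
    return rows


toy_seqs = ["CASSF", "CASTF", "CSSF"]
rows = position_probabilities(toy_seqs)
fixed = position_probabilities(toy_seqs, fixed_length=5)
print(fixed[3]["S"], fixed[3]["T"])  # 0.5 0.5
```

To draw such a table with logomaker, the rows can be wrapped in a DataFrame, e.g. `logomaker.Logo(pd.DataFrame(rows))`, which gives one row per position and one column per amino acid.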

Logos obtained from sequences of variable lengths are hard to interpret, so `xomx.tl.compute_logomaker_df()` provides the option `fixed_length` to compute a separate logo for each sequence length.

Below, we compute logos for sequences of length 12, 13, 14 and 15. We compare the logos computed from `ok_samples_test_set` (column on the left), and `all_GILGFVFTL_samples` (column on the right):

In [None]:
lengths = np.arange(12, 16)
fig, ax = plt.subplots(len(lengths), 2, figsize=[18, 3 * len(lengths)], sharex=True)
ax[0, 0].set_title("Sequences (in the test set) classified as being associated with GILGFVFTL by the classifier")
ax[0, 1].set_title('All sequences labelled "GILGFVFTL"')
for row, length in enumerate(lengths):
    # left column: test-set sequences predicted as GILGFVFTL-related
    logomaker.Logo(xomx.tl.compute_logomaker_df(xd, ok_samples_test_set, fixed_length=length), ax=ax[row, 0])
    # right column: all sequences labelled GILGFVFTL
    logomaker.Logo(xomx.tl.compute_logomaker_df(xd, all_GILGFVFTL_samples, fixed_length=length), ax=ax[row, 1])
fig.tight_layout()
plt.show()

Remark: you can train classifiers on epitopes other than GILGFVFTL, but not all epitopes lead to classifiers that perform well on the test set. For some epitopes, the trained classifier does no better than random predictions on the test set. The complexity of the patterns associated with an epitope may vary greatly from one epitope to another.