In [1]:
import torch
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import pandas as pd
import numpy as np

import sys
import os
sys.path.insert(1, os.getcwd() + '/src')
from cnn_paired import CNN_Paired
from utils import embed_genes, embed_seqs, ALPHABET, embed_sl, embed_dl, train_epoch_cell

from scipy.stats import fisher_exact

### Run configuration
EPOCHS = 10
SEED = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_DIR = os.getcwd() + '/data/'
MODEL_DIR = os.getcwd() + '/models/'

### Seed the global RNGs so that Restart & Run All is repeatable.
### pandas .sample() without random_state draws from numpy's global RNG;
### torch.manual_seed covers weight initialisation.
### NOTE: some GPU kernels can remain nondeterministic even with a fixed seed.
np.random.seed(SEED)
torch.manual_seed(SEED)

## Load Jaffe et al. BCR repertoires

In [2]:
sl = pd.read_csv(DATA_DIR + 'jaffe/naive_single_light.csv')
dl = pd.read_csv(DATA_DIR + 'jaffe/combined_double_light.csv')

### Restrict to B cells annotated naive and realistic CDR3 lengths
### Longer CDR3H are very rare
### Longer CDR3L are likely misannotated

### Per-dataset metadata, looked up through each cell's `source` column.
### Label-based .loc raises KeyError if a source is absent from the table,
### so an unannotated dataset cannot slip through silently.
donor_celltype = pd.read_csv(DATA_DIR + 'jaffe/donor_and_class.csv').set_index('dataset')
sl['donor'] = donor_celltype.loc[sl.source, 'donor'].to_numpy()
dl['donor'] = donor_celltype.loc[dl.source, 'donor'].to_numpy()
sl['celltype'] = donor_celltype.loc[sl.source, 'flow_class'].to_numpy()
dl['celltype'] = donor_celltype.loc[dl.source, 'flow_class'].to_numpy()

### .copy() so the length columns below are written to independent frames
### rather than to slices of the raw tables (avoids SettingWithCopyWarning).
dl = dl[dl.celltype == 'naive'].copy()
sl = sl[sl.celltype == 'naive'].copy()

### Vectorised string lengths; a missing CDR3 yields NaN, which fails the
### <= comparison and is therefore filtered out instead of raising.
dl['length_L_1'] = dl.CDRL3_1.str.len()
dl['length_L_2'] = dl.CDRL3_2.str.len()
dl = dl[(dl.length_L_1 <= 20) & (dl.length_L_2 <= 20)].copy()
dl['length_H'] = dl.CDRH3.str.len()
dl = dl[dl.length_H <= 32]

sl['length_L'] = sl.CDRL3.str.len()
sl = sl[sl.length_L <= 20].copy()
sl['length_H'] = sl.CDRH3.str.len()
sl = sl[sl.length_H <= 32]

In [3]:
### Restrict to heavy and light V-genes present in more than 1 donor.
### This is true for all J-genes in both heavy and light

### Each vocabulary maps a gene name to an integer index (its position in the list).

### Light chain: count distinct donors per V-gene.
### dropna=False counts a missing donor as its own value, like pd.unique does.
v_gene_donor_counts = sl.groupby('v_gene_L')['donor'].nunique(dropna=False)
v_genes = v_gene_donor_counts[v_gene_donor_counts > 1].index.values
### Third character of the gene name is the locus letter (IGK* / IGL*)
v_genes_l = {gene: idx for idx, gene in
             enumerate(entry for entry in v_genes if entry[2] in ('K', 'L'))}

j_genes_l = {gene: idx for idx, gene in
             enumerate(['IGLJ2', 'IGKJ1', 'IGKJ2', 'IGLJ3', 'IGKJ4', 'IGLJ1', 'IGKJ3', 'IGKJ5', 'IGLJ7'])}

### Heavy chain: same donor-count criterion
v_gene_donor_counts = sl.groupby('v_gene_H')['donor'].nunique(dropna=False)
v_genes = v_gene_donor_counts[v_gene_donor_counts > 1].index.values
v_genes_h = {gene: idx for idx, gene in
             enumerate(entry for entry in v_genes if entry[2] == 'H')}

j_genes_h = {gene: idx for idx, gene in
             enumerate(['IGHJ1', 'IGHJ2', 'IGHJ3', 'IGHJ4', 'IGHJ5', 'IGHJ6'])}

### Keep only cells whose V-genes are all in the vocabularies
sl = sl[sl.v_gene_H.isin(v_genes_h) & sl.v_gene_L.isin(v_genes_l)]
dl = dl[dl.v_gene_H.isin(v_genes_h) & dl.v_gene_L_1.isin(v_genes_l) & dl.v_gene_L_2.isin(v_genes_l)]

vj_genes = (v_genes_h, j_genes_h, v_genes_l, j_genes_l)

## Train model on light chains only

In [4]:
### Here we train on 3 donors and hold out the 4th
### In the paper we report results by averaging all four possible such models
### Here we train only one.
### Results will differ slightly, also due to replicate variance

test_donor = 4

dl_train = dl[dl.donor != test_donor]
sl_train = sl[sl.donor != test_donor]

### Held-out donor: draw twice as many single-light cells as there are double-light cells
dl_test = dl[dl.donor == test_donor]
sl_test_cells = sl[sl.donor == test_donor].sample(2 * dl_test.shape[0])
sl_test = embed_sl(sl_test_cells, device, vj_genes, drop_heavy=True)
(dl_test_1, dl_test_2) = embed_dl(dl_test, device, vj_genes, drop_heavy=True)

### V/J input widths are read off the embedded test tensors
model = CNN_Paired(use_vj=True, num_v=sl_test[1].shape[1], num_j=sl_test[2].shape[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
criterion = torch.nn.CrossEntropyLoss()

num_samp = 2 * dl_train.shape[0]

### Make sure the checkpoint directory exists before the first save
os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_path = MODEL_DIR + 'cnn_dl_noH_{0}.pt'.format(test_donor)

for epoch in tqdm(range(EPOCHS)):
    ### Fresh single-light subsample and reshuffled double-light cells every epoch
    sl_train_subs = sl_train.sample(num_samp).reset_index(drop=True)
    dl_train = dl_train.sample(frac=1).reset_index(drop=True)
    train_epoch_cell(model, device,
                     (sl_train_subs, dl_train),
                     (sl_test, dl_test_1, dl_test_2),
                     optimizer, criterion, epoch, vj_genes,
                     drop_heavy=True)
    ### Checkpoint after every epoch (the same file is overwritten)
    torch.save(model.state_dict(), checkpoint_path)

 10%|█         | 1/10 [00:05<00:46,  5.17s/it]

Epoch [0], Train loss: 0.6849095821380615, Test recall: 0.21439628482972137

 20%|██        | 2/10 [00:10<00:43,  5.43s/it]

Epoch [1], Train loss: 0.6663867235183716, Test recall: 0.26393188854489164

 30%|███       | 3/10 [00:16<00:40,  5.72s/it]

Epoch [2], Train loss: 0.6542438268661499, Test recall: 0.25402476780185756

 40%|████      | 4/10 [00:22<00:34,  5.69s/it]

Epoch [3], Train loss: 0.653691828250885, Test recall: 0.2657894736842105

 50%|█████     | 5/10 [00:28<00:28,  5.63s/it]

Epoch [4], Train loss: 0.6474936008453369, Test recall: 0.27987616099071205

 60%|██████    | 6/10 [00:31<00:20,  5.06s/it]

Epoch [5], Train loss: 0.6442214250564575, Test recall: 0.2885448916408669

 70%|███████   | 7/10 [00:36<00:15,  5.04s/it]

Epoch [6], Train loss: 0.6404878497123718, Test recall: 0.29272445820433435

 80%|████████  | 8/10 [00:41<00:09,  4.97s/it]

Epoch [7], Train loss: 0.641072690486908, Test recall: 0.3086687306501548

 90%|█████████ | 9/10 [00:46<00:05,  5.04s/it]

Epoch [8], Train loss: 0.640450656414032, Test recall: 0.31222910216718264

100%|██████████| 10/10 [00:52<00:00,  5.28s/it]

Epoch [9], Train loss: 0.6372255682945251, Test recall: 0.295046439628483




## Load and assess on van der Wijst data

In [5]:
### Data is prefiltered to only include naive B cells and vj genes in Jaffe data

### Single-light cells: read the table, then embed it for the model
sl_ucsf = pd.read_csv(
    DATA_DIR + 'van_der_wijst/naive_single_light.csv'
)
x_sl, v_sl, j_sl = embed_sl(
    sl_ucsf, device, vj_genes, drop_heavy=True
)

### Double-light cells: same steps; embed_dl's result is kept whole and indexed later
dl_ucsf = pd.read_csv(
    DATA_DIR + 'van_der_wijst/naive_double_light.csv'
)
embed_dl_tuple = embed_dl(
    dl_ucsf, device, vj_genes, drop_heavy=True
)

In [6]:
### Score single-light cells: difference between the model's two output columns.
### eval() disables any train-time layers (dropout / batch-norm updates) and
### no_grad() skips building the autograd graph during inference.
model.eval()
with torch.no_grad():
    scores = model(x_sl, v_sl, j_sl).cpu().numpy()
sl_ucsf['scores'] = scores[:, 1] - scores[:, 0]

In [7]:
### Score both embedded inputs of the double-light cells (stored as scores_1 / scores_2).
### eval() + no_grad(): run the network in inference mode without autograd bookkeeping.
model.eval()
with torch.no_grad():
    for i in range(2):
        x, v, j = embed_dl_tuple[i]
        scores = model(x, v, j).cpu().numpy()
        dl_ucsf['scores_{0}'.format(i+1)] = scores[:, 1] - scores[:, 0]

In [8]:
### Share of DL cells whose lower-scoring light chain falls below the
### 5th percentile of SL scores (i.e. worse than 95% of SL cells)
quantile = 0.05
thres = sl_ucsf['scores'].quantile(quantile)
worst_chain_score = dl_ucsf[['scores_1', 'scores_2']].min(axis=1)
worst_chain_score.lt(thres).mean()

0.3655536028119508

## Load and assess on Wardemann polyreactivity data

In [9]:
### Read the Wardemann polyreactivity table
data_poly = pd.read_csv(
    DATA_DIR + 'polyreactivity/wardemann_polyreactivity.csv'
)

In [10]:
### Embed the polyreactivity table for the model (heavy chain dropped)
poly_inputs = embed_sl(data_poly, device, vj_genes, drop_heavy=True)
x, v, j = poly_inputs

In [11]:
### Score the polyreactivity antibodies: difference between the model's two output columns.
### eval() + no_grad(): run the network in inference mode without autograd bookkeeping.
model.eval()
with torch.no_grad():
    scores = model(x, v, j).cpu().numpy()
data_poly['scores'] = scores[:, 1] - scores[:, 0]

In [12]:
### ROC AUC of the model score for separating rows with auto_count <= 1 from the rest
low_auto_count = data_poly.auto_count.le(1)
roc_auc_score(low_auto_count, data_poly.scores)

0.7004864055874283

## Load and assess on expression data

In [13]:
### Read the expression table
data_expr = pd.read_csv(
    DATA_DIR + 'expression/expression.csv'
)

In [14]:
### Embed the expression table for the model (heavy chain dropped)
expr_inputs = embed_sl(data_expr, device, vj_genes, drop_heavy=True)
x, v, j = expr_inputs

In [15]:
### Score the expression dataset: difference between the model's two output columns.
### eval() + no_grad(): run the network in inference mode without autograd bookkeeping.
model.eval()
with torch.no_grad():
    scores = model(x, v, j).cpu().numpy()
data_expr['scores'] = scores[:, 1] - scores[:, 0]

In [17]:
### Fisher's exact test: do igm_level == 'LO' cells fall below the score
### threshold more often than igm_level == 'HI' cells?
### The threshold is the 3% quantile of scores among the 'HI' cells.
quantile = 0.03
thres = np.quantile(data_expr.scores[data_expr.igm_level == 'HI'], quantile)

### 2x2 table: rows = (HI, LO), columns = (score >= thres, score < thres).
### >= on the first column so a score exactly equal to the threshold is still
### counted; strict > and < on both sides would drop it from the table.
counts = np.zeros((2, 2))
for row, level in enumerate(['HI', 'LO']):
    in_level = data_expr.igm_level == level
    counts[row, 0] = (in_level & (data_expr.scores >= thres)).sum()
    counts[row, 1] = (in_level & (data_expr.scores < thres)).sum()

fisher_exact(counts)

SignificanceResult(statistic=1.5304585665635886, pvalue=7.727560238765047e-05)