In [1]:
import torch
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import pandas as pd
import numpy as np
tqdm.pandas()

import sys
import os
sys.path.insert(1, os.getcwd() + '/src')
from cnn import CNN
from utils import embed_genes, embed_seqs, ALPHABET
from scipy.stats import mannwhitneyu

# Number of passes over the training set.
EPOCHS = 5
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input data and model checkpoints live under the current working directory.
DATA_DIR = '{}/data/'.format(os.getcwd())
MODEL_DIR = '{}/models/'.format(os.getcwd())

In [2]:
def train_epoch(model, device, train_tuple, test_tuple, optimizer, criterion, epoch, batch_size=10000):
    """Run one optimisation pass over the training set, then print train/test loss.

    Args:
        model: module invoked as ``model(x, v, j)``.
        device: torch device each mini-batch (and the test tensors) is moved to.
        train_tuple: ``(x, v, j, y)`` tensors sharing their first dimension. They
            are consumed as consecutive slices in the order given, so any
            shuffling has to be done by the caller beforehand.
        test_tuple: ``(x, v, j, y)`` evaluated in a single forward pass.
        optimizer: its ``zero_grad``/``step`` are called once per mini-batch.
        criterion: loss invoked as ``criterion(output, target)``.
        epoch: epoch index, used only in the printed summary.
        batch_size: number of rows per mini-batch.

    Returns:
        None. The summary line is printed; ``model`` is left in eval mode.
    """
    model.train()
    x, v, j, y = train_tuple

    train_loss = []

    for start in range(0, x.shape[0], batch_size):
        stop = start + batch_size
        x_i = x[start:stop].to(device)
        v_i = v[start:stop].to(device)
        j_i = j[start:stop].to(device)
        y_i = y[start:stop].to(device)

        # Gradients only need clearing once per step.
        optimizer.zero_grad()
        output = model(x_i, v_i, j_i)
        loss = criterion(output, y_i)
        loss.backward()
        optimizer.step()

        train_loss.append(loss.detach().cpu().numpy())

    train_loss = np.array(train_loss).mean()

    model.eval()
    # Moving is a no-op for tensors already on `device`; non-tensor entries
    # (e.g. None) are passed through untouched.
    xt, vt, jt, yt = (
        t.to(device) if torch.is_tensor(t) else t for t in test_tuple)
    # The evaluation pass needs no autograd graph; building one over the whole
    # test set in a single forward pass only wastes memory.
    with torch.no_grad():
        test_loss = criterion(model(xt, vt, jt), yt).detach().cpu().numpy()
    print('Epoch [{}], Train loss: {}, Test loss: {}'.format(epoch, train_loss, test_loss), end='')

## Load data

In [3]:
# Read the merged mouse repertoire and keep only the CDR3 sequence, the V/D/J
# gene calls, the source label and the individual identifier.
data = pd.read_csv(
    DATA_DIR + 'mouse/C57BL6_mouse_1234_merged_reduced.csv.gz', compression='gzip')
data = data[['aaSeqCDR3', 'bestVHit', 'bestDHit', 'bestJHit', 'source', 'individual']]

# Drop sequences containing any character that is not in ALPHABET.
in_alph = data['aaSeqCDR3'].progress_apply(lambda seq: set(seq).issubset(ALPHABET))
data = data.loc[in_alph]

# Keep only sequences whose first character is 'C' ...
conserved_C = data['aaSeqCDR3'].str.get(0).eq('C')
data = data.loc[conserved_C]
# ... and whose last character is 'W'.
conserved_W = data['aaSeqCDR3'].str.get(-1).eq('W')
data = data.loc[conserved_W]

100%|██████████| 21675696/21675696 [00:32<00:00, 662007.74it/s]


In [4]:
# Record each CDR3's length and keep sequences of 7 to 32 characters (both inclusive).
data['length'] = data['aaSeqCDR3'].progress_apply(len)
data = data[data['length'].between(7, 32)]

100%|██████████| 20805480/20805480 [00:15<00:00, 1328361.00it/s]


In [5]:
# Number of most frequent V genes to keep.
n_v = 90

# Map each retained gene name to an integer index in order of decreasing
# frequency: the four most common J genes over all rows, and the n_v most
# common V genes among rows whose source is 'naive'.
j_counts = data['bestJHit'].value_counts()
J_GENES = {gene: idx for idx, gene in enumerate(j_counts.index[:4].tolist())}
naive_v_counts = data.loc[data['source'] == 'naive', 'bestVHit'].value_counts()
V_GENES = {gene: idx for idx, gene in enumerate(naive_v_counts.index[:n_v].tolist())}

# Discard rows whose V or J call is not among the retained genes.
data = data[data['bestVHit'].isin(V_GENES) & data['bestJHit'].isin(J_GENES)]

## Train model

In [6]:
### Here we train on 3 mice and hold out the 4th
### In the paper we report results by averaging all four possible such models
### Here we train only one.
### Results will differ slightly, also due to replicate variance

# Seed the row sampling, the per-epoch shuffles and the weight initialisation so
# a re-run draws the same subsets (GPU kernels may still add small run-to-run noise).
SEED = 0
torch.manual_seed(SEED)


def embed_split(frame):
    """Embed one split as a (sequence, V gene, J gene, label) tuple of tensors.

    Labels are 1 for source 'preB', 2 for source 'naive' and 0 for any other source.
    """
    seqs = embed_seqs(frame.aaSeqCDR3.to_numpy(), pad_length=32, alph=ALPHABET).float()
    v_genes = embed_genes(frame['bestVHit'].to_numpy(), V_GENES).float()
    j_genes = embed_genes(frame['bestJHit'].to_numpy(), J_GENES).float()
    labels = torch.from_numpy(
        (1*(frame['source'] == 'preB') + 2*(frame['source'] == 'naive')).to_numpy()).long()
    return seqs, v_genes, j_genes, labels


# The first individual is held out for validation; the others are used for training.
mouse = pd.unique(data.individual)[0]

train = data[data.individual != mouse]
valid = data[data.individual == mouse]

# Draw the same number of rows from every source so the classes are balanced.
n_train = int(0.7 * train.source.value_counts().min())
# Cap the validation draw at the smallest source so sampling cannot fail when
# the held-out mouse has fewer than 10000 rows for some source.
n_valid = int(min(10000, valid.source.value_counts().min()))
train = train.groupby('source').sample(n=n_train, random_state=SEED).reset_index(drop=True)
valid = valid.groupby('source').sample(n=n_valid, random_state=SEED).reset_index(drop=True)

# Training tensors are moved to the device batch by batch inside train_epoch;
# the validation tensors are moved once, up front.
train_x, train_v, train_j, train_y = embed_split(train)
train_tuple = (train_x, train_v, train_j, train_y)

valid_x, valid_v, valid_j, valid_y = (t.to(device) for t in embed_split(valid))
valid_tuple = (valid_x, valid_v, valid_j, valid_y)

model = CNN(use_vj=False, num_v=len(V_GENES), num_j=len(J_GENES)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
criterion = torch.nn.CrossEntropyLoss()

# Make sure the checkpoint directory exists before the first save.
os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_path = MODEL_DIR + 'mouse_cnn_{0}.pt'.format(mouse)

for epoch in tqdm(range(EPOCHS)):
    # train_epoch consumes rows in order, so reshuffle before every epoch.
    shuffle = torch.randperm(train_x.size(0))
    train_x = train_x[shuffle]
    train_v = train_v[shuffle]
    train_j = train_j[shuffle]
    train_y = train_y[shuffle]
    train_tuple = (train_x, train_v, train_j, train_y)
    train_epoch(model, device, train_tuple, valid_tuple, optimizer, criterion, epoch)
    # Overwrite the checkpoint after every epoch.
    torch.save(model.state_dict(), checkpoint_path)

# Drop every reference to the embedded tensors, including the tuples, which
# would otherwise keep them alive, so their memory can actually be reclaimed.
train_x = train_v = train_j = train_y = train_tuple = None
valid_x = valid_v = valid_j = valid_y = valid_tuple = None
if device.type == 'cuda':
    torch.cuda.empty_cache()

 20%|██        | 1/5 [01:01<04:05, 61.35s/it]

Epoch [0], Train loss: 0.8060317635536194, Test loss: 0.7596032619476318

 40%|████      | 2/5 [01:33<02:12, 44.18s/it]

Epoch [1], Train loss: 0.7630923390388489, Test loss: 0.7555277347564697

 60%|██████    | 3/5 [02:04<01:16, 38.38s/it]

Epoch [2], Train loss: 0.7553618550300598, Test loss: 0.7487244606018066

 80%|████████  | 4/5 [02:37<00:36, 36.03s/it]

Epoch [3], Train loss: 0.7504685521125793, Test loss: 0.7467759251594543

100%|██████████| 5/5 [03:09<00:00, 37.91s/it]

Epoch [4], Train loss: 0.7469162940979004, Test loss: 0.7514888644218445




## Load and score polyreactivity data

In [7]:
poly_data = pd.read_csv(DATA_DIR + 'mouse/poly_mouse.csv')
# Vectorised length; a missing CDRH3 yields NaN here and is removed by the
# range filter on the next line.
poly_data['length'] = poly_data.CDRH3.str.len()
# Keep sequences of 7 to 32 characters (both inclusive).
poly_data = poly_data[poly_data.length.between(7, 32)]
# Keep sequences made only of characters from ALPHABET.
poly_data = poly_data[poly_data.CDRH3.apply(lambda x: len(set(x) - set(ALPHABET))) == 0]
# Work on an independent frame so the score columns below are written to it directly.
poly_data = poly_data.copy()

x_p = embed_seqs(poly_data.CDRH3.to_numpy(), pad_length=32, alph=ALPHABET).float().to(device)

# Score in inference mode explicitly rather than relying on the state an
# earlier cell left the model in, and skip building an autograd graph.
model.eval()
with torch.no_grad():
    logits = model(x_p, None, None)
# Per-row log-probabilities: logits minus their log-sum-exp over the class axis.
scores = torch.log_softmax(logits.cpu(), dim=1).numpy()

# Output columns follow the training labels: 0 = igor, 1 = pre, 2 = naive.
poly_data['cnn_igor'] = scores[:, 0]
poly_data['cnn_pre'] = scores[:, 1]
poly_data['cnn_naive'] = scores[:, 2]

In [8]:
# Statistically significant difference based on naive/pre log odds

# Per-sequence difference between the naive and pre log-scores.
naive_pre_odds = poly_data['cnn_naive'].sub(poly_data['cnn_pre'])
# Compare rows with poly_count <= 0 against rows with poly_count >= 2.
non_odds = naive_pre_odds.loc[poly_data['poly_count'].le(0)]
poly_odds = naive_pre_odds.loc[poly_data['poly_count'].ge(2)]

mannwhitneyu(non_odds, poly_odds)

MannwhitneyuResult(statistic=21557.0, pvalue=8.839563440718206e-06)

In [9]:
# No significant difference based on naive/igor log odds

# Per-sequence difference between the naive and igor log-scores.
naive_igor_odds = poly_data['cnn_naive'].sub(poly_data['cnn_igor'])
# Compare rows with poly_count <= 0 against rows with poly_count >= 2.
non_odds = naive_igor_odds.loc[poly_data['poly_count'].le(0)]
poly_odds = naive_igor_odds.loc[poly_data['poly_count'].ge(2)]

mannwhitneyu(non_odds, poly_odds)

MannwhitneyuResult(statistic=18340.0, pvalue=0.19026368174736452)

In [10]:
# No significant difference based on pre/igor log odds

# Per-sequence difference between the pre and igor log-scores.
pre_igor_odds = poly_data['cnn_pre'].sub(poly_data['cnn_igor'])
# Compare rows with poly_count <= 0 against rows with poly_count >= 2.
non_odds = pre_igor_odds.loc[poly_data['poly_count'].le(0)]
poly_odds = pre_igor_odds.loc[poly_data['poly_count'].ge(2)]

mannwhitneyu(non_odds, poly_odds)

MannwhitneyuResult(statistic=17869.0, pvalue=0.3947940029493612)