In [1]:
import operator as op
import typing as t
import copy
from itertools import chain
from collections import Counter

import numpy as np
import pandas as pd
import tensorflow as tf
# from tensorflow._api.v1.keras import backend as K, layers, models, optimizers, \
#    activations, initializers
from keras import backend as K, layers, models
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from fn import F
from fn.func import identity

from sklab.data import preprocessing as pp
from attention import core

# Project-wide dtypes: INT_T for integer-encoded sequences, FLOAT_T for float arrays.
INT_T = np.int32
FLOAT_T = np.float32

%matplotlib notebook

Using TensorFlow backend.


In [2]:
class SequenceEncoder:
    """Map character sequences to integer code arrays.

    The vocabulary is the sorted set of every character found in ``alphabet``
    (an iterable of strings, or of single characters). Known characters get the
    codes ``1..len(vocabulary)`` in sorted order; code 0 is never assigned, and
    any character outside the vocabulary is encoded as ``oov``
    (``len(vocabulary) + 1``).
    """

    def __init__(self, alphabet: t.Iterable[str]):
        vocabulary = sorted(set(chain.from_iterable(alphabet)))
        self._mapping = {
            symbol: code for code, symbol in enumerate(vocabulary, start=1)
        }
        self._oov = len(vocabulary) + 1

    def __call__(self, sequence: t.Sequence[str]) -> np.ndarray:
        """Encode ``sequence`` as a 1-D INT_T array with one code per character."""
        codes = self._mapping
        unknown = self._oov
        return np.fromiter(
            (codes.get(symbol, unknown) for symbol in sequence),
            dtype=INT_T,
            count=len(sequence),
        )

    @property
    def mapping(self) -> t.Mapping[str, int]:
        """A copy of the character -> code table; mutating it does not affect the encoder."""
        return copy.deepcopy(self._mapping)

    @property
    def oov(self) -> int:
        """Code assigned to characters that are not in the vocabulary."""
        return self._oov


def expand_categories(categories: np.ndarray, dtype=FLOAT_T) -> np.ndarray:
    """Expand ordinal category levels into cumulative ("thermometer") indicator rows.

    Row ``i`` of the result has ones in its first ``categories[i]`` columns and
    zeros elsewhere, e.g. ``[3, 1, 0]`` -> ``[[1, 1, 1], [1, 0, 0], [0, 0, 0]]``.

    :param categories: 1-D array-like (a pandas Series is accepted) of
        non-negative integer levels.
    :param dtype: dtype of the returned array.
    :return: array of shape ``(len(categories), max(categories))``; an empty
        input yields an empty ``(0, 0)`` array.
    :raises ValueError: if ``categories`` is not 1-D or contains negative values.
    """
    levels = np.asarray(categories)
    if levels.ndim != 1:
        raise ValueError(
            'categories must be 1-D, got shape {}'.format(levels.shape)
        )
    if levels.size == 0:
        return np.zeros((0, 0), dtype=dtype)
    if (levels < 0).any():
        # A negative level has no meaning here; slicing with it would silently
        # fill all but the last |level| columns.
        raise ValueError('categories must be non-negative')
    maxcat = int(levels.max())
    # Broadcast column indices against each row's level: column j is set iff j < level.
    return (np.arange(maxcat) < levels[:, np.newaxis]).astype(dtype)

In [3]:
# Load the raw inputs: the training records, the allotype binding-region sequences
# (keyed by FASTA record id) and the per-position ConSurf scores.
train = pd.read_csv('data/abundant_train.tsv', sep='\t')
allotypes = dict(
    (record.id, str(record.seq))
    for record in SeqIO.parse('data/allotype_binding_regions.faa', 'fasta')
)
consurf = pd.read_csv('data/consurf.tsv', sep='\t')


In [4]:
train.columns  # DataFrame.keys() is an alias for .columns

Index(['accession', 'allotype', 'article_affiliations', 'article_authors',
       'article_chemical_list', 'article_date', 'article_id', 'article_title',
       'bind_or_elution_id', 'category', 'collection', 'inequality',
       'journal_title', 'measurement', 'measurement_ord', 'measurement_type',
       'method', 'object_type', 'peptide', 'quantitative',
       'reference_category_name', 'reference_type', 'source', 'species',
       'submission_affiliations', 'submission_authors', 'submission_date',
       'submission_id', 'submission_title', 'submitter_name', 'units'],
      dtype='object')

In [5]:
consurf.columns  # DataFrame.keys() is an alias for .columns

Index(['position', 'consurf_score'], dtype='object')

In [6]:
# peptide length distribution: number of training records per peptide length

train['peptide'].map(len).value_counts()


9     116089
10     32707
8       5860
11      3439
13       311
12       273
15       223
14       175
18        34
17        31
20        14
7          9
16         6
21         5
19         3
6          2
5          2
26         2
23         1
30         1
Name: peptide, dtype: int64

In [7]:
PEPTIDE_LEN = 16
ALLOTYPE_LEN = 64

# Select the ALLOTYPE_LEN most variable MHC positions (i.e. positions with large ConSurf
# scores), then restore ascending positional order.
# - The column is selected by name, so the result does not depend on column order in consurf.tsv.
# - A stable sort breaks ties at the ALLOTYPE_LEN cut-off by file order instead of arbitrarily.
# NOTE(review): these positions are later used as 0-based indices into the allotype
# sequence strings; confirm that consurf.tsv numbers positions from 0, not 1.
mhc_positions = (
    consurf
    .sort_values('consurf_score', ascending=False, kind='mergesort')
    .head(ALLOTYPE_LEN)['position']
    .sort_values()
)
# Remove records with overly long peptides (they would not fit into PEPTIDE_LEN slots).
train_filt = train[train['peptide'].apply(len) <= PEPTIDE_LEN]

In [8]:
peptides_train = train_filt['peptide']
accessions_train = train_filt['accession']
categories_train = train_filt['measurement_ord']

# Fail early, with a readable message, if a record refers to an allotype whose
# binding-region sequence is missing from the FASTA file. Without this check the
# failure surfaces as an opaque "'NoneType' object is not subscriptable" TypeError.
_missing_accessions = set(accessions_train) - set(allotypes)
if _missing_accessions:
    raise KeyError(
        'no binding-region sequence for accession(s): {}'.format(
            sorted(map(str, _missing_accessions))
        )
    )

# For every record keep only the residues at the selected MHC positions
# (a tuple of characters per record when more than one position is selected).
_take_mhc_positions = op.itemgetter(*mhc_positions)
allotypes_train = [
    _take_mhc_positions(allotypes[accession]) for accession in accessions_train
]


In [9]:
# allotype AA distribution: relative frequency of each residue over the selected
# MHC positions of all training records, in ascending order
allotype_train_aa_counts = Counter(
    aa for allotype in allotypes_train for aa in allotype
)
pd.Series(allotype_train_aa_counts).div(sum(allotype_train_aa_counts.values())).sort_values()

C    0.001142
P    0.022825
N    0.023451
M    0.029600
K    0.030530
I    0.031085
F    0.032541
H    0.035138
W    0.038424
D    0.042060
L    0.044233
S    0.046269
V    0.057988
Y    0.060929
T    0.065004
Q    0.066136
G    0.071495
E    0.091462
R    0.100311
A    0.109377
dtype: float64

In [10]:
# peptide AA distribution: relative frequency of each residue over all training
# peptides, in ascending order
peptide_train_aa_counts = Counter(
    aa for peptide in peptides_train for aa in peptide
)
pd.Series(peptide_train_aa_counts).div(sum(peptide_train_aa_counts.values())).sort_values()

C    0.014703
W    0.019917
H    0.022215
M    0.032078
Q    0.032353
D    0.037534
N    0.038360
E    0.043513
P    0.046422
G    0.047950
K    0.052160
R    0.053258
Y    0.054931
T    0.059862
F    0.059978
S    0.063748
I    0.068264
V    0.068847
A    0.068972
L    0.114938
dtype: float64

In [11]:
# Build the residue vocabulary from every allotype and peptide residue seen in training.
# '-' (presumably an alignment gap) is kept out of the vocabulary, so it is encoded as
# seqencoder.oov like any other unknown character.
allotype_aa = chain.from_iterable(allotypes_train)
peptide_aa = chain.from_iterable(peptides_train)
seqencoder = SequenceEncoder(
    aa for aa in chain(allotype_aa, peptide_aa) if aa != '-'
)

In [12]:
# Pad the integer-encoded sequences into fixed-width INT_T arrays using filler 0
# (a code the encoder never assigns to a residue). pp.stack also returns a mask,
# presumably marking the filled (non-padding) cells.
peptides_enc_train, peptides_mask_train = pp.stack(
    list(map(seqencoder, peptides_train)),
    shape=(PEPTIDE_LEN,), dtype=INT_T, filler=0,
)
allotypes_enc_train, allotypes_mask_train = pp.stack(
    list(map(seqencoder, allotypes_train)),
    shape=(ALLOTYPE_LEN,), dtype=INT_T, filler=0,
)
# Cumulative (thermometer) encoding of the ordinal measurement category.
categories_enc_train = expand_categories(categories_train)


  op.setitem(stacked, [i, *slices_], op.getitem(arr, slices_))
  op.setitem(mask, [i, *slices_], True)


In [13]:
batch = 256
emb_d = 16
depth = 3
r = 4
ffn_hid = emb_d * r
ffn_act = 'elu'
dropout = 0.1

# `r` serves both as the FFN expansion ratio (ffn_hid above) and as the value passed
# to MultiHeadAttention below, presumably the number of heads (confirm against
# attention.core).
# NOTE(review): split these into separate hyperparameters if they should be tuned independently.
n_heads = r
# The per-head width floors emb_d / n_heads. Unless emb_d is divisible by n_heads, the
# heads presumably no longer span the full emb_d width, so fail loudly instead.
if emb_d % n_heads:
    raise ValueError(
        'emb_d ({}) must be divisible by the number of attention heads ({})'
        .format(emb_d, n_heads)
    )
head_d = emb_d // n_heads

input_allotypes = layers.Input(
    shape=(ALLOTYPE_LEN,), name='input_allotypes', dtype='int32'
)
input_peptides = layers.Input(
    shape=(PEPTIDE_LEN,), name='input_peptides', dtype='int32'
)

# input_dim = oov + 1 covers every code from 0 (padding) up to seqencoder.oov.
# NOTE(review): sequences are padded with code 0, but mask_zero=False and no mask is
# given to the attention layer, so padded positions are attended to. Confirm this is intended.
embeddings = layers.Embedding(
    input_dim=seqencoder.oov+1, output_dim=emb_d,
    mask_zero=False, name='embeddings'
)

# TODO: allotype branch not wired in yet.
# embedded_allotypes = layers.Lambda(
#     identity, output_shape=(ALLOTYPE_LEN, emb_d), name='embed_allotypes'
# )(embeddings(input_allotypes))
# The identity Lambda only names the embedded tensor and fixes its static output shape.
embedded_peptides = layers.Lambda(
    identity, output_shape=(PEPTIDE_LEN, emb_d), name='embed_peptides'
)(embeddings(input_peptides))

attention = core.ScaledDotProductAttention(dropout)
mhattention = core.MultiHeadAttention(attention, PEPTIDE_LEN, PEPTIDE_LEN, n_heads, head_d)

# Self-attention over the peptide: query, key and value are all the embedded peptide.
# The call returns two tensors, presumably (output, attention weights); keep both
# for later cells instead of discarding them after display.
peptide_attended, peptide_attention_weights = mhattention(
    embedded_peptides, embedded_peptides, embedded_peptides
)
peptide_attended, peptide_attention_weights

Q Tensor("Shape:0", shape=(3,), dtype=int32)
splits Tensor("strided_slice:0", shape=(), dtype=int32) Tensor("strided_slice_1:0", shape=(), dtype=int32) Tensor("strided_slice_2:0", shape=(), dtype=int32)


(<tf.Tensor 'dense_4/Reshape_2:0' shape=(?, 16, 16) dtype=float32>,
 <tf.Tensor 'lambda_7/transpose:0' shape=(?, ?, 4, ?) dtype=float32>)

In [14]:
embedded_peptides.get_shape()  # static shape of the embedded peptide tensor


TensorShape([Dimension(None), Dimension(16), Dimension(16)])