In [1]:
import operator as op
import typing as t
import copy
from itertools import chain
from collections import Counter

import numpy as np
import pandas as pd
import tensorflow as tf
# from tensorflow._api.v1.keras import backend as K, layers, models, optimizers, \
#    activations, initializers
from keras import backend as K, layers, models
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from fn import F
from fn.func import identity

from sklab.data import preprocessing as pp
from sklab.attention import core, util

FLOAT_T = np.float32
INT_T = np.int32

%matplotlib notebook

Using TensorFlow backend.


In [2]:
ls sklab/

__init__.py  [1m[34m__pycache__[m[m/ [1m[34mattention[m[m/   [1m[34mdata[m[m/        [1m[34mstats[m[m/


In [3]:
class SequenceEncoder:
    def __init__(self, alphabet: t.Iterable[str]):
        unique = sorted(set(chain.from_iterable(alphabet)))
        self._mapping = dict(
            (val, key) for key, val in enumerate(unique, 1)
        )
        self._oov = len(self._mapping) + 1
    
    def __call__(self, sequence: t.Sequence[str]) -> np.ndarray:
        encoded = (self._mapping.get(char, self._oov) for char in sequence)
        return np.fromiter(encoded, INT_T, len(sequence))
        
    @property
    def mapping(self) -> t.Mapping[str, int]:
        return copy.deepcopy(self._mapping)
    
    @property
    def oov(self) -> int:
        return self._oov


def expand_categories(categories: np.ndarray, dtype=FLOAT_T) -> np.ndarray:
    ncat = categories.shape[0]
    maxcat = categories.max()
    expanded = np.zeros(maxcat*ncat, dtype=FLOAT_T).reshape((ncat, maxcat))
    for i, j in enumerate(categories):
        expanded[i,:j] = 1
    return expanded

In [4]:
train = pd.read_csv('data/abundant_train.tsv', sep='\t')
allotypes = {
    seqrec.id: str(seqrec.seq) for seqrec in SeqIO.parse('data/allotype_binding_regions.faa', 'fasta')
}
consurf = pd.read_csv('data/consurf.tsv', sep='\t')


In [5]:
train.keys()

Index(['accession', 'allotype', 'article_affiliations', 'article_authors',
       'article_chemical_list', 'article_date', 'article_id', 'article_title',
       'bind_or_elution_id', 'category', 'collection', 'inequality',
       'journal_title', 'measurement', 'measurement_ord', 'measurement_type',
       'method', 'object_type', 'peptide', 'quantitative',
       'reference_category_name', 'reference_type', 'source', 'species',
       'submission_affiliations', 'submission_authors', 'submission_date',
       'submission_id', 'submission_title', 'submitter_name', 'units'],
      dtype='object')

In [6]:
consurf.keys()

Index(['position', 'consurf_score'], dtype='object')

In [7]:
# peptide length distribution

train['peptide'].apply(len).value_counts()


9     116089
10     32707
8       5860
11      3439
13       311
12       273
15       223
14       175
18        34
17        31
20        14
7          9
16         6
21         5
19         3
6          2
5          2
26         2
23         1
30         1
Name: peptide, dtype: int64

In [8]:
PEPTIDE_LEN = 16
ALLOTYPE_LEN = 64

# select ALLOTYPE_LEN most variable MHC positions (i.e. positions with large consurf scores)
mhc_positions = consurf.sort_values('consurf_score', ascending=False).iloc[:ALLOTYPE_LEN,0].sort_values()
# remove recods with overly long peptides
train_filt = train[train['peptide'].apply(len) <= PEPTIDE_LEN]

In [9]:
peptides_train = train_filt['peptide']
accessions_train = train_filt['accession']
categories_train = train_filt['measurement_ord']
allotypes_train = (
    F(map, allotypes.get)
    >> (map, op.itemgetter(*mhc_positions))
    >> list
)(accessions_train)


In [10]:
# allotype AA distribution
allotype_train_aa_counts = Counter(chain.from_iterable(allotypes_train))
(pd.Series(allotype_train_aa_counts) / sum(allotype_train_aa_counts.values())).sort_values()

C    0.001142
P    0.022825
N    0.023451
M    0.029600
K    0.030530
I    0.031085
F    0.032541
H    0.035138
W    0.038424
D    0.042060
L    0.044233
S    0.046269
V    0.057988
Y    0.060929
T    0.065004
Q    0.066136
G    0.071495
E    0.091462
R    0.100311
A    0.109377
dtype: float64

In [11]:
# peptide AA distribution
peptide_train_aa_counts = Counter(chain.from_iterable(peptides_train))
(pd.Series(peptide_train_aa_counts) / sum(peptide_train_aa_counts.values())).sort_values()

C    0.014703
W    0.019917
H    0.022215
M    0.032078
Q    0.032353
D    0.037534
N    0.038360
E    0.043513
P    0.046422
G    0.047950
K    0.052160
R    0.053258
Y    0.054931
T    0.059862
F    0.059978
S    0.063748
I    0.068264
V    0.068847
A    0.068972
L    0.114938
dtype: float64

In [12]:
allotype_aa = chain.from_iterable(allotypes_train)
peptide_aa = chain.from_iterable(peptides_train)
seqencoder = SequenceEncoder(
    (F(chain) >> (filter, lambda aa: aa != '-'))(allotype_aa, peptide_aa)
)

In [13]:
peptides_enc_train, peptides_mask_train = pp.stack(
    [seqencoder(pep) for pep in peptides_train],
    shape=(PEPTIDE_LEN,), dtype=INT_T, filler=0
)
allotypes_enc_train, allotypes_mask_train = pp.stack(
    [seqencoder(allo) for allo in allotypes_train],
    shape=(ALLOTYPE_LEN,), dtype=INT_T, filler=0 
)
categories_enc_train = expand_categories(categories_train)


  op.setitem(stacked, [i, *slices_], op.getitem(arr, slices_))
  op.setitem(mask, [i, *slices_], True)


In [16]:
batch = 256
emb_d = 8
depth = 3
r = 4
ffn_hid = emb_d * r
ffn_act = 'elu'
dropout = 0.1

input_allotypes = layers.Input(
    shape=(ALLOTYPE_LEN,), name='input_allotypes', dtype='int32'
)
input_peptides = layers.Input(
    shape=(PEPTIDE_LEN,), name='input_peptides', dtype='int32'
)

embeddings = layers.Embedding(
    input_dim=seqencoder.oov+1, output_dim=emb_d,
    mask_zero=False, name='embeddings'
)

# embedded_allotypes = layers.Lambda(
#     identity, output_shape=(ALLOTYPE_LEN, emb_d), name='embed_allotypes'
# )(embeddings(input_allotypes))
embedded_peptides = layers.Lambda(
    identity, output_shape=(lambda s: [s[0], s[1], s[2]]), name='embed_peptides'
)(embeddings(input_peptides))


In [17]:
embedded_peptides.shape


TensorShape([Dimension(None), Dimension(16), Dimension(8)])

In [18]:
# Keras tensor
KTensor = t.NewType('KTensor', tf.Tensor)

# TODO find a way to specify a list of length 3 as input and a list
# TODO of length 2 as output
QKVAttention = t.Callable[[t.List[KTensor]], t.List[KTensor]]

# TODO implement as Layer and Model objects



class LayerNormalisation(layers.Layer):

    def __init__(self, eps=K.epsilon(), **kwargs):
        self.eps = eps
        self.gamma = None  # set in LaterNormalisation.__build__
        self.beta = None  # set in LaterNormalisation.__build__
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.gamma = self.add_weight(
            name='gamma', shape=input_shape[-1:],
            initializer=initializers.Ones(), trainable=True
        )
        self.beta = self.add_weight(
            name='beta', shape=input_shape[-1:],
            initializer=initializers.Zeros(), trainable=True
        )
        super().build(input_shape)

    def call(self, inputs, **kwargs) -> KTensor:
        """
        :param inputs: a Keras tensor
        :param kwargs:
        :return:
        """
        x = inputs
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def compute_output_shape(self, input_shape):
        return input_shape


class BatchDot(layers.Layer):
    """
    A wrapper around keras.backend.batch_dot
    """
    
    def __init__(self, axes: t.Optional[t.Union[int, t.Tuple[int, int]]], **kwargs):
        super().__init__(**kwargs)
        self.axes = axes

    def call(self, inputs, **kwargs) -> KTensor:
        return layers.Lambda(
            lambda x: K.batch_dot(x[0], x[1], axes=self.axes)
        )(inputs)

    def compute_output_shape(self, input_shape):
        x_shape, y_shape = input_shape
        x_ndim, y_ndim = map(len, input_shape)
        if x_ndim < 2 or y_ndim < 2:
            raise ValueError(
                f'Can not do batch_dot on inputs with rank < 2. Received inputs '
                f'with shapes {x_shape} and {y_shape}.'
            )
        x_batch = x_shape[0]
        y_batch = y_shape[0]
        if not (x_batch is None or y_batch is None) and x_batch != y_batch:
            raise ValueError(
                f'Can not do batch_dot on inputs with different batch sizes. '
                f'Received inputs with shapes {x_shape} and {y_shape}.'
            )
        # resolve different axes cases
        axes = (
            [self.axes, self.axes] if isinstance(self.axes, int) else
            list(self.axes) if self.axes is not None else
            [x_ndim - 1, y_ndim - 1] if y_ndim == 2 else
            [x_ndim - 1, y_ndim - 2]
        )
        # make sure all axes are either None or integers
        # TODO rewrite this message and condition
        if any([isinstance(axis, (list, tuple)) for axis in axes]):
            raise ValueError(
                f'Multiple target dimensions are not supported. '
                f'Expected: None, int, (int, int). Received: {axes}.'
            )
        # resolve negative indices
        axes_noneg = [
            axes[0] if axes[0] >= 0 else axes[0] + x_ndim,
            axes[1] if axes[1] >= 0 else axes[1] + y_ndim
        ]
        # make sure we are not multiplying along the batch axis
        if 0 in axes:
            raise ValueError(
                'Can not perform batch_dot over axis 0. If your inputs are not '
                'batched, add a dummy batch dimension to your inputs using '
                'K.expand_dims(x, 0)'
            )
        # use a dummy Dot layer to calculate output shape
        dot = layers.Dot(axes_noneg)
        return dot.compute_output_shape(input_shape)


class SplitHeads(layers.Layer):
    # TODO add docs and argument checks
    def __init__(self, r: int, **kwargs):
        super().__init__(**kwargs)
        self.r = r

    def _split(self, x: tf.Tensor) -> tf.Tensor:
        return util.split_heads(self.r, x)

    def call(self, inputs: KTensor, **kwargs) -> KTensor:
        return layers.Lambda(
            self._split, output_shape=self.compute_output_shape
        )(inputs)

    def compute_output_shape(self, input_shape):
        b, l, d = input_shape
        d_r = d // self.r
        rb = None if b is None else b * self.r
        return rb, l, d_r


class MergeHeads(layers.Layer):
    # TODO add docs and argument checks
    def __init__(self, r: int, **kwargs):
        super().__init__(**kwargs)
        self.r = r

    def _merge(self, x: tf.Tensor) -> tf.Tensor:
        return util.merge_heads(self.r, x)

    def call(self, inputs, **kwargs):
        return layers.Lambda(
            self._merge, output_shape=self.compute_output_shape
        )(inputs)

    def compute_output_shape(self, input_shape):
        rb, l, d_r = input_shape
        d = self.r * d_r
        b = None if rb is None else rb // self.r
        return b, l, d
    

class GroupAttentions(layers.Layer):
    # TODO add docs and argument checks
    def __init__(self, r: int, **kwargs):
        super().__init__(**kwargs)
        self.r = r

    def _group(self, x: tf.Tensor) -> tf.Tensor:
        return util.group_attentions(self.r, x)

    def call(self, inputs, **kwargs) -> KTensor:
        return layers.Lambda(
            self._group, output_shape=self.compute_output_shape
        )(inputs)

    def compute_output_shape(self, input_shape):
        rb, l_q, l_k = input_shape
        b = None if rb is None else rb // self.r
        return b, l_q, self.r, l_k


class ScaledDotProductAttention(layers.Layer):
    """
    Build a subgraph for scaled dot product attention.
    """

    def __init__(self, dropout: float, return_drop=False, **kwargs):
        """
        :param dropout:
        :param return_drop: return attention matrix after dropout
        :param kwargs:
        """
        super().__init__(**kwargs)
        self.dropout = layers.Dropout(dropout) if dropout else None
        self.return_drop = return_drop

    def call(self, inputs: t.List[KTensor], **kwargs) -> t.List[KTensor]:
        q, k, v = inputs
        return self._call(q, k, v)

    # TODO merge call and _call
    def _call(self, q: KTensor, k: KTensor, v: KTensor) -> t.List[KTensor]:
        r"""
        Argument shape legend: b - batch, l - length (number of entries in a
        sequence), d – entry length (embedding dimensions)
        Given:
            $ Q \in {R}^{ {l}_{q} \times d } $
            $ K \in {R}^{ {l}_{k} \times d } $
        the scale dot-product attention matrix is defined as
        $$
        A = softmax( \frac{ Q \times {K}^{T}) }{ \sqrt{d} } )
        $$
        Given a value $ V \in {R}^{ {l}_{v} \times d } $, such that
        ${l}_{v} = {l}_{k}$ this layer calculates returns both the attention
         matrix and the $ A \times V $ product
        :param q: a query tensor of shape [b, l_q,  d]
        :param k: a key tensor of shape [b, l_k, d]
        :param v: a value tensor of shape [b, l_v, d], such that l_v == l_k
        :return: $ A \times V $ tensor of shape [b, l_v, d], attention
        matrix of shape [b, l_q, l_k]
        """
        d = K.shape(q)[-1]
        scaling_factor = K.sqrt(K.cast(d, dtype=K.floatx()))
        # Q \times {K}^{T} => shape = [b, l_q, l_k]
        similarity = BatchDot(axes=(2, 2))([q, k])
        att_scaled = layers.Activation('softmax')(similarity / scaling_factor)
        att_drop = self.dropout(att_scaled) if self.dropout else att_scaled
        # A \times V => shape = [b, l_v, d]
        att_v = BatchDot(axes=None)([att_drop, v])
        return [att_v, att_drop if self.return_drop else att_scaled]

    def compute_output_shape(self, input_shape):
        q_shape, k_shape, v_shape = input_shape
        b_q, l_q, d_q = q_shape
        b_k, l_k, d_k = k_shape
        b_v, l_v, d_v = v_shape
        # TODO check that:
        #     1. b_q == b_k == b_v (if they are not None)
        #     2. d_q == d_k; these must not be None
        #     3. l_k == l_v; these must not be None
        #     4. d_v is not None
        # if not (b_q is None or b_k is None) and b_q != b_k:
        #     raise ValueError(
        #         '...'
        #     )
        # if not (d_q is None or d_k is None) and d_q != d_k:
        #     raise ValueError(
        #         '...'
        #     )
        product_shape = (b_q, l_v, d_v)
        attention_shape = (b_q, l_q, l_k)
        return [product_shape, attention_shape]


class MultiHeadAttention(layers.Layer):
    """
    Transform a single-headed attention block into a multi-headed attention
    """

    def __init__(self, attention: QKVAttention, r: int, d_r: int, **kwargs):
        # TODO check d and r compatibility
        # TODO check dropout
        super().__init__(**kwargs)
        self.attention = attention
        self.r = r
        self.d_r = d_r
        self.d = d_r * r
        # head splitter and merger
        self.splitter = SplitHeads(self.r)
        self.merger = MergeHeads(self.r)
        self.att_grouper = GroupAttentions(self.r)
        # create linear mappings for Q, K and V
        self.q_map = layers.Dense(self.d, use_bias=False)
        self.k_map = layers.Dense(self.d, use_bias=False)
        self.v_map = layers.Dense(self.d, use_bias=False)
        # create a linear mapping for A \times V
        self.att_v_map = layers.Dense(self.d, use_bias=False)

    def call(self, inputs, **kwargs) -> t.List[KTensor]:
        q, k, v = inputs
        return self._call(q, k, v)

    def _call(self, q: KTensor, k: KTensor, v: KTensor) -> t.List[KTensor]:
        """
        :param q:
        :param k:
        :param v:
        :return: returns a grouped attention matrix (for more details see
        util.group_attentions)
        """
        # transform subspaces and split heads
        q_split = self.splitter(self.q_map(q))
        k_split = self.splitter(self.k_map(k))
        v_split = self.splitter(self.v_map(v))
        # calculate attention heads
        att_v_split, att_split = self.attention([q_split, k_split, v_split])
        # merge heads and apply a linear map
        att_v_merged = self.merger(att_v_split)
        att_v = self.att_v_map(att_v_merged)
        att_groups = self.att_grouper(att_split)
        return [att_v, att_groups]
    
    def compute_output_shape(self, input_shape):
        q_shape, k_shape, v_shape = input_shape
        q_split_shape = self.splitter.compute_output_shape(q_shape)
        k_split_shape = self.splitter.compute_output_shape(k_shape)
        v_split_shape = self.splitter.compute_output_shape(v_shape)
        att_v_split_shape, att_split_shape = self.attention.compute_output_shape(
            [q_split_shape, k_split_shape, v_split_shape]
        )
        att_v_merge_shape = self.merger.compute_output_shape(att_v_split_shape)
        att_v_shape = self.att_v_map.compute_output_shape(att_v_merge_shape)
        att_groups_shape = self.att_grouper.compute_output_shape(att_split_shape)
        return [att_v_shape, att_groups_shape]


In [19]:
ph_x = K.placeholder(shape=(None, PEPTIDE_LEN, emb_d))
ph_y = K.placeholder(shape=(None, ALLOTYPE_LEN, emb_d))

In [20]:
ph_z = BatchDot([2, 2])([ph_x, ph_y])
ph_z 

<tf.Tensor 'batch_dot_1/lambda_1/MatMul:0' shape=(?, 16, 64) dtype=float32>

In [21]:
BatchDot(None)([ph_z, ph_y])

<tf.Tensor 'batch_dot_2/lambda_2/MatMul:0' shape=(?, 16, 8) dtype=float32>

In [22]:
K.batch_dot(ph_z, ph_y)

<tf.Tensor 'MatMul:0' shape=(?, 16, 8) dtype=float32>

In [23]:
_split = SplitHeads(4)(ph_x)
_merge = MergeHeads(4)(_split)

K.int_shape(_split), K.int_shape(_merge)

((None, 16, 2), (None, 16, 8))

In [24]:
layers.Dense(8)(_merge)

<tf.Tensor 'dense_5/add:0' shape=(?, 16, 8) dtype=float32>

In [25]:
attention = ScaledDotProductAttention(dropout)
mhattention = MultiHeadAttention(attention, r, emb_d // r)

In [26]:
attention([ph_x, ph_y, ph_y])

[<tf.Tensor 'scaled_dot_product_attention_2/batch_dot_4/lambda_6/MatMul:0' shape=(?, 16, 8) dtype=float32>,
 <tf.Tensor 'scaled_dot_product_attention_2/activation_1/truediv:0' shape=(?, 16, 64) dtype=float32>]

In [27]:
list(map(K.int_shape, attention([ph_x, ph_y, ph_y])))

[(None, 64, 8), (None, 16, 64)]

In [28]:
list(map(K.int_shape, attention([ph_x, ph_x, ph_x])))

[(None, 16, 8), (None, 16, 16)]

In [29]:
mhattention([ph_x, ph_x, ph_x])

[<tf.Tensor 'multi_head_attention_2/dense_9/Reshape_2:0' shape=(?, 16, 8) dtype=float32>,
 <tf.Tensor 'multi_head_attention_2/group_attentions_2/lambda_17/transpose:0' shape=(?, ?, 4, ?) dtype=float32>]

In [30]:
list(map(K.int_shape, mhattention([ph_x, ph_x, ph_x])))

[(None, 16, 8), (None, 16, 4, 16)]