### Transformer: an implementation of the "Attention Is All You Need" paper

In [1]:
import tensorflow as tf

In [2]:
# Switch TensorFlow to eager execution so operations run immediately.
# NOTE(review): this top-level alias exists in TensorFlow 1.x only; in 2.x
# eager mode is the default and the function lives at
# tf.compat.v1.enable_eager_execution -- confirm the targeted TF version.
tf.enable_eager_execution()

In [3]:
# Short alias for the Keras layers namespace; the classes below refer to
# layers as L.Dense, L.Embedding and L.LayerNormalization.
L = tf.keras.layers

In [6]:
class Transformer(tf.keras.Model):
    """Abstract base holding the pieces shared by the encoder and decoder stacks.

    The constructor only stores hyper-parameters and creates the token
    embedding. Subclasses call ``_init_structure`` to create the positional
    encoding table and the per-block attention / feed-forward layers, and
    must implement ``_block_computation`` and ``call``.

    Args:
        num_blocks: Number of stacked blocks.
        num_heads: Number of attention heads per block.
        vocab_size: Size of the token vocabulary for the embedding.
        seq_len: Number of positions in the positional encoding table.
        d_model: Width of the embeddings and of every block output.
        d_k: Width of each head's query/key projection.
        d_v: Width of each head's value projection.
        d_ff: Hidden width of the position-wise feed-forward network.
    """

    def __init__(self, num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff):
        super(Transformer, self).__init__()
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.seq_len = seq_len
        self.d_k = d_k
        self.d_v = d_v
        self.d_ff = d_ff
        self.word_embed = L.Embedding(vocab_size, d_model)

    def _format(self, block, head):
        """Return the attribute-name suffix identifying one head of one block."""
        # The separator keeps e.g. (block 1, head 11) and (block 11, head 1)
        # from mapping to the same attribute and silently sharing a layer.
        return str(block) + "_" + str(head)

    def _positional_encoding(self):
        """Build the sinusoidal position table of shape (1, seq_len, d_model).

        Even feature indices hold a sine and odd indices a cosine; both
        members of a (sin, cos) pair use the same frequency
        ``1 / 10000 ** (2 * (index // 2) / d_model)``.
        """
        positions = tf.expand_dims(tf.cast(tf.range(self.seq_len), tf.float32), axis=1)
        dims = tf.range(self.d_model)
        exponents = tf.cast(2 * (dims // 2), tf.float32) / float(self.d_model)
        angles = positions / tf.pow(10000.0, exponents)  # (seq_len, d_model)
        # Blend with a 0/1 mask instead of assigning into a tensor: TensorFlow
        # tensors are immutable, and this form also handles an odd d_model.
        is_even = tf.cast(tf.equal(dims % 2, 0), tf.float32)
        table = is_even * tf.sin(angles) + (1.0 - is_even) * tf.cos(angles)
        return tf.expand_dims(table, axis=0)

    def _init_structure(self, decoder_part=False):
        """Create the positional encoding and every per-block layer.

        Layers are stored as attributes named by role plus block/head ids so
        Keras tracks them. Must be called exactly once, after ``__init__``.

        Args:
            decoder_part: When True, also create the ``*enc`` layers used by
                the attention over the encoder output.
        """
        assert not hasattr(self, "pos_enc"), "The structure is initialized already."
        self.pos_enc = self._positional_encoding()

        for block_id in range(self.num_blocks):
            for head_id in range(self.num_heads):
                suffix = self._format(block_id, head_id)
                setattr(self, "Q" + suffix, L.Dense(self.d_k))
                setattr(self, "K" + suffix, L.Dense(self.d_k))
                setattr(self, "V" + suffix, L.Dense(self.d_v))
                if decoder_part:
                    setattr(self, "Qenc" + suffix, L.Dense(self.d_k))
                    setattr(self, "Kenc" + suffix, L.Dense(self.d_k))
                    setattr(self, "Venc" + suffix, L.Dense(self.d_v))
            setattr(self, "O" + str(block_id), L.Dense(self.d_model))
            if decoder_part:
                # The encoder-decoder attention is a separate sub-layer and
                # gets its own output projection rather than sharing "O".
                setattr(self, "Oenc" + str(block_id), L.Dense(self.d_model))
            setattr(self, "FFN1" + str(block_id), L.Dense(self.d_ff, activation="relu"))
            setattr(self, "FFN2" + str(block_id), L.Dense(self.d_model))

    def _ffn(self, block_id, attention_output):
        """Apply the block's two-layer feed-forward network (ReLU in between)."""
        hidden = getattr(self, "FFN1" + str(block_id))(attention_output)
        return getattr(self, "FFN2" + str(block_id))(hidden)

    def _scaled_dot_product(self, Q, K, V):
        """Return ``softmax(Q K^T / sqrt(d_k)) V`` for one attention head."""
        logits = tf.matmul(Q, K, transpose_b=True)
        # tf.sqrt rejects integer inputs, so cast d_k to the logits dtype.
        scale = tf.sqrt(tf.cast(self.d_k, logits.dtype))
        weights = tf.nn.softmax(logits / scale, axis=-1)
        return tf.matmul(weights, V)

    def _multi_head_attention(self, block_id, Q, K, V, connection_head=False):
        """Run every head of a block and merge the results.

        Args:
            block_id: Index of the block whose layers are used.
            Q: Input projected to queries by each head.
            K: Input projected to keys by each head.
            V: Input projected to values by each head.
            connection_head: When True use the ``*enc`` layers (attention
                over the encoder output) instead of the self-attention ones.

        Returns:
            The concatenated head outputs projected to ``d_model`` features.
        """
        role = "enc" if connection_head else ""
        head_outputs = []
        for head_id in range(self.num_heads):
            suffix = self._format(block_id, head_id)
            # Project into fresh locals: every head must see the original
            # inputs, not the previous head's projections.
            queries = getattr(self, "Q" + role + suffix)(Q)
            keys = getattr(self, "K" + role + suffix)(K)
            values = getattr(self, "V" + role + suffix)(V)
            head_outputs.append(self._scaled_dot_product(queries, keys, values))
        merged = tf.concat(head_outputs, axis=-1)
        return getattr(self, "O" + role + str(block_id))(merged)

    def _block_computation(self, *args, **kwargs):
        """Compute one block; must be provided by subclasses."""
        raise NotImplementedError("Transformer is abstract class. You must implement this function!")

    def call(self, *args, **kwargs):
        """Run the full stack; must be provided by subclasses."""
        raise NotImplementedError("Transformer is abstract class. You must implement this function!")

In [8]:
class Encoder(Transformer):
    """Encoder stack: self-attention and feed-forward sub-layers per block.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization. Constructor arguments are those of ``Transformer``.
    """

    def __init__(self, num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff):
        super(Encoder, self).__init__(num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff)
        self._init_structure()
        # Create the normalization layers once so their gain/bias are tracked
        # and trained; building them inside the forward pass would produce
        # fresh, untrained parameters on every call.
        for block_id in range(num_blocks):
            setattr(self, "attn_norm" + str(block_id), L.LayerNormalization())
            setattr(self, "ffn_norm" + str(block_id), L.LayerNormalization())

    def _block_computation(self, block_id, x):
        """Apply block ``block_id`` to ``x`` and return the block output."""
        attention_output = self._multi_head_attention(block_id, x, x, x, False)
        attention_output = getattr(self, "attn_norm" + str(block_id))(attention_output + x)

        block_output = self._ffn(block_id, attention_output)
        block_output = getattr(self, "ffn_norm" + str(block_id))(block_output + attention_output)
        return block_output

    def call(self, x):
        """Embed the token ids ``x``, add positions and run every block.

        Args:
            x: Token ids accepted by the embedding layer -- presumably of
                shape (batch, length); confirm against callers.

        Returns:
            The output of the last block.
        """
        word_embed = self.word_embed(x)
        # Slice the table to the input length so sequences shorter than
        # seq_len can be added without a shape mismatch.
        word_embed = word_embed + self.pos_enc[:, :tf.shape(word_embed)[1], :]

        block_output = word_embed
        for block_id in range(self.num_blocks):
            block_output = self._block_computation(block_id, block_output)
        return block_output

In [9]:
class Decoder(Transformer):
    """Decoder stack producing vocabulary logits.

    Each block applies self-attention, attention over the encoder output and
    a feed-forward network, each wrapped in a residual connection followed by
    layer normalization. Constructor arguments are those of ``Transformer``.
    """

    def __init__(self, num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff):
        super(Decoder, self).__init__(num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff)
        self._init_structure(decoder_part=True)
        # Create the normalization layers once so their gain/bias are tracked
        # and trained; building them inside the forward pass would produce
        # fresh, untrained parameters on every call.
        for block_id in range(num_blocks):
            setattr(self, "attn_norm" + str(block_id), L.LayerNormalization())
            setattr(self, "cross_norm" + str(block_id), L.LayerNormalization())
            setattr(self, "ffn_norm" + str(block_id), L.LayerNormalization())
        self.logits = L.Dense(units=vocab_size)

    def _block_computation(self, block_id, x, encoder_output):
        """Apply block ``block_id`` to ``x`` using ``encoder_output`` as memory."""
        # NOTE(review): no look-ahead mask is applied here, so every position
        # can attend to later target positions -- confirm this is intended.
        attention_output = self._multi_head_attention(block_id, x, x, x, connection_head=False)
        attention_output = getattr(self, "attn_norm" + str(block_id))(attention_output + x)

        # Queries come from the decoder state; keys and values come from the
        # encoder output, so the result keeps the decoder's sequence length.
        connection_output = self._multi_head_attention(block_id, attention_output, encoder_output,
                                                       encoder_output, connection_head=True)
        connection_output = getattr(self, "cross_norm" + str(block_id))(connection_output + attention_output)

        block_output = self._ffn(block_id, connection_output)
        block_output = getattr(self, "ffn_norm" + str(block_id))(block_output + connection_output)
        return block_output

    def call(self, x, encoder_output):
        """Embed the token ids ``x``, run every block and project to logits.

        Args:
            x: Target token ids accepted by the embedding layer -- presumably
                of shape (batch, length); confirm against callers.
            encoder_output: Output of the encoder, attended to by every block.

        Returns:
            Unnormalized scores over the vocabulary for each position.
        """
        word_embed = self.word_embed(x)
        # Slice the table to the input length so sequences shorter than
        # seq_len can be added without a shape mismatch.
        word_embed = word_embed + self.pos_enc[:, :tf.shape(word_embed)[1], :]

        block_output = word_embed
        for block_id in range(self.num_blocks):
            block_output = self._block_computation(block_id, block_output, encoder_output)
        logits = self.logits(block_output)
        return logits

In [10]:
def loss_function(labels, logits):
    """Return the summed sparse softmax cross-entropy of ``logits`` vs ``labels``.

    Args:
        labels: Integer class ids, one per prediction.
        logits: Unnormalized class scores with a trailing class axis.

    Returns:
        A scalar tensor: the cross-entropy summed over every element.
    """
    per_element = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=labels,
    )
    total = tf.reduce_sum(per_element)
    return total