In [1]:
import math

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [134]:
class TokenAndPositionEmbedding(layers.Layer):
    """Adds a learned position embedding to a learned token embedding.

    Args:
        maxlen: Number of rows in the position embedding table.
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Output width of both embedding tables.
        **kwargs: Forwarded to `layers.Layer` (for example `name` or `dtype`).
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Return `token_emb(x) + pos_emb(0 .. len-1)` for the token ids in `x`."""
        # The position range is built from the runtime size of the last axis of
        # `x`, not from the `maxlen` given to the constructor.
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        # The position embeddings are added to the token embeddings by broadcasting.
        return x + positions

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update(
            {
                "maxlen": self.maxlen,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

In [135]:
class TransformerBlock(layers.Layer):
    """Self-attention block followed by a two-layer feed-forward network.

    Each of the two sub-blocks is followed by dropout, a residual addition and
    layer normalization.

    Args:
        embed_dim: Width of the block's input and output; also the output width
            of the feed-forward network.
        num_heads: Number of attention heads. Each head uses
            `embed_dim // num_heads` as its key dimension.
        ff_dim: Width of the hidden (ReLU) layer of the feed-forward network.
        rate: Dropout rate used after attention and after the feed-forward network.
        **kwargs: Forwarded to `layers.Layer` (for example `name` or `dtype`).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads
        )
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block to `inputs`.

        Args:
            inputs: Tensor used as both query and value of the attention layer.
            training: Forwarded to the two dropout layers. Defaults to None so the
                block can be called without it, as this notebook does.

        Returns:
            The layer-normalized output of the feed-forward sub-block.
        """
        attn_output = self.att(inputs, inputs)
        # Debug hook: keeps the attention output of the most recent call so that
        # it can be inspected from outside the layer.
        self.attention_output_result = attn_output
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "rate": self.rate,
            }
        )
        return config

In [4]:
vocab_size = 20000  # Vocabulary is capped at the 20k most frequent words
maxlen = 200  # Every review is padded / truncated to exactly 200 tokens
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
    num_words=vocab_size
)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")

# Short reviews get their padding appended at the end (padding='post').
# NOTE(review): `truncating` is left at its default; per the Keras docs that
# default is 'pre', which would keep the LAST `maxlen` tokens of a long review
# rather than the first ones - confirm this is intended.
x_train, x_val = (
    keras.preprocessing.sequence.pad_sequences(split, maxlen=maxlen, padding='post')
    for split in (x_train, x_val)
)

25000 Training sequences
25000 Validation sequences


In [5]:
# Build the reverse lookup (index -> word) from the IMDB word -> index table.
# NOTE(review): load_data() offsets token ids by `index_from` (3 by default per
# the Keras docs), so ids taken from x_train may need that offset removed before
# they are looked up here - confirm before decoding reviews.
word_index = keras.datasets.imdb.get_word_index()
index_word = {idx: token for token, idx in word_index.items()}

In [141]:
# Hyper-parameters of the single transformer block.
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer

# Functional model: token ids -> token+position embedding -> transformer block
# -> average over the sequence axis -> small dense head.
# The creation order below matters: later cells address the embedding layer and
# the transformer block as model.layers[1] and model.layers[2].
inputs = layers.Input(shape=(maxlen,))
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
x = transformer_block(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu")(x)
x = layers.Dropout(0.1)(x)
# Two-unit softmax output, i.e. one probability per class.
outputs = layers.Dense(2, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

In [142]:
# Push the first 8 padded reviews through the embedding layer, then through the
# transformer block, keeping both results for the attention walkthrough below.
x_train_emb = model.layers[1](x_train[:8, :])
x_train_tb = model.layers[2](x_train_emb)

In [140]:
# Print each layer of the model with its output shape and parameter count.
model.summary()

Model: "model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 200)]             0         
                                                                 
 token_and_position_embeddin  (None, 200, 32)          646400    
 g_7 (TokenAndPositionEmbedd                                     
 ing)                                                            
                                                                 
 transformer_block_7 (Transf  (None, 200, 32)          6202      
 ormerBlock)                                                     
                                                                 
 global_average_pooling1d_6   (None, 32)               0         
 (GlobalAveragePooling1D)                                        
                                                                 
 dropout_28 (Dropout)        (None, 32)                0   

In [128]:
# Shapes of the embedding layer's weights (token table, then position table).
[weight.shape for weight in model.layers[1].weights]

[TensorShape([20000, 32]), TensorShape([200, 32])]

In [15]:
# Shapes of every weight owned by the block's multi-head attention layer.
[weight.shape for weight in model.layers[2].att.weights]

[TensorShape([32, 3, 16]),
 TensorShape([3, 16]),
 TensorShape([32, 3, 16]),
 TensorShape([3, 16]),
 TensorShape([32, 3, 16]),
 TensorShape([3, 16]),
 TensorShape([3, 16, 32]),
 TensorShape([32])]

# Multi-Head Attention

$\large \mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$

## Prepare tensors: Query, Key, Value

In [167]:
# The multi-head attention sub-layer of the transformer block.
att = model.layers[2].att

# Project the embedded batch with the layer's own query / key / value
# projections. All three are computed from the same input (self-attention).
# `query` = [B, T, N, H]  (T == S here because query and key share their input)
query = att._query_dense(x_train_emb)

# `key` = [B, S, N, H]
key = att._key_dense(x_train_emb)

# `value` = [B, S, N, H]
value = att._value_dense(x_train_emb)

In [144]:
# Shape of the embedded sample batch that feeds the projections.
x_train_emb.shape

TensorShape([8, 200, 32])

In [145]:
# Shape of the projected query; the embedding axis has been split into two axes.
query.shape

TensorShape([8, 200, 2, 16])

In [146]:
# Reproduce the query projection with a bare einsum over the layer's kernel.
# NOTE(review): the projection layer may also add a bias; exact equality then
# presumably relies on that bias still being all zeros - confirm.
# np.all is used because np.alltrue was removed in NumPy 2.0.
np.all(tf.einsum('abc,cde->abde', x_train_emb, att._query_dense.kernel) == query)

True

In [147]:
# Reproduce the key projection with a bare einsum over the layer's kernel.
# NOTE(review): the projection layer may also add a bias; exact equality then
# presumably relies on that bias still being all zeros - confirm.
# np.all is used because np.alltrue was removed in NumPy 2.0.
np.all(tf.einsum('abc,cde->abde', x_train_emb, att._key_dense.kernel) == key)

True

In [148]:
# Reproduce the value projection with a bare einsum over the layer's kernel.
# NOTE(review): the projection layer may also add a bias; exact equality then
# presumably relies on that bias still being all zeros - confirm.
# np.all is used because np.alltrue was removed in NumPy 2.0.
np.all(tf.einsum('abc,cde->abde', x_train_emb, att._value_dense.kernel) == value)

True

## Compute attention score and output

In [168]:
# No attention mask is used, and `training` is passed as None.
attention_mask = None
training = None


# Reference result: let the layer's own attention routine process the projected
# tensors. It returns the attention output together with the attention scores.
attention_output, attention_scores = att._compute_attention(
    query, key, value, attention_mask, training)
# Keep the output as it is before the output projection, so it can be compared
# with a manual computation later.
att_output_ = attention_output
# From here on `attention_output` holds the result of the output projection.
attention_output = att._output_dense(attention_output)

In [169]:
# Manual, step-by-step version of the attention computation.

# Scale the query by 1/sqrt(d_k), reading d_k from the layer instead of
# hard-coding it. The result is bound to a new name so `query` stays unscaled
# and re-running this cell (or the previous one) gives the same result.
query_scaled = tf.multiply(query, 1.0 / math.sqrt(float(att._key_dim)))

print('key shape:', key.shape)
print('query shape:', query_scaled.shape)
print('scaled dot product(key*query):', att._dot_product_equation)
# The layer's einsum equation takes the key first and the query second.
att_score = tf.einsum(att._dot_product_equation, key, query_scaled)

# Normalize the raw scores (no mask is applied because attention_mask is None).
att_score = att._masked_softmax(att_score, attention_mask)
print('attention score shape:',att_score.shape)

attention_scores_dropout = att._dropout_layer(att_score, training=training)

print()
print('value shape:', value.shape)
print('context vector(score*value):', att._combine_equation)
# Weight the values by the attention scores to get one context vector per
# position and head.
att_output = tf.einsum(att._combine_equation, attention_scores_dropout, value)
print('context vector shape:',att_output.shape)

key shape (8, 200, 2, 16)
query shape (8, 200, 2, 16)
scaled dot product(key*query): aecd,abcd->acbe
attention score shape: (8, 2, 200, 200)

context vector(score*value): acbe,aecd->abcd
context vector shape: (8, 200, 2, 16)


In [151]:
# Check whether the dropout layer left the attention scores unchanged.
# np.all is used because np.alltrue was removed in NumPy 2.0.
np.all(attention_scores_dropout == att_score)

True

In [152]:
# Check the manually computed context vectors against the ones returned by the
# layer's own attention routine (before the output projection).
# np.all is used because np.alltrue was removed in NumPy 2.0.
np.all(att_output_ == att_output)

True

In [153]:
# Shape of the normalized attention scores.
att_score.shape

TensorShape([8, 2, 200, 200])

In [154]:
# Check the projected attention output against what the transformer block
# recorded on its most recent call. The block stores that tensor as
# `attention_output_result`; reading `attention_output` raises AttributeError.
# np.all is used because np.alltrue was removed in NumPy 2.0.
np.all(attention_output == model.layers[2].attention_output_result)

True