In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [14]:
class EncoderLayer(layers.Layer):
    def __init__(self,num_heads,key_dim,feature_dim,ff_dim,dropout):
        super().__init__()
        self.multiheadatt = layers.MultiHeadAttention(num_heads=num_heads,key_dim=key_dim,\
        	dropout=dropout)
        self.feed_forward_layer = keras.Sequential([
        	layers.Dense(ff_dim,activation='relu'),\
        	layers.Dense(feature_dim)])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)	

    def call(self,input,training,attention_mask):
    	attention_output = self.multiheadatt(input,input,input,attention_mask=attention_mask)
    	attention_output = self.dropout1(attention_output,training=training)
    	out1 = self.layernorm1(input+attention_output)

    	ffn_output = self.feed_forward_layer(out1)
    	ffn_output = self.dropout2(ffn_output,training=training)
    	out2 = self.layernorm2(out1+ffn_output)
    	return out2

In [13]:
sample_encoder_layer = EncoderLayer(num_heads=8,key_dim=512,feature_dim=512,ff_dim=2048,dropout=0.1)

In [14]:
sample_encoder_layer_output = sample_encoder_layer(tf.random.uniform((64, 43, 512)), False, None)

In [15]:
sample_encoder_layer_output

<tf.Tensor: shape=(64, 43, 512), dtype=float32, numpy=
array([[[-0.30807516, -0.7205544 ,  2.4360769 , ..., -0.5028709 ,
         -0.36315203,  1.3812314 ],
        [-0.9100085 , -0.753673  ,  0.17345531, ...,  0.3546194 ,
         -0.78518265,  1.040365  ],
        [ 1.2655596 ,  1.1582346 ,  0.2333653 , ..., -0.5825921 ,
          0.1024247 ,  0.08858624],
        ...,
        [ 0.8774717 ,  0.97786987, -0.4663051 , ..., -0.01119186,
          0.03093335,  0.3781778 ],
        [ 1.5765157 ,  1.002756  ,  2.4507341 , ..., -0.33592287,
          0.73702234,  0.83981174],
        [ 1.4055518 , -0.5618804 ,  1.8225296 , ..., -0.19343297,
          0.94391245,  0.84352016]],

       [[ 0.91029596, -0.3154266 ,  1.4921741 , ...,  1.1360452 ,
         -1.2503752 ,  1.2760714 ],
        [ 0.9605291 ,  0.5872383 ,  2.6066182 , ...,  0.2145131 ,
          0.44794998, -0.12872437],
        [ 0.59091246, -0.49139056,  0.36990115, ..., -0.5641501 ,
          1.6124548 , -0.7470893 ],
        ...,

In [4]:
class DecoderLayer(layers.Layer):
    def __init__(self,num_heads,key_dim,feature_dim,ff_dim,dropout):
        super().__init__()
        self.multiheadatt1 = layers.MultiHeadAttention(num_heads=num_heads,key_dim=key_dim,\
            dropout=dropout)
        self.multiheadatt2 = layers.MultiHeadAttention(num_heads=num_heads,key_dim=key_dim,\
            dropout=dropout)

        self.feed_forward_layer = keras.Sequential([
            layers.Dense(ff_dim,activation='relu'),\
            layers.Dense(feature_dim)])

        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)

    def call(self,input,encoder_outputs,training,look_ahead_mask,padding_mask):
        attention_output1 = self.multiheadatt1(input,input,input,attention_mask=look_ahead_mask)
        attention_output1 = self.dropout1(attention_output1,training=training)
        out1 = self.layernorm1(input+attention_output1)

        attention_output2 = self.multiheadatt2(value=encoder_outputs,key=encoder_outputs,\
                                               query=out1,attention_mask=padding_mask)
        attention_output2 = self.dropout1(attention_output2,training=training)
        out2 = self.layernorm2(out1+attention_output2)

        ffn_output = self.feed_forward_layer(out2)
        ffn_output = self.dropout2(ffn_output,training=training)
        out2 = self.layernorm3(out1+ffn_output)
        return out2

In [7]:
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)

In [15]:
class Encoder(layers.Layer):
    def __init__(self,num_encoder,num_heads,key_dim,feature_dim,ff_dim,dropout):
        super().__init__()
        patches = 100
        self.num_encoder = num_encoder
        self.pos_encoder = positional_encoding(patches,feature_dim)
        self.encoder_layers = [EncoderLayer(num_heads,key_dim,feature_dim,ff_dim,dropout) \
                              for _ in range(num_encoder)]
        self.dropout = layers.Dropout(dropout)
    
    def call(self, inputs,padding_mask,training=True):
        inputs += self.pos_encoder
        x = self.dropout(inputs,training=training)
        for i in range(self.num_encoder):
            x = self.encoder_layers[i](x,training=training,attention_mask=padding_mask)
        return x

In [16]:
sample_encoder = Encoder(5,8,64,256,512,0.1)

In [17]:
sample_encoder_layer_output = sample_encoder(
    tf.random.uniform((1,100,256)),padding_mask=None)

In [18]:
sample_encoder_layer_output

<tf.Tensor: shape=(1, 100, 256), dtype=float32, numpy=
array([[[-1.1621993 , -1.2075881 ,  0.65932155, ...,  0.58714896,
         -1.0473044 , -0.10137503],
        [-0.90982187, -0.8698658 ,  0.4754767 , ...,  0.2795088 ,
         -0.57237613, -1.2295765 ],
        [-0.54445714, -1.5848739 ,  0.83412457, ..., -0.16353464,
          0.19105615, -0.01728549],
        ...,
        [ 0.39641896, -1.9151562 ,  0.87335896, ..., -0.52655905,
         -0.30221856,  1.3261619 ],
        [-0.9286312 , -2.0479248 , -0.9718388 , ..., -0.35256937,
         -0.6017862 ,  1.2943859 ],
        [-0.49348664, -1.8590837 , -0.07375112, ..., -0.61502576,
         -0.8946515 ,  0.68893313]]], dtype=float32)>

In [24]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self,num_decoder,num_heads,key_dim,feature_dim,ff_dim,dropout):
        super(Decoder, self).__init__()
        patches = 100
        self.num_decoder = num_decoder
        self.pos_encoding = positional_encoding(patches, feature_dim)
        self.dec_layers = [DecoderLayer(num_heads,key_dim,feature_dim,ff_dim,dropout)
                           for _ in range(num_decoder)]
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
        x += self.pos_encoding
        x = self.dropout(x, training=training)
        for i in range(self.num_decoder):
            x = self.dec_layers[i](x, enc_output, training,
                                    look_ahead_mask, padding_mask)
        return x

In [28]:
sample_decoder = Decoder(num_decoder=2, num_heads=4,
                         key_dim=64,feature_dim=256,ff_dim=512,dropout=0.1)
temp_input = tf.random.uniform((1,100,256))
output = sample_decoder(temp_input,
                          enc_output=sample_encoder_layer_output,
                          training=False,
                          look_ahead_mask=None,
                          padding_mask=None)
output.shape

TensorShape([1, 100, 256])

In [29]:
sample_decoder.pos_encoding

<tf.Tensor: shape=(1, 100, 256), dtype=float32, numpy=
array([[[ 0.00000000e+00,  1.00000000e+00,  0.00000000e+00, ...,
          1.00000000e+00,  0.00000000e+00,  1.00000000e+00],
        [ 8.41470957e-01,  5.40302277e-01,  8.01961780e-01, ...,
          1.00000000e+00,  1.07460786e-04,  1.00000000e+00],
        [ 9.09297407e-01, -4.16146845e-01,  9.58144367e-01, ...,
          1.00000000e+00,  2.14921558e-04,  1.00000000e+00],
        ...,
        [ 3.79607737e-01, -9.25147533e-01,  7.45109499e-01, ...,
          9.99937236e-01,  1.04235075e-02,  9.99945700e-01],
        [-5.73381901e-01, -8.19288254e-01, -8.97521228e-02, ...,
          9.99935985e-01,  1.05309617e-02,  9.99944568e-01],
        [-9.99206841e-01,  3.98208797e-02, -8.52340877e-01, ...,
          9.99934673e-01,  1.06384167e-02,  9.99943435e-01]]],
      dtype=float32)>