In [2]:
import numpy as np
import tensorflow as tf
import pandas as pd

In [28]:
class FFN(tf.keras.layers.Layer):
    """Position-wise feed-forward sublayer: Dense(dff, activation) -> Dense(d_model).

    Both Dense layers act on the last axis of the input, so any leading
    axes (batch, sequence, ...) are passed through unchanged.

    Args:
        d_model: Units of the second (output) Dense layer, i.e. the size of
            the last axis of the returned tensor.
        dff: Units of the first (hidden) Dense layer.
        activation: Activation applied after the first Dense layer.
            Defaults to "relu" (the original behavior).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, d_model, dff, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(dff, activation=activation)
        # Linear projection back to the model width (no activation).
        self.dense2 = tf.keras.layers.Dense(d_model)

    def call(self, inputs):
        """Apply the hidden Dense layer, then the output projection."""
        hidden = self.dense1(inputs)
        return self.dense2(hidden)


In [34]:
class Block(tf.keras.layers.Layer):
    """Post-LayerNorm Transformer encoder block (self-attention + feed-forward).

    Computes::

        x1  = LayerNorm(x  + MHA(query=x, value=x, attention_mask=mask))
        out = LayerNorm(x1 + FFN(x1))

    Args:
        d_model: Model width; the last axis of ``inputs`` must have this size
            for the residual additions to line up with the FFN output.
        dff: Hidden units of the feed-forward sublayer.
        heads: Number of attention heads.
        key_dim: Per-head query/key size passed to ``MultiHeadAttention``.
            Defaults to ``d_model`` (this block's original behavior). Note the
            original Transformer uses ``d_model // heads``, which keeps the
            attention parameter count independent of ``heads``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, d_model: int, dff: int = 2048, heads: int = 8,
                 key_dim: "int | None" = None, **kwargs):
        super().__init__(**kwargs)

        # hyper-parameters
        self.d_model = d_model  # model dims
        self.dff = dff          # ffn hidden units
        self.heads = heads      # number of attention heads
        self.key_dim = d_model if key_dim is None else key_dim

        # sublayers
        self.ffn = FFN(d_model, dff)
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=heads, key_dim=self.key_dim)
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.ln2 = tf.keras.layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Run self-attention and the feed-forward sublayer, each with residual + LayerNorm.

        Args:
            inputs: assumed (batch, seq_len, d_model) — TODO confirm with callers;
                the last axis must equal ``d_model``.
            mask: Optional attention mask forwarded unchanged to
                ``MultiHeadAttention`` as ``attention_mask`` (per Keras docs,
                shape (batch, query_len, key_len), 1/True = may attend).
                ``None`` means no explicit mask. NOTE: because the argument is
                named ``mask``, Keras may auto-fill it from an upstream
                propagated mask when it is not passed explicitly.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        attention_outputs = self.mha(query=inputs, value=inputs, attention_mask=mask)
        attended = self.ln1(inputs + attention_outputs)
        ffn_outputs = self.ffn(attended)
        # The second residual must wrap the FFN sublayer, i.e. start from the
        # attention sublayer's output, not from the raw block inputs.
        return self.ln2(attended + ffn_outputs)

In [35]:
# Smoke test: build the block on symbolic Keras inputs and check the output shape.
SEQ_LEN = 4    # sequence length used for both the inputs and the mask
D_MODEL = 16   # model width; must match the Block's d_model for the residual adds

layer = Block(d_model=D_MODEL, dff=2048, heads=8)
mask = tf.keras.Input(shape=[SEQ_LEN, SEQ_LEN])     # per-example attention mask (query_len, key_len)
source = tf.keras.Input(shape=[SEQ_LEN, D_MODEL])   # per-example (seq_len, d_model)
outputs = layer(inputs=source, mask=mask)
assert tuple(outputs.shape) == (None, SEQ_LEN, D_MODEL), outputs.shape
print(outputs.shape)

(None, 4, 16)
