In [1]:
import tensorflow as tf


In [12]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self,vocab_sz,units):
        super(Encoder, self).__init__()
        self.embedding = tf.keras.layers.Embedding(
            input_dim=vocab_sz,
            output_dim=units
        )
        self.gru = tf.keras.layers.GRU(
            units = units,
            return_state=True,
            return_sequences=True
        )

    def call(self,x,states):
        print(f'before embedding x shape is {x.shape}')
        x = self.embedding(x)
        print(f'after embedding x shape is {x.shape}')
        enc_outputs,output_states = self.gru(
            inputs=x,
            initial_state=states
        )
        return enc_outputs,output_states

In [14]:
sample_encoder = Encoder(1000,512)
sample_states = tf.random.uniform([64,512])
sample_x = tf.random.uniform([64,10])
sample_outputs,sample_output_states = sample_encoder(sample_x,sample_states)
sample_outputs.shape,sample_output_states.shape

before embedding x shape is (64, 10)
after embedding x shape is (64, 10, 512)


(TensorShape([64, 10, 512]), TensorShape([64, 512]))

In [21]:
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self,units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.attention = tf.keras.layers.AdditiveAttention()

    def call(self,query,value,mask):
        w1_query = self.W1(query)
        w2_key = self.W2(value)

        query_mask = tf.ones(tf.shape(query)[:-1],dtype=bool)
        value_mask = mask
        context_vector,attention_weights = self.attention(
            inputs=[w1_query,value,w2_key],
            mask=[query_mask,value_mask],
            return_attention_scores=True
        )

        return context_vector,attention_weights

In [22]:
sample_attention = BahdanauAttention(512)

sample_context_vector,sample_attention_weights = sample_attention(
    tf.random.uniform([64,3,1000]),
    tf.random.uniform([64,10,512]),
    None
)
sample_context_vector.shape,sample_attention_weights.shape

(TensorShape([64, 3, 512]), TensorShape([64, 3, 10]))