In [1]:
import tensorflow as tf

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)  # Dense layer for the encoder hidden states
        self.W2 = tf.keras.layers.Dense(units)  # Dense layer for the decoder hidden state
        self.V = tf.keras.layers.Dense(1)       # Dense layer to compute alignment scores

    def call(self, query, values):
        """
        Args:
            query: Decoder hidden state (shape: [batch_size, hidden_size]).
            values: Encoder outputs (shape: [batch_size, seq_len, hidden_size]).
        Returns:
            context_vector: Weighted sum of encoder outputs (shape: [batch_size, hidden_size]).
            attention_weights: Attention weights (shape: [batch_size, seq_len]).
        """
        # Add time axis to query for broadcasting (shape: [batch_size, 1, hidden_size])
        query_with_time_axis = tf.expand_dims(query, 1)

        # Compute the alignment scores (shape: [batch_size, seq_len, 1])
        score = self.V(tf.nn.tanh(self.W1(values) + self.W2(query_with_time_axis)))

        # Remove the last axis (shape: [batch_size, seq_len])
        attention_weights = tf.nn.softmax(score, axis=1)

        # Compute the context vector as the weighted sum of values (shape: [batch_size, hidden_size])
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)

        return context_vector, attention_weights

# Example usage
if __name__ == "__main__":
    # Define batch size, sequence length, and hidden size
    batch_size = 64
    seq_len = 10
    hidden_size = 256
    attention_units = 128

    # Instantiate the attention layer
    attention = BahdanauAttention(units=attention_units)

    # Simulated encoder outputs (values) and decoder hidden state (query)
    encoder_outputs = tf.random.normal([batch_size, seq_len, hidden_size])
    decoder_hidden_state = tf.random.normal([batch_size, hidden_size])

    # Apply the attention mechanism
    context_vector, attention_weights = attention(decoder_hidden_state, encoder_outputs)

    print("Context vector shape:", context_vector)  # Expected: [batch_size, hidden_size]
    print("Attention weights shape:", attention_weights)  # Expected: [batch_size, seq_len]


Context vector shape: tf.Tensor(
[[ 0.255141    0.36995074 -0.06074618 ...  0.14289172 -0.33948275
   0.08204244]
 [ 0.17553207  0.20287433 -0.21972388 ...  0.38024572  0.10507832
  -0.1749211 ]
 [ 0.49297252 -0.71173334  0.06620681 ...  0.03723908  0.06681156
   0.04217959]
 ...
 [ 0.5043702   0.14670074 -0.13256606 ... -0.00935103  0.23354119
   0.47012478]
 [ 0.53213084  0.14736369  0.6249427  ...  0.03083069  0.03995124
  -0.91318446]
 [-0.07197922  0.0929099   0.4411536  ...  0.07401175 -0.03586886
   0.24210005]], shape=(64, 256), dtype=float32)
Attention weights shape: tf.Tensor(
[[[0.1168863 ]
  [0.04839835]
  [0.07178821]
  [0.02614368]
  [0.05491129]
  [0.15900852]
  [0.08935837]
  [0.06533986]
  [0.03839758]
  [0.3297678 ]]

 [[0.0757827 ]
  [0.04703011]
  [0.08379294]
  [0.02264802]
  [0.09782696]
  [0.06997547]
  [0.0836059 ]
  [0.10231104]
  [0.33210832]
  [0.08491845]]

 [[0.04823288]
  [0.1771738 ]
  [0.04916682]
  [0.05279026]
  [0.07524373]
  [0.06919555]
  [0.0938158