In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class LuongAttention(Layer):
    """Luong-style attention of a decoder state over a sequence of encoder outputs.

    The scoring function is selected by ``attention_type``:

    * ``"dot"``     -- score = <encoder_output, hidden>
    * ``"general"`` -- score = <W(encoder_output), hidden>, with ``W = Dense(hidden_size)``
    * ``"concat"``  -- score = <v, tanh(W([hidden; encoder_output]))>, with
      ``W = Dense(hidden_size)`` and a trainable vector ``v`` of length ``hidden_size``

    Args:
        attention_type: One of ``"dot"``, ``"general"``, ``"concat"``.
        hidden_size: Output width of the projection ``W`` and length of ``v``.
            For ``"general"`` it must equal the last dimension of ``hidden``,
            because the projected encoder outputs are dotted with ``hidden``.
            It is unused by ``"dot"``, which instead requires the encoder and
            decoder feature sizes to match.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``, ``dtype``).
            This lets ``from_config(get_config())`` rebuild the layer.

    Raises:
        ValueError: If ``attention_type`` is not one of the supported values.
    """

    _ATTENTION_TYPES = ("dot", "general", "concat")

    def __init__(self, attention_type, hidden_size, **kwargs):
        super().__init__(**kwargs)
        if attention_type not in self._ATTENTION_TYPES:
            # Fail at construction time instead of on the first forward pass.
            raise ValueError("Unknown attention type: {}".format(attention_type))
        self.attention_type = attention_type
        self.hidden_size = hidden_size

        if attention_type in ("general", "concat"):
            self.attention_weight = Dense(hidden_size)
        if attention_type == "concat":
            # Register v via add_weight so Keras tracks it in trainable_weights
            # (and therefore trains and saves it). A bare tf.Variable attribute
            # is not tracked by every Keras version. N(0, 1) matches the
            # previous tf.random.normal initialisation.
            self.v = self.add_weight(
                name="v",
                shape=(hidden_size,),
                initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
                trainable=True,
            )

    def score(self, hidden, encoder_outputs):
        """Return unnormalised alignment scores.

        Args:
            hidden: Decoder state of shape ``(batch, hidden_dim)``.
            encoder_outputs: Encoder states of shape ``(batch, seq_len, enc_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len)`` with one score per encoder step.

        Raises:
            ValueError: If ``self.attention_type`` is not a supported value.
        """
        if self.attention_type == "dot":
            # Batched dot product: (batch, seq, dim) @ (batch, dim, 1) -> (batch, seq).
            return tf.matmul(encoder_outputs, tf.expand_dims(hidden, axis=-1))[:, :, 0]

        elif self.attention_type == "general":
            # Project the encoder outputs, then take the dot product with the hidden state.
            energy = self.attention_weight(encoder_outputs)
            return tf.matmul(energy, tf.expand_dims(hidden, axis=-1))[:, :, 0]

        elif self.attention_type == "concat":
            # Repeat the hidden state along the time axis and concatenate it
            # with every encoder output before the projection.
            hidden_expanded = tf.expand_dims(hidden, axis=1)
            hidden_expanded = tf.tile(hidden_expanded, [1, tf.shape(encoder_outputs)[1], 1])
            concat_input = tf.concat([hidden_expanded, encoder_outputs], axis=-1)
            energy = tf.tanh(self.attention_weight(concat_input))
            # Contract the feature axis with v: (batch, seq, hidden_size) -> (batch, seq).
            return tf.reduce_sum(energy * self.v, axis=2)

        else:
            raise ValueError("Unknown attention type: {}".format(self.attention_type))

    def call(self, hidden, encoder_outputs, encoder_mask=None):
        """Attend over ``encoder_outputs`` using ``hidden`` as the query.

        Args:
            hidden: Decoder state of shape ``(batch, hidden_dim)``.
            encoder_outputs: Encoder states of shape ``(batch, seq_len, enc_dim)``.
            encoder_mask: Optional ``(batch, seq_len)`` tensor (bool, or numeric
                0/1) marking valid encoder steps. Positions where it is
                false/zero get effectively zero attention weight. ``None``
                (the default) attends to every position.

        Returns:
            ``(context_vector, attention_weights)``:

            * ``context_vector`` has shape ``(batch, enc_dim)``.
            * ``attention_weights`` has shape ``(batch, seq_len)``; each row sums to 1.
        """
        # Compute alignment scores.
        alignment_scores = self.score(hidden, encoder_outputs)

        if encoder_mask is not None:
            # Replace padded positions with the most negative finite value of
            # the score dtype. Using a finite value (rather than -inf) means a
            # fully masked row yields uniform weights instead of NaNs.
            keep = tf.cast(encoder_mask, tf.bool)
            floor = tf.cast(alignment_scores.dtype.min, alignment_scores.dtype)
            alignment_scores = tf.where(keep, alignment_scores, floor)

        # Softmax over the time axis to obtain attention weights.
        attention_weights = tf.nn.softmax(alignment_scores, axis=1)

        # Context vector = attention-weighted sum of encoder outputs:
        # (batch, 1, seq) @ (batch, seq, dim) -> (batch, 1, dim) -> (batch, dim).
        context_vector = tf.matmul(tf.expand_dims(attention_weights, axis=1), encoder_outputs)
        context_vector = tf.squeeze(context_vector, axis=1)

        return context_vector, attention_weights

    def get_config(self):
        """Return the constructor arguments so Keras can serialise this layer."""
        config = super().get_config()
        config.update(
            {
                "attention_type": self.attention_type,
                "hidden_size": self.hidden_size,
            }
        )
        return config


# Example usage
if __name__ == "__main__":
    # Small toy sizes; enough to exercise the layer end to end.
    batch_size, seq_len, hidden_size = 2, 5, 10

    # Random stand-ins for real activations. hidden plays the role of the
    # decoder state and encoder_outputs the per-step encoder states.
    # They are drawn in this order on purpose: hidden first, then the encoder outputs.
    hidden = tf.random.normal([batch_size, hidden_size])  # Decoder hidden state
    encoder_outputs = tf.random.normal([batch_size, seq_len, hidden_size])  # Encoder outputs

    # Scoring function to demo; switch the string to try another variant.
    attention_type = "dot"  # Options: "dot", "general", "concat"
    attention_layer = LuongAttention(attention_type, hidden_size)

    # A single forward pass returns the context vector and the attention distribution.
    context_vector, attention_weights = attention_layer(hidden, encoder_outputs)

    for caption, result in (
        ("Context vector:", context_vector),
        ("Attention weights:", attention_weights),
    ):
        print(caption, result.numpy())


Context vector: [[ 0.81676483  0.06053445  0.19244039 -0.72452813 -0.7267856   0.856664
   1.1044954   1.2002246   0.31152475  0.3382845 ]
 [ 2.384232   -1.3005022  -0.0826828  -0.21933986  1.2408372   0.11178041
  -0.3798153   0.28194258  0.16247584 -0.6566643 ]]
Attention weights: [[1.16724755e-06 1.03445565e-02 5.45606136e-01 4.33148965e-02
  4.00733232e-01]
 [7.02524900e-01 8.75239402e-06 2.93246597e-01 4.05862648e-03
  1.61005635e-04]]
