In [4]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

### Bahdanau attention ([Bahdanau et al., 2014](https://arxiv.org/abs/1409.0473))을 구현합니다.

Attention을 하기 위한 energy는 다음과 같이 계산합니다. 여기서 $s_{i-1}$은 decoder의 이전 hidden state, $h_j$는 $j$번째 encoder hidden state입니다.

$e_{i j}=a\left(s_{i-1}, h_{j}\right)=v_{a}^{\top} \tanh \left(W_{a} s_{i-1}+U_{a} h_{j}\right)$

Attention weights는 energy에 softmax 연산을 적용하여 확률분포 형태로 만들어 구합니다.

$\alpha_{i j}=\frac{\exp \left(e_{i j}\right)}{\sum_{k} \exp \left(e_{i k}\right)}$

마지막으로 context vector는 attention weight $\alpha_{ij}$를 가중치로 하여 encoder hidden state $h_j$들을 weighted sum하여 구합니다.

$c_{i}=\sum_{j} \alpha_{i j} h_{j}$

In [8]:
class Attention(Layer):
    """Bahdanau (additive) attention layer.

    Called on a pair ``[enc_states, dec_states]`` and returns one context
    vector per decoder time step:

        e_ij  = v^T tanh(W s_i + U h_j)
        a_ij  = softmax_j(e_ij)
        c_i   = sum_j a_ij h_j

    Inputs:
        enc_states: (batch, enc_len, enc_dim) encoder hidden states h_j.
        dec_states: (batch, dec_len, dec_dim) decoder states used as queries.

    Output:
        (batch, dec_len, enc_dim) context vectors c_i.
    """

    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable weights W, U and v.

        Args:
            input_shape: pair of shapes for ``[enc_states, dec_states]``;
                only the last (feature) dimension of each is used.
        """
        enc_dim = input_shape[0][-1]
        dec_dim = input_shape[1][-1]

        # Names are passed by keyword: in Keras 3 the first positional
        # parameter of add_weight is `shape`, not `name`.
        # Projects decoder states into the encoder feature space: (dec_dim, enc_dim)
        self.W = self.add_weight(name="W", shape=(dec_dim, enc_dim))
        # Projects encoder states: (enc_dim, enc_dim)
        self.U = self.add_weight(name="U", shape=(enc_dim, enc_dim))
        # Reduces tanh features to a scalar energy: (enc_dim, 1)
        self.v = self.add_weight(name="v", shape=(enc_dim, 1))
        super(Attention, self).build(input_shape)

    def _attend(self, enc_states, dec_states):
        """Additive attention for all decoder steps at once.

        Args:
            enc_states: (batch, enc_len, enc_dim)
            dec_states: (batch, dec_len, dec_dim)
        Returns:
            (batch, dec_len, enc_dim) context vectors.
        """
        # (batch, dec_len, enc_dim)
        w = dec_states @ self.W
        # (batch, enc_len, enc_dim) -- depends only on the encoder, so it is
        # computed once rather than once per decoder step.
        u = enc_states @ self.U
        # Broadcast W s_i + U h_j over every (i, j) pair:
        # (batch, dec_len, 1, enc_dim) + (batch, 1, enc_len, enc_dim)
        #   -> (batch, dec_len, enc_len, enc_dim)
        # NOTE: this intermediate is the memory-dominant term for long sequences.
        t = tf.tanh(tf.expand_dims(w, axis=2) + tf.expand_dims(u, axis=1))
        # (batch, dec_len, enc_len)
        energy = tf.squeeze(t @ self.v, axis=-1)
        # Normalize over encoder positions j: (batch, dec_len, enc_len)
        weight = tf.nn.softmax(energy, axis=-1)
        # Weighted sum of encoder states: (batch, dec_len, enc_dim)
        return weight @ enc_states

    def _get_attention_context(self, enc_states, dec_state):
        """Context vector for a single decoder step.

        Args:
            enc_states: (batch, enc_len, enc_dim)
            dec_state: (batch, dec_dim)
        Returns:
            (batch, enc_dim) context vector.
        """
        # Treat the single step as a length-1 decoder sequence and drop that axis.
        return self._attend(enc_states, tf.expand_dims(dec_state, axis=1))[:, 0]

    def call(self, inputs):
        """Compute context vectors for every decoder step.

        Args:
            inputs: ``[enc_states, dec_states]`` with shapes
                (batch, enc_len, enc_dim) and (batch, dec_len, dec_dim).
        Returns:
            (batch, dec_len, enc_dim) context vectors.
        """
        enc_states, dec_states = inputs
        # Vectorized over decoder steps: unlike a Python loop over the static
        # dec_states.shape[1], this also works when dec_len is unknown (None)
        # at graph-construction time, and avoids unrolling dec_len sub-graphs.
        return self._attend(enc_states, dec_states)

In [9]:
# Sanity check on random tensors: the layer should return one enc_dim-sized
# context vector per decoder step.
tf.random.set_seed(42)  # make the random inputs reproducible across re-runs

batch = 5
enc_len = 100
dec_len = 50
enc_dim = 256
dec_dim = 128

dec_states = tf.random.normal(shape=(batch, dec_len, dec_dim))
enc_states = tf.random.normal(shape=(batch, enc_len, enc_dim))

contexts = Attention(name='attention_layer')([enc_states, dec_states])
# Expected output shape: (batch, dec_len, enc_dim)
assert contexts.shape == (batch, dec_len, enc_dim), contexts.shape
print(contexts.shape)

(5, 50, 256)
