In [0]:
# Variety of attention score functions:
# Dot
# Scaled Dot
# General
# Concat
# Location



import tensorflow as tf
import numpy as np

In [0]:
class Attention(tf.keras.layers.Layer):
  """Attention layer with a selectable alignment strategy and score function.

  The layer is called on a list ``[source_hidden_states, target_hidden_state]``
  or ``[source_hidden_states, target_hidden_state, current_timestep]``:

  * ``source_hidden_states`` -- rank-3 tensor ``(batch, seq_len, hidden)``;
    ``build`` reads ``seq_len`` and ``hidden`` from its static shape.
  * ``target_hidden_state`` -- ``(batch, 1, hidden)``; a rank-2
    ``(batch, hidden)`` tensor is also accepted and gets a time axis added.
  * ``current_timestep`` -- position the window is centred on; only required
    for ``alignment_type='local-m'``.

  Args:
    alignment_type: ``'global'`` (attend to every source position),
      ``'local-m'`` (fixed window around ``current_timestep``) or
      ``'local-p'`` (predicted centre, weights damped by a Gaussian).
    window_width: half-width of the local window / twice the Gaussian standard
      deviation. Defaults to 8 on first call when left as ``None``.
    score_function: ``'dot'``, ``'scaled_dot'``, ``'general'``, ``'location'``
      or ``'concat'``.
    **kwargs: forwarded to ``tf.keras.layers.Layer``.

  Returns (from ``call``):
    ``(context_vector, attention_weights)`` where ``attention_weights`` has
    shape ``(batch, attended_len, 1)`` and ``context_vector`` is the
    element-wise product ``source_hidden_states * attention_weights`` (it is
    NOT reduced over the sequence axis).
  """

  def __init__(self, alignment_type='global', window_width=None, score_function='general', **kwargs):
    super().__init__(**kwargs)
    self.alignment_type = alignment_type
    self.window_width = window_width
    self.score_function = score_function

  def _make_projection(self, units):
    """Create and build a bias-free Dense projection over the hidden axis."""
    projection = tf.keras.layers.Dense(units=units, use_bias=False)
    projection.build(input_shape=(None, None, self.hidden_dim))
    return projection

  def build(self, input_shape):
    """Create the projections needed by the configured alignment/score type.

    Sub-layers are stored as attributes, so Keras tracks their weights
    automatically; no manual bookkeeping of trainable weights is required.
    """
    # input_shape[0] is the source hidden states shape: (b, seq_len, H)
    self.input_sequence_length = input_shape[0][1]
    self.hidden_dim = input_shape[0][2]

    if 'local-p' in self.alignment_type:
      # Two-step predictor of the aligned position: v_p(tanh(W_p(h_t))).
      self.W_p = self._make_projection(self.hidden_dim)
      self.v_p = self._make_projection(1)

    if 'dot' not in self.score_function:  # every score except dot / scaled_dot
      self.W_a = self._make_projection(self.hidden_dim)

    if self.score_function == 'concat':
      self.U_a = self._make_projection(self.hidden_dim)
      self.v_a = self._make_projection(1)

    super().build(input_shape)

  def call(self, inputs):
    """Compute attention weights and the weighted source hidden states."""
    source_hidden_states = inputs[0]
    target_hidden_state = inputs[1]
    # The timestep is only consumed by 'local-m', so it may be omitted otherwise.
    current_timestep = inputs[2] if len(inputs) > 2 else None

    # Every score below broadcasts the target against (b, S, H), so make sure
    # it carries an explicit time axis: (b, H) -> (b, 1, H).
    if len(target_hidden_state.shape) == 2:
      target_hidden_state = tf.expand_dims(target_hidden_state, axis=1)

    aligned_position = None
    if 'local' in self.alignment_type:
      self.window_width = 8 if self.window_width is None else self.window_width

      if self.alignment_type == 'local-m':  # monotonic: window around the timestep
        if current_timestep is None:
          raise ValueError("alignment_type='local-m' requires inputs[2] (current timestep).")
        left = int(max(current_timestep - self.window_width, 0))
        right = int(min(current_timestep + self.window_width, self.input_sequence_length))
        source_hidden_states = source_hidden_states[:, left:right, :]

      elif self.alignment_type == 'local-p':  # predictive: learn the window centre
        aligned_position = tf.sigmoid(self.v_p(tf.tanh(self.W_p(target_hidden_state))))
        # Sigmoid output in (0, 1) is scaled to a position in (0, seq_len): (b, 1, 1)
        aligned_position = aligned_position * self.input_sequence_length

    # --- Alignment scores, shape (b, S, 1) ---
    if 'dot' in self.score_function:
      attention_score = tf.matmul(source_hidden_states, target_hidden_state, transpose_b=True)
      if self.score_function == 'scaled_dot':
        # Keep the scale in the score's own dtype to avoid a float32/float64 clash.
        hidden_size = tf.cast(source_hidden_states.shape[2], attention_score.dtype)
        attention_score = attention_score / tf.math.sqrt(hidden_size)

    elif self.score_function == 'general':
      weighted_hidden_states = self.W_a(source_hidden_states)
      attention_score = tf.matmul(weighted_hidden_states, target_hidden_state, transpose_b=True)

    elif self.score_function == 'location':
      # NOTE(review): softmax over H followed by a sum over H gives a constant
      # score of 1 per position; kept as in the original formulation -- confirm intent.
      weighted_target_state = self.W_a(target_hidden_state)
      location_score = tf.nn.softmax(weighted_target_state, axis=-1)
      source_length = tf.shape(source_hidden_states)[1]
      location_score = tf.tile(location_score, [1, source_length, 1])
      attention_score = tf.reduce_sum(location_score, axis=-1, keepdims=True)

    elif self.score_function == 'concat':
      weighted_hidden_states = self.W_a(source_hidden_states)
      weighted_target_state = self.U_a(target_hidden_state)
      attention_score = self.v_a(tf.tanh(weighted_hidden_states + weighted_target_state))

    else:
      raise ValueError("Unknown score_function: {!r}".format(self.score_function))

    # Normalise across source positions (axis 1). The last axis has size 1, so
    # a softmax over it would turn every weight into 1.0.
    attention_weights = tf.nn.softmax(attention_score, axis=1)

    if self.alignment_type == 'local-p':
      # Gaussian centred on the predicted position with sigma = window_width / 2,
      # evaluated for all source positions at once: (1, S, 1) vs (b, 1, 1).
      positions = tf.cast(tf.range(self.input_sequence_length), attention_weights.dtype)
      positions = positions[tf.newaxis, :, tf.newaxis]
      two_sigma_squared = 2.0 * (self.window_width / 2.0) ** 2
      gaussian_factor = tf.exp(-tf.math.square(positions - aligned_position) / two_sigma_squared)
      attention_weights = attention_weights * gaussian_factor

    context_vector = source_hidden_states * attention_weights

    return context_vector, attention_weights

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
        'alignment_type': self.alignment_type,
        'window_width': self.window_width,
        'score_function': self.score_function,
    })
    return config