In [100]:
import tensorflow as tf
import numpy as np
import math
print(tf.__version__)

2.0.0-beta1


# 一 基础知识
## 如何实现 tf.keras.layers.Layer

根据 [tf.keras.layers.Layer 文档](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#class_layer)，实现一个自定义的 layer 需要实现三个方法:
1. \_\_init\_\_(): 保存一些设置
2. build(input_shape): 通常用于创建 layer 需要的参数。
3. call(inputs, **kwargs): 前向的逻辑。

例如下面自定义的 Dense Layer:

In [27]:
class MyDense(tf.keras.layers.Layer):
    """Bias-free fully connected layer computing ``inputs @ weights``.

    Args:
        hidden_size: Number of output units (size of the last output axis).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def build(self, input_shape):
        """Create the weight matrix once the last input dimension is known."""
        input_dim = input_shape.as_list()[-1]
        kernel_init = tf.random_normal_initializer(mean=0, stddev=0.1)
        with tf.name_scope('weights'):
            self.dense_weights = self.add_weight(
                'weights',
                shape=[input_dim, self.hidden_size],
                dtype='float32',
                initializer=kernel_init)
        super().build(input_shape)

    def call(self, inputs):
        """Multiply ``inputs`` by the weight matrix along the last axis."""
        with tf.name_scope('dense'):
            return tf.matmul(inputs, self.dense_weights)

    def get_config(self):
        """Return the constructor arguments of this layer as a dict."""
        return {'hidden_size': self.hidden_size}
    
def test_dense():
    """Smoke test: run a (5, 10) batch of ones through a 20-unit MyDense."""
    dense = MyDense(20)
    y = dense(tf.ones([5, 10]))
    print(f"Dense: (5, 10) x (10, 20) -> {y.shape}")


test_dense()

Dense: (5, 10) x (10, 20) -> (5, 20)


## 如何实现 tf.keras.Model

根据 [tf.keras.Model 文档](https://www.tensorflow.org/api_docs/python/tf/keras/Model)，实现一个自定义的 model 需要实现两个方法：
1. \_\_init\_\_(): 保存一些设置，通常设置 model 的 layer 信息。
2. call(inputs, training): 模型的前向计算。

例如下面的例子：

In [99]:
class MyModel(tf.keras.Model):
  """Toy model: Dense(4, relu) -> dropout (training only) -> Dense(5, softmax)."""

  def __init__(self):
    super().__init__()
    # Sub-layers are created here so Keras tracks their weights.
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, inputs, training=False):
    """Forward pass; dropout is applied only when `training` is truthy."""
    hidden = self.dense1(inputs)
    if training:
      hidden = self.dropout(hidden, training=training)
    return self.dense2(hidden)

def test_model():
    """Smoke test: predict on an (8, 16) batch of ones and print the shapes."""
    x = tf.ones([8, 16])
    y = MyModel().predict(x)
    print(f'{x.shape} x model -> {y.shape}' )


test_model()

(8, 16) x model -> (8, 5)


# 二 实现 transformer
## Embedding 层
根据论文的第3.4小节，在 transformer 中，embedding 层和输出 softmax 之前的映射层是共享权重的。

In [117]:
class EmbeddingSharedWeights(tf.keras.layers.Layer):
    """Embedding table that doubles as the pre-softmax projection.

    Implements the weight sharing described in section 3.4 of the paper: one
    `[vocab_size, hidden_size]` matrix is used both to look up token
    embeddings and (transposed) to map hidden states to vocabulary logits.
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        """Create the shared weight matrix."""
        # The initializer was chosen arbitrarily by the original author.
        weight_init = tf.random_normal_initializer(
            mean=0., stddev=self.hidden_size**-0.5)
        self.shared_weights = self.add_weight(
            "weights",
            shape=[self.vocab_size, self.hidden_size],
            dtype='float32',
            initializer=weight_init)
        super().build(input_shape)

    def call(self, inputs, mode='embedding'):
        """Dispatch to the embedding lookup or the linear projection.

        Args:
            inputs: Token ids for mode 'embedding', hidden states for 'linear'.
            mode: Either 'embedding' or 'linear'.

        Raises:
            ValueError: If `mode` is neither 'embedding' nor 'linear'.
        """
        if mode == 'embedding':
            return self._embedding(inputs)
        if mode == 'linear':
            return self._linear(inputs)
        raise ValueError(f'unknown mode {mode}')

    def get_config(self):
        """Return the constructor arguments of this layer as a dict."""
        return {
            'vocab_size': self.vocab_size,
            'hidden_size': self.hidden_size,
        }

    def _embedding(self, inputs):
        """Look up embeddings for ids of shape [batch_size, length].

        Positions whose id is 0 get an all-zero vector; every other embedding
        is scaled by sqrt(hidden_size).
        Returns a tensor of shape [batch_size, length, hidden_size].
        """
        with tf.name_scope('embedding'):
            # 1.0 where the id is non-zero, 0.0 elsewhere: [batch_size, length, 1]
            nonzero_mask = tf.expand_dims(
                tf.cast(tf.not_equal(inputs, 0), tf.float32), -1)
            # [batch_size, length, hidden_size]
            looked_up = tf.gather(self.shared_weights, inputs)
            looked_up *= nonzero_mask
            looked_up *= (self.hidden_size ** 0.5)
            return looked_up

    def _linear(self, inputs):
        """Project [batch_size, length, hidden_size] states to vocabulary logits.

        Returns a tensor of shape [batch_size, length, vocab_size].
        """
        with tf.name_scope('presoftmax_linear'):
            return tf.matmul(inputs, self.shared_weights, transpose_b=True)
        
def test_embedding_layer():
    """Smoke test both modes of EmbeddingSharedWeights and print the shapes."""
    vocab_size, hidden_size = 100, 256
    batch_size, length = 8, 32
    layer = EmbeddingSharedWeights(vocab_size, hidden_size)

    # Embedding mode: integer ids -> vectors.
    emb_inputs = tf.ones([batch_size, length], tf.int32)
    emb_outputs = layer(emb_inputs, mode='embedding')
    print(f'Embedding: {emb_inputs.shape} x {layer.shared_weights.shape} -> {emb_outputs.shape}')

    # Linear mode: hidden states -> vocabulary logits.
    linear_inputs = tf.ones([batch_size, length, hidden_size])
    linear_outputs = layer(linear_inputs, mode='linear')
    print(f'Linear: {linear_inputs.shape} x {layer.shared_weights.shape} -> {linear_outputs.shape}')


test_embedding_layer()

Embedding: (8, 32) x (100, 256) -> (8, 32, 256)
Linear: (8, 32, 256) x (100, 256) -> (8, 32, 100)


## position encoding
由于全连接无法获取词的位置信息，我们需要对词的位置进行编码。编码有两种方式，要么学一个位置的 embedding，要么直接定义一个位置函数对位置做 encoding。Transformer 中采用的是第二种方式，参考论文3.5节。

Position encoding 实现如下：

In [147]:
def get_position_encoding(
    length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
  """Return sinusoidal position encodings (section 3.5 of the paper).

  Each position p is multiplied by `hidden_size // 2` frequencies
  `min_timescale * (max_timescale / min_timescale) ** (-i / (n - 1))` for
  i = 0 .. n-1 (n = hidden_size // 2), i.e. a geometric progression. With the
  defaults the frequencies run from 1 down to 1e-4.

  Args:
    length: Number of positions to encode.
    hidden_size: Encoding width; `hidden_size // 2` sine and as many cosine
      channels are produced.
    min_timescale: Scale of the first (highest) frequency.
    max_timescale: Ratio control for the last (lowest) frequency.

  Returns:
    A float32 tensor of shape [length, 2 * (hidden_size // 2)]. Note that for
    an odd `hidden_size` the last dimension is `hidden_size - 1`.
  """
  position = tf.cast(tf.range(length), tf.float32)
  num_timescales = hidden_size // 2
  # Guard the denominator: with a single timescale (hidden_size of 2 or 3)
  # `num_timescales - 1` is 0 and the division would produce NaN encodings.
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.cast(num_timescales, tf.float32) - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
  # Outer product: [length, 1] * [1, num_timescales] -> [length, num_timescales]
  scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
  # Differs slightly from the paper, which interleaves sin/cos on even/odd
  # channels: here the first half of the encoding is sin, the second half cos.
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  return signal

def test_position_encoding():
    """Smoke test: build encodings and broadcast-add them to an embedding batch."""
    length, hidden_size = 32, 256

    encoding = get_position_encoding(length, hidden_size)
    print(f'Position encoding shape: {encoding.shape}')

    # [8, length, hidden_size] + [length, hidden_size] broadcasts over the batch.
    emb = tf.ones([8, length, hidden_size]) + encoding
    print(f'embedding+position encoding shape: {emb.shape}')


test_position_encoding()

Position encoding shape: (32, 256)
embedding+position encoding shape: (8, 32, 256)


##  Attention layer
attention 机制可以总结为把 query, key, value 映射到 output 的通用框架。其中，output 是所有 value 的加权平均，而权重由 query 和 对应的 key 计算得到。通常情况下， key 和 value 是同一个向量。在 transformer 中，这种框架用在了三个地方（论文第3.2.3节）。
1. encoder multi-head self-attention: 在 encoder 中，query, key, value 是一样的，都是上一层的输出。例如上一层输出了 length 个维度为 hidden_size 的向量，则 self-attention 通过计算这些向量的点积和 softmax 得到 length * length 个向量权重，最后取加权平均。Multihead 就是把向量映射到不同的空间，计算多次权重。最后把不同权重得到的加权平均拼在一起。
2. decoder multi-head self-attention: 在 decoder 中，self-attention 大体上和 encoder 的一样。不同的一点是，由于在预测时，当前位置的词是不知道后面位置的词是什么的，为了防止在训练时当前位置和后面的位置做 attention，我们对每个位置都做了 mask，即 mask 掉当前位置的所有后面的词。
3. encoder-decoder multi-head attention: 在 encoder-decoder attention 中，query 是 decoder 的输入（或者说是上一层的输出），key 和 value 是 encoder 的输出。这样的结构保证 decoder 的每个位置都可以和 input 的所有位置 attend 到。

In [47]:
def _float32_softmax(logits, name=None):
  """Apply softmax in float32 and cast the result back to the input dtype.

  Keeps the softmax numerically stable when the surrounding model runs in
  float16.

  Args:
    logits: A tensor, with any shape accepted by `tf.nn.softmax`.
    name: Optional name forwarded to `tf.nn.softmax`.

  Returns:
    A tensor with the same dtype as `logits`.
  """
  probabilities = tf.nn.softmax(tf.cast(logits, tf.float32), name=name)
  return tf.cast(probabilities, logits.dtype)


class Attention(tf.keras.layers.Layer):
  """Multi-headed scaled dot-product attention layer."""

  def __init__(self, hidden_size, num_heads, attention_dropout):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
      attention_dropout: float, dropout rate inside attention for training.

    Raises:
      ValueError: if `hidden_size` is not divisible by `num_heads`.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

  def build(self, input_shape):
    """Creates the bias-free projection layers for q, k, v and the output."""
    # Layers for linearly projecting the queries, keys, and values.
    self.q_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="q")
    self.k_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="k")
    self.v_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="v")
    self.output_dense_layer = tf.keras.layers.Dense(
        self.hidden_size, use_bias=False, name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments of this layer as a dict."""
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }

  def split_heads(self, x):
    """Split x into different heads, and transpose the resulting value.

    The tensor is transposed to ensure the inner dimensions hold the correct
    values during the matrix multiplication.

    Args:
      x: A tensor with shape [batch_size, length, hidden_size]

    Returns:
      A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
    """
    with tf.name_scope("split_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[1]

      # Calculate depth of last dimension after it has been split.
      depth = (self.hidden_size // self.num_heads)

      # Split the last dimension
      x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

      # Transpose the result
      return tf.transpose(x, [0, 2, 1, 3])

  def combine_heads(self, x):
    """Combine tensor that has been split.

    Args:
      x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

    Returns:
      A tensor with shape [batch_size, length, hidden_size]
    """
    with tf.name_scope("combine_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[2]
      x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
      return tf.reshape(x, [batch_size, length, self.hidden_size])

  def call(self, x, y, bias, training, cache=None):
    """Apply attention mechanism to x and y.

    Queries are projected from `x`; keys and values are projected from `y`.

    Args:
      x: a tensor with shape [batch_size, length_x, hidden_size]
      y: a tensor with shape [batch_size, length_y, hidden_size]
      bias: attention bias that will be added to the result of the dot product
        (the logits have shape [batch_size, num_heads, length_x, length_y],
        so `bias` must broadcast against that).
      training: boolean, whether in training mode or not.
      cache: (Used during prediction) dictionary with tensors containing results
        of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, key_channels],
             "v": tensor with shape [batch_size, i, value_channels]}
        where i is the current decoded length. It is updated in place with
        the concatenated keys and values.

    Returns:
      Attention layer output with shape [batch_size, length_x, hidden_size]
    """
    # Linearly project the query (q), key (k) and value (v) using different
    # learned projections. This is in preparation of splitting them into
    # multiple heads. Multi-head attention uses multiple queries, keys, and
    # values rather than regular attention (which uses a single q, k, v).
    q = self.q_dense_layer(x)
    k = self.k_dense_layer(y)
    v = self.v_dense_layer(y)

    if cache is not None:
      # Combine cached keys and values with new keys and values. Each cached
      # tensor is cast to the dtype of the projection it is joined with.
      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
      v = tf.concat([tf.cast(cache["v"], v.dtype), v], axis=1)

      # Update cache
      cache["k"] = k
      cache["v"] = v

    # Split q, k, v into heads.
    q = self.split_heads(q)
    k = self.split_heads(k)
    v = self.split_heads(v)

    # Scale q to prevent the dot product between q and k from growing too large.
    depth = (self.hidden_size // self.num_heads)
    q *= depth ** -0.5

    # Calculate dot product attention
    logits = tf.matmul(q, k, transpose_b=True)
    logits += bias
    weights = _float32_softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.matmul(weights, v)

    # Recombine heads --> [batch_size, length, hidden_size]
    attention_output = self.combine_heads(attention_output)

    # Run the combined outputs through another linear projection layer.
    attention_output = self.output_dense_layer(attention_output)
    return attention_output


class SelfAttention(Attention):
  """Multi-head attention in which queries, keys and values all come from x."""

  def call(self, x, bias, training, cache=None):
    """Run `Attention.call` with `x` used as both the query and memory input."""
    return super().call(x, x, bias, training, cache)


def test_attention():
    """Smoke test: self-attention over a (2, 16, 512) batch in training mode."""
    batch_size, length = 2, 16
    hidden_size, num_heads = 512, 8

    layer = SelfAttention(hidden_size, num_heads, 0.5)
    x = tf.ones([batch_size, length, hidden_size])
    # An all-zero bias masks nothing.
    zero_bias = tf.zeros([batch_size, 1, 1, length])
    outputs = layer(x, zero_bias, True)
    print(f"Encoder attention: {x.shape} -> {outputs.shape}")


test_attention()

Encoder attention: (2, 16, 512) -> (2, 16, 512)


## Position-wise Feed-Forward Networks
在 attention layer 之后，transformer 会接两个全连接层（论文第3.3节），即:

$l_1 = \mathrm{ReLU}(xW_1 + b_1), \quad l_2 = l_1 W_2 + b_2$

需要注意的是，这里的全连接不是把上一层的输出拼在一起计算，而是每个位置单独计算，且权重是共享的。

In [66]:
class FeedForwardNetwork(tf.keras.layers.Layer):
    """Position-wise feed-forward network: Dense(relu) -> dropout -> Dense.

    Args:
        hidden_size: Width of the output layer.
        filter_size: Width of the inner (ReLU) layer.
        relu_dropout: Dropout rate applied after the ReLU layer in training.
    """

    def __init__(self, hidden_size, filter_size, relu_dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.filter_size = filter_size
        self.relu_dropout = relu_dropout

    def build(self, input_shape):
        """Create the two dense sub-layers."""
        self.filter_dense_layer = tf.keras.layers.Dense(
            self.filter_size,
            use_bias=True,
            activation=tf.nn.relu,
            name='filter_layer')
        self.output_dense_layer = tf.keras.layers.Dense(
            self.hidden_size, use_bias=True, name='output_layer')
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments of this layer as a dict."""
        return {
            'hidden_size': self.hidden_size,
            'filter_size': self.filter_size,
            'relu_dropout': self.relu_dropout,
        }

    def call(self, inputs, training):
        """Apply the network to inputs of shape [batch_size, length, hidden_size].

        Dropout between the two dense layers is applied only when `training`
        is truthy.
        """
        hidden = self.filter_dense_layer(inputs)
        if training:
            hidden = tf.nn.dropout(hidden, rate=self.relu_dropout)
        return self.output_dense_layer(hidden)
    
def test_ffn():
    """Smoke test: run FeedForwardNetwork on zeros and print the kernel shapes."""
    batch_size = 8
    length = 16
    hidden_size = 256

    # Integer division: the filter size is a layer width and must be an int
    # (`hidden_size / 2` would pass the float 128.0 to Dense).
    layer = FeedForwardNetwork(hidden_size, hidden_size // 2, 0.5)
    inputs = tf.zeros([batch_size, length, hidden_size])
    outputs = layer(inputs, True)
    layer1_shape = layer.filter_dense_layer.weights[0].shape
    layer2_shape = layer.output_dense_layer.weights[0].shape
    print(f'FeedForwardNetwork: {inputs.shape} x layer1 {layer1_shape} x layer2 {layer2_shape} -> {outputs.shape}')


test_ffn()

FeedForwardNetwork: (8, 16, 256) x layer1 (256, 128) x layer2 (128, 256) -> (8, 16, 256)


## Layer Normalization
关于 layer normalization，可以看我的这篇文章: [Layer Normalization](https://nbviewer.jupyter.org/github/wzpfish/paper-note/blob/master/notes/nlp/layer_normalization.ipynb)

在这里实现如下：

In [63]:
class LayerNormalization(tf.keras.layers.Layer):
    """Layer normalization over the last axis with a learned scale and bias.

    Args:
        hidden_size: Size of the last input axis; also the length of the
            scale and bias vectors.
    """

    def __init__(self, hidden_size):
        super(LayerNormalization, self).__init__()
        self.hidden_size = hidden_size

    def build(self, input_shape):
        """Create the scale (initialized to ones) and bias (zeros) weights."""
        self.scale = self.add_weight(
            'layer_norm_scale',
            shape=[self.hidden_size],
            dtype='float32',
            initializer=tf.ones_initializer())
        self.bias = self.add_weight(
            'layer_norm_bias',
            shape=[self.hidden_size],
            dtype='float32',
            initializer=tf.zeros_initializer())
        super(LayerNormalization, self).build(input_shape)

    def get_config(self):
        """Return the constructor arguments of this layer as a dict."""
        return {
            'hidden_size': self.hidden_size
        }

    def call(self, inputs, epsilon=1e-6):
        """Normalize `inputs` of shape [batch_size, length, hidden_size].

        Args:
            inputs: Tensor normalized along its last axis.
            epsilon: Small constant added to the variance before the
                reciprocal square root, to avoid division by zero.

        Returns:
            A tensor with the same shape as `inputs`.
        """
        mean = tf.reduce_mean(inputs, axis=[-1], keepdims=True)
        variance = tf.reduce_mean(tf.square(inputs - mean), axis=[-1], keepdims=True)
        norm_inputs = (inputs - mean) * tf.math.rsqrt(variance + epsilon)
        outputs = norm_inputs * self.scale + self.bias
        return outputs
    

def test_layer_norm():
    """Smoke test: normalize a batch of zeros and print the weight shapes."""
    batch_size, length, hidden_size = 8, 16, 256

    layer = LayerNormalization(hidden_size)
    inputs = tf.zeros([batch_size, length, hidden_size])
    outputs = layer(inputs)
    # The weights only exist after the first call has built the layer.
    scale_shape, bias_shape = layer.scale.shape, layer.bias.shape
    print(f'{inputs.shape}, layer norm scale: {scale_shape}, layer norm bias: {bias_shape} -> {outputs.shape}')


test_layer_norm()

(8, 16, 256), layer norm scale: (256,), layer norm bias: (256,) -> (8, 16, 256)


## Encoder
在 transformer 中，encoder 由6个相同的 layer stack 在一起，每个 layer 又由两个 sublayer stack 在一起。其中，第一个 sublayer 是 multi-head self-attention，第二个 sublayer 是 position-wise feed forward network。在每个 sublayer 又有个 residual connection 以及 layer normalization。

encoder 的实现如下：

In [75]:
class LayerNormResidualWrapper(tf.keras.layers.Layer):
    """Wrap a sub-layer with layer norm, dropout and a residual connection.

    Computes ``inputs + dropout(layer(layer_norm(inputs), ...))``; dropout is
    applied only in training.

    Args:
        layer: The sub-layer to wrap.
        params: Dict providing 'hidden_size' and 'layer_normresidual_dropout'.
    """

    def __init__(self, layer, params):
        super().__init__()
        self.layer = layer
        self.params = params

    def build(self, input_shape):
        """Create the layer normalization applied before the wrapped layer."""
        self.layer_norm = LayerNormalization(self.params['hidden_size'])
        super().build(input_shape)

    def get_config(self):
        """Return the params dict under the key 'params'."""
        return {'params': self.params}

    def call(self, inputs, *args, **kwargs):
        """Apply the wrapped layer to the normalized inputs and add the residual.

        `training` must be passed as a keyword argument; all positional and
        keyword arguments are forwarded to the wrapped layer.
        """
        training = kwargs['training']
        sublayer_out = self.layer(self.layer_norm(inputs), *args, **kwargs)

        if training:
            drop_rate = self.params['layer_normresidual_dropout']
            sublayer_out = tf.nn.dropout(sublayer_out, rate=drop_rate)
        return sublayer_out + inputs
    

class EncoderStack(tf.keras.layers.Layer):
    """Stack of encoder layers followed by a final layer normalization.

    Each encoder layer consists of a self-attention sub-layer and a
    feed-forward sub-layer, both wrapped in `LayerNormResidualWrapper`.

    Args:
        params: Dict providing 'num_hidden_layers', 'hidden_size',
            'num_heads', 'attention_dropout', 'filter_size', 'relu_dropout'
            and 'layer_normresidual_dropout'.
    """

    def __init__(self, params):
        super(EncoderStack, self).__init__()
        self.params = params
        self.layers = []

    def build(self, input_shape):
        """Create `num_hidden_layers` pairs of wrapped sub-layers."""
        params = self.params
        for _ in range(params['num_hidden_layers']):
            self_attention_layer = SelfAttention(params['hidden_size'], params['num_heads'], params['attention_dropout'])
            feed_forward_network = FeedForwardNetwork(params['hidden_size'], params['filter_size'], params['relu_dropout'])

            self.layers.append([
                LayerNormResidualWrapper(self_attention_layer, params),
                LayerNormResidualWrapper(feed_forward_network, params)
            ])

        self.output_normalization = LayerNormalization(params['hidden_size'])
        super(EncoderStack, self).build(input_shape)

    def get_config(self):
        """Return the params dict under the key 'params'."""
        # Must read the instance attribute: a bare `params` is undefined here.
        return {
            'params': self.params
        }

    def call(self, inputs, attention_bias, training):
        """Run the inputs through every encoder layer.

        Args:
            inputs: [batch_size, length, hidden_size]
            attention_bias: [batch_size, 1, 1, length], added to the
                self-attention logits.
            training: Whether dropout is active.

        Returns:
            The layer-normalized output of the last encoder layer, with the
            same shape as `inputs`.
        """
        for i, layer in enumerate(self.layers):
            self_attention_layer = layer[0]
            feed_forward_network = layer[1]

            with tf.name_scope('layer_{}'.format(i)):
                with tf.name_scope('self_attention'):
                    inputs = self_attention_layer(inputs, attention_bias, training=training)
                with tf.name_scope('ffn'):
                    inputs = feed_forward_network(inputs, training=training)

        outputs = self.output_normalization(inputs)
        return outputs
    

def test_encoder_stack():
    """Smoke test: run a 6-layer EncoderStack on a (4, 6, 512) batch of ones."""
    batch_size, length = 4, 6
    params = {
        'num_hidden_layers': 6,
        'hidden_size': 512,
        'num_heads': 8,
        'filter_size': 256,
        'attention_dropout': 0.5,
        'relu_dropout': 0.5,
        'layer_normresidual_dropout': 0.5,
    }

    inputs = tf.ones([batch_size, length, params['hidden_size']])
    # An all-zero bias masks nothing.
    attention_bias = tf.zeros([batch_size, 1, 1, length])
    outputs = EncoderStack(params)(inputs, attention_bias, True)
    print(f'{inputs.shape} x encoder -> {outputs.shape}')


test_encoder_stack()

(4, 6, 512) x encoder -> (4, 6, 512)


## Decoder
decoder 也是由6个相同的 layer stack 在一起。与 encoder 不同的是，每个 layer 由3个 sublayer 构成，分别是 masked multi-head self-attention, encoder-decoder attention 和 position-wise feed forward network。同样，每个 sublayer 又有个 residual connection 以及 layer normalization。

In [94]:
def get_decoder_self_attetion_bias(length):
    """Return the additive bias that hides later positions from each position.

    Positions a query may attend to form a lower-triangular matrix, e.g. for
    length 3:
                            1 0 0
                            1 1 0
                            1 1 1
    The returned bias is 0 where that matrix is 1 and -1e9 where it is 0,
    reshaped to [1, 1, length, length].
    """
    with tf.name_scope('decoder_self_attention_bias'):
        # Lower-triangular matrix of ones.
        lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
        # Two leading axes of size 1 so the bias lines up with the
        # [batch_size, num_head, length, length] attention logits.
        lower_triangle = tf.reshape(lower_triangle, [1, 1, length, length])
        bias = -1e9 * (1.0 - lower_triangle)
    return bias

class DecoderStack(tf.keras.layers.Layer):
    """Stack of decoder layers followed by a final layer normalization.

    Each decoder layer has three sub-layers, all wrapped in
    `LayerNormResidualWrapper`: self-attention, encoder-decoder attention and
    a feed-forward network.

    Args:
        params: Dict providing 'num_hidden_layers', 'hidden_size',
            'num_heads', 'attention_dropout', 'filter_size', 'relu_dropout'
            and 'layer_normresidual_dropout'.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.layers = []

    def build(self, input_shape):
        """Create `num_hidden_layers` triples of wrapped sub-layers."""
        cfg = self.params
        hidden_size = cfg['hidden_size']
        for _ in range(cfg['num_hidden_layers']):
            sublayers = [
                SelfAttention(hidden_size, cfg['num_heads'], cfg['attention_dropout']),
                Attention(hidden_size, cfg['num_heads'], cfg['attention_dropout']),
                FeedForwardNetwork(hidden_size, cfg['filter_size'], cfg['relu_dropout']),
            ]
            self.layers.append(
                [LayerNormResidualWrapper(sublayer, cfg) for sublayer in sublayers])
        self.output_normalization = LayerNormalization(hidden_size)
        super().build(input_shape)

    def get_config(self):
        """Return the params dict under the key 'params'."""
        return {'params': self.params}

    def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias, input_attention_bias, training, cache=None):
        """Run the decoder inputs through every decoder layer.

        Args:
            decoder_inputs: (batch_size, target_length, hidden_size), the
                output of embedding + position encoding.
            encoder_outputs: (batch_size, input_length, hidden_size), the
                encoder output.
            decoder_self_attention_bias: Bias masking, for every decoder
                position, the positions after it.
            input_attention_bias: Bias masking the input, used by the
                encoder-decoder attention.
            training: Whether dropout is active.
            cache: Optional dict keyed by 'layer_<n>'; each entry is handed
                to that layer's self-attention.

        Returns:
            The layer-normalized output of the last decoder layer.
        """
        hidden = decoder_inputs
        for n, (self_attention, enc_dec_attention, ffn) in enumerate(self.layers):
            layer_name = 'layer_{}'.format(n)
            if cache is not None:
                layer_cache = cache[layer_name]
            else:
                layer_cache = None

            with tf.name_scope(layer_name):
                with tf.name_scope('self_attention'):
                    hidden = self_attention(
                        hidden,
                        decoder_self_attention_bias,
                        training=training,
                        cache=layer_cache)
                with tf.name_scope('encdec_attention'):
                    hidden = enc_dec_attention(
                        hidden,
                        encoder_outputs,
                        input_attention_bias,
                        training=training)
                with tf.name_scope('ffn'):
                    hidden = ffn(hidden, training=training)
        return self.output_normalization(hidden)


def test_decoder_stack():
    """Smoke test: run a 6-layer DecoderStack and print input/output shapes."""
    batch_size, input_length, target_length = 4, 16, 20
    params = {
        'num_hidden_layers': 6,
        'hidden_size': 512,
        'num_heads': 8,
        'filter_size': 256,
        'attention_dropout': 0.5,
        'relu_dropout': 0.5,
        'layer_normresidual_dropout': 0.5,
    }
    hidden_size = params['hidden_size']

    layer = DecoderStack(params)

    decoder_inputs = tf.ones([batch_size, target_length, hidden_size])
    encoder_outputs = tf.ones([batch_size, input_length, hidden_size])
    self_attention_bias = get_decoder_self_attetion_bias(target_length)
    # An all-zero bias masks nothing in the encoder output.
    input_attention_bias = tf.zeros([batch_size, 1, 1, input_length])

    outputs = layer(
        decoder_inputs,
        encoder_outputs,
        self_attention_bias,
        input_attention_bias,
        training=True)
    print(f'decoder inputs: {decoder_inputs.shape}, encoder outputs: {encoder_outputs.shape} x decoder -> {outputs.shape}')


test_decoder_stack()

decoder inputs: (4, 20, 512), encoder outputs: (4, 16, 512) x decoder -> (4, 20, 512)


## Beam Search
beam search 细节可以看[Beam Search 原理及实现](https://nbviewer.jupyter.org/github/wzpfish/paper-note/blob/master/notes/nlp/beam_search.ipynb)

In [128]:
# Large finite stand-in for infinity: `-INF` initializes the finished scores
# and is added to scores that must rank below every real score.
INF = 1. * 1e7

class _StateKeys(object):
    """State 中的 key 的定义
    """
    #TODO: 每个加一个注释.
    CUR_INDEX = "CUR_INDEX"
    ALIVE_SEQ = "ALIVE_SEQ"
    ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
    ALIVE_CACHE = "ALIVE_CACHE"
    FINISHED_SEQ = "FINISHED_SEQ"
    FINISHED_SCORES = "FINISHED_SCORES"
    FINISHED_FLAGS = "FINISHED_FLAGS"


class SequenceBeamSearch:
    def __init__(self, symbols_to_logits_fn, vocab_size, batch_size, beam_size, alpha, max_decode_length, eos_id):
        """
        """
        self.symbols_to_logits_fn = symbols_to_logits_fn
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.beam_size = beam_size
        self.alpha = alpha
        self.max_decode_length = max_decode_length
        self.eos_id = eos_id
        
    def search(self, initial_ids, initial_cache):
        """Run the beam-search loop and return the resulting sequences and scores.

        Args:
            initial_ids: Start ids handed to `_create_initial_state`.
            initial_cache: Initial cache handed to `_create_initial_state`.

        Returns:
            A `(sequences, scores)` pair. For batch entries with at least one
            finished sequence these are the finished sequences and scores;
            otherwise the alive sequences and their log probabilities.
        """
        initial_state, shape_invariants = self._create_initial_state(
            initial_ids, initial_cache)

        final_state = tf.while_loop(
            self._continue_search,
            self._search_step,
            loop_vars=[initial_state],
            shape_invariants=[shape_invariants],
            parallel_iterations=1,
            back_prop=False)[0]

        finished_seq = final_state[_StateKeys.FINISHED_SEQ]
        finished_scores = final_state[_StateKeys.FINISHED_SCORES]
        finished_flags = final_state[_StateKeys.FINISHED_FLAGS]

        # A batch entry may have no finished sequence at all (nothing reached
        # the EOS token); in that case fall back to the alive sequences.
        any_finished = tf.reduce_any(finished_flags, 1, name="finished_cond")
        result_seq = tf.where(
            _expand_to_same_rank(any_finished, finished_seq),
            finished_seq,
            final_state[_StateKeys.ALIVE_SEQ])
        result_scores = tf.where(
            _expand_to_same_rank(any_finished, finished_scores),
            finished_scores,
            final_state[_StateKeys.ALIVE_LOG_PROBS])
        return result_seq, result_scores
        
    def _create_initial_state(self, initial_ids, initial_cache):
        """Build the initial loop state and its shape invariants.

        Args:
            initial_ids: Start ids for prediction (usually 0), shape
                (batch_size,); e.g. [0, 0, 0] for a batch size of 3.
            initial_cache: Nested structure of tensors; every tensor gets a
                beam dimension via `_expand_to_beam_size`.

        Returns:
            A `(state, state_shape_invariants)` pair of dicts keyed by the
            `_StateKeys` constants.
        """
        batch_size, beam_size = self.batch_size, self.beam_size

        # Current decode position, starting at 0.
        cur_index = tf.constant(0)

        # Sequences not finished yet (no EOS token decoded so far), with
        # shape (batch_size, beam_size, 1).
        alive_seq = tf.expand_dims(
            _expand_to_beam_size(initial_ids, beam_size), axis=2)

        # Log probability of the sequence in every beam of every batch entry,
        # shape (batch_size, beam_size). Only the first beam starts with
        # probability 1 (log prob 0); e.g. for batch size 3 and beam size 4:
        # [[  0. -inf -inf -inf]
        #  [  0. -inf -inf -inf]
        #  [  0. -inf -inf -inf]]
        first_row = [[0.] + [-float("inf")] * (beam_size - 1)]
        alive_log_probs = tf.tile(tf.constant(first_row), [batch_size, 1])

        # Give every cached tensor a beam dimension so each beam can hold
        # different cache contents.
        alive_cache = tf.nest.map_structure(
            lambda t: _expand_to_beam_size(t, beam_size), initial_cache)

        # Storage for finished sequences, their scores, and the flags that
        # tell whether a slot holds a finished sequence.
        finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
        finished_scores = tf.ones([batch_size, beam_size]) * -INF
        finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

        # "state" follows tf.while_loop terminology (compare the initial
        # state of an RNN).
        state = {
            _StateKeys.CUR_INDEX: cur_index,
            _StateKeys.ALIVE_SEQ: alive_seq,
            _StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
            _StateKeys.ALIVE_CACHE: alive_cache,
            _StateKeys.FINISHED_SEQ: finished_seq,
            _StateKeys.FINISHED_SCORES: finished_scores,
            _StateKeys.FINISHED_FLAGS: finished_flags,
        }

        # tf.while_loop checks on every iteration that each state tensor is
        # compatible with its shape invariant and raises otherwise. Any
        # dimension that changes during the loop, or that depends on the
        # input and is not known up front (e.g. the batch size), is
        # therefore left as None.
        cache_invariants = tf.nest.map_structure(
            _get_shape_keep_last_dim, alive_cache)
        state_shape_invariants = {
            _StateKeys.CUR_INDEX: tf.TensorShape([]),
            _StateKeys.ALIVE_SEQ: tf.TensorShape([None, beam_size, None]),
            _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, beam_size]),
            _StateKeys.ALIVE_CACHE: cache_invariants,
            _StateKeys.FINISHED_SEQ: tf.TensorShape([None, beam_size, None]),
            _StateKeys.FINISHED_SCORES: tf.TensorShape([None, beam_size]),
            _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, beam_size]),
        }

        return state, state_shape_invariants
    
    def _continue_search(self, state):
        """Loop condition: return a boolean tensor that is True to keep decoding.

        Decoding stops when either
            1. the maximum decode length is reached, or
            2. for every batch entry the lowest finished score is higher than
               the best achievable alive score, i.e. no better sequence can
               be found any more.
        """
        step = state[_StateKeys.CUR_INDEX]
        alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
        finished_scores = state[_StateKeys.FINISHED_SCORES]
        finished_flags = state[_StateKeys.FINISHED_FLAGS]

        below_max_length = tf.less(step, self.max_decode_length)

        # Beam 0 is used as the best alive sequence; per the original
        # author's note the beams are stored already sorted.
        max_length_norm = _length_normalization(self.alpha, self.max_decode_length)
        best_alive_scores = alive_log_probs[:, 0] / max_length_norm

        # Lowest score among the finished sequences, shape (batch_size,).
        # Slots without a finished sequence are zeroed first.
        flags_as_float = tf.cast(finished_flags, tf.float32)
        lowest_finished_scores = tf.reduce_min(
            finished_scores * flags_as_float, axis=1)

        # Batch entries without any finished sequence get a very low score.
        has_finished = tf.reduce_any(finished_flags, 1)
        lowest_finished_scores += (1.0 - tf.cast(has_finished, tf.float32)) * -INF

        cannot_improve = tf.reduce_all(
            tf.greater(lowest_finished_scores, best_alive_scores))

        return tf.logical_and(below_max_length, tf.logical_not(cannot_improve))
        
    def _search_step(self, state):
        """Run one beam-search iteration and return the next state.

        Args:
          state: dict keyed by ``_StateKeys`` holding the current search state.

        Returns:
          A one-element list containing the updated state dict.
        """
        # Step 1. For every beam of every batch entry decode the next token
        # and keep the beam_size * 2 most probable sequences. Keeping twice
        # the beam size guarantees at least beam_size unfinished sequences:
        # even if EOS is the most probable token for every beam, the extra
        # candidates still contain non-EOS continuations.
        grown_seq, grown_log_probs, grown_cache = self._grow_alive_seq(state)

        # Step 4 (prepared first). The new state starts with the advanced index.
        next_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}

        # Step 2. Out of those candidates keep the beam_size most probable
        # sequences that have not finished decoding.
        next_state.update(
            self._get_new_alive_state(grown_seq, grown_log_probs, grown_cache))

        # Step 3. Merge the newly finished candidates with the finished
        # sequences found so far and keep the beam_size best of them.
        next_state.update(
            self._get_new_finished_state(state, grown_seq, grown_log_probs))

        return [next_state]
        
    def _get_new_finished_state(self, state, new_seq, new_log_probs):
        """Merge newly finished candidates into the stored finished beams.

        Args:
          state: dict keyed by ``_StateKeys`` holding the current search state.
          new_seq: candidate sequences produced by the current step.
          new_log_probs: log probs of ``new_seq``.

        Returns:
          dict with the FINISHED_SEQ, FINISHED_SCORES and FINISHED_FLAGS
          entries of the next state, holding the ``self.beam_size``
          best-scoring rows among the old finished beams and the candidates.
        """
        step = state[_StateKeys.CUR_INDEX]
        old_seq = state[_StateKeys.FINISHED_SEQ]
        old_scores = state[_StateKeys.FINISHED_SCORES]
        old_flags = state[_StateKeys.FINISHED_FLAGS]

        # Append one zero column to the stored sequences so they have the
        # same length as the candidates and can be concatenated with them.
        zero_column = tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
        old_seq = tf.concat([old_seq, zero_column], axis=2)

        # A candidate is finished when its last token is the EOS id.
        ends_with_eos = tf.equal(new_seq[:, :, -1], self.eos_id)

        # Length-normalised score; candidates that are still alive get -INF
        # added so top_k ranks them behind finished ones.
        candidate_scores = new_log_probs / _length_normalization(self.alpha, step + 1)
        candidate_scores = candidate_scores + (
            (1 - tf.cast(ends_with_eos, tf.float32)) * -INF)

        # Stack old finished beams and candidates along the beam axis.
        merged_seq = tf.concat([old_seq, new_seq], axis=1)
        merged_scores = tf.concat([old_scores, candidate_scores], axis=1)
        merged_flags = tf.concat([old_flags, ends_with_eos], axis=1)

        best_seq, best_scores, best_flags = _gather_topk_beams(
            [merged_seq, merged_scores, merged_flags],
            merged_scores, self.batch_size, self.beam_size)

        return {
            _StateKeys.FINISHED_SEQ: best_seq,
            _StateKeys.FINISHED_SCORES: best_scores,
            _StateKeys.FINISHED_FLAGS: best_flags
        }
        
    def _grow_alive_seq(self, state):
        """Extend every alive sequence by one token and keep the best beam_size * 2.

        Args:
          state: dict keyed by ``_StateKeys`` holding the current search state.

        Returns:
          topk_seq: the kept sequences, shape (batch_size, beam_size * 2, L + 1)
            where L is the length of the current alive sequences.
          topk_log_probs: log probs of those sequences,
            shape (batch_size, beam_size * 2).
          new_cache: decoder cache (attention k, v, ...) gathered to match
            ``topk_seq``.
        """
        step = state[_StateKeys.CUR_INDEX]
        alive_seq = state[_StateKeys.ALIVE_SEQ]
        alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
        alive_cache = state[_StateKeys.ALIVE_CACHE]

        keep = 2 * self.beam_size

        def restore_beam_dim(t):
            return _unflatten_beam_dim(t, self.batch_size, self.beam_size)

        # The model has no beam axis, so batch and beam are merged into one
        # axis before calling it.
        model_ids = _flatten_beam_dim(alive_seq)
        model_cache = tf.nest.map_structure(_flatten_beam_dim, alive_cache)
        model_logits, model_cache = self.symbols_to_logits_fn(model_ids, step, model_cache)

        # shape: [batch_size, beam_size, vocab_size]
        logits = restore_beam_dim(model_logits)
        new_cache = tf.nest.map_structure(restore_beam_dim, model_cache)

        # shape: [batch_size, beam_size, vocab_size]: log prob of each
        # sequence extended by each word of the vocabulary.
        total_log_probs = (_log_prob_from_logits(logits)
                           + tf.expand_dims(alive_log_probs, axis=2))

        # Each batch entry has beam_size * vocab_size candidates; pick the
        # `keep` most probable ones.
        per_batch_log_probs = tf.reshape(
            total_log_probs, [-1, self.beam_size * self.vocab_size])
        topk_log_probs, flat_positions = tf.nn.top_k(per_batch_log_probs, k=keep)

        # A flat position encodes (beam, word) as beam * vocab_size + word.
        # Both results have shape (batch_size, keep); the word ids get a
        # trailing axis so they can be appended to the sequences.
        source_beams = flat_positions // self.vocab_size
        next_words = tf.expand_dims(flat_positions % self.vocab_size, axis=2)

        topk_seq, new_cache = _gather_beams(
            [alive_seq, new_cache], source_beams, self.batch_size, keep)
        topk_seq = tf.concat([topk_seq, next_words], axis=2)

        return topk_seq, topk_log_probs, new_cache
    
    def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
        """Select the beam_size most probable candidates that are still alive.

        Args:
          new_seq: candidate sequences produced by the current step.
          new_log_probs: log probs of ``new_seq``.
          new_cache: decoder cache belonging to ``new_seq``.

        Returns:
          dict with the ALIVE_SEQ, ALIVE_LOG_PROBS and ALIVE_CACHE entries of
          the next state.
        """
        # Candidates whose last token is EOS are finished; adding -INF to
        # their log prob makes top_k rank them last.
        ends_with_eos = tf.equal(new_seq[:, :, -1], self.eos_id)
        penalised_log_probs = new_log_probs + tf.cast(ends_with_eos, tf.float32) * -INF

        alive_seq, alive_log_probs, alive_cache = _gather_topk_beams(
            [new_seq, penalised_log_probs, new_cache],
            penalised_log_probs, self.batch_size, self.beam_size)

        return {
            _StateKeys.ALIVE_SEQ: alive_seq,
            _StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
            _StateKeys.ALIVE_CACHE: alive_cache
        }
        
def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
    """Gather, from every tensor in ``nested``, the beams with the top scores.

    Args:
      nested: (possibly nested) structure of tensors to gather from.
      score_or_log_prob: per-beam ranking values passed to ``tf.nn.top_k``.
      batch_size: batch size.
      beam_size: number of beams to keep.

    Returns:
      A structure like ``nested`` containing only the selected beams.
    """
    best_beams = tf.nn.top_k(score_or_log_prob, k=beam_size)[1]
    return _gather_beams(nested, best_beams, batch_size, beam_size)

def _expand_to_beam_size(tensor, beam_size):
    """Insert a beam axis at position 1 and repeat the tensor along it.

    A tensor of shape (batch_size, ...) becomes (batch_size, beam_size, ...).
    For example tensor = [1, 2, 3, 4] with beam_size = 3 gives:
    [
     [1, 1, 1]
     [2, 2, 2]
     [3, 3, 3]
     [4, 4, 4]
    ]
    """
    with_beam_axis = tf.expand_dims(tensor, axis=1)
    # Repeat only along the new axis; every other axis keeps multiple 1.
    multiples = [1, beam_size] + [1] * (with_beam_axis.shape.ndims - 2)
    return tf.tile(with_beam_axis, multiples)

def _get_shape_keep_last_dim(tensor):
    """Return a TensorShape that keeps only the last dimension; all others are None."""
    dims = _shape_list(tensor)
    last_dim = dims[-1]

    # _shape_list substitutes a tf.Tensor for a dimension that is not known
    # statically. Such a value cannot be stored in a TensorShape, so an
    # unknown last dimension is recorded as None as well.
    if isinstance(last_dim, tf.Tensor):
        last_dim = None

    return tf.TensorShape([None] * (len(dims) - 1) + [last_dim])
    
def _shape_list(tensor):
    """Return the shape of ``tensor`` as a list.

    Dimensions that are known statically are returned as Python ints; unknown
    ones are replaced by the matching element of ``tf.shape(tensor)``.
    """
    static_dims = tensor.get_shape().as_list()
    runtime_dims = tf.shape(tensor)
    return [
        runtime_dims[axis] if dim is None else dim
        for axis, dim in enumerate(static_dims)
    ]
    
def _length_normalization(alpha, length):
    """Return the length-normalisation factor ((5 + length) / 6) ** alpha.

    Log probs are divided by this factor so that beam search gives short
    sequences some penalty.
    """
    length_as_float = tf.cast(length, tf.float32)
    base = (5. + length_as_float) / 6.
    return tf.pow(base, alpha)

def _flatten_beam_dim(tensor):
    """Merge the batch and beam axes into a single leading axis.

    (batch_size, beam_size, ...) -> (batch_size * beam_size, ...)
    """
    dims = _shape_list(tensor)
    merged_dim = dims[0] * dims[1]
    return tf.reshape(tensor, [merged_dim] + dims[2:])

def _unflatten_beam_dim(tensor, batch_size, beam_size):
    """Inverse of _flatten_beam_dim.

    (batch_size * beam_size, ...) -> (batch_size, beam_size, ...)
    """
    trailing_dims = _shape_list(tensor)[1:]
    return tf.reshape(tensor, [batch_size, beam_size, *trailing_dims])

def _log_prob_from_logits(logits):
    """Turn logits into log probabilities: log(exp(x_i) / sum_j exp(x_j)).

    The normalisation runs over axis 2.
    """
    log_normalizer = tf.reduce_logsumexp(logits, axis=2, keepdims=True)
    return logits - log_normalizer

def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
    """Pick the beams named by ``beam_indices`` from every tensor in ``nested``.

    Args:
      nested: (possibly nested) structure of tensors whose first two axes are
        batch and beam.
      beam_indices: (batch_size, new_beam_size) beam index to take per slot.
      batch_size: batch size.
      new_beam_size: number of beams in the result.

    Returns:
      A structure like ``nested`` whose tensors hold the selected beams.
    """
    # Batch index of every output slot, shape (batch_size, new_beam_size).
    # For batch_size = 2, new_beam_size = 3 this is:
    # [[0, 0, 0],
    #  [1, 1, 1]]
    flat_batch_ids = tf.range(batch_size * new_beam_size) // new_beam_size
    batch_ids = tf.reshape(flat_batch_ids, [batch_size, new_beam_size])

    # Pair every slot with its beam index: shape (batch_size, new_beam_size, 2)
    # where the last axis is (batch index, beam index), the coordinate format
    # tf.gather_nd expects.
    coordinates = tf.stack([batch_ids, beam_indices], axis=2)

    def pick(tensor):
        return tf.gather_nd(tensor, coordinates)

    return tf.nest.map_structure(pick, nested)

def _expand_to_same_rank(tensor, target):
    """Append trailing size-1 axes to ``tensor`` until it has the rank of ``target``.

    Raises:
      ValueError: if the rank of either argument is not known statically.
    """
    tensor_rank = tensor.shape.rank
    if tensor_rank is None:
        raise ValueError("Expect rank for tensor shape, but got None.")

    target_rank = target.shape.rank
    if target_rank is None:
        raise ValueError("Expect rank for target shape, but got None.")

    with tf.name_scope("expand_rank"):
        expanded = tensor
        # Nothing is added when the tensor already has at least the target rank.
        for _ in range(target_rank - tensor_rank):
            expanded = tf.expand_dims(expanded, -1)
        return expanded

## 完整的 Transformer Model
有了各个 layer 的定义，我们就可以定义完整的 transformer 模型了。

回顾一下一个 transformer 的组成部分：

1. Encoder: inputs -> embedding + position encoding, (multi-head self-attention + feed forward) x N
2. Decoder: targets -> embedding + position encoding, (masked multi-head self-attention + encoder-decoder attention + feed forward) x N
3. Probabilities: Linear + Softmax


In [164]:
def get_padding_bias(inputs):
    """Build the attention bias that masks padding positions (id 0).

    Args:
      inputs: (batch_size, input_length)

    Returns:
      Float tensor holding -1e9 at padding positions and 0 elsewhere, with
      two size-1 axes inserted after the batch axis.
    """
    with tf.name_scope('attention_bias'):
        is_padding = tf.cast(tf.equal(inputs, 0), tf.float32)
        bias = is_padding * -1e9
        # Reshape to (batch_size, 1, 1, length) to fit the attention layer.
        bias = tf.expand_dims(bias, axis=1)
        bias = tf.expand_dims(bias, axis=1)
    return bias

# Token id that marks end-of-sequence; Transformer._predict hands it to the
# beam search as `eos_id`.
EOS_ID = 1

class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer whose embedding layer also produces the logits.

    Called with ``[inputs, targets]`` it returns logits for training; called
    with ``[inputs]`` it decodes with beam search.
    """

    def __init__(self, params, name=None):
        """Create the embedding layer and the encoder/decoder stacks.

        Args:
          params: dict of hyper-parameters. This class itself reads
            'vocab_size', 'hidden_size', 'embedding_dropout',
            'extra_decode_length', 'num_hidden_layers', 'beam_size' and
            'alpha'; the whole dict is also passed to EncoderStack and
            DecoderStack.
          name: optional model name forwarded to tf.keras.Model.
        """
        super(Transformer, self).__init__(name=name)
        self.params = params
        self.embedding_softmax_layer = EmbeddingSharedWeights(params['vocab_size'], params['hidden_size'])
        self.encoder_stack = EncoderStack(params)
        self.decoder_stack = DecoderStack(params)

    def get_config(self):
        """Return the constructor arguments of this model."""
        # Read the instance attribute: a bare `params` is not defined in this
        # scope and would raise NameError.
        return {
            'params': self.params
        }

    def call(self, inputs, training):
        """Run the model in training or prediction mode.

        Args:
          inputs: list of one or two tensors.
            inputs[0] (x): (batch_size, input_length)
            inputs[1] (y): (batch_size, target_length); leave it out to run
              beam-search prediction.
          training: whether the model runs in training mode (enables dropout).

        Returns:
          With targets: the logits computed by ``_decode``.
          Without targets: the dict returned by ``_predict``, with keys
          'outputs' and 'scores'.
        """
        if len(inputs) == 2:
            inputs, targets = inputs[0], inputs[1]
        else:
            inputs, targets = inputs[0], None

        with tf.name_scope('Transformer'):
            attention_bias = get_padding_bias(inputs)
            encoder_outputs = self._encode(inputs, attention_bias, training)

            if targets is None:
                return self._predict(encoder_outputs, attention_bias, training)
            else:
                logits = self._decode(targets, encoder_outputs, attention_bias, training)
                return logits

    def _encode(self, inputs, attention_bias, training):
        """Embed the inputs, add the position encoding and run the encoder stack.

        Args:
          inputs (x): (batch_size, input_length)
          attention_bias: (batch_size, 1, 1, input_length)
          training: whether dropout is applied to the encoder inputs.

        Returns:
          The output of the encoder stack.
        """
        with tf.name_scope('encode'):
            embedded_inputs = self.embedding_softmax_layer(inputs)
            with tf.name_scope('add_pos_encoding'):
                length = tf.shape(inputs)[1]
                # The position encoding is not masked at padding positions.
                # That is fine: self-attention applies the padding mask, so
                # the embeddings of padded positions are never used.
                pos_encoding = get_position_encoding(length, self.params['hidden_size'])
                encoder_inputs = embedded_inputs + pos_encoding

            if training:
                encoder_inputs = tf.nn.dropout(encoder_inputs, rate=self.params['embedding_dropout'])

            return self.encoder_stack(encoder_inputs, attention_bias, training=training)

    def _decode(self, targets, encoder_outputs, attention_bias, training):
        """Compute training logits for the targets.

        Args:
          targets: (batch_size, target_length)
          encoder_outputs: (batch_size, input_length, hidden_size), the output
            of the encoder.
          attention_bias: padding bias of the encoder inputs, used for the
            encoder-decoder attention.
          training: whether dropout is applied to the decoder inputs.

        Returns:
          The logits produced by the shared embedding layer in 'linear' mode.
        """
        with tf.name_scope('decode'):
            embedded_targets = self.embedding_softmax_layer(targets)
            with tf.name_scope('shift_targets'):
                # embedded_targets: (batch_size, target_length, hidden_size)
                # Shift every sequence right by one step: pad one zero vector
                # at the front of axis 1 and drop the last step; axes 0 and 2
                # are left untouched.
                embedded_targets = tf.pad(embedded_targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
            with tf.name_scope('add_pos_encoding'):
                length = tf.shape(targets)[1]
                pos_encoding = get_position_encoding(length, self.params['hidden_size'])
                decoder_inputs = embedded_targets + pos_encoding
            if training:
                decoder_inputs = tf.nn.dropout(decoder_inputs, rate=self.params['embedding_dropout'])

            decoder_self_attention_bias = get_decoder_self_attetion_bias(length)
            outputs = self.decoder_stack(
                decoder_inputs,
                encoder_outputs,
                decoder_self_attention_bias,
                attention_bias,
                training=training)

            logits = self.embedding_softmax_layer(outputs, mode='linear')
            return logits

    def _get_symbols_to_logits_fn(self, max_decode_length, training):
        """Build the single-step decoding function used by the beam search.

        Args:
          max_decode_length: maximum number of decoding steps.
          training: forwarded to the decoder stack.

        Returns:
          A function ``(ids, i, cache) -> (logits, cache)`` that decodes the
          token at step ``i`` from the last id of every sequence in ``ids``.
        """
        # Computed once for the whole search and sliced per step below.
        pos_encoding = get_position_encoding(max_decode_length + 1, self.params['hidden_size'])
        decoder_self_attention_bias = get_decoder_self_attetion_bias(max_decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            # Only the most recent token is fed to the decoder; the cache is
            # passed along with it.
            decoder_input = ids[:, -1:]
            decoder_input = self.embedding_softmax_layer(decoder_input)
            decoder_input += pos_encoding[i: i+1]

            # Row i of the self-attention bias, limited to the first i + 1 positions.
            self_attention_bias = decoder_self_attention_bias[:, :, i:i+1, :i+1]
            decoder_outputs = self.decoder_stack(
                decoder_input,
                cache.get('encoder_outputs'),
                self_attention_bias,
                cache.get('encoder_decoder_attention_bias'),
                training=training,
                cache=cache)

            logits = self.embedding_softmax_layer(decoder_outputs, mode='linear')
            # Drop the length-1 step axis.
            logits = tf.squeeze(logits, axis=[1])
            # NOTE(review): the same cache dict is returned; presumably the
            # decoder stack updates its k/v entries in place - confirm there.
            return logits, cache

        return symbols_to_logits_fn

    def _predict(self, encoder_outputs, attention_bias, training):
        """Decode output sequences with beam search.

        Args:
          encoder_outputs: output of the encoder; axis 0 is the batch and
            axis 1 the input length.
          attention_bias: padding bias of the encoder inputs.
          training: forwarded to the decoder stack.

        Returns:
          dict with
            'outputs': ids of the first beam, without the initial id.
            'scores': score of the first beam.
        """
        batch_size = tf.shape(encoder_outputs)[0]
        input_length = tf.shape(encoder_outputs)[1]
        max_decode_length = input_length + self.params["extra_decode_length"]

        symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length, training)

        # Every sequence starts from id 0.
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
        # Per-layer k/v caches start with length 0 on axis 1.
        cache = {
            'layer_{}'.format(layer): {
                'k': tf.zeros([batch_size, 0, self.params['hidden_size']]),
                'v': tf.zeros([batch_size, 0, self.params['hidden_size']])
            } for layer in range(self.params['num_hidden_layers'])
        }
        # The encoder results travel inside the cache so that they go through
        # the same beam handling as the k/v entries.
        cache['encoder_outputs'] = encoder_outputs
        cache['encoder_decoder_attention_bias'] = attention_bias

        decoded_ids, scores = self._beam_search(
            symbols_to_logits_fn, initial_ids, cache, self.params['vocab_size'],
            self.params['beam_size'], self.params['alpha'], max_decode_length, EOS_ID)

        # Take beam 0 of every batch entry and strip the initial id at position 0.
        top_decode_ids = decoded_ids[:, 0, 1:]
        top_scores = scores[:, 0]

        return {
            'outputs': top_decode_ids,
            'scores': top_scores
        }

    def _beam_search(self, symbols_to_logits_fn, initial_ids, initial_cache,
                     vocab_size, beam_size, alpha, max_decode_length, eos_id):
        """Run SequenceBeamSearch and return the result of its ``search`` call.

        Args:
          symbols_to_logits_fn: single-step decoding function.
          initial_ids: (batch_size,) ids every sequence starts with.
          initial_cache: decoder cache before the first step.
          vocab_size: size of the vocabulary.
          beam_size: number of beams.
          alpha: exponent of the length normalisation.
          max_decode_length: maximum number of decoding steps.
          eos_id: id of the end-of-sequence token.
        """
        batch_size = tf.shape(initial_ids)[0]
        sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size, beam_size, alpha, max_decode_length, eos_id)
        return sbs.search(initial_ids, initial_cache)

def test_transformer():
    """Smoke-test Transformer in prediction and training mode and print the result shapes."""
    # Tiny hyper-parameters so the test runs quickly.
    params = {
        'vocab_size': 10,
        'hidden_size': 4,
        'num_hidden_layers': 6,
        'num_heads': 2,
        'embedding_dropout': 0.5,
        'attention_dropout': 0.5,
        'layer_normresidual_dropout': 0.5,
        'relu_dropout': 0.5,
        'filter_size': 8,
        'extra_decode_length': 3,
        'beam_size': 3,
        'alpha': 0.6
    }
    model = Transformer(params)
    
    # Two sequences of token ids; id 0 is treated as padding by get_padding_bias.
    inputs = np.array([[1, 0, 3], [1, 2, 0]])
    targets = np.array([[1, 2, 3, 4], [2, 3, 4, 0]])
    # Prediction mode: only the inputs are given, so beam search is used.
    outputs = model([inputs], training=False)
    print(f'Predict: {inputs.shape} x transformer -> seq: {outputs["outputs"].shape}, score: {outputs["scores"].shape}')
    
    # Training mode: inputs and targets are given, the model returns logits.
    outputs = model([inputs, targets], training=True)
    print(f'Train:  inputs-{inputs.shape} targets-{targets.shape} -> logits-{outputs.shape}')
test_transformer()

Predict: (2, 3) x transformer -> seq: (2, 6), score: (2,)
Train:  inputs-(2, 3) targets-(2, 4) -> logits-(2, 4, 10)


## Loss Layer

Loss 采用的是 kl 散度，并且用了 label smoothing。

label smoothing 的意思是：原本预测一个单词时，target 里的那个词概率为 1，vocab 中其它词的概率为 0；加入 label smoothing 后，target 里的词概率变为 1-smooth，其它每个词的概率为 smooth/(vocab_size-1)。

另外由于 target 中有 padding，在计算时需要把 padding 的影响去掉。

In [None]:
def _pad_tensors_to_same_length(x, y):
    """Zero-pad ``x`` and ``y`` along axis 1 so both have the same length there.

    ``x`` is padded as a rank-3 tensor and ``y`` as a rank-2 tensor.
    TODO: is this needed? Logits and labels should already have equal length.
    """
    with tf.name_scope('pad_to_same_length'):
        len_x = tf.shape(x)[1]
        len_y = tf.shape(y)[1]
        target_len = tf.maximum(len_x, len_y)

        # Pad only at the end of axis 1; the longer tensor gets zero padding.
        padded_x = tf.pad(x, [[0, 0], [0, target_len - len_x], [0, 0]])
        padded_y = tf.pad(y, [[0, 0], [0, target_len - len_y]])
        return padded_x, padded_y
    
def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
    """Label-smoothed cross entropy per position, with padding masked out.

    Args:
      logits: (batch_size, logits_length, vocab_size)
      labels: (batch_size, target_length); id 0 marks padding.
      smoothing: probability mass moved from the label word to the other words.
      vocab_size: size of the vocabulary.

    Returns:
      A pair ``(loss, mask)``: the per-position loss with padding positions
      set to zero, and the float mask that is 1 where the label is not 0.
    """
    with tf.name_scope('loss'):
        logits, labels = _pad_tensors_to_same_length(logits, labels)

        # The label word gets probability 1 - smoothing, every other word
        # gets smoothing / (vocab_size - 1).
        with tf.name_scope('smoothing_cross_entropy'):
            label_prob = 1.0 - smoothing
            other_prob = smoothing / tf.cast(vocab_size - 1, tf.float32)
            smoothed_labels = tf.one_hot(
                tf.cast(labels, tf.int32),
                depth=vocab_size,
                on_value=label_prob,
                off_value=other_prob)
            loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=smoothed_labels)

        # Subtract the entropy of the smoothed target distribution, which
        # turns the cross entropy into a KL divergence. The 1e-20 keeps the
        # log finite when other_prob is 0.
        label_term = label_prob * tf.math.log(label_prob)
        others_term = (tf.cast(vocab_size - 1, tf.float32) * other_prob *
                       tf.math.log(other_prob + 1e-20))
        target_entropy = -(label_term + others_term)
        loss = loss - target_entropy

        mask = tf.cast(tf.not_equal(labels, 0), tf.float32)
        return loss * mask, mask

def transformer_loss(logits, labels, smoothing, vocab_size):
    """Return the label-smoothed loss averaged over non-padding positions.

    Args:
      logits: (batch_size, logits_length, vocab_size)
      labels: (batch_size, target_length); id 0 marks padding.
      smoothing: label-smoothing amount.
      vocab_size: size of the vocabulary.

    Returns:
      Scalar tensor: the summed masked loss divided by the number of
      non-padding positions.
    """
    # The first parameter used to be misspelled `logtis`, so the body's
    # `logits` was an undefined name.
    xentropy, weights = padded_cross_entropy_loss(logits, labels, smoothing, vocab_size)
    return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

class LossLayer(tf.keras.layers.Layer):
    """Pass-through layer that registers the Transformer loss via add_loss."""

    def __init__(self, vocab_size, label_smoothing):
        """Store the loss settings.

        Args:
          vocab_size: size of the vocabulary.
          label_smoothing: label-smoothing amount passed to transformer_loss.
        """
        super(LossLayer, self).__init__()
        self.vocab_size = vocab_size
        self.label_smoothing = label_smoothing

    def get_config(self):
        """Return the constructor arguments of this layer."""
        # Read the instance attributes: the bare names `vocab_size` and
        # `label_smoothing` are not defined in this scope.
        return {
            'vocab_size': self.vocab_size,
            'label_smoothing': self.label_smoothing
        }

    def call(self, inputs):
        """Add the loss for ``(logits, targets)`` and return the logits unchanged.

        Args:
          inputs: sequence whose first element is the logits and whose second
            element is the targets.
        """
        logits, targets = inputs[0], inputs[1]
        loss = transformer_loss(logits, targets, self.label_smoothing, self.vocab_size)
        self.add_loss(loss)
        return logits

In [None]:
def create_model(params, is_train):
    """Wrap Transformer in a functional tf.keras.Model.

    Args:
      params: hyper-parameter dict passed to Transformer; the training model
        additionally reads 'vocab_size' and 'label_smoothing'.
      is_train: True builds the training model, False the inference model.

    Returns:
      Training: a Model mapping [inputs, targets] to the logits, with the
        loss attached through LossLayer.
      Inference: a Model mapping inputs to [outputs, scores].
    """
    with tf.name_scope('model'):
        if is_train:
            inputs = tf.keras.layers.Input((None,), dtype='int64', name='inputs')
            targets = tf.keras.layers.Input((None,), dtype='int64', name='targets')
            internal_model = Transformer(params, name="transformer_v2")
            logits = internal_model([inputs, targets], training=is_train)
            vocab_size = params['vocab_size']
            label_smoothing = params["label_smoothing"]
            # LossLayer is defined in this file; no `metrics` module is
            # imported here, so it is referenced directly.
            logits = LossLayer(vocab_size, label_smoothing)([logits, targets])
            # Identity layer that only gives the output tensor the name "logits".
            logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
            return tf.keras.Model([inputs, targets], logits)
        else:
            inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
            internal_model = Transformer(params, name="transformer_v2")
            ret = internal_model([inputs], training=is_train)
            outputs, scores = ret["outputs"], ret["scores"]
            return tf.keras.Model(inputs, [outputs, scores])

## 总结

Transformer 模型在论文中看起来比较简单，但是真正实现起来还是有许多细节要注意，主要有：

* embedding与输出全连接层共享权重
* position encoding 的含义
* 三个原理相同但应用地方不同的 attention
* layer normalization用在哪一层
* beam search的实现
* loss 的计算

总之，花点时间读+抄了一遍 transformer 的代码，对照原始论文，还是有新收获的。

## Reference
* [tensorflow official transformer implementation](https://github.com/tensorflow/models/tree/master/official/transformer)
* [Attention is all you need](https://arxiv.org/abs/1706.03762)