<a href="https://colab.research.google.com/github/yoheikikuta/TensorFlow2-check/blob/master/colab/ParameterSharedTransformerEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Parameter-sharing Transformer Encoder

The code below runs end to end, but its numerical results have not been validated yet.

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

# Ask Colab for TensorFlow 2.x before `tensorflow` is imported. Outside Colab the
# magic is unavailable; any exception it raises is swallowed so the notebook can
# still run in other environments.
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass

import tensorflow as tf

TensorFlow 2.x selected.


In [2]:
# Display the installed TensorFlow version (the recorded output below is 2.0.0-rc2).
tf.__version__

'2.0.0-rc2'

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import json
import math
import re
import numpy as np
import six

In [0]:
from tensorflow.keras import layers

In [0]:
def gelu(x):
    """Apply the Gaussian Error Linear Unit activation (tanh approximation).

    GELU is a smooth alternative to ReLU; see https://arxiv.org/abs/1606.08415.

    Args:
        x: float Tensor to activate.

    Returns:
        `x * cdf`, where `cdf = 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))`
        approximates the standard normal CDF evaluated element-wise at `x`.
    """
    tanh_argument = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    return x * (0.5 * (1.0 + tf.tanh(tanh_argument)))

In [0]:
def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
        tensor: A tf.Tensor object to find the shape of.
        expected_rank: (optional) int, or a list/tuple of ints. The expected
            rank(s) of `tensor`. If this is specified and the rank of `tensor`
            is not among them, a ValueError is raised.
        name: Optional name of the tensor for the error message. When omitted,
            `tensor.name` is used if the tensor has one.

    Returns:
        A list of dimensions of the shape of tensor. All static dimensions will
        be returned as python integers, and dynamic dimensions will be returned
        as tf.Tensor scalars.

    Raises:
        ValueError: if `expected_rank` is given and the rank does not match.
    """
    shape = tensor.shape.as_list()

    if expected_rank is not None:
        # Accept either a single rank or a collection of acceptable ranks.
        try:
            allowed_ranks = set(expected_rank)
        except TypeError:
            allowed_ranks = {expected_rank}
        if len(shape) not in allowed_ranks:
            if name is None:
                # Eager tensors have no meaningful `.name` (accessing it raises
                # AttributeError), so fall back to a placeholder.
                name = getattr(tensor, "name", "<unnamed>")
            raise ValueError(
                "For the tensor `%s`, the actual rank `%d` (shape = %s) is not "
                "equal to the expected rank `%s`" %
                (name, len(shape), shape, expected_rank))

    non_static_indexes = [index for index, dim in enumerate(shape) if dim is None]
    if not non_static_indexes:
        return shape

    # Only query the dynamic shape when some dimension is unknown statically.
    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape

In [0]:
def reshape_to_matrix(input_tensor):
    """Collapse all leading axes of `input_tensor` into one, giving a matrix.

    A rank-2 tensor is handed back unchanged; a higher-rank tensor is reshaped
    to `[-1, last_dim]`.

    Raises:
        ValueError: if the tensor's rank is below 2.
    """
    rank = input_tensor.shape.ndims
    if rank < 2:
        raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                         (input_tensor.shape))
    if rank > 2:
        last_dim = input_tensor.shape[-1]
        return tf.reshape(input_tensor, [-1, last_dim])
    return input_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
    """Undo `reshape_to_matrix`, restoring the original leading dimensions.

    Args:
        output_tensor: a rank-2 tensor whose first axis packs the original
            leading dimensions.
        orig_shape_list: shape list of the tensor before it was flattened.

    Returns:
        `output_tensor` itself when `orig_shape_list` has two entries;
        otherwise `output_tensor` reshaped to `orig_shape_list[:-1] + [width]`,
        where `width` is the last dimension of `output_tensor`.
    """
    if len(orig_shape_list) == 2:
        return output_tensor

    width = get_shape_list(output_tensor)[-1]
    return tf.reshape(output_tensor, orig_shape_list[0:-1] + [width])

In [0]:
def create_initializer(initializer_range=0.02):
    """Build the kernel initializer shared by the projections in this notebook.

    Args:
        initializer_range: standard deviation of the truncated normal.

    Returns:
        A `tf.keras.initializers.TruncatedNormal` with `stddev=initializer_range`.
    """
    truncated_normal = tf.keras.initializers.TruncatedNormal
    return truncated_normal(stddev=initializer_range)

In [0]:
def dropout(input_tensor, dropout_prob):
    """Perform dropout.

    Args:
        input_tensor: float Tensor.
        dropout_prob: Python float. The probability of dropping out a value (NOT
            of *keeping* it, which is what the second argument of the TF1
            `tf.nn.dropout` meant).

    Returns:
        A version of `input_tensor` with dropout applied, or `input_tensor`
        itself when `dropout_prob` is None or 0.0.

    Note:
        Dropout is applied unconditionally; this helper has no notion of a
        training versus inference phase.
    """
    if dropout_prob is None or dropout_prob == 0.0:
        return input_tensor

    # TF2's `tf.nn.dropout` takes the drop *rate*. Passing `1.0 - dropout_prob`
    # (a TF1-style keep_prob) would drop e.g. 90% of values for p=0.1.
    output = tf.nn.dropout(input_tensor, rate=dropout_prob)
    return output

In [0]:
def layer_norm(input_tensor, name=None):
    """Run layer normalization on the last dimension of the tensor.

    Args:
        input_tensor: Tensor to normalize over its last axis.
        name: Optional name for the created `LayerNormalization` layer.

    Returns:
        The normalized tensor.

    Warning:
        A new `tf.keras.layers.LayerNormalization` layer, with fresh gamma/beta
        variables, is created on every call. Its parameters are therefore not
        shared between calls and not tracked by any enclosing model. Models
        that need trainable normalization should own a `LayerNormalization`
        layer instead of calling this helper.
    """
    return tf.keras.layers.LayerNormalization(axis=-1, name=name)(input_tensor)

In [0]:
class Attention(layers.Layer):
    """Multi-headed scaled dot-product attention layer.

    Projects `from_tensor` to queries and `to_tensor` to keys and values with
    `Dense` layers owned by this layer. The projection weights are created once
    and reused on every call. Each query position then attends over all key
    positions.
    """

    def __init__(self,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 initializer_range=0.02,
                 attention_probs_dropout_prob=0.0,
                 do_return_2d_tensor=False,
                 name="attention",
                 **kwargs):
        """Create the query/key/value projections and the dropout layer.

        Args:
            num_attention_heads: number of attention heads (N).
            size_per_head: width of each head (H).
            query_act: optional activation for the query projection.
            key_act: optional activation for the key projection.
            value_act: optional activation for the value projection.
            initializer_range: stddev of the truncated-normal kernel initializer.
            attention_probs_dropout_prob: dropout rate for the attention
                probabilities, applied only in training mode. None means 0.0.
            do_return_2d_tensor: if True, `call` returns `[B*F, N*H]` instead of
                `[B, F, N*H]`.
            name: layer name.
            **kwargs: forwarded to `tf.keras.layers.Layer`.
        """
        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_attention_heads`
        #   H = `size_per_head`
        super(Attention, self).__init__(name=name, **kwargs)
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.do_return_2d_tensor = do_return_2d_tensor

        # `query_layer` = [B*F, N*H]
        self.query_layer = layers.Dense(
            units=self.num_attention_heads * self.size_per_head,
            activation=query_act,
            name="query",
            kernel_initializer=create_initializer(initializer_range))
        # `key_layer` = [B*T, N*H]
        self.key_layer = layers.Dense(
            units=self.num_attention_heads * self.size_per_head,
            activation=key_act,
            name="key",
            kernel_initializer=create_initializer(initializer_range))
        # `value_layer` = [B*T, N*H]
        self.value_layer = layers.Dense(
            units=self.num_attention_heads * self.size_per_head,
            activation=value_act,
            name="value",
            kernel_initializer=create_initializer(initializer_range))
        # A Keras Dropout layer takes a drop *rate* and only drops in training
        # mode, so inference is deterministic.
        self.attention_probs_dropout = layers.Dropout(
            attention_probs_dropout_prob or 0.0)

    def call(self, from_tensor, to_tensor, attention_mask=None,
             batch_size=None, from_seq_length=None, to_seq_length=None,
             training=None):
        """Attend from `from_tensor` to `to_tensor`.

        Args:
            from_tensor: `[B, F, W]`, or `[B*F, W]` together with the explicit
                sizes below.
            to_tensor: `[B, T, W]`, or `[B*T, W]` together with the explicit
                sizes below.
            attention_mask: optional `[B, F, T]` tensor; 1 for positions that
                may be attended to and 0 for masked positions.
            batch_size, from_seq_length, to_seq_length: required only when the
                inputs are rank 2.
            training: whether attention-probability dropout is active.

        Returns:
            `[B*F, N*H]` if `do_return_2d_tensor`, else `[B, F, N*H]`.

        Raises:
            ValueError: if the input ranks differ, are not 2 or 3, or rank-2
                inputs are given without the explicit sizes.
        """
        from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
        to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

        if len(from_shape) != len(to_shape):
            raise ValueError(
                "The rank of `from_tensor` must match the rank of `to_tensor`.")

        if len(from_shape) == 3:
            batch_size = from_shape[0]
            from_seq_length = from_shape[1]
            to_seq_length = to_shape[1]
        elif len(from_shape) == 2:
            if (batch_size is None or from_seq_length is None or to_seq_length is None):
                raise ValueError(
                    "When passing in rank 2 tensors to attention_layer, the values "
                    "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                    "must all be specified.")
        else:
            raise ValueError(
                "`from_tensor` and `to_tensor` must be rank 2 or 3, got rank %d."
                % len(from_shape))

        from_tensor_2d = reshape_to_matrix(from_tensor)
        to_tensor_2d = reshape_to_matrix(to_tensor)

        # `query_layer` = [B, N, F, H]
        query_layer = self.transpose_for_scores(
            self.query_layer(from_tensor_2d), batch_size,
            self.num_attention_heads, from_seq_length, self.size_per_head)

        # `key_layer` = [B, N, T, H]
        key_layer = self.transpose_for_scores(
            self.key_layer(to_tensor_2d), batch_size,
            self.num_attention_heads, to_seq_length, self.size_per_head)

        # Take the dot product between "query" and "key" to get the raw attention
        # scores, scaled by 1/sqrt(H).
        # `attention_scores` = [B, N, F, T]
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        attention_scores = tf.multiply(attention_scores,
                                       1.0 / math.sqrt(float(self.size_per_head)))

        if attention_mask is not None:
            # `attention_mask` = [B, 1, F, T]
            attention_mask = tf.expand_dims(attention_mask, axis=1)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0

            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_scores += adder

        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]
        attention_probs = tf.nn.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_probs_dropout(attention_probs, training=training)

        # `value_layer` = [B, N, T, H]
        value_layer = self.transpose_for_scores(
            self.value_layer(to_tensor_2d), batch_size,
            self.num_attention_heads, to_seq_length, self.size_per_head)

        # `context_layer` = [B, N, F, H]
        context_layer = tf.matmul(attention_probs, value_layer)

        # `context_layer` = [B, F, N, H]
        context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

        if self.do_return_2d_tensor:
            # `context_layer` = [B*F, N*H]
            context_layer = tf.reshape(
                context_layer,
                [batch_size * from_seq_length, self.num_attention_heads * self.size_per_head])
        else:
            # `context_layer` = [B, F, N*H]
            context_layer = tf.reshape(
                context_layer,
                [batch_size, from_seq_length, self.num_attention_heads * self.size_per_head])

        return context_layer

    def transpose_for_scores(self, input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        """Split `[B*S, N*W]` into heads and move the head axis: `[B, N, S, W]`."""
        # `tf.reshape` accepts static ints and dynamic scalar tensors alike, and
        # avoids building a new Keras `Reshape` layer on every call.
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])
        return tf.transpose(output_tensor, [0, 2, 1, 3])

In [0]:
# Smoke test: an Attention layer with its defaults (1 head of width 512).
test = Attention()

In [0]:
# Rank-3 integer dummy inputs of shape [1, 1, 712] holding 0..711.
dummy_from = tf.constant([[list(range(712))]])
dummy_to = tf.constant([[list(range(712))]])

In [188]:
# Output width is num_attention_heads * size_per_head (1 * 512 with the defaults).
test(dummy_from, dummy_to).shape

TensorShape([1, 1, 512])

In [0]:
class ParameterSharedTransformerEncoder(tf.keras.Model):
    """Transformer encoder whose layers all share one set of parameters.

    A single attention + feed-forward block is created once and applied
    `num_hidden_layers` times. That block includes its dropout and layer
    normalization layers.
    """

    def __init__(self,
                 attention_mask=None,
                 hidden_size=768,
                 num_hidden_layers=12,
                 intermediate_act_fn=gelu,
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 initializer_range=0.02,
                 do_return_all_layers=False,
                 name="transformer_encoder",
                 **kwargs):
        """Create the shared sub-layers.

        Args:
            attention_mask: default `[B, S, S]` mask (1 = attend, 0 = masked)
                used when `call` is not given one; None disables masking.
            hidden_size: model width; must equal the input's last dimension.
            num_hidden_layers: how many times the shared block is applied.
            intermediate_act_fn: activation of the feed-forward hidden layer.
            hidden_dropout_prob: dropout rate after the attention-output and
                feed-forward projections (training mode only). None means 0.0.
            attention_probs_dropout_prob: dropout rate on attention
                probabilities, forwarded to `Attention`.
            initializer_range: stddev of the truncated-normal kernel initializer.
            do_return_all_layers: if True, `call` returns every layer's output.
            name: model name.
            **kwargs: forwarded to `tf.keras.Model`.
        """
        super(ParameterSharedTransformerEncoder, self).__init__(name=name, **kwargs)
        self.attention_mask = attention_mask
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_act_fn = intermediate_act_fn
        # Heads are 64 wide; keep at least one head so a hidden_size below 64
        # does not lead to a division by zero on the next line.
        self.num_attention_heads = max(1, int(self.hidden_size / 64))
        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        self.intermediate_size = int(4 * self.hidden_size)
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.do_return_all_layers = do_return_all_layers

        self.attention_layer = Attention(
            self.num_attention_heads,
            self.attention_head_size,
            initializer_range=initializer_range,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            do_return_2d_tensor=True)
        self.attention_output_layer = layers.Dense(
            self.hidden_size,
            kernel_initializer=create_initializer(initializer_range))
        # Owned Dropout / LayerNormalization layers: their parameters are
        # tracked, trained and shared across all iterations of the block, and
        # dropout is skipped at inference.
        self.attention_dropout = layers.Dropout(hidden_dropout_prob or 0.0)
        self.attention_layer_norm = layers.LayerNormalization(axis=-1)
        self.intermediate_layer = layers.Dense(
            self.intermediate_size,
            activation=self.intermediate_act_fn,
            kernel_initializer=create_initializer(initializer_range))
        self.output_layer = layers.Dense(
            self.hidden_size,
            kernel_initializer=create_initializer(initializer_range))
        self.output_dropout = layers.Dropout(hidden_dropout_prob or 0.0)
        self.output_layer_norm = layers.LayerNormalization(axis=-1)

    def call(self, input_tensor, attention_mask=None, training=None):
        """Apply the shared block `num_hidden_layers` times.

        Args:
            input_tensor: `[B, S, hidden_size]` tensor.
            attention_mask: optional `[B, S, S]` mask for this call; falls back
                to the mask given to the constructor.
            training: whether the hidden dropout layers are active.

        Returns:
            The last layer's output shaped like `input_tensor`, or a list of
            every layer's output in that shape when `do_return_all_layers`.
        """
        if attention_mask is None:
            attention_mask = self.attention_mask

        input_shape = self.get_input_tensor_shape(input_tensor)
        batch_size, seq_length, _ = input_shape
        # All projections operate on the flattened [B*S, hidden_size] matrix.
        prev_output = reshape_to_matrix(input_tensor)

        all_layer_outputs = []
        for _ in range(self.num_hidden_layers):
            layer_input = prev_output

            ### Attention (self-attention over the flattened sequence)
            attention_output = self.attention_layer(
                layer_input,
                layer_input,
                attention_mask,
                batch_size,
                seq_length,
                seq_length)
            ### Attention output: projection, dropout, residual + layer norm
            attention_output = self.attention_output_layer(attention_output)
            attention_output = self.attention_dropout(attention_output, training=training)
            attention_output = self.attention_layer_norm(attention_output + layer_input)
            ### Intermediate (feed-forward hidden layer)
            intermediate_output = self.intermediate_layer(attention_output)
            ### Output: projection, dropout, residual + layer norm
            layer_output = self.output_layer(intermediate_output)
            layer_output = self.output_dropout(layer_output, training=training)
            layer_output = self.output_layer_norm(layer_output + attention_output)

            prev_output = layer_output
            all_layer_outputs.append(layer_output)

        if self.do_return_all_layers:
            return [reshape_from_matrix(layer_output, input_shape)
                    for layer_output in all_layer_outputs]
        return reshape_from_matrix(prev_output, input_shape)

    def get_input_tensor_shape(self, input_tensor):
        """Return the shape list of `input_tensor` after validating it.

        Raises:
            ValueError: if the input is not rank 3, or its last dimension
                differs from `hidden_size`.
        """
        input_shape = get_shape_list(input_tensor, expected_rank=3)
        if len(input_shape) != 3:
            raise ValueError("The input tensor must be rank 3, got shape %s" %
                             (input_shape,))
        input_width = input_shape[2]

        # The Transformer performs sum residuals on all layers so the input needs
        # to be the same as the hidden size. `%s` also formats a dynamic width.
        if input_width != self.hidden_size:
            raise ValueError("The width of the input tensor (%s) != hidden size (%s)" %
                             (input_width, self.hidden_size))

        return input_shape

In [0]:
# Encoder with default hyper-parameters (hidden_size=768, 12 shared layers).
test = ParameterSharedTransformerEncoder()

In [196]:
# Head count and per-head width derived from hidden_size in the constructor.
test.num_attention_heads, test.attention_head_size

(12, 64)

In [197]:
# Two sequences of length 1; every token vector holds 0..767 as float32.
dummy_input = tf.constant([[list(range(768))]] * 2, dtype=tf.float32)
dummy_input.shape

TensorShape([2, 1, 768])

In [0]:
# `dummy_input` is already a tensor, so it is passed directly (no tf.constant).
result = test(dummy_input)

In [199]:
# The encoder output is reshaped back to the input shape [batch, seq_length, hidden_size].
result.shape

TensorShape([2, 1, 768])

# Trial and Error

In [0]:
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
    """Functional multi-headed attention from `from_tensor` to `to_tensor`.

    Functional counterpart of the `Attention` layer. Every call builds new
    query/key/value `Dense` layers, so no projection weights are shared
    between calls.

    Args:
        from_tensor: `[B, F, W]`, or `[B*F, W]` with the explicit sizes below.
        to_tensor: `[B, T, W]`, or `[B*T, W]` with the explicit sizes below.
        attention_mask: optional `[B, F, T]` tensor; 1 for positions that may be
            attended to and 0 for masked positions.
        num_attention_heads: number of heads (N).
        size_per_head: width of each head (H).
        query_act, key_act, value_act: optional projection activations.
        attention_probs_dropout_prob: dropout probability for the attention
            probabilities (passed to the `dropout` helper).
        initializer_range: stddev of the kernel initializer.
        do_return_2d_tensor: return `[B*F, N*H]` instead of `[B, F, N*H]`.
        batch_size, from_seq_length, to_seq_length: required for rank-2 inputs.

    Returns:
        The context tensor, `[B*F, N*H]` or `[B, F, N*H]`.

    Raises:
        ValueError: if the input ranks differ, are not 2 or 3, or rank-2 inputs
            are given without the explicit sizes.
    """
    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        # [B*S, N*W] -> [B, S, N, W] -> [B, N, S, W]
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])
        return tf.transpose(output_tensor, [0, 2, 1, 3])

    def project(input_2d, activation, name):
        # `tf.layers.dense` does not exist in TF 2.x; apply a one-off Keras
        # Dense layer instead (new weights on every call).
        return tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=activation,
            name=name,
            kernel_initializer=create_initializer(initializer_range))(input_2d)

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")
    else:
        raise ValueError(
            "`from_tensor` and `to_tensor` must be rank 2 or 3, got rank %d."
            % len(from_shape))

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = project(from_tensor_2d, query_act, "query")
    # `key_layer` = [B*T, N*H]
    key_layer = project(to_tensor_2d, key_act, "key")
    # `value_layer` = [B*T, N*H]
    value_layer = project(to_tensor_2d, value_act, "value")

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                     to_seq_length, size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=1)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, N, T, H]
    value_layer = transpose_for_scores(value_layer, batch_size, num_attention_heads,
                                       to_seq_length, size_per_head)

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
