# In [None]:
import tensorflow as tf
import tensorflow.keras as tf_keras

class OptimizedMultiQueryAttentionLayerWithDownSampling(tf_keras.layers.Layer):
  """Multi Query Attention with spatial downsampling.

  Three parameters are introduced for the spatial downsampling:
    1. kv_strides: downsampling factor on keys and values only.
    2. query_h_strides: vertical strides on queries only.
    3. query_w_strides: horizontal strides on queries only.

  This is an optimized version:
    1. The attention projections are explicitly written out as 1x1 Conv2D.
    2. Additional reshapes are introduced, bringing up to a 3x speedup.

  All heads share one key projection (`key_dim` channels) and one value
  projection (`value_dim` channels); only the query projection produces
  `num_heads * key_dim` channels.

  The input is treated as channels-last `[batch, height, width, channels]`:
  `build` reads the channel count from the last axis and every reshape in
  `call` assumes that layout. Height and width may differ, and need not be
  multiples of the query strides.
  """

  def __init__(
      self,
      num_heads: int,
      key_dim: int,
      value_dim: int,
      query_h_strides: int = 1,
      query_w_strides: int = 1,
      kv_strides: int = 1,
      dropout: float = 0,
      dw_kernel_size: int = 3,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      **kwargs,
  ):
    """Initializer.

    Args:
      num_heads: Number of attention heads.
      key_dim: Size of the attention key dimension.
      value_dim: Size of the attention value dimension.
      query_h_strides: Vertical stride size for query only.
      query_w_strides: Horizontal stride size for query only.
      kv_strides: Key and value stride size.
      dropout: Dropout probability (between 0 and 1), applied to the
        attention scores.
      dw_kernel_size: Spatial dimension of the depthwise kernel.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: Momentum value for use with normalization moving average.
      norm_epsilon: Small float added to norm variance to avoid dividing by
        zero.
      **kwargs: Forwarded to `tf_keras.layers.Layer` (e.g. `name`, `dtype`).
    """
    super().__init__(**kwargs)
    self._num_heads = num_heads
    self._key_dim = key_dim
    self._value_dim = value_dim
    self._query_h_strides = query_h_strides
    self._query_w_strides = query_w_strides
    self._kv_strides = kv_strides
    self._dw_kernel_size = dw_kernel_size
    self._dropout = dropout
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon

    if use_sync_bn:
      self._norm = tf_keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf_keras.layers.BatchNormalization
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

  def get_config(self):
    """Returns the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
        'num_heads': self._num_heads,
        'key_dim': self._key_dim,
        'value_dim': self._value_dim,
        'query_h_strides': self._query_h_strides,
        'query_w_strides': self._query_w_strides,
        'kv_strides': self._kv_strides,
        'dropout': self._dropout,
        'dw_kernel_size': self._dw_kernel_size,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
    })
    return config

  def _make_norm(self):
    """Creates a normalization layer configured from the constructor args."""
    return self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
    )

  def build(self, input_shape):
    """Create layer state."""
    # Channels-last: the output projection maps back to the input's channels.
    self._channel_dim = input_shape[-1]

    if self._query_h_strides > 1 or self._query_w_strides > 1:
      # `strides` defaults to `pool_size`, so this pools non-overlapping
      # windows of (query_h_strides, query_w_strides).
      self._query_downsampling = tf_keras.layers.AvgPool2D(
          pool_size=(self._query_h_strides, self._query_w_strides),
          padding='same',
      )
      self._query_downsampling_norm = self._make_norm()

    self._query_proj = tf_keras.layers.Conv2D(
        filters=self._num_heads * self._key_dim,
        kernel_size=1,
        strides=1,
        padding='valid',
        use_bias=False,
    )

    if self._kv_strides > 1:
      self._key_dw_conv = tf_keras.layers.DepthwiseConv2D(
          kernel_size=self._dw_kernel_size,
          strides=self._kv_strides,
          padding='same',
          depth_multiplier=1,
          use_bias=False,
      )
      self._key_dw_norm = self._make_norm()
    # A single key projection shared by every head (multi-query attention).
    self._key_proj = tf_keras.layers.Conv2D(
        filters=self._key_dim,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=False,
    )

    if self._kv_strides > 1:
      self._value_dw_conv = tf_keras.layers.DepthwiseConv2D(
          kernel_size=self._dw_kernel_size,
          strides=self._kv_strides,
          padding='same',
          depth_multiplier=1,
          use_bias=False,
      )
      self._value_dw_norm = self._make_norm()
    # A single value projection shared by every head.
    self._value_proj = tf_keras.layers.Conv2D(
        filters=self._value_dim,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=False,
    )

    self._output_proj = tf_keras.layers.Conv2D(
        filters=self._channel_dim,
        kernel_size=1,
        strides=1,
        padding='valid',
        use_bias=False,
    )
    if self._query_h_strides > 1 or self._query_w_strides > 1:
      self._upsampling = tf_keras.layers.UpSampling2D(
          size=(self._query_h_strides, self._query_w_strides),
          interpolation='bilinear',
      )
    self._dropout_layer = tf_keras.layers.Dropout(rate=self._dropout)

  def _reshape_input(self, t):
    """Reshapes a tensor to three dimensions, keeping the first and last."""
    s = tf.shape(t)
    # Propagate the shape statically where possible.
    static_num = t.shape[1:-1].num_elements()
    num = static_num or tf.math.reduce_prod(s[1:-1])
    return tf.ensure_shape(
        tf.reshape(t, [s[0], num, s[-1]]), [t.shape[0], static_num, t.shape[-1]]
    )

  def _reshape_projected_query(self, t, num_heads, h_px, w_px, key_dim):
    """Reshapes projected query: [b, h_px, w_px, h x k] -> [b, h_px x w_px, h, k]."""
    s = tf.shape(t)
    return tf.reshape(t, [s[0], h_px * w_px, num_heads, key_dim])

  def _get_pixels(self, t, axis=1):
    """Returns the size of `t` along `axis` (height by default).

    The static size is returned as a Python int when known; otherwise the
    dynamic size is read from `tf.shape`.
    """
    static_num = t.shape[axis]
    return static_num or tf.shape(t)[axis]

  @staticmethod
  def _ceil_div(n, d):
    """Ceiling division that works on Python ints and integer tensors."""
    return (n + d - 1) // d

  def _reshape_output(self, t, num_heads, h_px, w_px):
    """Reshapes output: [b, h_px x w_px, h, k] -> [b, h_px, w_px, h x k]."""
    s = tf.shape(t)
    # Propagate the shape statically where possible.
    static_last_dim = t.shape[-1]
    last_dim = (static_last_dim or s[-1]) * num_heads
    return tf.reshape(t, [t.shape[0] or s[0], h_px, w_px, last_dim])

  def call(self, inputs, training=None):
    """Run layer computation.

    Args:
      inputs: Feature map, treated as channels-last [b, height, width, c].
      training: Optional boolean forwarded to the batch-norm and dropout
        sublayers. When None, the sublayers resolve the mode themselves, as
        they did before this argument was exposed.

    Returns:
      A tensor with the same shape as `inputs`.
    """
    x = inputs
    height = self._get_pixels(x, axis=1)
    width = self._get_pixels(x, axis=2)
    downsample_query = self._query_h_strides > 1 or self._query_w_strides > 1

    # 'same'-padded pooling with stride s yields ceil(size / s) positions, so
    # the query grid must be rounded up to match (exact when s divides size).
    q_h_px = self._ceil_div(height, self._query_h_strides)
    q_w_px = self._ceil_div(width, self._query_w_strides)

    if downsample_query:
      q = self._query_downsampling(x)
      q = self._query_downsampling_norm(q, training=training)
      q = self._query_proj(q)
    else:
      q = self._query_proj(x)

    # q: [b, q_h_px, q_w_px, h x k] -> [b, l, h, k] with l = q_h_px x q_w_px.
    q = self._reshape_projected_query(
        q,
        self._num_heads,
        q_h_px,
        q_w_px,
        self._key_dim,
    )

    if self._kv_strides > 1:
      k = self._key_dw_conv(x)
      k = self._key_dw_norm(k, training=training)
      k = self._key_proj(k)
    else:
      k = self._key_proj(x)
    # k: [b, kv_h, kv_w, k] -> [b, p, k] with p = kv_h x kv_w; shared by heads.
    k = self._reshape_input(k)

    # logits: [b, l, h, p].
    logits = tf.einsum('blhk,bpk->blhp', q, k)

    logits = logits / tf.math.sqrt(tf.cast(self._key_dim, logits.dtype))

    attention_scores = self._dropout_layer(
        tf.nn.softmax(logits), training=training
    )

    if self._kv_strides > 1:
      v = self._value_dw_conv(x)
      v = self._value_dw_norm(v, training=training)
      v = self._value_proj(v)
    else:
      v = self._value_proj(x)

    # v: [b, kv_h, kv_w, v] -> [b, p, v]; shared by heads.
    v = self._reshape_input(v)
    o = tf.einsum('blhp,bpk->blhk', attention_scores, v)
    # o: [b, l, h, v] -> [b, q_h_px, q_w_px, h x v].
    o = self._reshape_output(
        o,
        self._num_heads,
        q_h_px,
        q_w_px,
    )
    if downsample_query:
      o = self._upsampling(o)
      # Upsampling restores ceil(size / s) * s positions per axis, which
      # overshoots when size is not a multiple of s; trim back to the input's
      # spatial size (a no-op when the stride divides evenly).
      o = o[:, :height, :width, :]

    result = self._output_proj(o)

    return tf.ensure_shape(tf.reshape(result, tf.shape(x)), x.shape)
