In [None]:
from IPython import get_ipython
from IPython.display import display
# %%
from tensorflow import math, matmul, reshape, shape, transpose, cast, float32, concat
from tensorflow.keras.layers import Dense, Layer
from tensorflow.keras.backend import softmax
# Implementing the Scaled Dot-Product Attention

class DotProductAttention(Layer):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  The layer creates no weights of its own; it only combines the tensors
  handed to `call`.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, queries, keys, values, mask=None):
    """Weight `values` by the scaled similarity of `queries` and `keys`.

    Args:
      queries: Tensor whose last axis is the query/key depth d_k.
      keys: Tensor that is matmul-compatible with `queries` once its last
        two axes are transposed.
      values: Tensor that is matmul-compatible with the attention weights.
      mask: Optional tensor broadcastable to the score tensor. Non-zero
        entries receive a large negative score and are therefore suppressed
        by the softmax. Any numeric or boolean dtype is accepted.

    Returns:
      The softmax-weighted sum of `values`.
    """
    # Prefer the static depth. In graph mode it can be unknown (None), e.g.
    # after a reshape that uses -1 together with a dynamic batch size, so
    # fall back to the dynamic shape instead of failing on cast(None, ...).
    d_k = queries.shape[-1]
    if d_k is None:
      d_k = shape(queries)[-1]

    # Score the queries against the transposed keys, then scale by sqrt(d_k)
    scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

    # Push masked positions towards -inf so the softmax assigns them ~0 weight.
    # The cast lets callers pass boolean or integer masks as well as floats.
    if mask is not None:
      scores += -1e9 * cast(mask, scores.dtype)

    # Normalise the scores over the last axis into attention weights
    weights = softmax(scores)

    # Attention output: weighted sum of the value vectors
    return matmul(weights, values)
# Implementing the Multi-Head Attention

class MultiHeadAttention(Layer):
  """Multi-head attention built on top of `DotProductAttention`.

  Queries, keys and values are each linearly projected, split into `h` heads
  along the feature axis, attended per head in parallel, merged back and
  finally projected to `d_model` features.
  """

  def __init__(self, h, d_k, d_v, d_model, **kwargs):
    """Create the projection layers.

    Args:
      h: Number of attention heads.
      d_k: Total width of the projected queries and keys (shared by all
        heads, so each head sees d_k // h features).
      d_v: Total width of the projected values (each head sees d_v // h).
      d_model: Width of the layer output.
      **kwargs: Forwarded to the Keras `Layer` constructor.

    Raises:
      ValueError: If `h` is not positive, or `d_k` / `d_v` cannot be split
        evenly across `h` heads.
    """
    # Fail early with a readable message; otherwise the problem only shows
    # up as an opaque reshape error the first time the layer is called.
    if h < 1:
      raise ValueError(f"h must be a positive number of heads, got {h}")
    if d_k % h != 0 or d_v % h != 0:
      raise ValueError(
          f"d_k ({d_k}) and d_v ({d_v}) must both be divisible by h ({h})")

    super().__init__(**kwargs)
    self.attention = DotProductAttention()  # Scaled dot product attention
    self.heads = h  # Number of attention heads to use
    self.d_k = d_k  # Dimensionality of the linearly projected queries and keys
    self.d_v = d_v  # Dimensionality of the linearly projected values
    self.d_model = d_model  # Dimensionality of the model
    self.W_q = Dense(d_k)  # Learned projection matrix for the queries
    self.W_k = Dense(d_k)  # Learned projection matrix for the keys
    self.W_v = Dense(d_v)  # Learned projection matrix for the values
    self.W_o = Dense(d_model)  # Learned projection matrix for the multi-head output

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
        "h": self.heads,
        "d_k": self.d_k,
        "d_v": self.d_v,
        "d_model": self.d_model,
    })
    return config

  def reshape_tensor(self, x, heads, flag):
    """Split a tensor into heads, or merge the heads back together.

    Args:
      x: With `flag` true, a rank-3 tensor (batch_size, seq_length, features);
        otherwise a rank-4 tensor (batch_size, heads, seq_length, depth).
      heads: Number of heads to split into (only used when `flag` is true).
      flag: True to split into heads, False to merge heads.

    Returns:
      (batch_size, heads, seq_length, features // heads) when splitting, or
      (batch_size, seq_length, heads * depth) when merging.
    """
    if flag:
      # (batch, seq, features) -> (batch, seq, heads, -1) -> (batch, heads, seq, -1)
      x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
      x = transpose(x, perm=(0, 2, 1, 3))
    else:
      # (batch, heads, seq, depth) -> (batch, seq, heads, depth) -> (batch, seq, heads * depth)
      x = transpose(x, perm=(0, 2, 1, 3))
      x_shape = shape(x)
      new_shape = (x_shape[0], x_shape[1], x_shape[2] * x_shape[3])
      x = reshape(x, new_shape)

    return x

  def call(self, queries, keys, values, mask=None):
    """Run multi-head attention.

    Args:
      queries: Tensor of shape (batch_size, seq_length, features).
      keys: Tensor of shape (batch_size, seq_length, features).
      values: Tensor of shape (batch_size, seq_length, features).
      mask: Optional mask passed unchanged to the dot-product attention.

    Returns:
      Tensor whose last axis has `d_model` features.
    """
    # Project, then rearrange so that all heads are computed in parallel.
    # Each resulting tensor has shape (batch_size, heads, seq_length, -1).
    q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
    k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
    v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)

    # Per-head attention; result shape: (batch_size, heads, seq_length, -1)
    o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, mask=mask)

    # Concatenate the heads again: (batch_size, seq_length, d_v)
    output = self.reshape_tensor(o_reshaped, self.heads, False)

    # Final linear projection: (batch_size, seq_length, d_model)
    return self.W_o(output)
# %%
from numpy import random
# --- Smoke test of the multi-head attention layer on random inputs ---

# Hyper-parameters of the toy problem
input_seq_length = 5  # Maximum length of the input sequence
h = 8  # Number of self-attention heads
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
d_model = 512  # Dimensionality of the model sub-layers' outputs
batch_size = 64  # Batch size from the training process

# Uniform random stand-ins for real queries, keys and values
qk_shape = (batch_size, input_seq_length, d_k)
v_shape = (batch_size, input_seq_length, d_v)
queries = random.random(qk_shape)
keys = random.random(qk_shape)
values = random.random(v_shape)

# Build the layer and print the attention output for the random batch
multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
attention_output = multihead_attention(queries, keys, values)
print(attention_output)

tf.Tensor(
[[[ 0.14855118  0.20013703 -0.02880823 ...  0.01359212 -0.44502118
    0.06768567]
  [ 0.13617986  0.19600089 -0.02540962 ...  0.02122036 -0.43507555
    0.06835034]
  [ 0.13912079  0.20453653 -0.02111277 ...  0.01619427 -0.4403253
    0.07720984]
  [ 0.13880375  0.20353049 -0.02084962 ...  0.01248611 -0.4367996
    0.0766429 ]
  [ 0.1421045   0.19639619 -0.02397741 ...  0.01725038 -0.44577417
    0.07329702]]

 [[ 0.28757727  0.09672768 -0.04910994 ...  0.15489005 -0.63113034
    0.19256432]
  [ 0.29114443  0.08817305 -0.03802683 ...  0.15269296 -0.63121635
    0.19065249]
  [ 0.29930386  0.09466944 -0.04478071 ...  0.14228976 -0.6378579
    0.20184374]
  [ 0.29785025  0.09283872 -0.04667297 ...  0.14812633 -0.6405736
    0.19957878]
  [ 0.2941439   0.0961503  -0.04171791 ...  0.15648103 -0.6281992
    0.1887621 ]]

 [[ 0.27810428  0.39767057 -0.03655063 ... -0.1052499  -0.5261304
    0.06247696]
  [ 0.28281605  0.39534903 -0.044312   ... -0.09199257 -0.51383215
    0.06096