In [1]:
from numpy import random
from tensorflow import matmul, math, cast, float32
from tensorflow.keras.layers import Layer
from tensorflow.keras.backend import softmax

# Problem sizes shared by the attention demo.
input_seq_length = 5  # longest input sequence the demo feeds in
batch_size = 64       # batch size, as used by the training process
d_k = d_v = 64        # width of projected queries/keys (d_k) and of projected values (d_v)

# Implementing the Scaled-Dot Product Attention
class DotProductAttention(Layer):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k) + mask_term) V``.

    The optional mask is added to the scores after being multiplied by a large
    negative constant, so positions where the mask is 1 receive (near-)zero
    attention weight after the softmax.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Attend ``queries`` over ``keys`` and return the weighted sum of ``values``.

        Args:
            queries: Query tensor. Its last axis is contracted against the last
                axis of ``keys`` (the score ``matmul`` uses ``transpose_b=True``).
            keys: Key tensor.
            values: Value tensor, multiplied on the right by the attention weights.
            d_k: Dimensionality of the queries/keys; scores are divided by
                ``sqrt(d_k)``.
            mask: Optional tensor broadcastable to the score matrix. A value of 1
                adds a large negative offset to that score (effectively removing
                the position after the softmax); 0 leaves the score unchanged.
                Boolean and integer masks are accepted and cast to the score dtype.

        Returns:
            ``softmax(scores) @ values``.
        """
        # Score the queries against the keys, transposing the latter.
        scores = matmul(queries, keys, transpose_b=True)

        # Cast d_k to the scores' own dtype (not a fixed float32) so the scaling
        # also works when the layer computes in float16 or float64.
        scores = scores / math.sqrt(cast(d_k, scores.dtype))

        if mask is not None:
            # -1e9 is not representable in float16: it becomes -inf, and
            # -inf * 0 yields NaN on unmasked positions. Clamp to the dtype's
            # most negative finite value; for float32/float64 this stays -1e9.
            large_negative = float(max(-1e9, scores.dtype.min))
            scores += cast(mask, scores.dtype) * large_negative

        # Turn scores into attention weights (keras.backend.softmax, default axis=-1).
        weights = softmax(scores)

        # Attention output: weighted sum of the value vectors.
        return matmul(weights, values)

# Random demo inputs. A seeded Generator makes the printed result reproducible,
# and drawing float32 directly matches the layer's default float32 compute dtype
# instead of relying on implicit float64 -> float32 conversion of numpy arrays
# (Keras only autocasts the first call argument; keys/values would otherwise be
# converted inside matmul).
rng = random.default_rng(seed=0)
queries = rng.random((batch_size, input_seq_length, d_k), dtype="float32")
keys = rng.random((batch_size, input_seq_length, d_k), dtype="float32")
values = rng.random((batch_size, input_seq_length, d_v), dtype="float32")

attention = DotProductAttention()
print(attention(queries, keys, values, d_k))

tf.Tensor(
[[[0.46470755 0.7384408  0.2860181  ... 0.31849805 0.48213288 0.4744081 ]
  [0.45854688 0.7264023  0.2824839  ... 0.3174438  0.47815514 0.46740547]
  [0.46358797 0.7459711  0.28433827 ... 0.32279712 0.48154926 0.49903095]
  [0.47782344 0.747559   0.29384312 ... 0.32176462 0.48443544 0.46941072]
  [0.46284714 0.7325383  0.26369342 ... 0.30441666 0.50560844 0.4716068 ]]

 [[0.60151136 0.27857032 0.62798256 ... 0.46278474 0.26041555 0.42701924]
  [0.6013931  0.2806837  0.6344434  ... 0.4687489  0.26205397 0.42693996]
  [0.60211295 0.26842213 0.6249392  ... 0.45708007 0.25561857 0.4305825 ]
  [0.59268886 0.30392352 0.61443347 ... 0.4551766  0.27553713 0.44794706]
  [0.6004778  0.2865997  0.6208826  ... 0.46500584 0.2733471  0.43604922]]

 [[0.59839356 0.65023565 0.4653538  ... 0.56444085 0.70241755 0.64458144]
  [0.5837182  0.62851477 0.4857946  ... 0.550605   0.7079192  0.65832466]
  [0.5659439  0.61847943 0.47208583 ... 0.54991466 0.7142468  0.67001987]
  [0.60217845 0.629879 