In [1]:
import tensorflow as tf



In [2]:
class FNN(tf.keras.layers.Layer):
    """Two-layer feed-forward block: Dense(d_ff, relu) followed by Dense(d_model).

    Both Dense layers act on the last axis of the input, so the output has
    the same leading dimensions as the input and a last axis of size
    ``d_model``.

    Args:
        d_model: Number of units of the second (output) Dense layer.
        d_ff: Number of units of the first (hidden) Dense layer.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``,
            ``trainable``) forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, d_model, d_ff, **kwargs):
        # Forward base-layer kwargs so the layer can be named / typed like
        # any built-in Keras layer.
        super().__init__(**kwargs)

        # Keep the constructor arguments so get_config() can report them.
        self.d_model = d_model
        self.d_ff = d_ff

        self.dense1 = tf.keras.layers.Dense(d_ff, activation="relu")
        self.dense2 = tf.keras.layers.Dense(d_model)  # no activation

    def call(self, x):
        """Apply the hidden (relu) projection, then the linear output projection."""
        return self.dense2(self.dense1(x))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created
        from its config (needed when saving/loading a model that uses it)."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "d_ff": self.d_ff})
        return config

In [3]:
# Random dummy input; the last axis (512) matches the d_model passed to FNN below.
# Presumably laid out as (batch, seq_len, d_model) — confirm against intended usage.
x = tf.random.uniform((2, 6, 512))

In [4]:
# Hidden Dense layer gets 2048 units; output Dense layer projects back to 512.
ffn = FNN(d_model = 512, d_ff=2048)


In [5]:
ffn

<FNN name=fnn, built=False>

In [6]:
# Run the layer on the dummy input; the result keeps x's shape, (2, 6, 512).
output = ffn(x)

In [7]:
output

<tf.Tensor: shape=(2, 6, 512), dtype=float32, numpy=
array([[[-8.35852772e-02, -4.55801524e-02,  4.02150929e-01, ...,
         -2.74178863e-01, -7.84548998e-01,  1.41986385e-01],
        [-1.07644871e-01,  1.74158692e-01,  4.90542978e-01, ...,
         -1.23695202e-01, -5.29394865e-01, -1.69134632e-01],
        [-8.01319331e-02,  2.08575070e-01,  3.63009125e-01, ...,
          1.83665663e-01, -2.78570652e-01, -6.65538758e-02],
        [ 4.61086631e-04,  1.92499086e-01,  1.86467752e-01, ...,
         -3.53020668e-01, -4.32115048e-01, -1.57496244e-01],
        [ 3.12806293e-02,  1.00791916e-01,  2.11452082e-01, ...,
         -3.02582026e-01, -2.43920624e-01, -7.78396949e-02],
        [-4.53638464e-01,  1.58476144e-01,  4.48831052e-01, ...,
         -3.92302752e-01, -3.92625391e-01, -1.62144527e-01]],

       [[-1.75863981e-01,  4.49056178e-02,  2.20573321e-01, ...,
         -1.16870470e-01, -3.56314421e-01, -1.48955151e-01],
        [-2.19064087e-01,  1.72210768e-01,  4.33883905e-01, ...

### Scaled Dot-Product Attention

In [8]:
import tensorflow as tf 
import numpy as np 

def scaled_dot_product_attention(Q, K, V, mask=None, verbose=True):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q, K, V: Query, key and value tensors. Must have shape
            (batch_size, num_heads, seq_len, depth).
        mask: Optional tensor broadcastable to the score tensor
            (..., seq_len_q, seq_len_k). ``mask * -1e9`` is added to the
            scaled scores, so non-zero entries push the corresponding
            attention weight towards zero. Any numeric or boolean dtype is
            accepted; it is cast to the dtype of the scores.
        verbose: When True (the default, matching the original behaviour)
            the intermediate score tensors are printed for inspection.
            Pass False to silence the prints.

    Returns:
        A tuple ``(output, weights)`` where ``weights`` is the softmax of
        the (masked) scaled scores over the last axis and ``output`` is
        ``weights @ V``.
    """
    scores = tf.matmul(Q, K, transpose_b=True)

    # Cast the key depth to the dtype of the scores rather than a fixed
    # float32: TensorFlow does not auto-promote, so a hard-coded float32
    # divisor breaks for e.g. float64 inputs.
    dk = tf.cast(tf.shape(K)[-1], scores.dtype)
    if verbose:
        print(f"Depth of key: {dk}")
        print(f"Scores: {scores}")

    scores = scores / tf.math.sqrt(dk)
    if verbose:
        print(f"Scores after scaling: {scores}")

    if mask is not None:
        # Cast so that boolean / integer masks work as well as float ones.
        scores += tf.cast(mask, scores.dtype) * -1e9
    if verbose:
        # Printed even when no mask was given (scores are then unchanged).
        print(f"Scores after masking: {scores}")

    weights = tf.nn.softmax(scores, axis=-1)
    output = tf.matmul(weights, V)

    return output, weights

In [23]:
# Toy dimensions, small enough that the printed tensors can be read by eye.
batch_size, num_heads = 2, 2
seq_len, depth = 3, 4

# Q, K and V all share one shape; they are drawn in that order.
qkv_shape = (batch_size, num_heads, seq_len, depth)
Q, K, V = (tf.random.uniform(qkv_shape) for _ in range(3))

print(Q)

output, weights = scaled_dot_product_attention(Q, K, V)

tf.Tensor(
[[[[0.13192785 0.06096387 0.8971474  0.96041584]
   [0.48912    0.02534592 0.26345062 0.5428848 ]
   [0.5472523  0.32141066 0.9289235  0.4597925 ]]

  [[0.3177898  0.9874586  0.3536203  0.48786688]
   [0.98315907 0.67036283 0.6021272  0.7656299 ]
   [0.03581572 0.44106424 0.873765   0.9436176 ]]]


 [[[0.26351404 0.94685376 0.5665641  0.14089334]
   [0.2653072  0.81594086 0.62459075 0.30572188]
   [0.7382891  0.04632449 0.03618693 0.91745996]]

  [[0.09616554 0.69231737 0.3298658  0.47897792]
   [0.01625562 0.74581444 0.7972766  0.41946185]
   [0.7451682  0.6699667  0.12812448 0.24165046]]]], shape=(2, 2, 3, 4), dtype=float32)
Depth of key: 4.0
Scores: [[[[1.0587585  1.0833192  1.4869896 ]
   [0.6707403  0.96783924 0.70198315]
   [1.2723552  1.1933345  1.3883924 ]]

  [[1.1699665  0.7106823  1.3150005 ]
   [1.3136827  0.97592485 1.8117888 ]
   [0.5738417  1.1524079  1.0484695 ]]]


 [[[1.2551565  0.9133693  1.2422342 ]
   [1.2884903  0.9173759  1.2582126 ]
   [1.4167737  1.1

In [24]:
# output matches V's shape (2, 2, 3, 4); weights holds one seq_len x seq_len
# attention matrix per batch element and head, hence (2, 2, 3, 3).
output.shape, weights.shape

(TensorShape([2, 2, 3, 4]), TensorShape([2, 2, 3, 3]))

In [25]:
output

<tf.Tensor: shape=(2, 2, 3, 4), dtype=float32, numpy=
array([[[[0.73807085, 0.729236  , 0.6344718 , 0.29381454],
         [0.7079821 , 0.7370317 , 0.64335686, 0.29303288],
         [0.7290781 , 0.7360657 , 0.6410845 , 0.28820172]],

        [[0.6396093 , 0.41814047, 0.5314377 , 0.42670676],
         [0.6338031 , 0.40491855, 0.51910657, 0.40121862],
         [0.570755  , 0.3676917 , 0.49975106, 0.37539488]]],


       [[[0.39296493, 0.38775754, 0.43081352, 0.38847193],
         [0.3944624 , 0.38647643, 0.4304597 , 0.38673842],
         [0.3844385 , 0.39777744, 0.46474037, 0.4072292 ]],

        [[0.32698074, 0.48194206, 0.7055749 , 0.5351881 ],
         [0.32080248, 0.47820938, 0.6978057 , 0.5394257 ],
         [0.37857187, 0.51062906, 0.7075405 , 0.53352463]]]],
      dtype=float32)>

In [26]:
weights

<tf.Tensor: shape=(2, 2, 3, 3), dtype=float32, numpy=
array([[[[0.30758616, 0.3113867 , 0.38102722],
         [0.31487194, 0.36529875, 0.31982931],
         [0.3310168 , 0.31819323, 0.35079002]],

        [[0.34842852, 0.27693728, 0.37463424],
         [0.31975225, 0.2700663 , 0.4101814 ],
         [0.2775227 , 0.37062317, 0.35185412]]],


       [[[0.35255077, 0.297169  , 0.35028023],
         [0.35516202, 0.29501227, 0.34982577],
         [0.38862512, 0.33548805, 0.27588677]],

        [[0.29144812, 0.39926153, 0.3092903 ],
         [0.28383666, 0.42077142, 0.2953919 ],
         [0.35259444, 0.35211542, 0.29529017]]]], dtype=float32)>

### Multi-Head Attention