In [1]:
import tensorflow as tf
import numpy as np

# Calling the layer once on a symbolic (None, 4, 3) input builds it (creates its
# kernels), which the set_weights() call in the next cell relies on.
x = tf.keras.Input(shape=[4, 3])
layer = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=2, use_bias=False)
output_tensor = layer(query=x, value=x)
print(output_tensor.shape)

(None, 4, 3)


In [2]:
# Hand-picked kernels so the layer output can be recomputed by hand below.
# Query/key/value kernels: shape (input_dim=3, num_heads=1, key_dim=2).
# Output kernel:           shape (num_heads=1, key_dim=2, output_dim=3).
q = np.array([[ 0.4,  0.3],
              [-0.1, -0.1],
              [ 0.2, -0.1]]).reshape(3, 1, 2)
k = np.array([[ 0.1,  0.2],
              [-0.3, -0.4],
              [-0.1,  0.2]]).reshape(3, 1, 2)
v = np.array([[-0.2,  0.1],
              [-0.4,  0.2],
              [ 0.4, -0.6]]).reshape(3, 1, 2)
o = np.array([[ 0.1, -0.1,  0.6],
              [ 0.9,  0.3,  0.1]]).reshape(1, 2, 3)

# With use_bias=False the layer holds only these four kernels, in this order.
layer.set_weights([q, k, v, o])

In [4]:
# Key/value input, shape (1, 4, 3): a batch of one matching Input(shape=[4, 3]).
# The same tensor is passed as both key and value in cell [6].
key_value = tf.convert_to_tensor(
    np.array([1., 3., 2., 6., 2., 1., 5., 8., 4., 7., 3., 4.]).reshape(1, 4, 3)
)

print(key_value.shape)
print(key_value)

(1, 4, 3)
tf.Tensor(
[[[1. 3. 2.]
  [6. 2. 1.]
  [5. 8. 4.]
  [7. 3. 4.]]], shape=(1, 4, 3), dtype=float64)


In [5]:
# Query input, shape (1, 4, 3): a batch of one matching Input(shape=[4, 3]).
query = tf.convert_to_tensor(
    np.array([1., 6., 6., 1., 2., 4., 3., 8., 2., 3., 6., 5.]).reshape(1, 4, 3)
)

print(query.shape)
print(query)

(1, 4, 3)
tf.Tensor(
[[[1. 6. 6.]
  [1. 2. 4.]
  [3. 8. 2.]
  [3. 6. 5.]]], shape=(1, 4, 3), dtype=float64)


In [6]:
# Run the layer with key == value and also ask for the attention weights,
# so both can be compared against the manual computation below.
output_tensor, weights = layer(
    query=query,
    value=key_value,
    key=key_value,
    return_attention_scores=True,
)
print(output_tensor.shape, weights.shape, sep="\n")

(1, 4, 3)
(1, 1, 4, 4)


In [7]:
# Drop the batch axis (and the single head axis of the weights) for display.
print(output_tensor[0], weights[0, 0], sep="\n")

tf.Tensor(
[[-0.39957035  0.05262335 -0.86507285]
 [-0.3857062   0.04718071 -0.819083  ]
 [-0.407583    0.04010282 -0.82246053]
 [-0.31891015  0.06590582 -0.79602516]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[0.3307562  0.29122677 0.21947964 0.15853739]
 [0.25756186 0.37732145 0.1217204  0.24339637]
 [0.24064699 0.36782053 0.12205947 0.269473  ]
 [0.20629655 0.48195118 0.05307306 0.25867924]], shape=(4, 4), dtype=float32)


## Verify: recompute the attention output by hand with NumPy

In [10]:
# NumPy copy of the key/value input from cell [4] without the batch axis: (4, 3).
# NOTE: this rebinds `key_value`, shadowing the tensor used by cell [6].
key_value = np.array([[1., 3., 2.],
                      [6., 2., 1.],
                      [5., 8., 4.],
                      [7., 3., 4.]])

print(key_value.shape)
print(key_value)

(4, 3)
[[1. 3. 2.]
 [6. 2. 1.]
 [5. 8. 4.]
 [7. 3. 4.]]


In [11]:
# NumPy copy of the query input from cell [5] without the batch axis: (4, 3).
# NOTE: this rebinds `query`, shadowing the tensor used by cell [6].
query = np.array([[1., 6., 6.],
                  [1., 2., 4.],
                  [3., 8., 2.],
                  [3., 6., 5.]])

print(query.shape)
print(query)

(4, 3)
[[1. 6. 6.]
 [1. 2. 4.]
 [3. 8. 2.]
 [3. 6. 5.]]


In [12]:
# Same kernels as q, k, v in cell [2], written directly as (input_dim=3, key_dim=2)
# matrices: with num_heads=1 the middle heads axis of the (3, 1, 2) kernels has size 1.
W_Q = np.array([[ 0.4,  0.3],
                [-0.1, -0.1],
                [ 0.2, -0.1]])
W_K = np.array([[ 0.1,  0.2],
                [-0.3, -0.4],
                [-0.1,  0.2]])
W_V = np.array([[-0.2,  0.1],
                [-0.4,  0.2],
                [ 0.4, -0.6]])

print(W_Q, W_K, W_V, sep="\n")

[[ 0.4  0.3]
 [-0.1 -0.1]
 [ 0.2 -0.1]]
[[ 0.1  0.2]
 [-0.3 -0.4]
 [-0.1  0.2]]
[[-0.2  0.1]
 [-0.4  0.2]
 [ 0.4 -0.6]]


In [13]:
# Linear projections, each (4, 3) @ (3, 2) -> (4, 2).
# Key and value both come from the same `key_value` input.
Q = query @ W_Q
K = key_value @ W_K
V = key_value @ W_V

print(Q, K, V, sep="\n")

[[ 1.  -0.9]
 [ 1.  -0.3]
 [ 0.8 -0.1]
 [ 1.6 -0.2]]
[[-1.  -0.6]
 [-0.1  0.6]
 [-2.3 -1.4]
 [-0.6  1. ]]
[[-0.6 -0.5]
 [-1.6  0.4]
 [-2.6 -0.3]
 [-1.  -1.1]]


In [15]:
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    Parameters
    ----------
    x : array_like
        Input scores of any rank. For the 2-D (queries, keys) score matrix used
        in this notebook, the default ``axis=-1`` normalizes each row, exactly
        as the previous row-only implementation did.
    axis : int, default -1
        Axis along which the outputs sum to 1.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``; entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtracting the per-slice max avoids overflow in exp(); softmax is
    # invariant to this shift. keepdims=True keeps the result broadcastable
    # for any rank, replacing the manual reshape to (rows, 1).
    e_x = np.exp(x - x.max(axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Scaled dot-product scores, (4, 2) @ (2, 4) -> (4, 4), divided by sqrt(d_k).
# d_k is read from the key projection width (key_dim=2 in the layer) rather
# than hard-coded, so the scale stays consistent if key_dim changes.
d_k = K.shape[-1]
buffer = Q @ K.T / np.sqrt(d_k)
print(buffer.shape)
print(buffer)

(4, 4)
[[-0.32526912 -0.45254834 -0.73539105 -1.06066017]
 [-0.57982756 -0.1979899  -1.32936075 -0.6363961 ]
 [-0.52325902 -0.09899495 -1.20208153 -0.41012193]
 [-1.04651804 -0.1979899  -2.40416306 -0.82024387]]


In [16]:
# Attention weights: row-wise softmax of the scaled scores, shape (4, 4).
alpha = softmax(buffer)
# Each query's weights over the keys must form a probability distribution.
assert np.allclose(alpha.sum(axis=1), 1.0), "attention rows must sum to 1"
print(alpha)

[[0.33075618 0.29122678 0.21947966 0.15853739]
 [0.25756182 0.37732143 0.1217204  0.24339636]
 [0.24064699 0.36782053 0.12205949 0.26947299]
 [0.20629654 0.48195116 0.05307307 0.25867923]]


In [17]:
# Attention-weighted combination of the value rows: (4, 4) @ (4, 2) -> (4, 2).
context_vector = alpha @ V
print(context_vector)

[[-1.39360105 -0.2891224 ]
 [-1.31812076 -0.28210445]
 [-1.31972871 -0.30623341]
 [-1.29156899 -0.21083688]]


In [18]:
# Output kernel `o` from cell [2], written directly as a (key_dim=2, 3) matrix
# (the leading num_heads=1 axis of its (1, 2, 3) shape is dropped).
W_O = np.array([[ 0.1, -0.1,  0.6],
                [ 0.9,  0.3,  0.1]])
print(W_O)

[[ 0.1 -0.1  0.6]
 [ 0.9  0.3  0.1]]


In [19]:
# Project the context vectors back to 3 features: (4, 2) @ (2, 3) -> (4, 3).
output = context_vector @ W_O
print(output)

[[-0.39957026  0.05262338 -0.86507287]
 [-0.38570608  0.04718074 -0.8190829 ]
 [-0.40758294  0.04010285 -0.82246057]
 [-0.31891009  0.06590583 -0.79602508]]


In [None]:
# Check the hand computation against the Keras layer programmatically instead
# of eyeballing a pasted copy of its output. `recorded_tf_output` is the layer
# output printed in cell [7]; the live tensors from cell [6] are compared too.
# Small differences (~1e-7) are expected: the layer ran in float32, while the
# NumPy path above runs in float64.
recorded_tf_output = np.array([[-0.39957035,  0.05262335, -0.86507285],
                               [-0.3857062 ,  0.04718071, -0.819083  ],
                               [-0.407583  ,  0.04010282, -0.82246053],
                               [-0.31891015,  0.06590582, -0.79602516]])

np.testing.assert_allclose(output, recorded_tf_output, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(output, output_tensor[0].numpy(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(alpha, weights[0, 0].numpy(), rtol=1e-5, atol=1e-6)
print("NumPy computation matches tf.keras.layers.MultiHeadAttention")