In [1]:
import pathlib
import random
import string
import re
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization

In [2]:
# A single-head attention layer without biases, so every weight can be set by
# hand and the computation reproduced with plain NumPy further down.
x = tf.keras.Input(shape=(4, 3))
layer = tf.keras.layers.MultiHeadAttention(
    num_heads=1,
    key_dim=2,
    use_bias=False,
)

# Self-attention on the symbolic input (query and value are the same tensor);
# calling the layer once also builds its weights.
output_tensor = layer(x, x)
print(output_tensor.shape)

(None, 4, 3)


In [3]:
# Hand-picked kernels. Query/key/value kernels have shape (3, 1, 2) and the
# output kernel has shape (1, 2, 3): the middle/leading singleton axis is the
# single attention head.
q = np.array([[ 0.4,  0.3],
              [-0.1, -0.1],
              [ 0.2, -0.1]])[:, np.newaxis, :]
k = np.array([[ 0.1,  0.2],
              [-0.3, -0.4],
              [-0.1,  0.2]])[:, np.newaxis, :]
v = np.array([[-0.2,  0.1],
              [-0.4,  0.2],
              [ 0.4, -0.6]])[:, np.newaxis, :]
o = np.array([[ 0.1, -0.1,  0.6],
              [ 0.9,  0.3,  0.1]])[np.newaxis, :, :]

# Weight order assumed to be query, key, value, output (use_bias=False, so no
# bias vectors) -- consistent with the NumPy check reproducing this layer.
layer.set_weights([q, k, v, o])

In [4]:
# Toy input: a batch of one sequence with 4 positions of 3 features each,
# converted to a (float64) TensorFlow tensor for the Keras layer.
data = tf.convert_to_tensor(
    np.array([1., 3., 2., 6., 2., 1., 5., 8., 4., 7., 3., 4.]).reshape((1, 4, 3))
)

print(data.shape)
print(data)

(1, 4, 3)
tf.Tensor(
[[[1. 3. 2.]
  [6. 2. 1.]
  [5. 8. 4.]
  [7. 3. 4.]]], shape=(1, 4, 3), dtype=float64)


In [11]:
def get_causal_attention_mask(batch_size, seq_length):
    """Build a lower-triangular causal mask for Keras ``attention_mask``.

    Entry ``[b, i, j]`` is 1 when ``j <= i`` (query position ``i`` may look at
    key position ``j``) and 0 otherwise.

    NOTE(review): a later cell redefines this name with the opposite
    convention (1 = masked) and no batch axis; whichever cell ran last wins.

    Args:
        batch_size: number of copies of the mask along the leading axis.
        seq_length: number of query/key positions.

    Returns:
        int32 tensor of shape (batch_size, seq_length, seq_length).
    """
    rows = tf.range(seq_length)[:, tf.newaxis]   # (seq_length, 1)
    cols = tf.range(seq_length)[tf.newaxis, :]   # (1, seq_length)
    lower_triangle = tf.cast(rows >= cols, dtype=tf.int32)

    # Add a leading batch axis and repeat the same mask for every example.
    multiples = tf.convert_to_tensor([batch_size, 1, 1], dtype=tf.int32)
    return tf.tile(lower_triangle[tf.newaxis, ...], multiples)

# Causal mask for a batch of one length-4 sequence.
mask = get_causal_attention_mask(1, 4)
print(mask.shape, mask, sep="\n")

(1, 4, 4)
tf.Tensor(
[[[1 0 0 0]
  [1 1 0 0]
  [1 1 1 0]
  [1 1 1 1]]], shape=(1, 4, 4), dtype=int32)


In [7]:
# Masked self-attention on the concrete input; also request the attention
# scores so they can be compared with the NumPy re-implementation.
output_tensor, weights = layer(
    data,
    data,
    attention_mask=mask,
    return_attention_scores=True,
)
print(output_tensor.shape, weights.shape, sep="\n")

(1, 4, 3)
(1, 1, 4, 4)


In [8]:
# Strip the leading singleton axes to show the (4, 3) output and the
# (4, 4) attention-score matrix.
print(output_tensor[0], weights[0, 0], sep="\n")

tf.Tensor(
[[-0.51       -0.09       -0.4100001 ]
 [ 0.15930527  0.2587929  -0.8907686 ]
 [ 0.05923991  0.2129699  -0.84682053]
 [-0.21573429  0.11351576 -0.8429406 ]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [5.7316449e-02 9.4268352e-01 0.0000000e+00 0.0000000e+00]
 [1.7421028e-01 8.0240512e-01 2.3384636e-02 0.0000000e+00]
 [2.4871059e-02 6.6630757e-01 5.4240436e-04 3.0827892e-01]], shape=(4, 4), dtype=float32)


## Verify: reproduce the Keras result with NumPy

In [8]:
def get_causal_attention_mask(bsize, seq_length):
    """Return an int32 mask marking *future* key positions with 1.

    NOTE(review): this redefines -- and silently shadows -- the earlier
    ``get_causal_attention_mask`` with the opposite convention: here entry
    ``[i, j]`` is 1 when ``j > i`` (position must be masked out) and there is
    no batch axis. It is meant for building the additive NumPy mask, not for
    Keras' ``attention_mask`` argument.

    Args:
        bsize: unused; kept only so the call signature matches the earlier
            version.
        seq_length: number of query/key positions.

    Returns:
        int32 tensor of shape (seq_length, seq_length), ones strictly above
        the diagonal and zeros elsewhere.
    """
    i = tf.range(seq_length)[:, tf.newaxis]  # (seq_length, 1)
    j = tf.range(seq_length)                 # (seq_length,)
    # Broadcasting already yields a (seq_length, seq_length) tensor, so the
    # former reshape and second int32 cast were no-ops and are dropped.
    return tf.cast(j > i, dtype=tf.int32)

# Turn the 0/1 "future position" mask into an additive one: allowed positions
# add 0, future positions add the most negative int32, which drives their
# softmax weight to 0.
mask = get_causal_attention_mask(1, 4).numpy()
mask[mask == 1] = np.iinfo(np.int32).min
print(mask)

[[          0 -2147483648 -2147483648 -2147483648]
 [          0           0 -2147483648 -2147483648]
 [          0           0           0 -2147483648]
 [          0           0           0           0]]


In [9]:
# The same toy sequence as the Keras input, kept as a plain NumPy array.
data = np.array([[[1., 3., 2.],
                  [6., 2., 1.],
                  [5., 8., 4.],
                  [7., 3., 4.]]])
print(data.shape, data, sep="\n")

(1, 4, 3)
[[[1. 3. 2.]
  [6. 2. 1.]
  [5. 8. 4.]
  [7. 3. 4.]]]


In [10]:
import tensorflow as tf
import numpy as np

# The same projection weights given to the Keras layer, written directly as
# (3, 2) matrices -- the single-head axis is dropped for plain matmuls.
W_Q = np.array([[ 0.4,  0.3],
                [-0.1, -0.1],
                [ 0.2, -0.1]])
W_K = np.array([[ 0.1,  0.2],
                [-0.3, -0.4],
                [-0.1,  0.2]])
W_V = np.array([[-0.2,  0.1],
                [-0.4,  0.2],
                [ 0.4, -0.6]])

print(W_Q, W_K, W_V, sep="\n")

[[ 0.4  0.3]
 [-0.1 -0.1]
 [ 0.2 -0.1]]
[[ 0.1  0.2]
 [-0.3 -0.4]
 [-0.1  0.2]]
[[-0.2  0.1]
 [-0.4  0.2]
 [ 0.4 -0.6]]


In [11]:
# Drop the batch axis so the check below can use plain 2-D matrix products.
data = np.reshape(data, (4, 3))
mask = np.reshape(mask, (4, 4))
print(data.shape, mask.shape, sep="\n")

(4, 3)
(4, 4)


In [12]:
# Project the (4, 3) input into queries, keys and values, one row per position.
Q, K, V = (data @ W for W in (W_Q, W_K, W_V))

print(Q, K, V, sep="\n")

[[ 0.5 -0.2]
 [ 2.4  1.5]
 [ 2.   0.3]
 [ 3.3  1.4]]
[[-1.  -0.6]
 [-0.1  0.6]
 [-2.3 -1.4]
 [-0.6  1. ]]
[[-0.6 -0.5]
 [-1.6  0.4]
 [-2.6 -0.3]
 [-1.  -1.1]]


In [13]:
def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large scores
    do not overflow; very negative (masked) scores simply become 0.

    Args:
        x: array-like of scores, any rank >= 1.
        axis: axis to normalise over. Defaults to the last axis, which for
            the 2-D score matrix used here is the same as the previous
            hard-coded ``axis=1``.

    Returns:
        Array with the same shape as ``x`` whose slices along ``axis`` sum
        to 1.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis so the statistics broadcast back
    # against x for any rank (the old reshape(n, 1) only worked for 2-D).
    e_x = np.exp(x - x.max(axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Scaled dot-product attention scores with the additive causal mask applied.
# The mask is added *before* dividing by sqrt(d_k), so masked entries end up at
# mask / sqrt(d_k) (about -1.5e9) -- still negative enough that softmax gives
# them zero weight. d_k is read from K instead of being hard-coded as 2.
buffer = (np.dot(Q, K.T) + mask) / np.sqrt(K.shape[-1])
print(buffer.shape)

(4, 4)


In [14]:
# Masked (future) positions hold roughly -1.5e9, so softmax gives them zero weight.
buffer

array([[-2.68700577e-01, -1.51850025e+09, -1.51850025e+09,
        -1.51850025e+09],
       [-2.33345238e+00,  4.66690476e-01, -1.51850026e+09,
        -1.51850025e+09],
       [-1.54149278e+00, -1.41421356e-02, -3.54967604e+00,
        -1.51850025e+09],
       [-2.92742207e+00,  3.60624458e-01, -6.75286976e+00,
        -4.10121933e-01]])

In [15]:
# Row-wise softmax turns the masked scores into attention weights; these
# should match the Keras attention scores printed earlier.
alpha = softmax(x=buffer)
print(alpha)

[[1.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [5.73164569e-02 9.42683543e-01 0.00000000e+00 0.00000000e+00]
 [1.74210257e-01 8.02405097e-01 2.33846467e-02 0.00000000e+00]
 [2.48710601e-02 6.66307594e-01 5.42404741e-04 3.08278941e-01]]


In [16]:
# Attention-weighted sum of the value rows: one context vector per query position.
context_vector = alpha @ V
print(context_vector)

[[-0.6        -0.5       ]
 [-1.54268354  0.34841519]
 [-1.44917439  0.22684152]
 [-1.39070398 -0.08518205]]


In [17]:
# Output projection used by the Keras layer, with its single-head axis dropped:
# maps the 2-dim context vector back to the 3 input features.
W_O = np.array([[ 0.1, -0.1,  0.6],
                [ 0.9,  0.3,  0.1]])
print(W_O)

[[ 0.1 -0.1  0.6]
 [ 0.9  0.3  0.1]]


In [18]:
# Project each context vector back to the 3-feature model space.
output = context_vector @ W_O
print(output)

[[-0.51       -0.09       -0.41      ]
 [ 0.15930532  0.25879291 -0.89076861]
 [ 0.05923993  0.21296989 -0.84682048]
 [-0.21573424  0.11351578 -0.84294059]]


In [None]:
# Keras MultiHeadAttention output for the same input and weights (copied from
# the printed `output_tensor[0]`). Assert that the NumPy re-implementation
# reproduces it instead of leaving the comparison to the eye.
keras_output = np.array([
    [-0.51,       -0.09,       -0.4100001 ],
    [ 0.15930527,  0.2587929,  -0.8907686 ],
    [ 0.05923991,  0.2129699,  -0.84682053],
    [-0.21573429,  0.11351576, -0.8429406 ],
])
# Keras printed float32 values, so allow float32-level rounding differences.
np.testing.assert_allclose(output, keras_output, rtol=1e-5, atol=1e-6)