In [7]:
import numpy as np


def softmax(x, axis=None):
    return np.exp(x) / np.sum(np.exp(x), axis=axis)


def tensor_dot(q, k):
    b = softmax((k @ q) / np.sqrt(q.shape[0]))
    return b


i_query = np.random.normal(size=(4,))
i_keys = np.random.normal(size=(11, 4))

b = tensor_dot(i_query, i_keys)
print("b = ", b)

b =  [0.11948588 0.05694906 0.26137693 0.04176397 0.07714815 0.0220334
 0.12616168 0.06698716 0.05799734 0.12461754 0.04547888]


In [8]:
def attention_layer(q, k, v):
    b = tensor_dot(q, k)
    return b @ v


i_values = np.random.normal(size=(11, 2))
attention_layer(i_query, i_keys, i_values)

array([-0.70352711, -0.30749631])

In [9]:
def batched_tensor_dot(q, k):
    # a will be batch x seq x feature dim
    # which is N x N x 4
    # batched dot product in einstein notation
    a = np.einsum("ij,kj->ik", q, k) / np.sqrt(q.shape[0])
    # now we softmax over sequence
    b = softmax(a, axis=1)
    return b


def self_attention(x):
    b = batched_tensor_dot(x, x)
    return b @ x


i_batched_query = np.random.normal(size=(11, 4))
self_attention(i_batched_query)

array([[-0.14166611,  0.82768644, -0.20830393, -0.13968515],
       [ 0.62958833,  1.52217987, -0.98164199, -0.81337288],
       [-0.37721688,  0.51519563,  1.01461165, -0.50778403],
       [-0.15620989,  0.52947983, -0.00580699, -0.22095754],
       [-0.49284163,  0.70659222,  1.29013892, -0.81564934],
       [-0.21432032,  1.20800275, -0.44441405, -0.27398989],
       [-0.3590753 ,  0.64706809,  0.71811405, -0.49860559],
       [-1.32671927,  0.48657072,  1.42693456, -0.13747287],
       [-0.26645251,  0.71149397,  0.23760909,  0.06369302],
       [-1.70293759,  0.64473755,  0.7384292 ,  0.53364533],
       [-0.26003857,  0.63070486, -0.30909207, -0.01952356]])

In [13]:
batched_tensor_dot(i_batched_query, i_batched_query)[0:3]

array([[0.12769163, 0.09124732, 0.03532783, 0.09123936, 0.03330473,
        0.12311383, 0.05293832, 0.02072694, 0.10300264, 0.03439216,
        0.11929049],
       [0.15402048, 0.51885061, 0.03003771, 0.13019326, 0.0341949 ,
        0.16123542, 0.05393516, 0.00488009, 0.05912303, 0.00521977,
        0.10893934],
       [0.04492095, 0.0226277 , 0.20560824, 0.07523942, 0.16833385,
        0.02918269, 0.13766222, 0.10261309, 0.08268766, 0.03375067,
        0.02883877]])

In [14]:
# weights should be input feature_dim -> desired output feature_dim
w_q = np.random.normal(size=(4, 4))
w_k = np.random.normal(size=(4, 4))
w_v = np.random.normal(size=(4, 2))


def trainable_self_attention(x, w_q, w_k, w_v):
    q = x @ w_q
    k = x @ w_k
    v = x @ w_v
    b = batched_tensor_dot(q, k)
    return b @ v


trainable_self_attention(i_batched_query, w_q, w_k, w_v)

array([[-9.30191444e+00,  1.33519393e+01],
       [-1.07976527e+01,  1.71696788e+01],
       [-2.06228584e-01,  5.26121744e-01],
       [-2.07197807e+00,  4.16864276e+00],
       [-7.24454804e+00,  1.73452878e+01],
       [-3.09215914e+02,  5.69863172e+02],
       [-7.20045416e+00,  1.60051619e+01],
       [-1.29332446e+00,  3.26045502e+00],
       [-8.53170329e-01,  8.68704900e-01],
       [-5.67902393e+00,  1.03121008e+01],
       [-3.09908251e+00,  4.59573106e+00]])

In [15]:
w_q_h1 = np.random.normal(size=(4, 4))
w_k_h1 = np.random.normal(size=(4, 4))
w_v_h1 = np.random.normal(size=(4, 2))
w_q_h2 = np.random.normal(size=(4, 4))
w_k_h2 = np.random.normal(size=(4, 4))
w_v_h2 = np.random.normal(size=(4, 2))
w_h = np.random.normal(size=2)


def multihead_attention(x, w_q_h1, w_k_h1, w_v_h1, w_q_h2, w_k_h2, w_v_h2):
    h1_out = trainable_self_attention(x, w_q_h1, w_k_h1, w_v_h1)
    h2_out = trainable_self_attention(x, w_q_h2, w_k_h2, w_v_h2)
    # join along last axis so we can use dot.
    all_h = np.stack((h1_out, h2_out), -1)
    return all_h @ w_h


multihead_attention(i_batched_query, w_q_h1, w_k_h1, w_v_h1, w_q_h2, w_k_h2, w_v_h2)

array([[ 5.41295587e-01, -1.24346344e-02],
       [ 5.22764962e+00, -7.50054109e+00],
       [-1.73216017e-01,  2.90437943e-01],
       [ 1.14987061e-01, -2.44625446e-01],
       [ 2.81643642e+00, -1.71539036e+00],
       [ 3.54703195e+00,  7.77289153e-02],
       [ 8.32457972e-01, -3.56496920e-01],
       [ 1.24207454e+00,  4.60469372e+00],
       [ 2.77992592e-01,  5.84268211e-01],
       [ 2.96992346e+02,  3.59071906e+02],
       [ 1.24000150e+00,  1.35369927e+00]])