In [1]:
import jax.numpy as np
import jax.nn as nn
from jax import grad, jit, vmap
from jax import random

In [2]:
# Fixed seed, so everything sampled from this key is reproducible across runs.
key = random.PRNGKey(0)



In [3]:
# Draw a 10 x 3 float32 matrix of standard-normal samples; the cells below
# pass this same array as queries, keys and values.
xs = random.normal(key, shape=(10, 3), dtype=np.float32)

In [4]:
# Display the sampled input array.
xs

DeviceArray([[-0.28371075,  0.9368161 , -1.0050074 ],
             [ 1.4165013 ,  1.05433   ,  0.9108126 ],
             [-0.42656714,  0.98618793, -0.5575325 ],
             [ 0.01532494, -2.0785687 ,  0.554837  ],
             [ 0.9142364 ,  0.57445955,  0.72278625],
             [ 0.12106168, -0.32373545,  1.6234994 ],
             [ 0.24500382, -1.3809782 , -0.6111238 ],
             [ 0.14037243,  0.8410042 , -1.094358  ],
             [-1.0775021 , -1.1396459 , -0.5933381 ],
             [-0.15576522, -0.38321453, -1.1144515 ]], dtype=float32)

In [5]:
def attention0(qs, ks, vs):
    """Reference (deliberately naive) dot-product attention.

    For every query ``qs[i]`` the dot product with each key in ``ks`` is
    computed in a Python loop, the scores are normalised with a softmax, and
    the output is the softmax-weighted sum of the rows of ``vs``. No
    ``1/sqrt(d)`` scaling is applied.

    Written for 2-D inputs such as the ``(10, 3)`` array used in this
    notebook, with ``ks`` and ``vs`` having the same number of rows.

    Returns:
        A Python list with one output vector per query.
    """
    def fn_w(i):
        # One scalar score per key, normalised into attention weights.
        return [x for x in nn.softmax(np.stack([np.dot(qs[i], k) for k in ks]))]

    def fn_h(i):
        w = fn_w(i)
        # Stack the weighted value rows into one array before reducing:
        # jax.numpy reductions only accept arrays, not Python lists.
        weighted = np.stack([x * vs[j] for j, x in enumerate(w)])
        return np.sum(weighted, axis=0)

    return [fn_h(i) for i in range(len(qs))]

In [6]:
%timeit attention0(xs, xs, xs)

211 ms ± 1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# Show the reference result: a list holding one output vector per query.
attention0(xs, xs, xs)

[DeviceArray([-0.11803952,  0.6514982 , -0.7964032 ], dtype=float32),
 DeviceArray([1.186598  , 0.87185955, 0.8337545 ], dtype=float32),
 DeviceArray([-0.08570289,  0.65454096, -0.6735319 ], dtype=float32),
 DeviceArray([-0.01892377, -1.854406  ,  0.39680448], dtype=float32),
 DeviceArray([0.929942 , 0.6487618, 0.7303074], dtype=float32),
 DeviceArray([ 0.32398266, -0.35916838,  1.1163737 ], dtype=float32),
 DeviceArray([-0.08826812, -1.3582238 , -0.1771639 ], dtype=float32),
 DeviceArray([-0.04798152,  0.6310004 , -0.76828456], dtype=float32),
 DeviceArray([-0.48142552, -1.1528716 , -0.38736904], dtype=float32),
 DeviceArray([-0.21049924, -0.36080134, -0.68926257], dtype=float32)]

In [8]:
def attention1(qs, ks, vs):
    """Dot-product attention with the query/key scoring vectorised via vmap.

    Same computation as the naive version, except that the scores of one
    query against all keys come from a single ``vmap``-ed ``np.vdot`` instead
    of a Python loop. The weighted sum over values is still an explicit loop.

    Written for 2-D inputs such as the ``(10, 3)`` array used in this
    notebook, with ``ks`` and ``vs`` having the same number of rows.

    Returns:
        A Python list with one output vector per query.
    """
    def fn_w(i):
        # Map vdot over the rows of ks while holding the query fixed.
        return nn.softmax(vmap(np.vdot, (0, None), 0)(ks, qs[i]))

    def fn_h(i):
        w = fn_w(i)
        # Stack the weighted value rows into one array before reducing:
        # jax.numpy reductions only accept arrays, not Python lists.
        weighted = np.stack([x * vs[j] for j, x in enumerate(w)])
        return np.sum(weighted, axis=0)

    return [fn_h(i) for i in range(len(qs))]

In [9]:
%timeit attention1(xs, xs, xs)

128 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
def attention2(qs, ks, vs):
    """Dot-product attention, one query at a time, using matmul for the mix.

    Per query: score it against every key with a vmapped ``np.vdot``,
    softmax the scores, then combine the rows of ``vs`` with a single
    ``np.matmul``. The loop over queries remains in Python.

    Returns:
        A Python list with one output vector per query.
    """
    def attend(idx):
        # Scores of query `idx` against each row of ks, turned into weights.
        scores = vmap(np.vdot, (0, None), 0)(ks, qs[idx])
        weights = nn.softmax(scores)
        # Weighted combination of the value rows.
        return np.matmul(weights, vs)

    outputs = []
    for idx in range(len(qs)):
        outputs.append(attend(idx))
    return outputs

In [11]:
%timeit attention2(xs, xs, xs)

25.6 ms ± 801 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
def one_query(q, ks, vs):
    """Attention output for a single query `q` over the rows of `ks` / `vs`.

    Scores `q` against each row of `ks` with a vmapped ``np.vdot``, applies
    a softmax to the scores, and returns the resulting weights multiplied
    into `vs` with ``np.matmul``.
    """
    # Map over the leading axis of ks only; q is held fixed for every key.
    score_keys = vmap(np.vdot, in_axes=(0, None), out_axes=0)
    weights = nn.softmax(score_keys(ks, q))
    return np.matmul(weights, vs)

def attention3(qs, ks, vs):
    """Attention for all queries: `one_query` vmapped over the rows of `qs`.

    `ks` and `vs` are passed unmapped to every per-query call; the per-query
    outputs are stacked along axis 0 of the result.
    """
    batched_query = vmap(one_query, in_axes=(0, None, None), out_axes=0)
    return batched_query(qs, ks, vs)

In [13]:
%timeit attention3(xs, xs, xs)

4.38 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
@jit
def attention3_jit(qs, ks, vs):
    """JIT-compiled attention: `one_query` vmapped over the rows of `qs`.

    `ks` and `vs` are passed unmapped to every per-query call; the per-query
    outputs are stacked along axis 0 of the result.
    """
    # The enclosing @jit already traces and compiles everything called in
    # this body, so one_query does not need its own jit wrapper.
    return vmap(one_query, (0, None, None), 0)(qs, ks, vs)

In [15]:
%timeit attention3_jit(xs, xs, xs)

117 µs ± 1.24 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [16]:
def attention4(qs, ks, vs):
    """Fully vectorised dot-product attention.

    Computes ``softmax(qs @ ks^T) @ vs`` with the softmax taken over the last
    (key) axis. No ``1/sqrt(d)`` scaling is applied.

    For arrays with two or more dimensions only the last two axes of `ks`
    are swapped, so inputs with leading batch dimensions (``(..., n, d)``)
    are handled as well as plain 2-D matrices.
    """
    # `ks.T` reverses *every* axis, which is a matrix transpose only for 2-D
    # input; swap just the trailing pair so batched keys work too. Inputs
    # with fewer than two dimensions keep the original `.T` behaviour.
    ks_t = np.swapaxes(ks, -1, -2) if ks.ndim >= 2 else ks.T
    weights = nn.softmax(np.matmul(qs, ks_t))
    return np.matmul(weights, vs)

In [17]:
%timeit attention4(xs, xs, xs)

971 µs ± 8.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [18]:
attention4_jit = jit(attention4)
%timeit attention4_jit(xs, xs, xs)

124 µs ± 1.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
