In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Import some additional JAX and dataloader helpers
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers

import torch
from torchvision import datasets, transforms
import mediapy as mpy
import time
from tqdm.auto import tqdm


import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random

In [2]:
# Generate key which is used to generate random numbers
key = random.PRNGKey(1)



# Attention layer

$$
\text{Given } Q \in \mathbb{R}^{B \times d_k}, K \in \mathbb{R}^{B \times d_k}, V \in \mathbb{R}^{B \times d_v} \\
Attention(Q, K, V) = softmax \left( \frac{QK^T}{\sqrt{d_k}} \right) V
$$

In [15]:
from attention.layers.encoder import EncoderBlock
import haiku as hk
import jax.numpy as jnp
import jax

In [16]:
def encoder(x):
    encoder = EncoderBlock(num_heads=2, key_size=16, value_size=32, model_size=8, name='blob')
    return encoder(x)

In [17]:
encoder = hk.transform(encoder)
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 224, 224, 3])
params = encoder.init(rng, x)
encoded = encoder.apply(params, rng, x)

TypeError: add got incompatible shapes for broadcasting: (8, 224, 224, 8), (8, 224, 224, 3).

In [46]:
softmax(np.array([[1, 2, 0], [0, 1, 2]]))

DeviceArray([[0.24472846, 0.66524094, 0.09003057],
             [0.09003057, 0.24472846, 0.66524094]], dtype=float32)

In [47]:
q = np.array([[0, 0, 1], [1, 0, 0]])
k = np.array([[0, 0, 1], [0, 0, 1]])
v = np.array([[1], [2]])
attention_layer(q, k, v)

(DeviceArray([[1.5],
              [1.5]], dtype=float32),
 DeviceArray([[0.5, 0.5],
              [0.5, 0.5]], dtype=float32),
 DeviceArray([[1, 1],
              [0, 0]], dtype=int32))

In [48]:
q

DeviceArray([[0, 0, 1],
             [1, 0, 0]], dtype=int32)

In [23]:
k.T

DeviceArray([[0, 0],
             [0, 0],
             [1, 1]], dtype=int32)

In [13]:
d_k = 10
d_v = 5
n = 3
q = random.uniform(key, (n, d_k))
k = random.uniform(key, (n, d_k))
v = random.uniform(key, (n, d_v))

attention_layer(q, k, v)

DeviceArray([[0.15055665, 0.02880473, 0.20766646, 0.14306472, 0.05770602],
             [0.11433297, 0.02340085, 0.1404331 , 0.08351222, 0.03435331],
             [0.12930429, 0.03125305, 0.18979977, 0.10114681, 0.03891528]],            dtype=float32)