<a href="https://colab.research.google.com/github/vifirsanova/100-days-of-code/blob/main/day15/Reformer_Efficient_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

* *Reversible Layers* reduce memory usage
* *Locality Sensitive Hashing (LSH)* reduces the cost of the Dot Product attention for large input sizes

In [1]:
!pip install trax

import os
import trax
from trax import layers as tl 
import jax
from trax import fastmath
# NOTE(review): elsewhere in this notebook `fastmath.use_backend(...)` is used
# as a context manager (`with fastmath.use_backend("jax"):`). Called as a bare
# statement like this it presumably does not switch the backend (the layer
# outputs further down are JAX DeviceArrays) -- confirm, then either wrap the
# code that needs it in a `with` block or drop this line.
fastmath.use_backend('tensorflow-numpy')
import functools
# `np` is trax.fastmath's numpy module, not the standalone NumPy package.
from trax.fastmath import numpy as np
# NOTE(review): `tie_in` is believed to be deprecated/removed in newer JAX
# releases, and `pip install trax` above is unpinned -- confirm and pin the
# trax / jax versions this notebook was written against.
from jax.lax import tie_in
# Helpers reused from trax's efficient-attention implementation.
from trax.layers import (
    length_normalized,
    apply_broadcasted_dropout,
    look_adjacent,
    permute_via_gather,
    permute_via_sort,
)

Collecting trax
[?25l  Downloading https://files.pythonhosted.org/packages/42/51/305b839f51d53abb393777f743e497d27bb341478f3fdec4d6ddaccc9fb5/trax-1.3.7-py2.py3-none-any.whl (521kB)
[K     |████████████████████████████████| 522kB 5.2MB/s 
Collecting tensorflow-text
[?25l  Downloading https://files.pythonhosted.org/packages/b6/c0/c0fed4301f592c3b56638ae7292612c17d91a43891ba1aaf9636d535beae/tensorflow_text-2.4.3-cp37-cp37m-manylinux1_x86_64.whl (3.4MB)
[K     |████████████████████████████████| 3.4MB 8.8MB/s 
Collecting t5
[?25l  Downloading https://files.pythonhosted.org/packages/65/83/376533337f39711929bb3f5c2263bfec4bf54abe5f2f1987f3ddf2e10a76/t5-0.9.0-py3-none-any.whl (230kB)
[K     |████████████████████████████████| 235kB 34.7MB/s 
Collecting funcsigs
  Downloading https://files.pythonhosted.org/packages/69/cb/f5be453359271714c01b9bd06126eaf2e368f1fddfff30818754b5ac2328/funcsigs-1.0.2-py2.py3-none-any.whl
Collecting rouge-score
  Downloading https://files.pythonhosted.org/packa

In [2]:
def mask_self_attention(dots, q_info, kv_info, causal=True, exclude_self=True, masked=False):
    """Apply additive masks to a matrix of raw attention scores.

    Every enabled mask subtracts a large constant from the affected entries
    so that they receive a negligible weight after a softmax.

    Args:
        dots: attention scores; must broadcast against the element-wise
            comparison of ``q_info`` with ``kv_info``.
        q_info: per-query metadata compared against ``kv_info`` (the callers
            in this notebook pass position indices).
        kv_info: per-key metadata; negative entries are treated as masked
            out when ``masked`` is set.
        causal: subtract 1e9 wherever ``q_info < kv_info``.
        exclude_self: subtract 1e5 wherever ``q_info == kv_info`` (a smaller
            penalty than the other two masks).
        masked: subtract 1e9 wherever ``kv_info < 0``.

    Returns:
        The scores with all requested penalties applied.
    """

    def _penalise(scores, where, penalty):
        # Turn the boolean condition into 0/1 floats and subtract the penalty there.
        return scores - penalty * where.astype(np.float32)

    if causal:
        dots = _penalise(dots, fastmath.lt(q_info, kv_info), 1e9)
    if exclude_self:
        dots = _penalise(dots, np.equal(q_info, kv_info), 1e5)
    if masked:
        zero_reference = tie_in(kv_info, np.zeros_like(kv_info))
        dots = _penalise(dots, fastmath.lt(kv_info, zero_reference), 1e9)
    return dots

Softmax
$$ softmax(x_i)=\frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
___

Alternative softmax calculation
$$ logsumexp(x)=\log{({\sum_j \exp(x_j)})}$$
$$ softmax(x_i)=\exp({x_i - logsumexp(x)})$$

In [5]:
def softmax(x, passthrough=False):
    """Softmax over the last axis, computed through log-sum-exp.

    Args:
        x: array of scores.
        passthrough: when true, skip the normalisation and return ``x``
            unchanged together with an all-zero normaliser.

    Returns:
        A pair ``(values, logsumexp)``. ``values`` is ``exp(x - logsumexp)``
        (or ``x`` itself in passthrough mode); ``logsumexp`` keeps the
        reduced last axis with size 1 (zeros in passthrough mode).
    """
    normaliser = fastmath.logsumexp(x, axis=-1, keepdims=True)
    if passthrough:
        return (x, np.zeros_like(normaliser))
    return (np.exp(x - normaliser), normaliser)

# Compare two softmax calculation methods on a small example vector.
test = np.array([4.0, 3.0, 2.0, 1.0])
# Direct definition: exp(x_i) / sum_j exp(x_j).
print('The first method: ', np.exp(test) / sum(np.exp(test)))
# softmax() returns (probabilities, logsumexp): index 0 is the probabilities,
# index 1 the log-sum-exp normaliser.
print('The second method:', softmax(test)[0])
print('Logsumexp is', softmax(test)[1])

The first method:  [0.6439142  0.2368828  0.08714432 0.0320586 ]
The second method: [0.64391416 0.23688279 0.08714431 0.0320586 ]
Logsumexp is [4.44019]


In [6]:
def attend(
    q,
    k=None,
    v=None,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
    verbose=False,
    passthrough=False,
):
    """Unchunked dot-product attention with optional masking and dropout.

    NOTE: a later cell of this notebook defines another function that is
    also called ``attend`` (the chunked trax version, which has no
    ``verbose`` / ``passthrough`` parameters). Once that cell has run it
    shadows this definition.

    Args:
        q: query vectors; the last axis is the feature axis and the
            second-to-last axis is the sequence axis.
        k: key vectors, or None to reuse ``q`` as keys (shared-QK attention;
            the keys are then length-normalised).
        v: value vectors (required).
        mask_fn: optional callable invoked as
            ``mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])``;
            its return value replaces the raw scores.
        q_info: per-query metadata for ``mask_fn``; defaults to the query
            positions ``0 .. q_len - 1``.
        kv_info: per-key metadata for ``mask_fn``; defaults to ``q_info``
            when q and k are shared, otherwise to the key positions.
        dropout: dropout rate applied to the attention weights.
        rng: random key; required when ``dropout > 0``.
        verbose: print the shapes of intermediate results.
        passthrough: forwarded to ``softmax``; when true the scores are not
            normalised and the returned normaliser is all zeros.

    Returns:
        ``(out, dots_logsumexp)`` where ``out`` is reshaped to
        ``(-1, d_v)`` and ``dots_logsumexp`` is flattened to one value per
        query.
    """
    assert v is not None
    share_qk = k is None

    # Default the masking metadata to plain position indices so a mask_fn can
    # be used without the caller having to build q_info / kv_info by hand.
    if q_info is None:
        q_info = np.arange(q.shape[-2], dtype=np.int32)
    if kv_info is None:
        kv_info = q_info if share_qk else np.arange(v.shape[-2], dtype=np.int32)

    if share_qk:
        # Shared-QK attention: the keys are the length-normalised queries.
        k = length_normalized(q)
    k = k / np.sqrt(k.shape[-1])

    # Dot-product attention.
    kr = np.swapaxes(k, -1, -2)  # k transpose
    dots = np.matmul(q, kr)
    if verbose:
        print("Dots", dots.shape)

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax. `passthrough` is forwarded so that the flag actually takes effect.
    dots, dots_logsumexp = softmax(dots, passthrough=passthrough)
    if verbose:
        print("Attend dots post softmax", dots.shape, dots_logsumexp.shape)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        # Rescale the kept weights by 1 / keep_prob.
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    if verbose:
        print("Attend out1", out.shape)
    out = np.reshape(out, (-1, out.shape[-1]))
    if verbose:
        print("Attend out2", out.shape)
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp


# Toy problem sizes for the attend() demo.
seq_len = 10
emb_len = 5
d_qk = 2
d_v = 4
with fastmath.use_backend("jax"):  # run this demo on the JAX backend
    rng_attend = fastmath.random.get_prng(1)
    # NOTE(review): q/k and v are all drawn from the same PRNG key, so they are
    # not independent samples -- consider splitting the key if that matters.
    q = k = jax.random.uniform(rng_attend, (seq_len, d_qk), dtype=np.float32)
    v = jax.random.uniform(rng_attend, (seq_len, d_v), dtype=np.float32)
    # k is passed explicitly, so attend() does not take its shared-QK path
    # even though q and k are the same array here.
    o, logits = attend(
        q,
        k,
        v,
        mask_fn=None,
        q_info=None,
        kv_info=None,
        dropout=0.0,
        rng=rng_attend,
        verbose=True,
    )
# `logits` holds the second value returned by attend(): the flattened
# log-sum-exp normaliser, one entry per query.
print(o, "\n", logits)

Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
[[0.41215032 0.4410399  0.5783373  0.5694611 ]
 [0.4146571  0.44229436 0.5821161  0.5674372 ]
 [0.41805273 0.44460416 0.58764815 0.5645627 ]
 [0.41195232 0.44039991 0.5797606  0.5747733 ]
 [0.4137767  0.44150043 0.5821494  0.5723262 ]
 [0.4117964  0.4412175  0.5771261  0.5673998 ]
 [0.41744018 0.44462895 0.5903716  0.57462746]
 [0.4107483  0.44038767 0.5765146  0.57118237]
 [0.41690123 0.44381702 0.5880829  0.5720327 ]
 [0.41648537 0.44338638 0.5865165  0.570186  ]] 
 [2.3861837 2.4581163 2.5707448 2.4802527 2.5137725 2.33331   2.755928
 2.36104   2.6662705 2.6040623]


In [7]:
class SelfAttention(tl.SelfAttention):
    """trax ``SelfAttention`` whose single-example forward pass is spelled out.

    The override projects the input to queries/keys/values, delegates the
    attention itself to the notebook's ``attend`` function and prints the
    intermediate shapes along the way.
    """

    def forward_unbatched(
        self, x, mask=None, *, weights, state, rng, update_state, verbose=False
    ):
        """Run attention for one example / one head.

        Args:
            x: input activations; multiplied by the projection weights.
            mask: boolean validity mask (True means "is valid token"); must
                be given exactly when the layer was built with ``masked``.
            weights: projection matrices (and biases when ``self.bias``);
                the key projection is absent when ``self.share_qk``.
            state: layer state, returned unchanged.
            rng: random key, split between attention and output dropout.
            update_state: unused.
            verbose: accepted for signature compatibility; the shape prints
                below are unconditional.

        Returns:
            ``(out, state)``.
        """
        del update_state
        attend_rng, output_rng = fastmath.random.split(rng)

        # The layout of `weights` depends on the bias / shared-QK settings.
        if self.bias and self.share_qk:
            w_q, w_v, w_o, b_q, b_v = weights
        elif self.bias:
            w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights
        elif self.share_qk:
            w_q, w_v, w_o = weights
        else:
            w_q, w_k, w_v, w_o = weights

        print("x.shape, w_q.shape", x.shape, w_q.shape)

        # Linear projections; k stays None when queries double as keys.
        q = np.matmul(x, w_q)
        k = None if self.share_qk else np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        if self.bias:
            q = q + b_q
            if k is not None:
                k = k + b_k
            v = v + b_v

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=self.share_qk,
            masked=self.masked,
        )

        # Position indices serve as masking metadata for queries and keys.
        positions = tie_in(x, np.arange(q.shape[-2], dtype=np.int32))
        q_info = positions
        kv_info = positions

        assert (mask is not None) == self.masked
        if self.masked:
            # mask is a boolean array (True means "is valid token"); the
            # metadata of invalid keys is multiplied by -1.
            ones_like_mask = tie_in(x, np.ones_like(mask, dtype=np.int32))
            sign = np.where(mask, ones_like_mask, -ones_like_mask)
            kv_info = kv_info * sign

        attended, _ = attend(
            q,
            k,
            v,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=attend_rng,
            verbose=True,
        )

        # Output projection followed by output dropout.
        projected = np.matmul(attended, w_o)
        projected = apply_broadcasted_dropout(projected, self.output_dropout, output_rng)
        return projected, state


# Configuration for the SelfAttention demo.
# NOTE: `masked` and `mask` are defined here but are not passed to the layer
# constructed below.
causal = False
masked = False
mask = None
attention_dropout = 0.01
n_heads = 12
d_qk = 2
d_v = 4
seq_len = 10
emb_len = 5
batch_size = 16

osa = SelfAttention(
    n_heads=n_heads,
    d_qk=d_qk,
    d_v=d_v,
    causal=causal,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

rng_osa = fastmath.random.get_prng(1)
# Random input of shape (batch_size, seq_len, emb_len) drawn with a fixed key.
x = jax.random.uniform(
    jax.random.PRNGKey(0), (batch_size, seq_len, emb_len), dtype=np.float32
)
# Initialise the layer from the input's shape/dtype signature.
_, _ = osa.init(tl.shapes.signature(x), rng=rng_osa)

# Run the layer; forward_unbatched prints its shapes on every invocation.
osa(x)

x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)
Attend out1 (10, 4)
Attend out2 (10, 4)
x.shape, w_q.shape (10, 5) (5, 2)
Dots (10, 10)
Attend dots post softmax (10, 10) (10, 1)


DeviceArray([[[ 2.70541877e-01, -1.91065639e-01, -9.75493491e-02,
               -4.98708904e-01, -9.28128883e-02],
              [ 2.70205379e-01, -1.89427227e-01, -9.98546556e-02,
               -4.99432772e-01, -9.20516998e-02],
              [ 2.67894775e-01, -1.87397450e-01, -9.83914435e-02,
               -5.02314806e-01, -8.69192630e-02],
              [ 2.69780993e-01, -1.88655436e-01, -9.91298556e-02,
               -5.00778019e-01, -8.91079381e-02],
              [ 2.64413923e-01, -1.61305130e-01, -9.35228318e-02,
               -4.34522957e-01, -9.77130830e-02],
              [ 2.68971711e-01, -1.87803984e-01, -1.00712404e-01,
               -5.01329243e-01, -8.79787207e-02],
              [ 2.72200435e-01, -1.90876096e-01, -9.56467688e-02,
               -4.98503745e-01, -9.28032100e-02],
              [ 2.18693882e-01, -1.79120436e-01, -9.83639807e-02,
               -4.28791493e-01, -1.08542286e-01],
              [ 2.65745878e-01, -1.86937898e-01, -1.03170246e-01,
      

**LSH Self-Attention**

* uses q's only (no k's)
* calculates similarity of each q relative to all other q's
* uses bucketing
* generates multiple hash tables
* dot product is generated only between members of the bucket
* we get a reduced dot-product attention array

In [15]:
def hash_vectors(vecs, rng, n_buckets, n_hashes, mask=None, verbose=False):
    """Assign every vector to a bucket in each of ``n_hashes`` hashing rounds.

    The vectors are projected onto random directions; the bucket is the
    argmax over the projections concatenated with their negations.

    Args:
        vecs: array of vectors, one per row (``(seqlen, n_features)``).
        rng: random key used to draw the random projections.
        n_buckets: number of buckets per round; must be an even int.
        n_hashes: number of hashing rounds.
        mask: optional boolean array over the rows of ``vecs``; rows whose
            mask is false are placed in one extra bucket.
        verbose: print intermediate shapes and values.

    Returns:
        A flat array of length ``n_hashes * seqlen`` with bucket ids; the ids
        of round ``i`` are shifted by ``i * n_buckets`` (using the enlarged
        bucket count when a mask is given) so rounds never share an id.
    """

    def _debug(*parts):
        if verbose:
            print(*parts)

    # check for even, integer bucket sizes
    assert isinstance(n_buckets, int) and n_buckets % 2 == 0

    rng = fastmath.stop_gradient(tie_in(vecs, rng))
    # Only half as many directions as buckets are drawn; the other half
    # comes from negating the projections below.
    half_rot = n_buckets // 2

    random_rotations = fastmath.random.normal(
        rng, [vecs.shape[-1], n_hashes, half_rot]
    ).astype(np.float32)
    _debug("random.rotations.shape", random_rotations.shape)

    if fastmath.backend_name() == "jax":
        rotated_vecs = np.einsum("tf,fhb->htb", vecs, random_rotations)
    else:
        # Same contraction expressed as reshape -> dot -> reshape -> transpose.
        random_rotations = np.reshape(random_rotations, [-1, n_hashes * half_rot])
        _debug("random_rotations reshaped", random_rotations.shape)
        rotated_vecs = np.dot(vecs, random_rotations)
        _debug("rotated_vecs1", rotated_vecs.shape)
        rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, half_rot])
        _debug("rotated_vecs2", rotated_vecs.shape)
        rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2))
        _debug("rotated_vecs3", rotated_vecs.shape)

    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    _debug("rotated_vecs.shape", rotated_vecs.shape)
    buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32)
    _debug("buckets.shape", buckets.shape)
    _debug("buckets", buckets)

    if mask is not None:
        # Create an extra bucket for padding tokens only
        padding_bucket = n_buckets
        n_buckets = n_buckets + 1
        buckets = np.where(mask[None, :], buckets, padding_bucket)

    # buckets is now (n_hashes, seqlen)
    # offsets is needed for bucket numbers from different hashing rounds not to overlap
    round_index = tie_in(buckets, np.arange(n_hashes, dtype=np.int32))
    offsets = np.reshape(round_index * n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1,))
    _debug("buckets with offsets", buckets.shape, "\n", buckets)
    return buckets


ohv_q = np.ones((10, 5))  # 10 identical vectors: each hash round puts them all in one bucket
ohv_n_buckets = 6  # even number
ohv_n_hashes = 3
with fastmath.use_backend("tensorflow-numpy"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = hash_vectors(
        ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None, verbose=True
    )
    print("ohv shape", ohv.shape, "\nohv", ohv)  # shape is (ohv_n_hashes * number of vectors,) = (30,)
# the random number generators do not produce the same results with different backends
with fastmath.use_backend("jax"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = hash_vectors(ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None)
    print("ohv shape", ohv.shape, "\nohv", ohv)  # shape is (ohv_n_hashes * number of vectors,) = (30,)

random.rotations.shape (5, 3, 3)
random_rotations reshaped (5, 9)
rotated_vecs1 (10, 9)
rotated_vecs2 (10, 3, 3)
rotated_vecs3 (3, 10, 3)
rotated_vecs.shape (3, 10, 6)
buckets.shape (3, 10)
buckets ndarray<tf.Tensor(
[[0 0 0 0 0 0 0 0 0 0]
 [3 3 3 3 3 3 3 3 3 3]
 [5 5 5 5 5 5 5 5 5 5]], shape=(3, 10), dtype=int32)>
buckets with offsets (30,) 
 ndarray<tf.Tensor(
[ 0  0  0  0  0  0  0  0  0  0  9  9  9  9  9  9  9  9  9  9 17 17 17 17
 17 17 17 17 17 17], shape=(30,), dtype=int32)>
ohv shape (30,) 
ohv ndarray<tf.Tensor(
[ 0  0  0  0  0  0  0  0  0  0  9  9  9  9  9  9  9  9  9  9 17 17 17 17
 17 17 17 17 17 17], shape=(30,), dtype=int32)>
ohv shape (30,) 
ohv [ 0  0  0  0  0  0  0  0  0  0 11 11 11 11 11 11 11 11 11 11 15 15 15 15
 15 15 15 15 15 15]


`q[n_seq,n_q]`
`n_hash = 2`
`n_buckets = 4`
`n_seq = 8`

`bucket = [0,1,2,3,0,1,2,3, 4,5,6,7,4,5,6,7]`

The bucket array is `n_hash*n_seq` long; the bucket ids of each hashing round are offset by `n_buckets` times the round index, so ids from different rounds do not overlap.

In [9]:
def sort_buckets(buckets, q, v, n_buckets, n_hashes, seqlen, verbose=True):
    """Reorder queries and values so that members of a bucket are adjacent.

    Args:
        buckets: flat bucket ids of length ``n_hashes * seqlen``.
        q: query vectors, one row per sequence position.
        v: value vectors, one row per sequence position.
        n_buckets: unused; kept for signature compatibility.
        n_hashes: number of hashing rounds.
        seqlen: sequence length.
        verbose: print the intermediate arrays.

    Returns:
        ``(sq, sv, sorted_ticker, undo_sort)``: the gathered queries and
        values in sorted order, the sorted flat indices, and the permutation
        that restores the original order.
    """
    if verbose:
        print("---sort_buckets--")

    # One running index per (hash round, position) pair.
    flat_index = np.arange(n_hashes * seqlen)
    if verbose:
        print("ticker", flat_index.shape, flat_index)

    # Composite key: bucket id first, position within the sequence second.
    sort_key = seqlen * buckets + (flat_index % seqlen)
    if verbose:
        print("buckets_and_t", sort_key.shape, sort_key)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sorted_key, sorted_index = fastmath.sort_key_val(sort_key, flat_index, dimension=-1)
    if verbose:
        print("sorted_buckets_and_t", sorted_key.shape, sorted_key)
        print("sorted_ticker", sorted_index.shape, sorted_index)

    # Sorting the sorted indices against the running index yields the inverse permutation.
    _, undo_sort = fastmath.sort_key_val(sorted_index, flat_index, dimension=-1)
    if verbose:
        print("undo_sort", undo_sort.shape, undo_sort)

    # Map the flat indices back to sequence positions and gather the rows.
    source_rows = sorted_index % seqlen
    sq = np.take(q, source_rows, axis=0)
    sv = np.take(v, source_rows, axis=0)
    return sq, sv, sorted_index, undo_sort


# Toy sizes for the sort_buckets() demo.
t_n_hashes = 2
t_n_buckets = 6
t_n_seq = t_seqlen = 10
t_n_q = 3
n_v = 5

# Row j of t_q is filled with the value j % t_n_buckets.
t_q = (np.array([(j % t_n_buckets) for j in range(t_n_seq)]) * np.ones((t_n_q, 1))).T
t_v = np.ones((t_n_seq, n_v))
# Hand-made bucket ids: position j goes to bucket j % t_n_buckets, shifted by
# t_n_buckets for every additional hash round.
t_buckets = np.array(
    [
        (j % t_n_buckets) + t_n_buckets * i
        for i in range(t_n_hashes)
        for j in range(t_n_seq)
    ]
)
print("q\n", t_q)
print("t_buckets: ", t_buckets)

t_sq, t_sv, t_sticker, t_undo_sort = sort_buckets(
    t_buckets, t_q, t_v, t_n_buckets, t_n_hashes, t_seqlen, verbose=True
)

print("sq.shape", t_sq.shape, "sv.shape", t_sv.shape)
print("sq\n", t_sq)

q
 [[0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]
 [4. 4. 4.]
 [5. 5. 5.]
 [0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]]
t_buckets:  [ 0  1  2  3  4  5  0  1  2  3  6  7  8  9 10 11  6  7  8  9]
---sort_buckets--
ticker (20,) [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
buckets_and_t (20,) [  0  11  22  33  44  55   6  17  28  39  60  71  82  93 104 115  66  77
  88  99]
sorted_buckets_and_t (20,) [  0   6  11  17  22  28  33  39  44  55  60  66  71  77  82  88  93  99
 104 115]
sorted_ticker (20,) [ 0  6  1  7  2  8  3  9  4  5 10 16 11 17 12 18 13 19 14 15]
undo_sort (20,) [ 0  2  4  6  8  9  1  3  5  7 10 12 14 16 18 19 11 13 15 17]
sq.shape (20, 3) sv.shape (20, 5)
sq
 [[0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [3. 3. 3.]
 [3. 3. 3.]
 [4. 4. 4.]
 [5. 5. 5.]
 [0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [3. 3. 3.]
 [3. 3. 3.]
 [4. 4. 4.]
 [5. 5. 5.]]


In [12]:
def dotandv(
    sorted_q, sorted_v, undo_sort, kv_chunk_len, n_hashes, seqlen, passthrough, verbose=False
):
    """Chunked shared-QK attention over bucket-sorted queries and values.

    Args:
        sorted_q: queries in bucket-sorted order.
        sorted_v: values in bucket-sorted order.
        undo_sort: permutation that restores the original order.
        kv_chunk_len: number of rows per chunk; attention is computed only
            within a chunk.
        n_hashes: number of hashing rounds contained in the sorted arrays.
        seqlen: sequence length.
        passthrough: forwarded to ``softmax``.
        verbose: print intermediate shapes and values.

    Returns:
        The attention output in original order. With more than one hashing
        round the rounds are combined into one row per position, weighted by
        a softmax over the rounds' log-sum-exp values.
    """

    def _trace(*parts):
        if verbose:
            print(*parts)

    # Split the sorted queries into chunks and score each chunk against itself.
    q_chunks = np.reshape(sorted_q, (-1, kv_chunk_len, sorted_q.shape[-1]))
    q_chunks_t = np.swapaxes(q_chunks, -1, -2)
    _trace("rsorted_q.shape,reshaped_sorted_qt.shape: ", q_chunks.shape, q_chunks_t.shape)
    scores = np.matmul(q_chunks, q_chunks_t)
    _trace("dotlike\n", scores)

    weights, chunk_logits = softmax(scores, passthrough)
    _trace("dotlike post softmax\n", weights)

    # Weighted sum of the values, chunk by chunk.
    v_chunks = np.reshape(sorted_v, (-1, kv_chunk_len, sorted_v.shape[-1]))
    _trace("dotlike.shape, vr.shape:", weights.shape, v_chunks.shape)
    chunk_out = np.matmul(weights, v_chunks)
    _trace("so.shape:", chunk_out.shape)
    flat_out = np.reshape(chunk_out, (-1, chunk_out.shape[-1]))
    flat_logits = np.reshape(chunk_logits, (-1,))
    _trace("so.shape,sorted_logits.shape", flat_out.shape, flat_logits.shape)

    # Put outputs and normalisers back into the original order.
    o = np.take(flat_out, undo_sort, axis=0)
    logits = np.take(flat_logits, undo_sort, axis=0)
    _trace("o.shape,o", o.shape, o)
    _trace("logits.shape, logits", logits.shape, logits)

    if n_hashes <= 1:
        return o

    # Combine the hashing rounds with weights from a softmax over the rounds.
    per_round_out = np.reshape(o, (n_hashes, seqlen, o.shape[-1]))
    per_round_logits = np.reshape(logits, (n_hashes, seqlen, 1))
    round_weights = np.exp(
        per_round_logits - fastmath.logsumexp(per_round_logits, axis=0, keepdims=True)
    )
    return np.sum(per_round_out * round_weights, axis=0)

In [13]:
# Chunks of two rows each.
t_kv_chunk_len = 2
# First run: passthrough=True skips the softmax normalisation, so the raw dot
# products flow straight through.
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=True,
    verbose=True,
)
print("out\n", out)
print("\n-----With softmax enabled----\n")
# Second run: same inputs with the softmax applied.
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=False,
    verbose=True,
)
print("out\n", out)

rsorted_q.shape,reshaped_sorted_qt.shape:  (10, 2, 3) (10, 3, 2)
dotlike
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[48. 60.]
  [60. 75.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[48. 60.]
  [60. 75.]]]
dotlike post softmax
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[48. 60.]
  [60. 75.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[48. 60.]
  [60. 75.]]]
dotlike.shape, vr.shape: (10, 2, 2) (10, 2, 5)
so.shape: (10, 2, 5)
so.shape,sorted_logits.shape (20, 5) (20,)
o.shape,o (20, 5) [[  0.   0.   0.   0.   0.]
 [  6.   6.   6.   6.   6.]
 [ 24.  24.  24.  24.  24.]
 [ 54.  54.  54.  54.  54.]
 [108. 108. 108. 108. 108.]
 [135. 135. 135. 135. 135.]
 [  0.   0.   0.   0.   0.]
 [  6.   6.   6.   6.   6.]
 [ 24.  24.  24.  24.  24.]
 [ 

In [14]:
# original version from trax 1.3.4
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

    NOTE: this redefines the simplified ``attend`` from an earlier cell of
    this notebook; after this cell has run, the name ``attend`` refers to
    this version, which has no ``verbose`` / ``passthrough`` parameters.

    Args:
        q: Query vectors, shape [q_len, d_qk]
        k: Key vectors, shape [kv_len, d_qk]; or None, in which case the
            queries are reused (length-normalised) as keys
        v: Value vectors, shape [kv_len, d_v]
        q_chunk_len: Chunk length for the query vectors; None (the default)
            disables chunking
        kv_chunk_len: Chunk length for the key/value vectors; None disables
            chunking. With shared queries/keys it must be None or equal to
            q_chunk_len
        n_chunks_before: Number of adjacent previous chunks to attend to
        n_chunks_after: Number of adjacent subsequent chunks to attend to
        mask_fn: Optional callable invoked as
            mask_fn(dots, q_info[..., :, None], kv_info[..., None, :]); its
            return value replaces the raw scores
        q_info: Query-associated metadata for masking; defaults to the
            query positions
        kv_info: Key-associated metadata for masking; defaults to the key
            positions, or to q_info when queries and keys are shared
        dropout: Dropout rate
        rng: RNG for dropout; required when dropout > 0

    Returns:
        A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
        dots_logsumexp has shape [q_len]. The logsumexp of the attention
        probabilities is useful for combining multiple rounds of attention (as in
        LSH attention).
    """
    assert v is not None
    share_qk = k is None

    # Default masking metadata: plain position indices.
    if q_info is None:
        q_info = np.arange(q.shape[-2], dtype=np.int32)

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2], dtype=np.int32)

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        # Shared QK: keys are the (already chunked) queries.
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        if kv_info is None:
            kv_info = q_info
        elif kv_chunk_len is not None:
            # kv_info is not None, but reshape as required.
            kv_info = np.reshape(kv_info, (-1, kv_chunk_len))
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    # Keys are length-normalised only in the shared-QK case, then scaled by
    # 1 / sqrt(d_qk).
    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    # Chunking must be enabled for both q and kv or for neither; looking at
    # adjacent chunks only makes sense when chunking is enabled.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        # Draw a keep mask and rescale the kept weights by 1 / keep_prob.
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    # Flatten any chunk axis back into one row per query.
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp

In [17]:
class LSHSelfAttention(tl.LSHSelfAttention):
    """trax ``LSHSelfAttention`` with the single-example forward pass spelled out.

    The override hashes the queries with this notebook's ``hash_vectors``,
    sorts positions by bucket, runs the chunked ``attend`` with shared
    queries/keys, and merges the hashing rounds.
    """

    def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state):
        """Run LSH attention for one example / one head.

        Args:
            x: input activations; ``x.shape[0]`` is used as the sequence length.
            mask: boolean validity mask; must be given exactly when the
                layer was built with ``masked``.
            weights: the triple ``(w_q, w_v, w_o)``.
            state: pair ``(buckets, hash_rng)``.
            rng: random key, split between attention and output dropout.
            update_state: when true, re-hash the queries and return a new
                state; otherwise reuse the buckets stored in ``state``.

        Returns:
            ``(out, state)``.
        """
        attend_rng, output_rng = fastmath.random.split(rng)
        w_q, w_v, w_o = weights

        # No key projection: queries double as keys.
        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            # Draw fresh hash randomness and recompute the bucket assignment.
            _, old_hash_rng = state
            hash_rng, hash_subrng = fastmath.random.split(old_hash_rng)
            buckets = hash_vectors(
                q, hash_subrng, self.n_buckets, self.n_hashes, mask=mask
            )
            s_buckets = buckets
            if self._max_length_for_buckets:
                # Zero-pad the stored buckets up to a fixed length so the
                # state keeps the same shape for shorter sequences.
                length = self.n_hashes * self._max_length_for_buckets
                if buckets.shape[0] < length:
                    s_buckets = np.concatenate(
                        [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)],
                        axis=0,
                    )
            state = (s_buckets, hash_rng)
        else:
            # Reuse the stored buckets, dropping any padding.
            buckets, _ = state
            if self._max_length_for_buckets:
                buckets = buckets[: self.n_hashes * x.shape[0]]

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        # Composite sort key: bucket id first, sequence position second.
        ticker = tie_in(x, np.arange(self.n_hashes * seqlen, dtype=np.int32))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = fastmath.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted").
        sbuckets_and_t, sticker = fastmath.sort_key_val(
            buckets_and_t, ticker, dimension=-1
        )
        # Sorting the sorted indices against the ticker gives the inverse permutation.
        _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t)
        sticker = fastmath.stop_gradient(sticker)
        undo_sort = fastmath.stop_gradient(undo_sort)

        # Map the flat sorted indices back to sequence positions and gather.
        st = sticker % seqlen
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=True,
            masked=self.masked,
        )
        # The original sequence positions serve as masking metadata.
        q_info = st

        assert (mask is not None) == self.masked
        kv_info = None
        if self.masked:
            # The metadata of invalid keys is multiplied by -1.
            smask = np.take(mask, st, axis=0)
            ones_like_mask = tie_in(x, np.ones_like(smask, dtype=np.int32))
            kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask)

        # k=None selects the shared-QK path of attend().
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=attend_rng,
        )

        # Presumably these helpers restore the original (unsorted) order of
        # the outputs and normalisers -- confirm against the trax sources.
        o = permute_via_gather(so, undo_sort, sticker, axis=0)
        logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

        if self.n_hashes > 1:
            # Combine the hashing rounds, weighting each by a softmax over
            # the rounds' log-sum-exp values.
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        # Output projection followed by output dropout.
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
        return out, state


# Here we're going to try out our LSHSelfAttention
# NOTE: `masked` and `mask` are defined for reference only; they are not
# passed to the layer below.
n_heads = 3
causal = False
masked = False
mask = None
chunk_len = 10
n_chunks_before = 0
n_chunks_after = 0
attention_dropout = 0.0
n_hashes = 5
n_buckets = 4
seq_len = 10
emb_len = 3
al = LSHSelfAttention(
    n_heads=n_heads,
    d_qk=3,
    d_v=4,
    causal=causal,
    # Use the configured value instead of repeating the literal, so that
    # editing `chunk_len` above actually changes the layer.
    chunk_len=chunk_len,
    n_chunks_before=n_chunks_before,
    n_chunks_after=n_chunks_after,
    n_hashes=n_hashes,
    n_buckets=n_buckets,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

# A single random example of shape (1, seq_len, emb_len) drawn with a fixed key.
x = jax.random.uniform(jax.random.PRNGKey(0), (1, seq_len, emb_len), dtype=np.float32)
al_osa = fastmath.random.get_prng(1)
# Initialise the layer from the input's shape/dtype signature.
_, _ = al.init(tl.shapes.signature(x), rng=al_osa)

al(x)

DeviceArray([[[-0.1887156 , -0.4345857 , -0.30507997],
              [-0.17001462, -0.3992369 , -0.27571315],
              [-0.18278453, -0.42552638, -0.31050825],
              [-0.19504203, -0.42496088, -0.30478832],
              [-0.17509116, -0.40555453, -0.28232577],
              [-0.17590663, -0.3963362 , -0.29975712],
              [-0.2093995 , -0.45383382, -0.2939606 ],
              [-0.19098149, -0.43987733, -0.29390997],
              [-0.2013992 , -0.44397017, -0.32320336],
              [-0.20334871, -0.45338914, -0.3046518 ]]], dtype=float32)