# Reformer Efficient Attention

# Trax Efficient Attention classes

In [1]:
import os
import trax
from trax import layers as tl  # core building block
import jax
from trax import fastmath  # uses jax, offers numpy on steroids

# fastmath.use_backend('tensorflow-numpy')
import functools
from trax.fastmath import numpy as np  # note, using fastmath subset of numpy!
from trax.layers import (
    tie_in,
    length_normalized,
    apply_broadcasted_dropout,
    look_adjacent,
    permute_via_gather,
    permute_via_sort,
)



In [2]:
def mask_self_attention(
    dots, q_info, kv_info, causal=True, exclude_self=True, masked=False
):
    """Performs masking for self-attention."""
    if causal:
        mask = fastmath.lt(q_info, kv_info).astype(np.float32)
        dots = dots - 1e9 * mask
    if exclude_self:
        mask = np.equal(q_info, kv_info).astype(np.float32)
        dots = dots - 1e5 * mask
    if masked:
        zeros_like_kv_info = tie_in(kv_info, np.zeros_like(kv_info))
        mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32)
        dots = dots - 1e9 * mask
    return dots

In [3]:
def our_softmax(x, passthrough=False):
    """ softmax with passthrough"""
    logsumexp = fastmath.logsumexp(x, axis=-1, keepdims=True)
    o = np.exp(x - logsumexp)
    if passthrough:
        return (x, np.zeros_like(logsumexp))
    else:
        return (o, logsumexp)

In [4]:
## compare softmax(a) using both methods
a = np.array([1.0, 2.0, 3.0, 4.0])
sma = np.exp(a) / sum(np.exp(a))
print(sma)
sma2, a_logsumexp = our_softmax(a)
print(sma2)
print(a_logsumexp)



[0.0320586  0.08714432 0.2368828  0.6439142 ]
[0.0320586  0.08714431 0.23688279 0.64391416]
[4.44019]


In [None]:
def our_simple_attend(
    q, k=None, v=None,
    mask_fn=None, q_info=None, kv_info=None,
    dropout=0.0, rng=None, verbose=False, passthrough=False
    ):


In [5]:
def our_simple_attend(
    q,
    k=None,
    v=None,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
    verbose=False,
    passthrough=False,
):
    """Dot-product attention,  with masking, without optional chunking and/or.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    mask_fn: a function reference that implements masking (e.g. mask_self_attention)
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
    assert v is not None
    share_qk = k is None
    if share_qk:
        k = q
        if kv_info is None:
            kv_info = q_info

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Dot-product attention.
    kr = np.swapaxes(k, -1, -2)  # note the fancy transpose for later..

    ## Step 1  ##
    dots = np.matmul(q, kr)
    if verbose: print("Our attend dots", dots.shape)

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    #dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)  #original
    #dots = np.exp(dots - dots_logsumexp)  #original
    ## Step 2  ##
    # replace with our_softmax()
    dots, dots_logsumexp = our_softmax(dots, passthrough=passthrough)
    if verbose: 
        print("Our attend dots post softmax", dots.shape, dots_logsumexp.shape)


    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    ## Step 3  ##
    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    if verbose:
        print("Our attend out1", out.shape)
    out = np.reshape(out, (-1, out.shape[-1]))
    if verbose:
        print("Our attend out2", out.shape)
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp

In [6]:
seq_len = 8
emb_len = 5
d_qk = 3
d_v = 4
with fastmath.use_backend("jax"):  # specify the backend for consistency
    rng_attend = fastmath.random.get_prng(1)
    q = k = jax.random.uniform(rng_attend, (seq_len, d_qk), dtype=np.float32)
    v = jax.random.uniform(rng_attend, (seq_len, d_v), dtype=np.float32)
    o, logits = our_simple_attend(
        q,
        k,
        v,
        mask_fn=None,
        q_info=None,
        kv_info=None,
        dropout=0.0,
        rng=rng_attend,
        verbose=True,
    )
print(o, "\n", logits)

Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
[[0.56063247 0.72906053 0.5251243  0.47101077]
 [0.5713517  0.7199196  0.5033343  0.46975708]
 [0.56228864 0.72884583 0.5217213  0.46318394]
 [0.5568317  0.72234154 0.542236   0.46997213]
 [0.565045   0.72274375 0.5204978  0.4723133 ]
 [0.5617596  0.72167814 0.5329314  0.48003793]
 [0.56754    0.72232544 0.5141734  0.46625745]
 [0.5710045  0.707855   0.53253627 0.45907974]] 
 [2.6512175 2.1914332 2.6630518 2.7792363 2.4583826 2.5421977 2.4145055
 2.5111294]


# Class OurSelfAttention

In [7]:
class OurSelfAttention(tl.SelfAttention):
    """Our self-attention. Just the Forward Function."""

    def forward_unbatched(
        self, x, mask=None, *, weights, state, rng, update_state, verbose=False
    ):
        print("ourSelfAttention:forward_unbatched")
        del update_state
        attend_rng, output_rng = fastmath.random.split(rng)
        if self.bias:
            if self.share_qk:
                w_q, w_v, w_o, b_q, b_v = weights
            else:
                w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights
        else:
            if self.share_qk:
                w_q, w_v, w_o = weights
            else:
                w_q, w_k, w_v, w_o = weights

        print("x.shape,w_q.shape", x.shape, w_q.shape)
        q = np.matmul(x, w_q)
        k = None
        if not self.share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        if self.bias:
            q = q + b_q
            if not self.share_qk:
                k = k + b_k
            v = v + b_v

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=self.share_qk,
            masked=self.masked,
        )
        q_info = kv_info = tie_in(x, np.arange(q.shape[-2], dtype=np.int32))

        assert (mask is not None) == self.masked
        if self.masked:
            # mask is a boolean array (True means "is valid token")
            ones_like_mask = tie_in(x, np.ones_like(mask, dtype=np.int32))
            kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

        # Notice, we are callout our vesion of attend
        o, _ = our_simple_attend(
            q,
            k,
            v,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=attend_rng,
            verbose=True,
        )

        # Notice, wo weight matrix applied to output of attend in forward_unbatched
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
        return out, state

In [8]:
causal = False
masked = False
mask = None
attention_dropout = 0.0
n_heads = 3
d_qk = 3
d_v = 4
seq_len = 8
emb_len = 5
batch_size = 1

osa = OurSelfAttention(
    n_heads=n_heads,
    d_qk=d_qk,
    d_v=d_v,
    causal=causal,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

rng_osa = fastmath.random.get_prng(1)
x = jax.random.uniform(
    jax.random.PRNGKey(0), (batch_size, seq_len, emb_len), dtype=np.float32
)
_, _ = osa.init(tl.shapes.signature(x), rng=rng_osa)

In [9]:
osa(x)

ourSelfAttention:forward_unbatched
x.shape,w_q.shape (8, 5) (5, 3)
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
ourSelfAttention:forward_unbatched
x.shape,w_q.shape (8, 5) (5, 3)
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
ourSelfAttention:forward_unbatched
x.shape,w_q.shape (8, 5) (5, 3)
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)


DeviceArray([[[ 6.70414150e-01, -1.04319841e-01, -5.33822298e-01,
                1.92711830e-01, -4.52995300e-05],
              [ 6.64090276e-01, -1.01875424e-01, -5.35732985e-01,
                1.88311636e-01, -6.30623102e-03],
              [ 6.73379958e-01, -1.06952399e-01, -5.31989932e-01,
                1.90056771e-01,  1.30265951e-03],
              [ 6.84564888e-01, -1.13240302e-01, -5.50182462e-01,
                1.95673436e-01,  5.47635555e-03],
              [ 6.81435883e-01, -1.11069024e-01, -5.32343268e-01,
                1.91912323e-01,  5.69400191e-03],
              [ 6.80724978e-01, -1.08496964e-01, -5.34994245e-01,
                1.96332291e-01,  5.89764118e-03],
              [ 6.80933297e-01, -1.14087194e-01, -5.18660128e-01,
                1.90674156e-01,  1.14096403e-02],
              [ 6.80265129e-01, -1.09031767e-01, -5.38248658e-01,
                1.94203168e-01,  4.23955917e-03]]], dtype=float32)

# Trax LSHSelfAttention

In [10]:
def our_hash_vectors(vecs, rng, n_buckets, n_hashes, mask=None, verbose=False):
    """ 
  Args:
    vecs: tensor of at least 2 dimension, 
    rng: random number generator
    n_buckets: number of buckets in each hash table
    n_hashes: the number of hash tables
    mask: None indicating no mask or a 1D boolean array of length vecs.shape[0], containing the location of padding value
    verbose: controls prints for debug
  Returns:
    A vector of size n_hashes * vecs.shape[0] containing the buckets associated with each input vector per hash table.
    
    """

    # check for even, integer bucket sizes
    assert isinstance(n_buckets, int) and n_buckets % 2 == 0

    rng = fastmath.stop_gradient(tie_in(vecs, rng))
    rot_size = n_buckets
    ### Start Code Here

    ### Step 1 ###
    rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2)
    random_rotations = fastmath.random.normal(rng, rotations_shape).astype(np.float32)
    if verbose:
        print("random.rotations.shape", random_rotations.shape)

    ### Step 2 ###
    if fastmath.backend_name() == "jax":
        rotated_vecs = np.einsum("tf,fhb->htb", vecs, random_rotations)
        print("using jax")
    else:
        # Step 2a
        random_rotations = np.reshape(random_rotations,
                                    [-1, n_hashes * (rot_size // 2)])
        if verbose:
            print("random_rotations reshaped", random_rotations.shape)
        # Step 2b
        rotated_vecs = np.dot(vecs, random_rotations)
        if verbose:
            print("rotated_vecs1", rotated_vecs.shape)
        # Step 2c
        rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size//2])
        if verbose:
            print("rotated_vecs2", rotated_vecs.shape)
        # Step 2d
        rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2))
        if verbose:
            print("rotated_vecs3", rotated_vecs.shape)

    ### Step 3 ###
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    if verbose:
        print("rotated_vecs.shape", rotated_vecs.shape)
    ### Step 4 ###
    buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32)
    if verbose:
        print("buckets.shape", buckets.shape)
    if verbose:
        print("buckets", buckets)

    if mask is not None:
        n_buckets += 1  # Create an extra bucket for padding tokens only
        buckets = np.where(mask[None, :], buckets, n_buckets - 1)

    # buckets is now (n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = tie_in(buckets, np.arange(n_hashes, dtype=np.int32))
    offsets = np.reshape(offsets * n_buckets, (-1, 1))
    ### Step 5 ###
    buckets = np.reshape(buckets + offsets, (-1,))
    if verbose:
        print("buckets with offsets", buckets.shape, "\n", buckets)
    ### End Code Here
    return buckets

In [11]:
# example code. Note for reference, the sizes in this example match the values in the diagram above.
ohv_q = np.ones((8, 5))  # (seq_len=8, n_q=5)
ohv_n_buckets = 4  # even number
ohv_n_hashes = 3
with fastmath.use_backend("tf"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = our_hash_vectors(
        ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None, verbose=True
    )
    print("ohv shape", ohv.shape, "\nohv", ohv)  # (ohv_n_hashes * ohv_n_buckets)
# note the random number generators do not produce the same results with different backends
with fastmath.use_backend("jax"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = our_hash_vectors(ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None)
    print("ohv shape", ohv.shape, "\nohv", ohv)  # (ohv_n_hashes * ohv_n_buckets)

random.rotations.shape (5, 3, 2)
random_rotations reshaped (5, 6)
rotated_vecs1 (8, 6)
rotated_vecs2 (8, 3, 2)
rotated_vecs3 (3, 8, 2)
rotated_vecs.shape (3, 8, 4)
buckets.shape (3, 8)
buckets ndarray<tf.Tensor(
[[3 3 3 3 3 3 3 3]
 [3 3 3 3 3 3 3 3]
 [3 3 3 3 3 3 3 3]], shape=(3, 8), dtype=int32)>
buckets with offsets (24,) 
 ndarray<tf.Tensor([ 3  3  3  3  3  3  3  3  7  7  7  7  7  7  7  7 11 11 11 11 11 11 11 11], shape=(24,), dtype=int32)>
ohv shape (24,) 
ohv ndarray<tf.Tensor([ 3  3  3  3  3  3  3  3  7  7  7  7  7  7  7  7 11 11 11 11 11 11 11 11], shape=(24,), dtype=int32)>
using jax
ohv shape (24,) 
ohv [ 3  3  3  3  3  3  3  3  5  5  5  5  5  5  5  5 11 11 11 11 11 11 11 11]


# Sorting Buckets

In [12]:
def sort_buckets(buckets, q, v, n_buckets, n_hashes, seqlen, verbose=True):
    """ 
  Args:
    buckets: tensor of at least 2 dimension, 
    n_buckets: number of buckets in each hash table
    n_hashes: the number of hash tables    
    """
    if verbose:
        print("---sort_buckets--")
    ## Step 1
    ticker = np.arange(n_hashes * seqlen)
    if verbose:
        print("ticker", ticker.shape, ticker)
    ## Step 2
    buckets_and_t = seqlen * buckets + (ticker % seqlen)  # provided
    if verbose:
        print("buckets_and_t", buckets_and_t.shape, buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    # Step 3
    sbuckets_and_t, sticker = fastmath.sort_key_val(buckets_and_t, ticker, dimension=-1)
    if verbose:
        print("sbuckets_and_t", sbuckets_and_t.shape, sbuckets_and_t)
    if verbose:
        print("sticker", sticker.shape, sticker)
    # Step 4
    _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
    if verbose:
        print("undo_sort", undo_sort.shape, undo_sort)

    # Step 5
    st = (sticker % seqlen)
    sq = np.take(q, st, axis=0)
    sv = np.take(v, st, axis=0)
    return sq, sv, sticker, undo_sort

In [13]:
t_n_hashes = 2
t_n_buckets = 4
t_n_seq = t_seqlen = 8
t_n_q = 3
n_v = 5

t_q = (np.array([(j % t_n_buckets) for j in range(t_n_seq)]) * np.ones((t_n_q, 1))).T
t_v = np.ones((t_n_seq, n_v))
t_buckets = np.array(
    [
        (j % t_n_buckets) + t_n_buckets * i
        for i in range(t_n_hashes)
        for j in range(t_n_seq)
    ]
)
print("q\n", t_q)
print("t_buckets: ", t_buckets)

t_sq, t_sv, t_sticker, t_undo_sort = sort_buckets(
    t_buckets, t_q, t_v, t_n_buckets, t_n_hashes, t_seqlen, verbose=True
)

print("sq.shape", t_sq.shape, "sv.shape", t_sv.shape)
print("sq\n", t_sq)

q
 [[0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]
 [0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]]
t_buckets:  [0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7]
---sort_buckets--
ticker (16,) [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
buckets_and_t (16,) [ 0  9 18 27  4 13 22 31 32 41 50 59 36 45 54 63]
sbuckets_and_t (16,) [ 0  4  9 13 18 22 27 31 32 36 41 45 50 54 59 63]
sticker (16,) [ 0  4  1  5  2  6  3  7  8 12  9 13 10 14 11 15]
undo_sort (16,) [ 0  2  4  6  1  3  5  7  8 10 12 14  9 11 13 15]
sq.shape (16, 3) sv.shape (16, 5)
sq
 [[0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [3. 3. 3.]
 [3. 3. 3.]
 [0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [3. 3. 3.]
 [3. 3. 3.]]


# Chunked dot product attention

In [14]:
a = np.arange(16 * 3).reshape((16, 3))
chunksize = 2
ar = np.reshape(
    a, (-1, chunksize, a.shape[-1])
)  # the -1 usage is very handy, see numpy reshape
print(ar.shape)

(8, 2, 3)


In [15]:
def dotandv(
    sq, sv, undo_sort, kv_chunk_len, n_hashes, seqlen, passthrough, verbose=False
):
    # Step 1
    rsq = np.reshape(sq,(-1, kv_chunk_len, sq.shape[-1]))
    rsqt =  np.swapaxes(rsq, -1, -2)
    if verbose:
        print("rsq.shape,rsqt.shape: ", rsq.shape, rsqt.shape)
    dotlike = np.matmul(rsq, rsqt)
    if verbose:
        print("dotlike\n", dotlike)

    # Step 2
    dotlike, slogits = our_softmax(dotlike, passthrough)
    if verbose:
        print("dotlike post softmax\n", dotlike)

    # Step 3
    vr = np.reshape(sv, (-1, kv_chunk_len, sv.shape[-1]))
    if verbose:
        print("dotlike.shape, vr.shape:", dotlike.shape, vr.shape)
    so = np.matmul(dotlike, vr)
    if verbose:
        print("so.shape:", so.shape)
    so = np.reshape(so, (-1, so.shape[-1]))
    slogits = np.reshape(slogits, (-1,))  # provided
    if verbose:
        print("so.shape,slogits.shape", so.shape, slogits.shape)

    # Step 4
    o = np.take(so, undo_sort, axis=0)
    logits = np.take(slogits, undo_sort, axis=0)
    if verbose:
        print("o.shape,o", o.shape, o)
    if verbose:
        print("logits.shape, logits", logits.shape, logits)

    # Step 5 (Provided)
    if n_hashes > 1:
        o = np.reshape(o, (n_hashes, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (n_hashes, seqlen, 1))
        probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
        o = np.sum(o * probs, axis=0)

    return o

In [16]:
t_kv_chunk_len = 2
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=True,
    verbose=True,
)
print("out\n", out)
print("\n-----With softmax enabled----\n")
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=False,
    verbose=True,
)
print("out\n", out)

rsq.shape,rsqt.shape:  (8, 2, 3) (8, 3, 2)
dotlike
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]]
dotlike post softmax
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]]
dotlike.shape, vr.shape: (8, 2, 2) (8, 2, 5)
so.shape: (8, 2, 5)
so.shape,slogits.shape (16, 5) (16,)
o.shape,o (16, 5) [[ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]]
logits.shape, logits 

# OurLSHSelfAttention

In [17]:
# original version from trax 1.3.4
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
    assert v is not None
    share_qk = k is None

    if q_info is None:
        q_info = np.arange(q.shape[-2], dtype=np.int32)

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2], dtype=np.int32)

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        if kv_info is None:
            kv_info = q_info
        elif kv_chunk_len is not None:
            # kv_info is not None, but reshape as required.
            kv_info = np.reshape(kv_info, (-1, kv_chunk_len))
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        #
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp

In [18]:
class OurLSHSelfAttention(tl.LSHSelfAttention):
    """Our simplified LSH self-attention """

    def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state):
        attend_rng, output_rng = fastmath.random.split(rng)
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_hash_rng = state
            hash_rng, hash_subrng = fastmath.random.split(old_hash_rng)
            #      buckets = self.hash_vectors(q, hash_subrng, mask)  #  original
            ## use our version of hash
            buckets = our_hash_vectors(
                q, hash_subrng, self.n_buckets, self.n_hashes, mask=mask
            )
            s_buckets = buckets
            if self._max_length_for_buckets:
                length = self.n_hashes * self._max_length_for_buckets
                if buckets.shape[0] < length:
                    s_buckets = np.concatenate(
                        [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)],
                        axis=0,
                    )
            state = (s_buckets, hash_rng)
        else:
            buckets, _ = state
            if self._max_length_for_buckets:
                buckets = buckets[: self.n_hashes * x.shape[0]]

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = tie_in(x, np.arange(self.n_hashes * seqlen, dtype=np.int32))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = fastmath.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = fastmath.sort_key_val(
            buckets_and_t, ticker, dimension=-1
        )
        _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t)
        sticker = fastmath.stop_gradient(sticker)
        undo_sort = fastmath.stop_gradient(undo_sort)

        st = sticker % seqlen
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=True,
            masked=self.masked,
        )
        q_info = st

        assert (mask is not None) == self.masked
        kv_info = None
        if self.masked:
            # mask is a boolean array (True means "is valid token")
            smask = np.take(mask, st, axis=0)
            ones_like_mask = tie_in(x, np.ones_like(smask, dtype=np.int32))
            kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask)

        ## use original version of attend (could use ours but lacks masks and masking)
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=attend_rng,
        )

        # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
        # also work, but these helpers include performance optimizations for TPU.
        o = permute_via_gather(so, undo_sort, sticker, axis=0)
        logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
        return out, state

In [19]:
# Here we're going to try out our LSHSelfAttention
n_heads = 3
causal = False
masked = False
mask = None
chunk_len = 8
n_chunks_before = 0
n_chunks_after = 0
attention_dropout = 0.0
n_hashes = 5
n_buckets = 4
seq_len = 8
emb_len = 5
al = OurLSHSelfAttention(
    n_heads=n_heads,
    d_qk=3,
    d_v=4,
    causal=causal,
    chunk_len=8,
    n_chunks_before=n_chunks_before,
    n_chunks_after=n_chunks_after,
    n_hashes=n_hashes,
    n_buckets=n_buckets,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

x = jax.random.uniform(jax.random.PRNGKey(0), (1, seq_len, emb_len), dtype=np.float32)
al_osa = fastmath.random.get_prng(1)
_, _ = al.init(tl.shapes.signature(x), rng=al_osa)

In [20]:
al(x)

using jax
using jax
using jax


DeviceArray([[[ 6.6842866e-01, -1.1364317e-01, -5.4430598e-01,
                2.1126229e-01, -1.0988474e-02],
              [ 7.0949805e-01, -1.5455173e-01, -5.9923279e-01,
                2.2719435e-01,  1.3833910e-02],
              [ 7.1442688e-01, -1.2046629e-01, -5.3956538e-01,
                1.7320301e-01, -1.6552359e-02],
              [ 6.7178923e-01, -7.6611072e-02, -5.9399861e-01,
                2.1236289e-01,  7.9482794e-04],
              [ 7.1518469e-01, -1.1359149e-01, -5.7821870e-01,
                2.1304403e-01,  3.0598417e-02],
              [ 6.8235356e-01, -9.3979865e-02, -5.5341852e-01,
                2.1608178e-01, -6.6673756e-04],
              [ 6.1286646e-01, -8.1026971e-02, -4.8148811e-01,
                1.9373316e-01,  3.1555206e-02],
              [ 7.2203499e-01, -1.0199657e-01, -5.5215180e-01,
                1.7872262e-01, -2.2289127e-02]]], dtype=float32)

# Putting the "Re" in Reformer

In [21]:
import trax
from trax import layers as tl               # core building block
import numpy as np                          # regular ol' numpy
from trax.models.reformer.reformer import (
    ReversibleHalfResidualV2 as ReversibleHalfResidual,
)                                           # unique spot
from trax import fastmath                   # uses jax, offers numpy on steroids
from trax import shapes                     # data signatures: dimensionality and type
from trax.fastmath import numpy as jnp      # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature

# Residual Networks

In [22]:
# simple function taking one input and one output
bl_add1 = tl.Fn("add1", lambda x0: (x0 + 1), n_out=1)
bl_add2 = tl.Fn("add2", lambda x0: (x0 + 2), n_out=1)
bl_add3 = tl.Fn("add3", lambda x0: (x0 + 3), n_out=1)
# try them out
x = np.array([1])
print(bl_add1(x), bl_add2(x), bl_add3(x))
# some information about our new layers
print(
    "name:",
    bl_add1.name,
    "number of inputs:",
    bl_add1.n_in,
    "number of outputs:",
    bl_add1.n_out,
)

[2] [3] [4]
name: add1 number of inputs: 1 number of outputs: 1


In [23]:
bl_3add1s = tl.Branch(bl_add1, bl_add2, bl_add3)
bl_3add1s

Branch_out3[
  add1
  add2
  add3
]

In [24]:
# n_in = 1, Each bl_addx pushes n_out = 1 elements onto the stack
bl_3add1s(x)

(array([2]), array([3]), array([4]))

In [25]:
# n = np.array([10]); m = np.array([20])  # n, m will remain on the stack
n = "n"
m = "m"  # n, m will remain on the stack
bl_3add1s([x, n, m]) 

(array([2]), array([3]), array([4]), 'n', 'm')

In [26]:
bl_addab = tl.Fn(
    "addab", lambda x0, x1: (x0 + x1), n_out=1
)  # Trax figures out how many inputs there are
bl_rep3x = tl.Fn(
    "add2x", lambda x0: (x0, x0, x0), n_out=3
)  # but you have to tell it how many outputs there are
bl_3ops = tl.Branch(bl_add1, bl_addab, bl_rep3x)

In [27]:
# Before Running this cell, what is the output you are expecting?
y = np.array([3])
bl_3ops([x, y, n, m])

(array([2]), array([4]), array([1]), array([1]), array([1]), 'n', 'm')

In [28]:
bl_2ops = tl.Branch(bl_add1, None)
bl_2ops([x, n, m])

(array([2]), array([1]), 'n', 'm')

# Residual Model

In [29]:
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Combinators for composing layers."""

from trax import fastmath
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers.base import Fn
from trax.shapes import ShapeDtype


class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).
  This combinator is commonly used to construct deep networks, e.g., like this::
      mlp = tl.Serial(
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(10),
        tl.LogSoftmax()
      )
  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:
    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and
    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.
  A Serial instance with no sublayers acts as a special-case (but useful)
  1-input 1-output no-op.
  """

  def __init__(self, *sublayers, name=None, sublayers_to_print=None):
    super().__init__(
        name=name, sublayers_to_print=sublayers_to_print)

    sublayers = _ensure_flat(sublayers)
    self._sublayers = sublayers
    self._n_layers = len(sublayers)

    if sublayers:
      self._n_in, self._n_out = self._n_inputs_n_outputs(sublayers)
      self._weights = tuple(None for l in sublayers)
      self._state = tuple(None for l in sublayers)

  def forward(self, xs):
    self._validate_forward_inputs(xs)
    state, weights = self.state, self.weights
    rngs = _split_rngs(self.rng, self._n_layers)
    if not self.sublayers:  # No-op: leave args unchanged.
      return xs

    stack = xs
    new_state = []
    n_layers = self._n_layers
    if len(weights) != n_layers:
      raise ValueError(
          f'Number of weight elements ({len(weights)}) does not equal '
          f'number of sublayers ({n_layers}).')
    if len(state) != n_layers:
      raise ValueError(
          f'Number of state elements ({len(state)}) does not equal '
          f'number of sublayers ({n_layers}).')

    for layer, w, s, rng in zip(self.sublayers, weights, state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
    self.state = new_state
    return stack

  # pylint: disable=protected-access
  def init_weights_and_state(self, input_signature):
    weights = []
    states = []
    # In the code below, stack, inputs, and outputs are abstract (shapes and
    # dtypes), but weights and states are non-abstract actual values.
    stack = input_signature
    for sublayer in self.sublayers:
      inputs = _inputs_from_stack(sublayer, stack)
      weights_or_cache_marker, state_or_cache_marker = (
          sublayer.init(inputs, use_cache=True))
      outputs, _ = sublayer._forward_abstract(inputs)
      stack = _outputs_onto_stack(sublayer, outputs, stack)

      weights.append(weights_or_cache_marker)
      states.append(state_or_cache_marker)
    self.state = states
    self.weights = weights
  # pylint: enable=protected-access

  def _n_inputs_n_outputs(self, layers):
    del self
    running_max = 0
    running_total = 0
    for layer in layers:
      running_total += layer.n_in
      running_max = max(running_max, running_total)
      running_total -= layer.n_out
    return running_max, (running_max - running_total)

  def _validate_forward_inputs(self, xs):
    if not isinstance(xs, (tuple, list)) and self._n_in != 1:
      raise TypeError(f'Serial.forward input must be a tuple or list; '
                      f'instead got {type(xs)}.')
      # TODO(jonni): Include full xs (or shape) in error message?
    len_xs = 1 if isinstance(xs, jnp.ndarray) else len(xs)
    if len_xs < self.n_in:
      raise ValueError(
          f'Number of inputs ({len(xs)}) to Serial.forward less than n_in '
          f'({self.n_in}).')


class Parallel(base.Layer):
  """Combinator that applies a list of layers in parallel to its inputs.
  Layers in the list apply to successive spans of inputs, where the spans are
  determined how many inputs each layer takes. The resulting output is the
  (flattened) concatenation of the respective layer outputs.
  For example, suppose one has three layers:
    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)
  Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:
    - inputs: a, b, c, d, e, f
    - outputs: F(a), G(b, c, d), h1, h2
  As an important special case, a None argument to Parallel acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For
  example:
    Parallel(None, F)
  creates a layer that passes its first input unchanged and applies F to the
  following input(s).
  """

  def __init__(self, *sublayers, name=None):
    """The constructor.
    Args:
      *sublayers: A list of sublayers.
      name: Descriptive name for this layer.
    Returns:
      A new layer in which each of the given sublayers applies to its
      corresponding span of elements in the dataflow stack.
    """
    super().__init__(name=name)
    sublayers = self._validate(sublayers)
    self._n_layers = len(sublayers)
    self._sublayers = sublayers
    self._n_in = sum(l.n_in for l in sublayers)
    self._n_out = sum(l.n_out for l in sublayers)
    self._weights = tuple(None for l in sublayers)
    self._state = tuple(None for l in sublayers)

  def forward(self, inputs):
    n_layers, layers = self._n_layers, self.sublayers
    sublayer_inputs = self._allot_to_sublayers(inputs)
    state, weights = self.state, self.weights
    rngs = _split_rngs(self.rng, n_layers)
    if len(sublayer_inputs) != n_layers:
      raise ValueError(
          f'Number of inputs for sublayers ({len(sublayer_inputs)}) does not equal '
          f'number of sublayers ({n_layers}).')
    if len(weights) != n_layers:
      raise ValueError(
          f'Number of weight elements ({len(weights)}) does not equal '
          f'number of sublayers ({n_layers}).')
    if len(state) != n_layers:
      raise ValueError(
          f'Number of state elements ({len(state)}) does not equal '
          f'number of sublayers ({n_layers}).')
    if len(rngs) != n_layers:
      raise ValueError(
          f'Number of rngs ({len(rngs)}) does not equal '
          f'number of sublayers ({n_layers}).')
    outputs = []
    new_state = []
    for layer, x, w, s, r in zip(layers, sublayer_inputs, weights, state, rngs):
      # Note that zip silently truncates its result if lengths don't match.
      sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True)
      if layer.n_out == 1:
        outputs.append(sub_outputs)
      else:
        outputs.extend(sub_outputs)
      new_state.append(sub_state)
    output = outputs[0] if self.n_out == 1 else tuple(outputs)
    self.state = tuple(new_state)
    return output

  def init_weights_and_state(self, input_signature):
    sublayer_signatures = self._allot_to_sublayers(input_signature)
    inits = [layer.init(signature, use_cache=True)
             for layer, signature
             in zip(self.sublayers, sublayer_signatures)]
    if inits:
      weights, state = tuple(zip(*inits))
      self.state = state
      self.weights = weights

  def _validate(self, layers):
    if not layers or len(layers) < 2:
      raise ValueError(
          f'layers ({layers}) must be a list with at least two elements')
    layers = list(layers)  # Ensure we can modify layers.
    for i, obj in enumerate(layers):
      if obj is None or obj == []:  # pylint: disable=g-explicit-bool-comparison
        layers[i] = Serial(None)
      elif isinstance(obj, (list, tuple)):
        layers[i] = Serial(obj)
      else:
        if not isinstance(obj, base.Layer):
          raise ValueError(
              f'Found nonlayer object ({obj}) in layers list: [{layers}]')
      if layers[i].n_in == 0:
        raise ValueError(
            f'Sublayer with n_in = 0 not allowed in Parallel: {layers[i]}')
    return layers

  def _allot_to_sublayers(self, inputs):
    """Divides Parallel's inputs for use by the sublayers.
    Args:
      inputs: Tuple of ndarrays or ShapeDtype instances.
    Returns:
      A tuple that partitions this layer's inputs among its sublayers.
      Sublayers that take one argument get that argument directly. All other
      sublayers get a tuple of items.
    """
    start, end = 0, 0
    sub_inputs = []
    for layer in self.sublayers:
      n_in = layer.n_in
      end = start + n_in
      if n_in == 1:
        sub_inputs.append(inputs[start])
      else:
        sub_inputs.append(inputs[start:end])
      start = end
    return tuple(sub_inputs)


class Concatenate(base.Layer):
  """Concatenates n tensors into a single tensor."""

  def __init__(self, n_items=2, axis=-1):
    name = 'Concatenate' if axis == -1 else f'Concatenate_axis{axis}'
    super().__init__(n_in=n_items, name=name)
    self._n_items = n_items
    self._axis = axis

  def forward(self, xs):
    return jnp.concatenate(xs, self._axis)


class Split(base.Layer):
  """Splits the input into n items along an axis."""

  def __init__(self, n_items=2, axis=-1):
    super().__init__(n_out=n_items)
    self._n_items = n_items
    self._axis = axis

  def forward(self, inputs):
    return tuple(jnp.split(inputs, self._n_items, self._axis))


def _scan(f, xs, init_value, axis=0, remat=False):
  """Scans the f over the given axis of xs.
  In pseudo-python, the scan function would look as follows:
  def scan(f, xs, init_value, axis):
    xs  = [xs[..., i, ...] for i in range(xs.shape[axis])]
    cur_value = init_value
    ys = []
    for x in xs:
      y, cur_value = f(x, cur_value)
      ys.append(y)
    return np.stack(ys, axis), cur_value
  Args:
    f: function (x, carry) -> (y, new_carry)
    xs: tensor, x will be xs slices on axis
    init_value: tensor, initial value of the carry-over
    axis: int, the axis on which to slice xs
    remat: whether to re-materialize f
  Returns:
    A pair (ys, last_value) as described above.
  """
  def swapaxes(x):
    transposed_axes = list(range(len(x.shape)))
    transposed_axes[axis] = 0
    transposed_axes[0] = axis
    return jnp.transpose(x, axes=transposed_axes)
  if axis != 0:
    xs = fastmath.nested_map(swapaxes, xs)
  def transposed_f(c, x):
    y, d = f(x, c)
    return d, y
  if remat:
    transposed_f = fastmath.remat(transposed_f)
  last_value, ys = fastmath.scan(transposed_f, init_value, xs)
  if axis != 0:
    ys = fastmath.nested_map(swapaxes, ys)
  return ys, last_value


class Scan(base.Layer):
  """Applies a layer progressively/cumulatively to an axis-derived sequence.
  Conceptually, this is a function from a list to a same-length list of partial
  (cumulative) results. For instance, a list of values (`[1, 2, 3, 4, 5]`) can
  transform to a list of cumulative sums (`[1, 3, 6, 10, 15]`). Functions for
  the same concept are called `scan` in Scala, `scanl` in Haskell, and
  `accumulate*` in Factor.
  In more detail, we assume the layer takes a tuple of inputs of the following
  form:
    (input1, ..., inputN, carry1, ..., carryM)
  and returns:
    (output1, ..., outputK, new_carry1, ..., new_carryM)
  The scanned version applies the layer iteratively to a tensor treating values
  at the given axis as if they were a list. For example, to calculate all
  sums of prefixes of a tensor, we can do this::
    def add(x, carry):
      def f(input, carry):
        res = input + carry
        return res, res  # output and carry are the same
      return tl.Fn('add', f, n_out=2)
    Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6
  """

  def __init__(self, layer, axis=0, n_carry=1, remat=False):
    super().__init__(n_in=layer.n_in, n_out=layer.n_out)
    self._sublayers = [layer]
    self._n_carry = n_carry
    self._axis = axis
    self._remat = remat
    self._weights = (None,)
    self._state = (None,)

  @property
  def sublayer(self):
    """Returns the unique sublayer managed by this layer."""
    return self._sublayers[0]

  def forward(self, inputs):
    weights = self.weights[0]
    if isinstance(inputs, list):
      inputs = tuple(inputs)  # so that inputs structure matches outputs
    n_carry = self._n_carry
    def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
      carry, state = carry_and_state
      x_and_carry = x + carry if n_carry > 0 else x
      res, new_state = self.sublayer.pure_fn(
          x_and_carry, weights, state, self.rng, use_cache=True)
      if n_carry > 0:
        return (res[:-n_carry], (res[-n_carry:], new_state))
      else:
        return (res, ([], new_state))

    if n_carry > 0:
      xs = inputs[:-n_carry]  # Split input stack into inputs and carry.
      init = (inputs[-n_carry:], self.state[0])
    else:
      xs, init = inputs, ([], self.state[0])
    ys, (carry, new_state) = _scan(scannable_fn, xs, init,
                                   axis=self._axis, remat=self._remat)
    res = ys + carry if n_carry > 0 else ys
    self.state = (new_state,)
    return res  # Put outputs and carry back on stack.

  def init_weights_and_state(self, input_signature):
    n_carry = self._n_carry
    if n_carry == 0:
      if isinstance(input_signature, (list, tuple)):
        layer_sig = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype)
                     for x in input_signature]
        layer_sig = tuple(layer_sig)
      else:
        layer_sig = ShapeDtype(_shape_without_axis(input_signature, self._axis),
                               input_signature.dtype)
      weights, state = self.sublayer.init(layer_sig)
      self.state = (state,)
      self.weights = (weights,)
    else:
      xs = input_signature[:-n_carry]
      init = input_signature[-n_carry:]
      xs_slices = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype)
                   for x in xs]
      layer_signature = tuple(xs_slices + list(init))
      weights, state = self.sublayer.init(layer_signature, use_cache=True)
      self.state = (state,)
      self.weights = (weights,)


class Cond(base.Layer):
  """Applies layers conditionally.
  For parameters `cond`, `true`, and `false` runs the equivalent of `true(y)
  if cond(x) else false(y)`, where `x` is `cond.n_in` elements from front of the
  stack and `y` is the rest of the stack.
  Exactly one of `true` and `false` functions is executed, so it can be used to
  conditionally run long computations. The state of non-executed function is not
  updated. Note that different branches may be executed on different devices
  if `cond` returns different values on them.
  `cond` must return exactly one element: a Boolean value.
  `true` and `false` must have the same n_in, and the same n_out.
  """

  def __init__(self, cond, true, false, name=None):
    super(Cond, self).__init__(name=name)

    sublayers = [cond, true, false]
    self._sublayers = sublayers
    self._n_layers = len(sublayers)
    self._cond = cond
    self._true = true
    self._false = false

    if cond.n_out != 1:
      raise ValueError(
          'cond.n_out must be 1: cond:{}->{}'.format(cond.n_in, cond.n_out))
    if true.n_in != false.n_in:
      raise ValueError(
          'true.n_in and false.n_in must be equal: true:{}->{} ; false:{}->{}'
          .format(true.n_in, true.n_out, false.n_in, false.n_out))
    if true.n_out != false.n_out:
      raise ValueError(
          'true.n_out and false.n_out must be equal: true:{}->{} ; false:{}->{}'
          .format(true.n_in, true.n_out, false.n_in, false.n_out))

    self._n_in = cond.n_in + true.n_in
    self._n_out = true.n_out
    self._weights = tuple(None for l in sublayers)
    self._state = tuple(None for l in sublayers)

  # pylint: disable=protected-access
  def init_weights_and_state(self, input_signature):
    weights = []
    states = []
    # In the code below, stack, inputs, and outputs are abstract (shapes and
    # dtypes), but weights and states are non-abstract actual values.
    stack = _make_tuple(input_signature)

    # Inputs/outputs of `cond`.
    inputs = _inputs_from_stack(self._cond, stack)
    weights_or_cache_marker, state_or_cache_marker = (
        self._cond.init(inputs, use_cache=True))
    weights.append(weights_or_cache_marker)
    states.append(state_or_cache_marker)
    self._cond._forward_abstract(inputs)
    stack = _make_tuple(_outputs_onto_stack(self._cond, [], stack))

    # Inputs/outputs of `true` and `false`.
    for sublayer in [self._true, self._false]:
      inputs = _inputs_from_stack(sublayer, stack)
      weights_or_cache_marker, state_or_cache_marker = (
          sublayer.init(inputs, use_cache=True))
      weights.append(weights_or_cache_marker)
      states.append(state_or_cache_marker)

    self.state = states
    self.weights = weights
    # pylint: enable=protected-access

  def _validate_forward_inputs(self, xs):
    xs = _make_tuple(xs)
    if len(xs) < self.n_in:
      raise ValueError(
          f'Number of inputs ({len(xs)}) to Cond.forward less than n_in '
          f'({self.n_in}).')

  def forward(self, xs):
    # TODO(jaszczur): modify; it's a copy from SkippingSerial
    self._validate_forward_inputs(xs)
    layers_state = self.state
    # Get 3 rngs, one for each layer.
    rngs = _split_rngs(self.rng, 3)

    # Prepare the stack and do some safety checks as in the parent class.
    stack = _make_tuple(xs)
    weights = self.weights
    if len(weights) != 3:
      raise ValueError('number of weights ({}) not equal to 3'
                       .format(len(weights)))
    if len(layers_state) != 3:
      raise ValueError('length of state ({}) not equal to 3'
                       .format(len(layers_state)))

    def true_func(t):
      outputs, new_true_state = self._true.pure_fn(
          t[0][0], t[1][0], t[2][0], t[3][0])
      # t[2][1] is old_false_state which is not changing if true is executed.
      return outputs, (new_true_state, t[2][1])
    def false_func(t):
      outputs, new_false_state = self._false.pure_fn(
          t[0][1], t[1][1], t[2][1], t[3][1])
      # t[2][1] is old_true_state, which is not changing if false is executed.
      return outputs, (t[2][0], new_false_state)

    cond_inputs = _inputs_from_stack(self._cond, xs)
    cond_output, s = self._cond.pure_fn(cond_inputs, self.weights[0],
                                        self.state[0], rngs[0], use_cache=True)
    stack = _outputs_onto_stack(self._cond, [], stack)
    self._cond.state = s

    outputs, both_states = fastmath.cond(
        cond_output,
        true_func,
        false_func,
        [(stack, stack),
         (self.weights[1], self.weights[2]),
         (self.state[1], self.state[2]),
         (rngs[1], rngs[2])]
    )
    stack = _outputs_onto_stack(self._cond, [], stack)

    # We don't know which (`true` or `false`) branch was run, but both of them
    # are adding (n_out) and removing (n_in) the same number of elements of the
    # stack (this was checked in __init__). _outputs_onto_stack just uses the
    # layer's n_in and n_out, so we can pass either `true` or `false` to it.
    # Note that `outputs` is the actual output of `true` or `false` branch,
    # whichever was run, and we add it to the stack in any case.
    stack = _outputs_onto_stack(self._true, outputs, stack)
    self._true.state = both_states[0]
    self._false.state = both_states[1]
    return _make_singleitem_or_original(stack)


# pylint: disable=invalid-name
def Branch(*layers, name='Branch'):
  """Combinator that applies a list of layers in parallel to copies of inputs.
  Each layer in the input list is applied to as many inputs from the stack
  as it needs, and their outputs are successively combined on stack.
  For example, suppose one has three layers:
    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)
  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:
    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)
  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)
  Args:
    *layers: List of layers.
    name: Descriptive name for this layer.
  Returns:
    A branch layer built from the given sublayers.
  """
  if len(layers) == 1:
    return layers[0]
  parallel_layer = Parallel(*layers)
  indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers]
  return Serial(Select(_deep_flatten(indices)), parallel_layer,
                name=name, sublayers_to_print=layers)


def Residual(*layers, shortcut=None):
  """Wraps a series of layers with a residual connection.
  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.
  Returns:
      A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else Serial(layers)
  # TODO(jonni): Should we require layer.n_out = 1 and shortcut.n_out = 1?
  return Serial(
      Branch(shortcut, layer),
      Add(),  # pylint: disable=no-value-for-parameter
  )


def Select(indices, n_in=None, name=None):
  """Copies, reorders, or deletes stack elements according to `indices`.
  Args:
    indices: A list or tuple of 0-based indices to select elements relative to
        the top of the stack.
    n_in: Number of input elements to pop from the stack, and replace with
        those specified by `indices`. If not specified, its value will be
        calculated as `max(indices) + 1`.
    name: Descriptive name for this layer.
  Returns:
    Tensors, matching the number selected (`n_out = len(indices)`).
    Specifically:
      - n_out = 0: an empty tuple
      - n_out = 1: one tensor (NOT wrapped in a tuple)
      - n_out > 1: a tuple of tensors, with n_out items
  """
  if n_in is None:
    n_in = max(indices) + 1
  if name is None:
    name = f'Select{indices}'.replace(' ', '')

  def select(xs):  # pylint: disable=invalid-name
    if not isinstance(xs, (tuple, list)):
      xs = (xs,)
    selected = tuple(xs[i] for i in indices)
    return selected[0] if len(selected) == 1 else selected

  return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name)


def Drop():
  """Drops the top stack element."""
  return Fn('Drop', lambda x: (), n_out=0)


def Dup():
  """Duplicates (copies) the top element on the data stack."""
  return Fn('Dup', lambda x: (x, x), n_out=2)


def Swap():
  """Swaps the top two stack elements."""
  return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2)


def SerialWithSideOutputs(layers, n_side_outputs=1):
  """Serial layer with side outputs.
  This layer makes it easier to manage the stack when layers have side outputs.
  In the simplest case of layers with n_in=1, n_out=2 and with
  n_side_outputs=1, this layer runs the following computation on x::
    side_outputs = []
    for i in range(len(layers)):
      x, side_output = layers[i](x)
      side_outputs.append(side_output)
    return [x] + side_outputs
  In the general case of layers with variable n_in and n_out and
  n_side_outputs being a list of N integers, it does the following::
    side_outputs = []
    for i in range(N):
      res = layer[i](cur_stack)  # remove n_in from stack
      cur_stack.append(res[:n_side_outputs[i]])  # put back some on stack
      side_outputs.extend(res[n_side_outputs:])
    return cur_stack + side_outputs
  Args:
    layers: a list of layers to execute
    n_side_outputs: an int or a list of ints, how many outputs of each layer
        to put aside
  Returns:
    A layer that performs the above computation.
  """
  if isinstance(n_side_outputs, int):
    n_side_outputs = [n_side_outputs] * len(layers)

  # Calculate the n_in for this layer.
  running_max = 0
  running_total = 0
  for layer, n_side_output in zip(layers, n_side_outputs):
    running_total += layer.n_in
    running_max = max(running_max, running_total)
    running_total -= layer.n_out - n_side_output
  n_in = running_max

  # Create the list of layers to run serially.
  cur_stack_size = n_in
  serial_layers = []
  for layer, n_side_output in zip(layers, n_side_outputs):
    serial_layers.append(layer)
    cur_stack_size += layer.n_out - layer.n_in
    # Indices to move n_side_outputs to the back of the stack.
    # Don't touch first n_out - n_side_outputs.
    move_back_indices = list(range(layer.n_out - n_side_output))
    # Then comes the rest of the stack that we're not moving.
    move_back_indices += [i + layer.n_out
                          for i in range(cur_stack_size - layer.n_out)]
    # Finally the indices we move.
    move_back_indices += [i + layer.n_out - n_side_output
                          for i in range(n_side_output)]
    # Swap them on stack.
    serial_layers.append(Select(move_back_indices))

  return Serial(serial_layers)


def FlattenList():
  """Flatten lists."""
  # TODO(jonni): Consider renaming layer to DeepFlatten.
  return Fn('FlattenList', lambda x: tuple(_deep_flatten(x)))


def Add():
  """Adds two tensors."""
  return Fn('Add', lambda x0, x1: x0 + x1)


def SubtractTop():
  """Subtracts the first tensor from the second."""
  return Fn('SubtractTop', lambda x0, x1: x1 - x0)


def Multiply():
  """Multiplies two tensors."""
  return Fn('Multiply', lambda x0, x1: x0 * x1)


def Gate():
  """Returns a gating layer on a (memory, gate, candidate) tuple.
  Final update is memory * gate + (1 - gate) * candidate
  This gating equation may also be referred to as Highway Network.
  Highway Networks: https://arxiv.org/abs/1505.00387
  """
  return Fn('Gate', lambda m, g, c: g * m + (1.0 - g) * c)


class Cache(base.Layer):
  """Applies a layer on the first run and returns the outputs on next calls."""

  def __init__(self, layer):
    super().__init__(n_in=layer.n_in, n_out=layer.n_out)
    self._sublayers = [layer]

  @property
  def sublayer(self):
    """Returns the unique sublayer managed by this layer."""
    return self._sublayers[0]

  @property
  def state(self):
    """Returns a tuple containing this layer's state; may be empty."""
    return self._state

  @state.setter
  def state(self, state):
    """Recursively sets state on this layer and all sublayers."""
    if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE:
      return
    self._state = state
    self.sublayer.state = state[1]

  def init_weights_and_state(self, input_signature):
    weights, layer_state = self.sublayer.init(input_signature, use_cache=True)
    self.state = ((), layer_state)
    self._weights = (weights,)

  def forward(self, inputs):
    state, weights = self.state, self.weights[0]
    if state[0] is ():  # pylint: disable=literal-comparison
      res, layer_state = self.sublayer.pure_fn(
          inputs, weights, state[1], self.rng)
      self.state = (res, layer_state)
      return res
    else:
      return state[0]


class BatchLeadingAxes(base.Layer):
  """Applies a layer after flattening all but n_last_axes_to_keep to batch.
  This can be used to make layers accept an arbitrary number of leading
  axes (dimensions) as batch. For example, a Convolution layer may normally
  only operate on tensors of shape [B, W, H, C]. In this case, the layer
      BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3)
  will operate on any tensor [..., W, H, C] and treat the leading axes as batch.
  """

  def __init__(self, layer, n_last_axes_to_keep=1):
    super().__init__(n_in=layer.n_in, n_out=layer.n_out)
    self._sublayers = [layer]
    self._n_last_axes_to_keep = n_last_axes_to_keep
    self._weights = (None,)
    self._state = (None,)

  @property
  def sublayer(self):
    """Returns the unique sublayer managed by this layer."""
    return self._sublayers[0]

  def forward(self, inputs):
    batched_axes_shape = list(inputs.shape[:-self._n_last_axes_to_keep])
    batched_shape = [-1] + list(inputs.shape[-self._n_last_axes_to_keep:])
    inputs = jnp.reshape(inputs, batched_shape)
    res, layer_state = self.sublayer.pure_fn(
        inputs, self.weights[0], self.state[0], self.rng)
    self.state = (layer_state,)
    return jnp.reshape(res, batched_axes_shape + list(res.shape[1:]))

  def init_weights_and_state(self, input_signature):
    weights, layer_state = self.sublayer.init(input_signature, use_cache=True)
    self.state = (layer_state,)
    self.weights = (weights,)


# All module-private helper functions are below.
# pylint: disable=invalid-name


def _deep_flatten(items):
  """Returns a list of objects, flattening sublists/subtuples along the way.
  Example: _deep_flatten([1, (2, 3, (4, 5), [6, 7]), [[[8]]]]) would return
  the list [1, 2, 3, 4, 5, 6, 7, 8].
  Args:
    items: An iterable. If elements of this iterable are lists or tuples, they
        will be (recursively) flattened until non-list non-tuple objects are
        reached.
  Returns:
    A list of non-list, non-tuple objects.
  """
  def _flat_gen(xs):
    for x in xs:
      if isinstance(x, (list, tuple)):
        for y in _flat_gen(x):
          yield y
      else:
        yield x
  return list(_flat_gen(items))


def _ensure_sublayers(layers):
  """Ensures that elements in a layer list are layers.
  Args:
    layers: A tuple or list whose elements can each be a layer, tuple, or list,
        and so on recursively.
  Returns:
    An analogous collection of layers in which embedded layer lists are
    wrapped in Serial layer instances.
  """
  if not layers:  # None or an empty list can signal a no-op.
    return Serial(None)  # no-op, but still handles shapes and initialization
  elif isinstance(layers, (list, tuple)):
    sublayers_not_lists = []
    for layer in layers:
      sublayers_not_lists.append(
          Serial(layer) if isinstance(layer, (list, tuple)) else layer)
    return sublayers_not_lists
  else:
    raise TypeError(type(layers))


def _split_rngs(rng, n_copies):
  if rng is None:
    return (None,) * n_copies
  return fastmath.random.split(rng, n_copies)


def _inputs_from_stack(layer, stack, n_in=None):
  """Returns the correct number/format of inputs for the given layer."""
  if n_in is None:
    n_in = layer.n_in
  stack = _make_tuple(stack)
  return _make_singleitem_or_original(stack[:n_in])


def _outputs_onto_stack(layer, outputs, stack, n_in=None, n_out=None):
  """"Returns the new stack after outputs have been pushed onto it."""
  if n_in is None:
    n_in = layer.n_in
  if n_out is None:
    n_out = layer.n_out
  outputs = _make_tuple(outputs)
  stack = _make_tuple(stack)
  return _make_singleitem_or_original(outputs + stack[n_in:])


def _make_tuple(xs):
  """Returns a tuple from a list, a tuple, or a single element."""
  if isinstance(xs, (list, tuple)):
    return tuple(xs)
  else:
    return (xs,)


def _make_singleitem_or_original(xs):
  """Returns a single element if possible, or the original list/tuple if not."""
  if isinstance(xs, (list, tuple)) and len(xs) == 1:
    return xs[0]
  else:
    return xs


def _shape_without_axis(x, axis):
  return x.shape[:axis] + x.shape[axis + 1:]


def _ensure_flat(layers):
  """Ensures that layers is a single flat list of Layer instances."""
  if len(layers) == 1 and layers[0] is None:
    layers = ()
  else:
    layers = _deep_flatten(layers)
  for obj in layers:
    if not isinstance(obj, base.Layer):
      raise ValueError(
          f'Found nonlayer object ({obj}) in layers: {layers}')
  return layers

In [None]:
def MyResidual(layer):
    return tl.Serial(
        ### START CODE HERE ###
        # tl.----,
        # tl.----,
        ### END CODE HERE ###
    )

In [30]:
# Lets Try it
mr = MyResidual(bl_add1)
x = np.array([1])
mr([x, n, m])

[array([1]), 'n', 'm']

In [31]:
Fl = tl.Fn("F", lambda x0: (2 * x0), n_out=1)
Gl = tl.Fn("G", lambda x0: (10 * x0), n_out=1)
x1 = np.array([1])

In [32]:
resfg = tl.Serial(
    ### START CODE HERE ###
    # None,  #Fl    # x + F(x)
    # None,  #Gl    # x + F(x) + G(x + F(x)) etc
    # None,  #Fl
    # None,  #Gl
    ### END CODE HERE ###
)

In [33]:
# Lets try it
resfg([x1, n, m])

[array([1]), 'n', 'm']

# Reversible Residual Networks

In [None]:
refm = trax.models.reformer.ReformerLM(
    vocab_size=33000, n_layers=2, mode="train"  # Add more options.
)
refm

In [None]:
x1 = np.array([1])
x2 = np.array([5])
# Dup() duplicates the Top of Stack and returns the stack
dl = tl.Dup()
dl(x1)

In [None]:
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Implementations of reversible layers."""

from trax import fastmath
from trax.layers import base
from trax.layers import combinators as cb

# pylint: disable=protected-access
_inputs_from_stack = cb._inputs_from_stack
_outputs_onto_stack = cb._outputs_onto_stack
_split_rngs = cb._split_rngs
# pylint: enable=protected-access


class ReversibleLayer(base.Layer):
  """Reversible Layer."""

  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    """Reverse this layer: compute input given output."""
    raise NotImplementedError

  def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(),
                       rng=None):
    """Backward pass: computes the inverse of a layer and propagates gradients.
    While you may choose to only implement reverse, some layers implement this
    function directly as computation may be shared between reversing and
    computing gradients.
    Args:
      output: Output activations; can be a (possibly nested) tuple.
      grad: gradient signal (cotangent) computed based on subsequent layers.
        The structure and shape must match the output.
      weights: layer weights
      state: start state
      new_state: updated state computed by the forward pass
      rng: Single-use random number generator (JAX PRNG key).
    Returns:
      A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input,
      x_grad is the gradient signal for the input, and weights_grad is the
      gradient signal for the weights.
    """
    def _do_forward(x, weights):
      old_weights, old_state, old_rng = self.weights, self.state, self._rng
      self.state, self._rng = state, rng
      self.weights = weights
      res = self.forward(x)
      self.weights, self.state, self._rng = old_weights, old_state, old_rng
      return res

    reconstructed_x = self.reverse(output, weights, state, new_state, rng)
    _, vjpfun = fastmath.vjp(_do_forward, reconstructed_x, weights)
    x_weights_grad = vjpfun(grad)
    return reconstructed_x, x_weights_grad

  @property
  def has_backward(self):
    return True

  def backward(self, inputs, output, grad, weights, state, new_state, rng):
    del inputs
    _, inputs_weights_grad = (
        self.reverse_and_grad(output, grad, weights, state, new_state, rng))
    return inputs_weights_grad


class ReversibleSwap(ReversibleLayer):
  """Swap the first two element on the stack."""

  def __init__(self):
    super().__init__(n_in=2, n_out=2)

  def forward(self, inputs):
    x0, x1 = inputs
    return x1, x0

  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    del state, new_state, rng, weights
    # Swap is its own inverse, except that reverse doesn't return the state.
    return self.forward(output)


class ReversibleSerial(ReversibleLayer, cb.Serial):
  """A reversible version of tl.Serial (requires reversible sub-layers)."""

  def __init__(self, *layers):
    super().__init__(*layers)
  # def __init__(self, *layers):  # pylint: disable=super-init-not-called
  #   cb.Serial.__init__(self, layers)

    # Note that sublayers has already been flattened to remove nested lists.
    for i, layer in enumerate(self.sublayers):
      if not isinstance(layer, ReversibleLayer):
        raise ValueError(
            'Sub-layer {} of ReversibleSerial is not reversible: {}'.format(
                i, layer))

  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = fastmath.random.split(rng, self._n_layers)

    stack = output
    for layer, p, s, ns, rng in reversed(list(zip(
        self.sublayers, weights, state, new_state, rngs))):
      layer_val = _inputs_from_stack(layer, stack, layer.n_out)
      layer_val = layer.reverse(layer_val, p, s, ns, rng=rng)
      stack = _outputs_onto_stack(
          layer, layer_val, stack, layer.n_out, layer.n_in)

    return stack

  def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(),
                       rng=None):
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = fastmath.random.split(rng, self._n_layers)

    stack = output
    stack_grad = grad
    weights_grad = []
    for layer, p, s, ns, rng in reversed(list(zip(
        self.sublayers, weights, state, new_state, rngs))):
      layer_val = _inputs_from_stack(layer, stack, layer.n_out)
      layer_ct = _inputs_from_stack(layer, stack_grad, layer.n_out)
      layer_val, layer_ct = layer.reverse_and_grad(
          layer_val, layer_ct, p, s, ns, rng=rng)
      layer_ct, p_ct = layer_ct
      weights_grad.insert(0, p_ct)
      stack = _outputs_onto_stack(
          layer, layer_val, stack, layer.n_out, layer.n_in)
      stack_grad = _outputs_onto_stack(
          layer, layer_ct, stack_grad, layer.n_out, layer.n_in)

    return stack, (stack_grad, tuple(weights_grad))


class ReversibleHalfResidual(ReversibleLayer):
  """Half of a RevNet-style residual that optionally performs attention.
  When attention_layer is None, this layer has the signature ::
      [accumulator, *context] -> [accumulator + f(context), *context]
  The attention_layer must be an instance of EfficientAttentionBase or one of
  its subclasses (see efficient_attention.py), or None.
  Attention is special-cased for the following two reasons:
  - LSH attention needs to save bucket assignments from the forward pass to the
    backward pass, for training stability. This requires special-casing it.
  - We can call attention_layer.forward_and_or_backward to compute its output
    (needed for inverting a reversible residual layer) while simultaneously
    performing the backward pass. Sharing computation between these two
    operations improves training speed.
  """

  def __init__(self, *residual_layers, attention_layer=None):
    super().__init__()

    self.compute_residual = cb.Serial(*residual_layers)
    self.attention_layer = attention_layer

    if self.attention_layer is None:
      self._sublayers = (self.compute_residual,)
    else:
      assert hasattr(attention_layer, 'forward_and_or_backward')
      self._sublayers = (self.compute_residual, self.attention_layer)

    running_max = 0
    running_total = 0
    for layer in self._sublayers:
      running_total += layer.n_in
      running_max = max(running_max, running_total)
      running_total -= layer.n_out
    self._n_in = self._n_out = running_max + 1

  def forward(self, xs):
    rngs = _split_rngs(self.rng, len(self.sublayers))
    accumulator, *context = xs
    stack = context = tuple(context)
    new_state = []
    for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      outputs, s = layer.pure_fn(inputs, w, s, rng)
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack

    output = accumulator + residual
    stack = (output,) + context
    self.state = tuple(new_state)
    return stack

  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    raise NotImplementedError('Only reverse_and_grad is actually used.')

  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       rng=None):
    rngs = _split_rngs(rng, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
      res, _ = self.compute_residual.pure_fn(
          x, weights=weights, state=state[0], rng=rngs[0])
      if not isinstance(res, (tuple, list)):
        return res, None
      else:
        n_differentiable = 1
        if self.attention_layer is not None:
          n_differentiable = min(len(res), self.attention_layer.n_in)
        return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = _inputs_from_stack(self.compute_residual, stack)
    outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
      n_differentiable_outputs = len(outputs)
      outputs = outputs + outputs_aux
    stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
      residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
      inputs = _inputs_from_stack(self.attention_layer, stack)
      (residual, _, attn_inputs_ct, attn_weights_ct
      ) = self.attention_layer.forward_and_or_backward(
          inputs, weights[1], new_state[1], rngs[1],
          output_grad=accumulator_output_ct,
          compute_output=True, update_state=False)
      stack_ct = _outputs_onto_stack(
          self.attention_layer, attn_inputs_ct, stack_ct,
          self.attention_layer.n_out, self.attention_layer.n_in)

    compute_residual_ct = _inputs_from_stack(
        self.compute_residual, stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
      if not isinstance(compute_residual_ct, (tuple, list)):
        compute_residual_ct = (compute_residual_ct,)
      compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
      assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = _outputs_onto_stack(
        self.compute_residual, compute_residual_inputs_ct, stack_ct,
        self.compute_residual.n_out, self.compute_residual.n_in)
    if not isinstance(stack_ct, (tuple, list)):
      stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
        lambda x, y: x+y, context_ct[:len(stack_ct)], stack_ct
        ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
      weights_ct = (compute_residual_weights_ct,)
    else:
      weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)

  # pylint: disable=protected-access
  def init_weights_and_state(self, input_signature):
    stack = input_signature[1:]
    if len(stack) == 1:
      stack = stack[0]

    inputs = _inputs_from_stack(self.compute_residual, stack)
    weights, state = self.compute_residual.init(inputs)
    outputs, _ = self.compute_residual._forward_abstract(inputs)
    stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

    if self.attention_layer is None:
      self.state = (state,)
      self.weights = (weights,)
    else:
      inputs = _inputs_from_stack(self.attention_layer, stack)
      attn_weights, attn_state = self.attention_layer.init(inputs)
      self.state = (state, attn_state)
      self.weights = (weights, attn_weights)
  # pylint: enable=protected-access

In [None]:
# ReversibleSwap() duplicates the Top of Stack and returns the stack
sl = tl.ReversibleSwap()
sl([x1, x2])

In [None]:
# Demonstrate reverse swap
print(x1, x2, sl.reverse([x1, x2]))

In [None]:
Fl = tl.Fn("F", lambda x0: (2 * x0), n_out=1)
Gl = tl.Fn("G", lambda x0: (10 * x0), n_out=1)

In [None]:
half_res_F = ReversibleHalfResidual(Fl)
print(type(half_res_F), "\n", half_res_F)

In [None]:
half_res_F([x1, x1])  # this is going to produce an error - why?

In [None]:
# we have to initialize the ReversibleHalfResidual layer to let it know what the input is going to look like
half_res_F.init(shapes.signature([x1, x1]))
half_res_F([x1, x1])

# Build a reversible model

# Chatbot

In [None]:
import numpy as np
import trax
#from trax import layers as tl
#from trax.fastmath import numpy as fastnp
#from trax.supervised import training

# UNIT TEST for UNQ_C1
def test_get_conversation(target):

    data = {'file1.json': {'log':[{'text': 'hi'},
                                  {'text': 'hello'},
                                  {'text': 'nice'}]},
            'file2.json':{'log':[{'text': 'a b'}, 
                                 {'text': ''}, 
                                 {'text': 'good '}, 
                                 {'text': 'no?'}]}}
    
    res1 = target('file1.json', data)
    res2 = target('file2.json', data)
    
    expected1 = ' Person 1: hi Person 2: hello Person 1: nice'
    expected2 = ' Person 1: a b Person 2:  Person 1: good  Person 2: no?'

    success = 0
    fails = 0
    
    try:
        assert res1 == expected1
        success += 1
    except ValueError:
        print('Error in test 1 \nResult  : ', res1, 'x \nExpected: ', expected1)
        fails += 1
    try:
        assert res2 == expected2
        success += 1
    except:
        print('Error in test 2 \nResult  : ', res2, ' \nExpected: ', expected2)
        fails += 1
            
    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")


# UNIT TEST for UNQ_C2
def test_reversible_layer_forward(target):
    f1 = lambda x: x + 2
    g1 = lambda x: x * 3
    
    f2 = lambda x: x + 1
    g2 = lambda x: x * 2
    
    input_vector1 = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    expected1 = np.array([8, 10, 12, 14, 29, 36, 43, 50])
    
    input_vector2 = np.array([1] * 128)
    expected2 = np.array([3] * 64 + [7] * 64)
    
    success = 0
    fails = 0
    try:
        res = target(input_vector1, f1, g1)
        assert isinstance(res, np.ndarray)
        success += 1
    except:
        print('Wrong type! Output is not of type np.ndarray')
        fails += 1
    try:
        res = target(input_vector1, f1, g1)
        assert np.allclose(res, expected1)
        success += 1
    except ValueError:
        print('Error in test 1 \nResult  : ', res, 'x \nExpected: ', expected1)
        fails += 1
    try:
        res = target(input_vector2, f2, g2)
        assert np.allclose(res, expected2)
        success += 1
    except:
        print('Error in test 2 \nResult  : ', res, ' \nExpected: ', expected2)
        fails += 1
            
    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")


# UNIT TEST for UNQ_C3
def test_reversible_layer_reverse(target):
    
    f1 = lambda x: x + 2
    g1 = lambda x: x * 3
    
    f2 = lambda x: x + 1
    g2 = lambda x: x * 2
    
    input_vector1 = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    expected1 = np.array([-3,  0,  3,  6,  2,  0, -2, -4])
    
    input_vector2 = np.array([1] * 128)
    expected2 = np.array([1] * 64 + [-1] * 64)
    
    success = 0
    fails = 0
    try:
        res = target(input_vector1, f1, g1)
        assert isinstance(res, np.ndarray)
        success += 1
    except:
        print('Wrong type! Output is not of type np.ndarray')
        fails += 1
    try:
        res = target(input_vector1, f1, g1)
        assert np.allclose(res, expected1)
        success += 1
    except ValueError:
        print('Error in test 1 \nResult  : ', res, 'x \nExpected: ', expected1)
        fails += 1
    try:
        res = target(input_vector2, f2, g2)
        assert np.allclose(res, expected2)
        success += 1
    except:
        print('Error in test 2 \nResult  : ', res, ' \nExpected: ', expected2)
        fails += 1
            
    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")
        

# UNIT TEST for UNQ_C4
def test_ReformerLM(target):
    test_cases = [
                {
                    "name":"layer_len_check",
                    "expected":11,
                    "error":"We found {} layers in your model. It should be 11.\nCheck the LSTM stack before the dense layer"
                },
                {
                    "name":"simple_test_check",
      "expected":"Serial[ShiftRight(1)Embedding_train_512DropoutPositionalEncodingDup_out2ReversibleSerial_in2_out2[ReversibleHalfResidualV2_in2_out2[Serial[LayerNorm]SelfAttention]ReversibleSwap_in2_out2ReversibleHalfResidualV2_in2_out2[Serial[LayerNormDense_2048DropoutFastGeluDense_512Dropout]]ReversibleSwap_in2_out2ReversibleHalfResidualV2_in2_out2[Serial[LayerNorm]SelfAttention]ReversibleSwap_in2_out2ReversibleHalfResidualV2_in2_out2[Serial[LayerNormDense_2048DropoutFastGeluDense_512Dropout]]ReversibleSwap_in2_out2]Concatenate_in2LayerNormDropoutDense_trainLogSoftmax]",
                    "error":"The ReformerLM is not defined properly."
                }
            ]
    temp_model = target('train')
    
    success = 0
    fails = 0
    
    for test_case in test_cases:
        try:
            if test_case['name'] == "simple_test_check":
                assert test_case["expected"] == str(temp_model).replace(' ', '').replace('\n','')
                success += 1
            if test_case['name'] == "layer_len_check":
                if test_case["expected"] == len(temp_model.sublayers):
                    success += 1
                else:
                    print(test_case["error"].format(len(temp_model.sublayers))) 
                    fails += 1
        except:
            print(test_case['error'])
            fails += 1
            
    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")


# UNIT TEST for UNQ_C5
def test_tasks(train_task, eval_task):
    target = train_task
    success = 0
    fails = 0
     
    # Test the labeled data parameter for train_task
    try:
        strlabel = str(target._labeled_data)
        assert ("generator" in strlabel) and ("add_loss_weights" in  strlabel)
        success += 1
    except:
        fails += 1
        print("Wrong labeled data parameter in train_task")
    
    # Test the cross entropy loss data parameter
    try:
        strlabel = str(target._loss_layer)
        assert(strlabel == "CrossEntropyLoss_in3")
        success += 1
    except:
        fails += 1
        print("Wrong loss functions. CrossEntropyLoss_in3 was expected")
        
     # Test the optimizer parameter
    try:
        assert(isinstance(target.optimizer, trax.optimizers.adam.Adam))
        success += 1
    except:
        fails += 1
        print("Wrong optimizer")
        
    # Test the schedule parameter
    try:
        assert(isinstance(target._lr_schedule,trax.supervised.lr_schedules._BodyAndTail))
        success += 1
    except:
        fails += 1
        print("Wrong learning rate schedule type")
    
    # Test the _n_steps_per_checkpoint parameter
    try:
        assert(target._n_steps_per_checkpoint==10)
        success += 1
    except:
        fails += 1
        print("Wrong checkpoint step frequency")
        
    target = eval_task
    # Test the labeled data parameter for eval_task
    try:
        strlabel = str(target._labeled_data)
        assert ("generator" in strlabel) and ("add_loss_weights" in  strlabel)
        success += 1
    except:
        fails += 1
        print("Wrong labeled data parameter in eval_task")
    
    # Test the metrics in eval_task 
    try:
        strlabel = str(target._metrics).replace(' ', '')
        assert(strlabel == "[CrossEntropyLoss_in3,Accuracy_in3]")
        success += 1
    except:
        fails += 1
        print(f"Wrong metrics. found {strlabel} but expected [CrossEntropyLoss_in3,Accuracy_in3]")
        
        
    if fails == 0:
        print("\033[92m All tests passed")
    else:
        print('\033[92m', success," Tests passed")
        print('\033[91m', fails, " Tests failed")

# Exploring the MultiWoz dataset

In [34]:
import json
import random
import numpy as np
from termcolor import colored

import trax   
from trax import layers as tl
from trax.supervised import training
!pip list | grep trax

trax                               1.3.1


In [None]:
# filename of the MultiWOZ dialogue dataset
DATA_FILE = 'data.json'

# data directory
DATA_DIR = './data'

# dictionary where we will load the dialogue dataset
DIALOGUE_DB = {}

# vocabulary filename
VOCAB_FILE = 'en_32k.subword'

# vocabulary file directory
VOCAB_DIR = 'data/vocabs'

In [None]:
# help function to load a JSON file
def load_json(directory, file):
    with open(f'{directory}/{file}') as file: 
        db = json.load(file)
    return db

# load the dialogue data set into our dictionary
DIALOGUE_DB = load_json(DATA_DIR, DATA_FILE)

In [None]:
print(f'The number of dialogues is: {len(DIALOGUE_DB)}')

In [None]:
# print 7 keys from the dataset to see the filenames
print(list(DIALOGUE_DB.keys())[0:7]) 

In [None]:
# get keys of the fifth file in the list above
print(DIALOGUE_DB['SNG0073.json'].keys())

In [None]:
DIALOGUE_DB['SNG0073.json']['goal']

In [None]:
# get first element of the log list
DIALOGUE_DB['SNG0073.json']['log'][0]

In [None]:
print(' Person 1: ', DIALOGUE_DB['SNG0073.json']['log'][0]['text'])
print(' Person 2: ',DIALOGUE_DB['SNG0073.json']['log'][1]['text'])

In [None]:
# UNQ_C1
# GRADED FUNCTION: get_conversation
def get_conversation(file, data_db):
    '''
    Args:
        file (string): filename of the dialogue file saved as json
        data_db (dict): dialogue database
    
    Returns:
        string: A string containing the 'text' fields of  data[file]['log'][x]
    '''
    
    # initialize empty string
    result = ''
    
    # get length of file's log list
    len_msg_log = len(data_db[file]['log'])
    
    # set the delimiter strings
    delimiter_1 = ' Person 1: '
    delimiter_2 = ' Person 2: '
    
    # loop over the file's log list
    for i in range(len_msg_log):
        
    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    
        # get i'th element of file log list
        cur_log = data_db[file]['log'][i]
        
        # check if i is even
        if i%2 == 0:                   
            # append the 1st delimiter string
            result += delimiter_1
        else: 
            # append the 2nd delimiter string
            result += delimiter_2
        
        # append the message text from the log
        result += cur_log['text']
    
    ### END CODE HERE ###

    return result

In [None]:
# BEGIN UNIT TEST
import w4_unittest
w4_unittest.test_get_conversation(get_conversation)
# END UNIT TEST

In [None]:
file = 'SNG01856.json'
conversation = get_conversation(file, DIALOGUE_DB)

# print raw output
print(conversation)

In [None]:
def print_conversation(conversation):
    
    delimiter_1 = 'Person 1: '
    delimiter_2 = 'Person 2: '
    
    split_list_d1 = conversation.split(delimiter_1)
    
    for sublist in split_list_d1[1:]:
        split_list_d2 = sublist.split(delimiter_2)
        print(colored(f'Person 1: {split_list_d2[0]}', 'red'))
        
        if len(split_list_d2) > 1:
            print(colored(f'Person 2: {split_list_d2[1]}', 'green'))

            
print_conversation(conversation)

In [None]:
DIALOGUE_DB['SNG01856.json']['log'][0]

In [None]:
# this is an example of the attractions file
attraction_file = open('data/attraction_db.json')
attractions = json.load(attraction_file)
print(attractions[0])

In [None]:
# this is an example of the hospital file
hospital_file = open('data/hospital_db.json')
hospitals = json.load(hospital_file)
print(hospitals[0]) # feel free to index into other indices

In [None]:
# this is an example of the hotel file
hotel_file = open('data/hotel_db.json')
hotels = json.load(hotel_file)
print(hotels[0]) # feel free to index into other indices

In [None]:
# this is an example of the police file
police_file = open('data/police_db.json')
police = json.load(police_file)
print(police[0]) # feel free to index into other indices

In [None]:
# this is an example of a restuarant file
restaurant_file = open('data/restaurant_db.json')
restaurants = json.load(restaurant_file)
print(restaurants[0]) # feel free to index into other indices

In [None]:
with open('data/README') as file:
    print(file.read())

# Part 2: Processing the data for Reformer inputs

In [None]:
# the keys are the file names
all_files = DIALOGUE_DB.keys()

# initialize empty list
untokenized_data = []

# loop over all files
for file in all_files:
    # this is the graded function you coded
    # returns a string delimited by Person 1 and Person 2
    result = get_conversation(file, DIALOGUE_DB)
    
    # append to the list
    untokenized_data.append(result)

# print the first element to check if it's the same as the one we got before
print(untokenized_data[0])

In [None]:
# shuffle the list we generated above
random.shuffle(untokenized_data)

# define a cutoff (5% of the total length for this assignment)
# convert to int because we will use it as a list index
cut_off = int(len(untokenized_data) * .05)

# slice the list. the last elements after the cut_off value will be the eval set. the rest is for training. 
train_data, eval_data = untokenized_data[:-cut_off], untokenized_data[-cut_off:]

print(f'number of conversations in the data set: {len(untokenized_data)}')
print(f'number of conversations in train set: {len(train_data)}')
print(f'number of conversations in eval set: {len(eval_data)}')

# 2.1 Tokenizing, batching with bucketing

In [None]:
def stream(data):
    # loop over the entire data
    while True:
        # get a random element
        d = random.choice(data)
        
        # yield a tuple pair of identical values 
        # (i.e. our inputs to the model will also be our targets during training)
        yield (d, d)

In [None]:
# trax allows us to use combinators to generate our data pipeline
data_pipeline = trax.data.Serial(
    # randomize the stream
    trax.data.Shuffle(),
    
    # tokenize the data
    trax.data.Tokenize(vocab_dir=VOCAB_DIR,
                       vocab_file=VOCAB_FILE),
    
    # filter too long sequences
    trax.data.FilterByLength(2048),
    
    # bucket by length
    trax.data.BucketByLength(boundaries=[128, 256,  512, 1024],
                             batch_sizes=[16,    8,    4,   2, 1]),
    
    # add loss weights but do not add it to the padding tokens (i.e. 0)
    trax.data.AddLossWeights(id_to_mask=0)
)

# apply the data pipeline to our train and eval sets
train_stream = data_pipeline(stream(train_data))
eval_stream = data_pipeline(stream(eval_data))

In [None]:
# the stream generators will yield (input, target, weights). let's just grab the input for inspection
inp, _, _ = next(train_stream)

# print the shape. format is (batch size, token length)
print("input shape: ", inp.shape)

# detokenize the first element
print(trax.data.detokenize(inp[0], vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE))

# Part 3: Reversible layers

In [None]:
# UNQ_C2
# GRADED FUNCTION: reversible_layer_forward
def reversible_layer_forward(x, f, g):
    """
    Args: 
        x (np.array): an input vector or matrix
        f (function): a function which operates on a vector/matrix
        g (function): a function which operates on a vector/matrix
    Returns: 
        y (np.array): an output vector or matrix whose form is determined by 'x', f and g
    """
    # split the input vector into two (* along the last axis because it is the depth dimension)
    x1, x2 = np.split(x, 2, axis=-1) 
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    
    # get y1 using equation 3
    y1 = x1 + f(x2)
    
    # get y2 using equation 4
    y2 = x2 + g(y1)
    
    # concatenate y1 and y2 along the depth dimension. be sure output is of type np.ndarray
    y = np.concatenate([y1, y2], axis=-1)
    
    ### END CODE HERE ### 
    return y

In [None]:
# BEGIN UNIT TEST
w4_unittest.test_reversible_layer_forward(reversible_layer_forward)
# END UNIT TEST

In [None]:
# UNQ_C3
# GRADED FUNCTION: reversible_layer_reverse
def reversible_layer_reverse(y, f, g):
    """
    Args: 
        y (np.array): an input vector or matrix
        f (function): a function which operates on a vector/matrix of the form of 'y'
        g (function): a function which operates on a vector/matrix of the form of 'y'
    Returns: 
        y (np.array): an output vector or matrix whose form is determined by 'y', f and g
    """
    
    # split the input vector into two (* along the last axis because it is the depth dimension)
    y1, y2 = np.split(y, 2, axis=-1)
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    
    # compute x2 using equation 5
    x2 = y2 - g(y1)
    
    # compute x1 using equation 6
    x1 = y1 - f(x2)
    
    # concatenate x1 and x2 along the depth dimension
    x = np.concatenate([x1, x2], axis=-1) 
    
    ### END CODE HERE ### 
    return x

In [None]:
# BEGIN UNIT TEST
w4_unittest.test_reversible_layer_reverse(reversible_layer_reverse)
# END UNIT TEST

In [None]:
# UNIT TEST COMMENT: assert at the end can be used in grading as well
f = lambda x: x + 2
g = lambda x: x * 3
input_vector = np.random.uniform(size=(32,))

output_vector = reversible_layer_forward(input_vector, f, g)
reversed_vector = reversible_layer_reverse(output_vector, f, g)

assert np.allclose(reversed_vector, input_vector)

# Reversible layers and randomness

In [None]:
# Layers like dropout have noise, so let's simulate it here:
f = lambda x: x + np.random.uniform(size=x.shape)

# See that the above doesn't work any more:
output_vector = reversible_layer_forward(input_vector, f, g)
reversed_vector = reversible_layer_reverse(output_vector, f, g)

assert not np.allclose(reversed_vector, input_vector)  # Fails!!

# It failed because the noise when reversing used a different random seed.

random_seed = 27686
rng = trax.fastmath.random.get_prng(random_seed)
f = lambda x: x + trax.fastmath.random.uniform(key=rng, shape=x.shape)

# See that it works now as the same rng is used on forward and reverse.
output_vector = reversible_layer_forward(input_vector, f, g)
reversed_vector = reversible_layer_reverse(output_vector, f, g)

assert np.allclose(reversed_vector, input_vector,  atol=1e-07)

# Part 4: ReformerLM Training

In [None]:
# UNQ_C4
# GRADED FUNCTION
def ReformerLM(vocab_size=33000, n_layers=2, mode='train', attention_type=tl.SelfAttention):
    """
    Args: 
        vocab_size (int): size of the vocabulary
        n_layers (int): number of decoder layers
        mode (string): setting of the model which can be 'train', 'eval', or 'predict' 
        attention_type(class): attention class to use 
    Returns: 
        model (ReformerLM): a reformer language model implemented in Trax
    """    
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    # initialize an instance of Trax's ReformerLM class
    model = trax.models.reformer.ReformerLM( 
        # set vocab size
        vocab_size=vocab_size,
        # set number of layers
        n_layers=n_layers,
        # set mode
        mode=mode,
        # set attention type
        attention_type=attention_type
    )
    
    ### END CODE HERE ###
    return model

In [None]:
# display the model
temp_model = ReformerLM('train')
print(str(temp_model))

# free memory
del temp_model 

In [None]:
# BEGIN UNIT TEST
w4_unittest.test_ReformerLM(ReformerLM)
# END UNIT TEST

In [None]:
# UNQ_C5
# GRADED FUNCTION: train_model
def training_loop(ReformerLM, train_gen, eval_gen, output_dir = "./model/"):
    """
    Args:
        ReformerLM:  the Reformer language model you are building
        train_gen (generator): train data generator.
        eval_gen (generator): Validation generator. 
        output_dir (string): Path to save the model output. Defaults to './model/'.

    Returns:
        trax.supervised.training.Loop: Training loop for the model.
    """

    # use the warmup_and_rsqrt_decay learning rate schedule
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(
        n_warmup_steps=1000, max_value=0.01)

    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    
    # define the train task
    train_task = training.TrainTask(            
        # labeled data
        labeled_data=train_gen,
        # loss layer
        loss_layer=tl.CrossEntropyLoss(),
        # optimizer
        optimizer=trax.optimizers.Adam(0.01),
        # lr_schedule
        lr_schedule=lr_schedule,
        # n_steps
        n_steps_per_checkpoint=10
    )

    # define the eval task
    eval_task = training.EvalTask(                      
        # labeled data
        labeled_data=eval_gen,
        # metrics
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]
    )

    ### END CODE HERE ###
    loop = training.Loop(ReformerLM(mode='train'),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    return loop

In [None]:
# UNIT TEST COMMENT: Use the train task and eval task for grading train_model
test_loop = training_loop(ReformerLM, train_stream, eval_stream)
train_task = test_loop._task
eval_task = test_loop._eval_task

print(train_task)
print(eval_task)

In [None]:
# BEGIN UNIT TEST
w4_unittest.test_tasks(train_task, eval_task)
# END UNIT TEST

In [None]:
# we will now test your function
!rm -f model/model.pkl.gz
loop = training_loop(ReformerLM, train_stream, eval_stream)
loop.run(10)

# Part 5: Decode from a pretrained model

In [None]:
# define the `predict_mem_len` and `predict_drop_len` of tl.SelfAttention
def attention(*args, **kwargs):
    # number of input positions to remember in a cache when doing fast inference. 
    kwargs['predict_mem_len'] = 120
    # number of input elements to drop once the fast inference input cache fills up.
    kwargs['predict_drop_len'] = 120
    # return the attention layer with the parameters defined above
    return tl.SelfAttention(*args, **kwargs)

# define the model using the ReformerLM function you implemented earlier.
model = ReformerLM(
    vocab_size=33000,
    n_layers=6,
    mode='predict',
    attention_type=attention,
)

# define an input signature so we can initialize our model. shape will be (1, 1) and the data type is int32.
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

In [None]:
# initialize from file
model.init_from_file('chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

# save the starting state
STARTING_STATE = model.state

In [None]:
def tokenize(sentence, vocab_file, vocab_dir):
    return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]

def detokenize(tokens, vocab_file, vocab_dir):
    return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)

In [None]:
# UNQ_C6
# GRADED FUNCTION
def ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file, vocab_dir, temperature):
    """
    Args:
        ReformerLM:  the Reformer language model you just trained
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    
    # Create input tokens using the the tokenize function
    input_tokens = tokenize(start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)
    
    # Add batch dimension to array. Convert from (n,) to (x, n) where 
    # x is the batch size. Default is 1. (hint: you can use np.expand_dims() with axis=0)
    input_tokens_with_batch = np.array(input_tokens)[None, :]
    
    # call the autoregressive_sample_stream function from trax
    output_gen = trax.supervised.decoding.autoregressive_sample_stream( 
        # model
        ReformerLM,
        # inputs will be the tokens with batch dimension
        inputs=input_tokens_with_batch,
        # temperature
        temperature=temperature
    )
    
    ### END CODE HERE ###
    
    return output_gen

In [None]:
# BEGIN UNIT TEST
import pickle

WEIGHTS_FROM_FILE = ()

with open('weights', 'rb') as file:
    WEIGHTS_FROM_FILE = pickle.load(file)

shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 120
    kwargs['predict_drop_len'] = 120
    return tl.SelfAttention(*args, **kwargs)

test_model = ReformerLM(vocab_size=5, n_layers=1, mode='predict', attention_type=attention)

test_output_gen = ReformerLM_output_gen(test_model, "test", vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, temperature=0)

test_model.init_weights_and_state(shape11)

test_model.weights = WEIGHTS_FROM_FILE

output = []

for i in range(6):
    output.append(next(test_output_gen)[0])

print(output)

# free memory
del test_model 
del WEIGHTS_FROM_FILE
del test_output_gen
# END UNIT TEST

In [None]:
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 120  # max length for predictions
    kwargs['predict_drop_len'] = 120  # never drop old stuff
    return tl.SelfAttention(*args, **kwargs)

model = ReformerLM(
    vocab_size=33000,
    n_layers=6,
    mode='predict',
    attention_type=attention,
)

In [None]:
model.init_from_file('chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

STARTING_STATE = model.state

In [None]:
def generate_dialogue(ReformerLM, model_state, start_sentence, vocab_file, vocab_dir, max_len, temperature):
    """
    Args:
        ReformerLM:  the Reformer language model you just trained
        model_state (np.array): initial state of the model before decoding
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        max_len (int): maximum number of tokens to generate 
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """  
    
    # define the delimiters we used during training
    delimiter_1 = 'Person 1: ' 
    delimiter_2 = 'Person 2: '
    
    # initialize detokenized output
    sentence = ''
    
    # token counter
    counter = 0
    
    # output tokens. we insert a ': ' for formatting
    result = [tokenize(': ', vocab_file=vocab_file, vocab_dir=vocab_dir)]
    
    # reset the model state when starting a new dialogue
    ReformerLM.state = model_state
    
    # calls the output generator implemented earlier
    output = ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, temperature=temperature)
    
    # print the starting sentence
    print(start_sentence.split(delimiter_2)[0].strip())
    
    # loop below yields the next tokens until max_len is reached. the if-elif is just for prettifying the output.
    for o in output:
        
        result.append(o)
        
        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
        
        if sentence.endswith(delimiter_1):
            sentence = sentence.split(delimiter_1)[0]
            print(f'{delimiter_2}{sentence}')
            sentence = ''
            result.clear()
        
        elif sentence.endswith(delimiter_2):
            sentence = sentence.split(delimiter_2)[0]
            print(f'{delimiter_1}{sentence}')
            sentence = ''
            result.clear()

        counter += 1
        
        if counter > max_len:
            break   

We can now feed in different starting sentences and see how the model generates the dialogue. You can even input your own starting sentence. Just remember to ask a question that covers the topics in the Multiwoz dataset so you can generate a meaningful conversation.

In [None]:
sample_sentence = ' Person 1: Are there theatres in town? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

In [None]:
sample_sentence = ' Person 1: Is there a hospital nearby? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

In [None]:
sample_sentence = ' Person 1: Can you book a taxi? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)