In [854]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax

from trax import backend
from trax import layers as tl
from trax.backend import numpy as np
from trax.layers.combinators import _pop_rng_and_split
from trax.shapes import signature, ShapeDtype

In [157]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [128]:
class pMap(nn.Module):
    """Combinator for applying a layers to a list/tuple"""
    def __init__(self, layer, n_sections=1, check_shapes=True):
        """
        
        :param layers: the layer you wish to apply to each element.
        :param n_sections: number of sections to map to. defaults to 1.
        :param check_shapes: tests to see the shapes are identical.
        :returns: new layers with mapped layer to all elements.
        """
        super(pMap, self).__init__()
        if layer is None or isinstance(layer, (list, tuple)):
            layer = nn.Sequential(*layer)
        self._layer = layer
        self._check_shapes = check_shapes
        self._n_sections = n_sections
        self.n_in = n_sections
        self.n_out = n_sections
        
        
    def forward(self, inputs, **kwargs):
        """
        Trying to replace with a basic forward pass. Trax implementation
        uses a PRNG key split into subsections zipped with the inputs var
        then returns a list of those forward passed subsections as results
        """
        
        if self._n_sections == 1:
            results = self._layer(inputs, **kwargs)
        
        else:
            
            results = [self._layer(x) for x in inputs] 
            
        return results, self._layer.state_dict()

In [248]:
pmap = pMap(layer=[
    nn.LayerNorm(1,1), 
    nn.Linear(1,1),
    nn.LogSoftmax()], n_sections=10)

pmap

pMap(
  (_layer): Sequential(
    (0): LayerNorm((1,), eps=1, elementwise_affine=True)
    (1): Linear(in_features=1, out_features=1, bias=True)
    (2): LogSoftmax()
  )
)

In [246]:
test = torch.tensor([0,1,2,3,4,5,6,7,8,9], dtype=torch.float)
print(test.unsqueeze(0).shape)
print(test)

torch.Size([1, 10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])


In [134]:
pmap(inputs=test.view(-1,1))

  input = module(input)


(tensor([[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]], grad_fn=<LogSoftmaxBackward>),
 OrderedDict([('0.weight', tensor([1.])),
              ('0.bias', tensor([0.])),
              ('1.weight', tensor([[0.4920]])),
              ('1.bias', tensor([0.9482]))]))

In [342]:
class pBroadcastedDropout(nn.Module):
    
    def __init__(self, rate=0.0, mode='train', broadcast_dims=(-2,)):
        super(pBroadcastedDropout, self).__init__()
        
        self.rate = rate
        if self.rate >= 1.0:
            raise ValueError(f'Dropout rate ({self.rate}) must be lower than 1')
        elif self.rate < 0:
            raise ValueError(f'Dropout rate ({self.rate}) must be at least 0.0')
        
        self.broadcast_dims = broadcast_dims
        self.mode = mode
        
    def forward(self, x: torch.tensor, **kwargs):
        if self.mode == 'train' and self.rate > 0.0:
            noise_shape = list(x.shape)
            
            for dim in self.broadcast_dims:
                noise_shape[dim] = 1
                
            keep_prob = 1 - self.rate
            keep = np.random.binomial(n=1, p=keep_prob, size=tuple(noise_shape))
            keep = torch.tensor(keep)
            multiplier = keep / keep_prob
            return x * multiplier
        else:
            return x

In [343]:
bd = pBroadcastedDropout(mode='train', rate=0.2)

In [370]:
bd(test.unsqueeze(0))

tensor([[ 0.0000,  1.2500,  2.5000,  3.7500,  5.0000,  6.2500,  7.5000,  8.7500,
         10.0000,  0.0000]])

In [426]:
def FeedForward(sample, 
                d_input: int, 
                d_output: int, 
                dropout: float, 
                activation: bool, 
                mode: str):
    """
    Building a simple Feed Forward NN
    :param sample: sample data used to set the LayerNorm dimensions.
    :param d_input: model input dimension.
    :param d_output: model output dimension.
    :param dropout: amount of dropout to apply on range [0.0, 1.0).
    :param activation: True applies ReLU activation.
    :param mode: whether to run dropout in train or eval mode.
    """
    
    if len(sample.shape) == 1:
        norm_output = sample.shape[0]
        norm_input = 1
    elif len(sample.shape) == 2:
        norm_output = sample.shape[1]
        norm_input = sample.shape[0]
    else:
        norm_output = sample.shape[-1]
        norm_input = sample.shape[-2]
        
    
    return nn.Sequential(
        nn.LayerNorm(normalized_shape=(norm_input, norm_output)),
        nn.Linear(d_input, d_output),
        pBroadcastedDropout(rate=dropout, mode=mode),
        activation,
        nn.Linear(d_output, d_output),
        pBroadcastedDropout(rate=dropout, mode=mode)
    )
    

In [451]:
ff = FeedForward(sample=test.view(1,-1), 
                 d_input=10, 
                 d_output=10, 
                 dropout=0.2, 
                 activation=nn.ReLU(), 
                 mode='train')
ff(test.view(1,-1))

tensor([[ 0.6268, -0.0000,  0.7972,  0.0000, -1.2683,  0.0000, -0.3375,  0.1356,
          0.0523, -0.1159]], grad_fn=<MulBackward0>)

In [707]:
class SplitForOutput(nn.Module):
    """
    Splits activations into sections, to be used prior to the output layer.
    
    After the reversible portion of the network, there is a final output portion that's 
    non-reversible where the minimum includes normalization, output projection, and log-softmax. 
    The output portion needs to operate on chucks of the sequence to avoid running out of memory
    for large vocabulary sizes.
    
    This layer concatenates the two subparts of the activations along the feature dimension
    then splits into chunks along the time dimension. 
    
    """
    
    def __init__(self, n_sections=2, axis=-2, n_in=2):
        super(SplitForOutput, self).__init__()
        self.n_sections = n_sections
        self.axis = axis
        self.n_in = 2
        self.n_out = n_sections
            
    def forward(self, inputs: torch.tensor):
                    
        x1, x2 = inputs

        x1_split = np.split(x1, self.n_sections, self.axis)
        x2_split = np.split(x2, self.n_sections, self.axis)

        res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
        return tuple(res)

    def reverse(self, output, **kwargs):
        
        x1_split = []
        x2_split = []
        for y in output:
            y1, y2 = np.split(y, 2, -1)
            x1_split.append(y1)
            x2_split.append(y2)

        x1 = np.concatenate(x1_split, self.axis)
        x2 = np.concatenate(x2_split, self.axis)

        return (x1, x2)
        

In [823]:
def Chunk(x, n_sections=2):
    assert x.shape[1] % n_sections == 0
    
    return torch.cat(torch.chunk(x, chunks=n_sections, dim=-2))

def Unchunk(x, n_sections=2):
    assert x.shape[0] % n_sections == 0
    
    return torch.cat(torch.chunk(x, chunks=n_sections, dim=-3), dim=-2)

In [887]:
class ReversibleHalfResidual(nn.Module):
    
    def __init__(self, residual_layers):
        
        x = ([], (x, x))
        
        self.compute_residual = nn.Sequential(
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            
        )

In [888]:
residual_layers = tl.Serial(tl.Dense(10))

compute_residual = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(residual_layers, [], []))
    

In [876]:
t = np.arange(-7, 9).astype(np.float32)
t

array([-7., -6., -5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.,  5.,
        6.,  7.,  8.], dtype=float32)

In [899]:
x = t.reshape(4, -1)
print(x)
y = tls(tuple(x))

[[-7. -6. -5. -4.]
 [-3. -2. -1.  0.]
 [ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]]


LayerError: Exception passing through layer Serial (in _forward_internal):
  layer created in file [...]/<ipython-input-858-6edb9536a799>, line 3
  layer input shapes: (ShapeDtype{shape:(4,), dtype:float32}, ShapeDtype{shape:(4,), dtype:float32}, ShapeDtype{shape:(4,), dtype:float32}, ShapeDtype{shape:(4,), dtype:float32})

  File [...]/trax/layers/combinators.py, line 75, in forward_with_state
    '({})'.format(len(state), n_layers))

ValueError: length of state (0) not equal to number of layers (2)

In [847]:
tls.forward(tuple(t), None)

NotImplementedError: 

In [708]:
sfo = SplitForOutput()

In [709]:
test = torch.rand((4,4))
a, b = sfo.forward((test.unsqueeze(0),test.unsqueeze(0)))
c, d = sfo.reverse((test.unsqueeze(0),test.unsqueeze(0)))

In [94]:
class Map(tl.Layer):
  """Combinator for applying a layer to a list or tuple."""

  def __init__(self, layer, n_sections=1, check_shapes=True):
    """Initialize the combinator.
    Args:
      layer: a layer to apply to each element.
      n_sections: how many sections to map to (default: 1).
      check_shapes: whether to check that shapes are identical (default: true).
    Returns:
      A new layer representing mapping layer to all elements of the input.
    """
    super(Map, self).__init__()#n_in=n_sections, n_out=n_sections)
    if layer is None or isinstance(layer, (list, tuple)):
        layer = tl.Serial(layer)
    self._layer = layer
    
    # Generally a Map should be applied to lists where all elements have
    # the same shape -- because self._layer will only be initialized once
    # and it could have different parameters for different shapes. But there
    # are valid cases -- e.g., when self._layer has no parameters -- where we
    # can apply Map to different shapes -- set check_shapes=False in such cases.
    self._check_shapes = check_shapes
    self._n_sections = n_sections
    self.n_in = n_sections
    self.n_out = n_sections

    def forward_with_state(self, inputs, weights=(), state=(), **kwargs):
      if self._n_sections == 1:
        results = self._layer(inputs, weights=weights, state=state, **kwargs)
      else:
        rngs = _pop_rng_and_split(kwargs, len(inputs))
        results = [self._layer(x, weights=weights, state=state, rng=r, **kwargs)
                 for x, r in zip(inputs, rngs)]
        results = tuple(results)
      # TODO(kitaev): think about how to merge state across copies in the map.
      return results, self._layer.state

  def new_weights_and_state(self, input_signature):
    if self._n_sections == 1:
      return self._layer.init(input_signature)
    first_shape = input_signature[0].shape
    if self._check_shapes:
      for shape_dtype in input_signature:
        if shape_dtype.shape != first_shape:
          raise ValueError('Map layer can only be applied to list of elements '
                           'with the same shapes. This shape %s vs first shape '
                           '%s.' % (str(shape_dtype.shape), str(first_shape)))
    return self._layer.init(input_signature[0])

  @tl.Layer.weights.setter
  def weights(self, weights):
    self._weights = self._layer.weights = weights

  @tl.Layer.state.setter
  def state(self, state):
    self._state = self._layer.state = state

  def _set_input_signature_recursive(self, input_signature):
    self._input_signature = input_signature
    self._layer._set_input_signature_recursive(input_signature)  # pylint: disable=protected-access


In [164]:
class BroadcastedDropout(tl.Layer):
  """Layer constructor function for a broadcasted dropout layer."""

  def __init__(self, rate=0.0, mode='train', broadcast_dims=(-2,)):
    super(BroadcastedDropout, self).__init__()
    self._rate = rate
    if self._rate >= 1.0:
      raise ValueError('Dropout rate (%f) must be lower than 1.' % rate)
    self._broadcast_dims = broadcast_dims
    self._mode = mode

  def forward_with_state(self, x, weights, state, rng):
    """Dropout, with broadcasting to save memory."""
    del weights
    if rng is None:
      raise ValueError('BroadcastedDropout requires rng kwarg.')
    if self._mode == 'train' and self._rate > 0.0:
      noise_shape = list(x.shape)
      for dim in self._broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = jax.lax.tie_in(rng, 1.0 - self._rate)
      keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / jax.lax.tie_in(keep, keep_prob)
      return x * multiplier, state
    else:
      return x, state

In [895]:
def FeedForward(d_model, d_ff, dropout, activation, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      activation(),
      tl.Dense(d_model),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]


In [97]:
class SplitForOutput(tl.ReversibleLayer):
  """Splits activations into sections (for use right before the output layer).
  After the reversible portion of the network, there is a final output portion
  that's non-reversible (which at minimum includes normalization, output
  projection, and log-softmax). The output portion needs to operate on chunks
  of the sequence to avoid running out of memory for large vocabulary sizes.
  This layer concatenates the two subparts of the activations along the feature
  dimension, and then splits into chunks along the time dimension. We implement
  it is a subclass of tl.ReversibleLayer because we want to ensure that multiple
  copies of the activations don't exist simultaneously except in the middle of a
  memory copy operation.
  """

  def __init__(self, n_sections=2, axis=-2):
    super(SplitForOutput, self).__init__(n_in=2, n_out=n_sections)
    self._n_sections = n_sections
    self._axis = axis

  def forward(self, inputs, weights):
    del weights
    x1, x2 = inputs

    x1_split = np.split(x1, self._n_sections, self._axis)
    x2_split = np.split(x2, self._n_sections, self._axis)

    res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
    return tuple(res)

  def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    del weights, kwargs

    x1_split = []
    x2_split = []
    for y in output:
      y1, y2 = np.split(y, 2, -1)
      x1_split.append(y1)
      x2_split.append(y2)

    x1 = np.concatenate(x1_split, self._axis)
    x2 = np.concatenate(x2_split, self._axis)

    return (x1, x2)

  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    del weights, kwargs
    return self.reverse(output), (self.reverse(ct), ())


In [98]:
@tl.layer()
def Chunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[1] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] * n_sections,
      x.shape[1] // n_sections,
      ) + x.shape[2:])


@tl.layer()
def Unchunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[0] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] // n_sections,
      x.shape[1] * n_sections,
      ) + x.shape[2:])


In [894]:

class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
  """Half of a RevNet-style residual (only updates part of the hidden state)."""

  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(residual_layers, [], []),
    )

    layers = [
        self.compute_residual,
        tl.Parallel(tl.Add(), [])
    ]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [self.compute_residual, self.subtract_top]

  def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    reconstructed_x = output
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = backend.random.split(rng, self._n_layers)
    # Note that self.sublayers aligns exactly with self.reverse_layers in
    # terms of parameter and rng usage, so no re-ordering is required.
    for layer, p, s, ns, rng in zip(
        self.reverse_layers, weights, state, new_state, rngs):
      reconstructed_x = layer(reconstructed_x, weights=p,
                              state=s, new_state=ns, rng=rng, **kwargs)
    return reconstructed_x

  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = backend.random.split(rng, self._n_layers)

    def call_compute_residual(x, weights):
      res = self.compute_residual(x, weights=weights, state=state[0],
                                  rng=rngs[0], **kwargs)
      return res

    assert len(ct) == 2
    ct = ((ct[0], ct[0], ct[1]))

    stack_with_residual, vjpfun = jax.vjp(
        call_compute_residual, output, weights[0])
    reconstructed_x = self.subtract_top(
        stack_with_residual, weights=weights[-1], state=state[-1], rng=rngs[-1],
        **kwargs)

    x_ct, residual_weights_ct = vjpfun(ct)
    assert not jax.tree_util.tree_leaves(weights[-1])
    add_top_weights_ct = weights[-1]
    return reconstructed_x, (x_ct, [residual_weights_ct, add_top_weights_ct])

In [100]:

class ApplyAttentionWrapper(tl.Parallel):
  """Like tl.Parallel(attention, [], []) but implements forward_and_backward."""

  def __init__(self, attention):
    assert hasattr(attention, 'forward_and_backward')
    super(ApplyAttentionWrapper, self).__init__(attention, [], [])
    self.attention = attention

  def forward_and_backward(self, inputs, ct, state, new_state, rng=None,
                           **kwargs):
    # Simultaneous forward pass and backprop through the attention mechanism.
    qkv = inputs[:3]
    passthrough = inputs[3:]
    out_ct = ct[0]
    passthrough_ct = ct[1:]
    if rng is not None:
      # Adjust RNG to match the forward pass.
      rng = backend.random.split(rng, self._n_layers)[0]

    out, qkv_ct = self.attention.forward_and_backward(
        qkv, out_ct, rng=rng, state=state[0], new_state=new_state[0], **kwargs)
    return (out,) + passthrough, qkv_ct + passthrough_ct



In [101]:

class ReversibleAttentionHalfResidual(tl.ReversibleLayer, tl.Serial):
  """Half of a RevNet-style residual that performs attention.
  If inputs are (x1, x2), then outputs are (x1 + z, x2) where:
  z = post_attention(attention(pre_attention(x1)))
  Other than an efficiency optimization, this layer is equivalent to
  ReversibleHalfResidual([pre_attention, attention, post_attention]).
  The post_attention layers must be linear in their input (typically they will
  consists of reshaping and dense linear layers), which allows the following
  optimization. We can back-propagate the gradient signal from the output of
  ReversibleAttentionHalfResidual to the output of the "attention" portion based
  only on the network parameters. Then, attention.forward_and_backward can be
  used to recover the output of the "attention" portion while simultaneously
  performing the backward pass, which allows shared computation between the two
  directions.
  """

  def __init__(self, pre_attention, attention, post_attention):
    self.pre_attention = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(pre_attention, [], []),
    )
    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        tl.Parallel(tl.Add(), []),
    ]
    super(ReversibleAttentionHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        self.subtract_top,
    ]

  def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = backend.random.split(rng, self._n_layers)

    reconstructed_x = output
    # Note that self.sublayers aligns exactly with self.reverse_layers in
    # terms of parameter and rng usage, so no re-ordering is required.
    for layer, p, s, ns, rng in zip(self.reverse_layers, weights,
                                    state, new_state, rngs):
      reconstructed_x = layer.reverse(reconstructed_x, weights=p,
                                      state=s, new_state=ns, rng=rng, **kwargs)
    return reconstructed_x

  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = backend.random.split(rng, self._n_layers)

    # Forward pass through self.pre_attention, while preparing for
    # later backprop.
    def call_pre_attention(x, weights):
      res = self.pre_attention(x, weights=weights, state=state[0], rng=rngs[0],
                               **kwargs)
      return res
    stack, pre_attention_vjpfun = jax.vjp(call_pre_attention,
                                          output, weights[0])

    # Backprop through adding the residual
    assert len(ct) == 2
    ct = saved_ct = (ct[0], ct[0], ct[1])

    # Backprop through self.post_attention with respect to the inputs only
    def call_post_attention(x):
      res = self.post_attention(x, weights=weights[2], state=state[2],
                                rng=rngs[2], **kwargs)
      return res
    # Note: these are *not* the actual inputs to self.post_attention.
    # If self.post_attention is not linear, we will get incorrect gradients.
    dummy_inputs = (stack[-3], stack[-2], stack[-1])
    _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
    (ct,) = post_attention_vjpfun(ct)

    # Simultaneous forward pass and backprop through the attention mechanism
    stack, ct = self.attention.forward_and_backward(
        stack, ct, rng=rngs[1], state=state[1], new_state=new_state[1],
        **kwargs)
    assert not jax.tree_util.tree_leaves(weights[1])
    attention_weights_ct = weights[1]  # This is valid when weights is empty.

    # Backprop through self.pre_attention
    x_ct, pre_attention_weights_ct = pre_attention_vjpfun(ct)

    # Forward pass for self.post_attention, and backprop with respect to the
    # parameters only
    def call_post_attention2(weights):
      res = self.post_attention(stack, weights=weights, state=state[2],
                                rng=rngs[2], **kwargs)
      return res
    stack, post_attention_vjpfun = jax.vjp(call_post_attention2, weights[2])
    (post_attention_weights_ct,) = post_attention_vjpfun(saved_ct)

    # Forward pass through subtracting the residual
    reconstructed_x = self.subtract_top(
        stack, weights=weights[-1], state=state[-1], rng=rngs[-1], **kwargs)

    assert not jax.tree_util.tree_leaves(weights[-1])
    add_top_weights_ct = weights[-1]
    weights_ct = [
        pre_attention_weights_ct,
        attention_weights_ct,
        post_attention_weights_ct,
        add_top_weights_ct,
    ]

    return reconstructed_x, (x_ct, weights_ct)

In [102]:

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, n_attention_chunks, attention_type,
                 dropout, share_qk, ff_activation, ff_use_sru, mode):
  """Reversible transformer decoder layer.
  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: string, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'
  Returns:
    the layer.
  """
  if share_qk:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
        tl.Dup(),
    ]
  else:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(), tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
    ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [FeedForward(d_model, d_ff, dropout, ff_activation, mode)]

  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]


In [103]:

def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).
  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train', 'eval', or 'predict'
  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)
    decoder_blocks.append(decoder_block)

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )


In [104]:

def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
  """Reversible transformer language model with shortening.
  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape
    [batch, length // F, d_model]
  almost until the end -- at the end it's un-shortend and a SRU is applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.
  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'
  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )