In [524]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import jax
import trax
from trax import backend
from trax import layers as tl
from trax.backend import numpy as np
from trax.layers.combinators import _pop_rng_and_split
from trax.shapes import signature, ShapeDtype

In [74]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [75]:
from trax.models.research.reformer import ReformerLM

In [658]:
vocab_size = 1000
d_model=512,
d_ff=2048,
d_attention_key=64,
d_attention_value=64,
n_layers=6,
n_heads=8,
dropout=0.1,
max_len=2048,
n_chunks=0,
n_attention_chunks=1,
attention_type=tl.DotProductCausalAttention,
share_qk=False,
axial_pos_shape=(),
d_axial_pos_embs=None,
ff_activation=tl.FastGelu,
ff_use_sru=0,
mode='train'
r = ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               mode='train')

In [77]:
test = np.array([[0,1,2,3,4,5,6,7]])

In [78]:
weights, state = r.init(input_signature)
rng = backend.random.get_prng(0)
inputs = (test, )*2
output = r(inputs, weights=weights, state=state, rng=rng)
dummy_loss = backend.numpy.sum(output[0])
#r.forward_with_state(xs=(test,), weights=(1,)*r._n_layers, state=(1,)*r._n_layers)

In [79]:
output

(DeviceArray([[[-8.335339 , -8.690216 , -5.8988566, ..., -8.088729 ,
                -8.524491 , -7.8314757],
               [-8.316006 , -8.7872925, -6.1046233, ..., -8.088554 ,
                -8.654272 , -7.870216 ],
               [-8.455892 , -8.115887 , -5.906452 , ..., -8.373974 ,
                -8.440801 , -7.525086 ],
               ...,
               [-8.354822 , -8.8492565, -6.0343447, ..., -6.980222 ,
                -9.428055 , -6.6733527],
               [-8.75255  , -9.481588 , -6.420909 , ..., -7.0931687,
                -8.65214  , -5.918196 ],
               [-8.388021 , -8.544396 , -6.143963 , ..., -7.09911  ,
                -7.7136116, -6.223985 ]]], dtype=float32),
 array([[0, 1, 2, 3, 4, 5, 6, 7]]))

In [80]:
vocab_size = 16
input_sd = ShapeDtype((1, 8), np.int32)
input_signature = (input_sd, input_sd)
model = ReformerLM(
    vocab_size, d_model=32, d_ff=64,
    d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
    max_len=16, n_chunks=2, n_attention_chunks=1)
final_shape = tl.check_shape_agreement(
    model, input_signature)
assert ((1, 8, 16), (1, 8, 16)) == final_shape
final_shape

((1, 8, 16), (1, 8, 16))

# Map

In [87]:
class pMap(nn.Module):
    """Combinator for applying a layers to a list/tuple"""
    def __init__(self, layer, n_sections=1, check_shapes=True):
        """
        
        :param layers: the layer you wish to apply to each element.
        :param n_sections: number of sections to map to. defaults to 1.
        :param check_shapes: tests to see the shapes are identical.
        :returns: new layers with mapped layer to all elements.
        """
        super(pMap, self).__init__()
        if layer is None or isinstance(layer, (list, tuple)):
            layer = nn.Sequential(*layer)
        self.layer = layer
        self.check_shapes = check_shapes
        self.n_sections = n_sections
        self.n_in = n_sections
        self.n_out = n_sections
        
        
    def forward(self, inputs, **kwargs):
        """
        Trying to replace with a basic forward pass. Trax implementation
        uses a PRNG key split into subsections zipped with the inputs var
        then returns a list of those forward passed subsections as results
        """
            
        if self.n_sections == 1:
            results = self.layer(inputs, **kwargs)
        
        else:
            results = [self.layer(x) for x in inputs] 
            
        return results, self.layer.state_dict()

In [88]:
pmap = pMap(layer=[
    nn.LayerNorm(1,1), 
    nn.Linear(1,1),
    nn.LogSoftmax()], n_sections=10)

pmap

pMap(
  (layer): Sequential(
    (0): LayerNorm((1,), eps=1, elementwise_affine=True)
    (1): Linear(in_features=1, out_features=1, bias=True)
    (2): LogSoftmax()
  )
)

In [89]:
test = torch.tensor([0,1,2,3,4,5,6,7,8,9], dtype=torch.float)
print(test.unsqueeze(0).shape)
print(test)

torch.Size([1, 10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])


In [90]:
pmap(inputs=test.view(-1,1))

  input = module(input)


([tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>)],
 OrderedDict([('0.weight', tensor([1.])),
              ('0.bias', tensor([0.])),
              ('1.weight', tensor([[0.8039]])),
              ('1.bias', tensor([0.5340]))]))

# BroadcastedDropout

In [91]:
class pBroadcastedDropout(nn.Module):
    
    def __init__(self, rate=0.0, mode='train', broadcast_dims=(-2,)):
        super(pBroadcastedDropout, self).__init__()
        
        self.rate = rate
        if self.rate >= 1.0:
            raise ValueError(f'Dropout rate ({self.rate}) must be lower than 1')
        elif self.rate < 0:
            raise ValueError(f'Dropout rate ({self.rate}) must be at least 0.0')
        
        self.broadcast_dims = broadcast_dims
        self.mode = mode
        
    def forward(self, x: torch.tensor, **kwargs):
        if self.mode == 'train' and self.rate > 0.0:
            noise_shape = list(x.shape)
            
            for dim in self.broadcast_dims:
                noise_shape[dim] = 1
                
            keep_prob = 1 - self.rate
            keep = np.random.binomial(n=1, p=keep_prob, size=tuple(noise_shape))
            keep = torch.tensor(keep)
            multiplier = keep / keep_prob
            return x * multiplier
        else:
            return x

In [92]:
bd = pBroadcastedDropout(mode='train', rate=0.2)

In [93]:
bd(test.unsqueeze(0))

tensor([[ 0.0000,  1.2500,  2.5000,  3.7500,  5.0000,  6.2500,  7.5000,  8.7500,
         10.0000, 11.2500]])

# Simple FeedForward

In [545]:
def FeedForward(sample, 
                d_input: int, 
                d_output: int, 
                dropout: float, 
                activation: bool, 
                mode: str):
    """
    Building a simple Feed Forward NN
    :param sample: sample data used to set the LayerNorm dimensions.
    :param d_input: model input dimension.
    :param d_output: model output dimension.
    :param dropout: amount of dropout to apply on range [0.0, 1.0).
    :param activation: True applies ReLU activation.
    :param mode: whether to run dropout in train or eval mode.
    """
    
    if len(sample.shape) == 1:
        norm_output = sample.shape[0]
        norm_input = 1
    elif len(sample.shape) == 2:
        norm_output = sample.shape[1]
        norm_input = sample.shape[0]
    else:
        norm_output = sample.shape[-1]
        norm_input = sample.shape[-2]
        
    
    return nn.Sequential(
        nn.LayerNorm(normalized_shape=(norm_input, norm_output)),
        nn.Linear(d_input, d_output),
        pBroadcastedDropout(rate=dropout, mode=mode),
        activation,
        nn.Linear(d_output, d_output),
        pBroadcastedDropout(rate=dropout, mode=mode)
    )
    

In [549]:
ff = FeedForward(sample=test.view(1,-1), 
                 d_input=test.view(1,-1).shape[1], 
                 d_output=10, 
                 dropout=0.2, 
                 activation=nn.ReLU(), 
                 mode='train')
ff(test.view(1,-1))

AttributeError: module 'jax.numpy' has no attribute 'random'

In [550]:
class SplitForOutput(nn.Module):
    """
    Splits activations into sections, to be used prior to the output layer.
    
    After the reversible portion of the network, there is a final output portion that's 
    non-reversible where the minimum includes normalization, output projection, and log-softmax. 
    The output portion needs to operate on chucks of the sequence to avoid running out of memory
    for large vocabulary sizes.
    
    This layer concatenates the two subparts of the activations along the feature dimension
    then splits into chunks along the time dimension. 
    
    """
    
    def __init__(self, n_sections=2, axis=-2, n_in=2):
        super(SplitForOutput, self).__init__()
        self.n_sections = n_sections
        self.axis = axis
        self.n_in = 2
        self.n_out = n_sections
            
    def forward(self, inputs: torch.tensor):
                    
        x1, x2 = inputs

        x1_split = np.split(x1, self.n_sections, self.axis)
        x2_split = np.split(x2, self.n_sections, self.axis)

        res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
        return tuple(res)

    def reverse(self, output, **kwargs):
        
        x1_split = []
        x2_split = []
        for y in output:
            y1, y2 = np.split(y, 2, -1)
            x1_split.append(y1)
            x2_split.append(y2)

        x1 = np.concatenate(x1_split, self.axis)
        x2 = np.concatenate(x2_split, self.axis)

        return (x1, x2)
        

In [812]:
class Chunk(nn.Module):
    
    def __init__(self, n_sections=2):
        super(Chunk, self).__init__()
        self.n_sections = n_sections
        
    def forward(self, x):
        assert x.shape[1] % n_sections == 0
        return torch.cat(torch.chunk(x, chunks=n_sections, dim=-2))
    
class Unchunk(nn.Module):
    
    def __init__(self, n_sections=2):
        super(Unchunk, self).__init__()
        self.n_sections = n_sections
        
    def forward(self, x):
        assert x.shape[0] % self.n_sections == 0
        return torch.cat(torch.chunk(x, chunks=self.n_sections, dim=-3), dim=-2)

In [666]:
def Chunk(x, n_sections=2):
    assert x.shape[1] % n_sections == 0
    
    return torch.cat(torch.chunk(x, chunks=n_sections, dim=-2))

def Unchunk(x, n_sections=2):
    assert x.shape[0] % n_sections == 0
    
    return torch.cat(torch.chunk(x, chunks=n_sections, dim=-3), dim=-2)

In [552]:
test = torch.tensor([[1,0,0,0],[0,1,0,0],
                    [0,0,1,0], [0,0,0,1]])

In [580]:
torch.chunk(test, chunks=2, dim=-2)

(tensor([[1., 0., 0., 0.],
         [0., 1., 0., 0.]]), tensor([[0., 0., 1., 0.],
         [0., 0., 0., 1.]]))

In [553]:
test

tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])

## compute_residual

In [554]:
#test = test.view(1,-1)

ff = feed_forward(d_model=4, d_ff=4, dropout=0.2, activation=True, mode='train')

residual_layers = ff
compute_residual = tl.Serial(
    tl.Parallel([], tl.Dup()),
    tl.Swap(),
    tl.Parallel(residual_layers, [], [])
)

layers = [compute_residual, tl.Parallel(tl.Add(), [])]
compute_residual = tl.Serial(layers)

input_sd = ShapeDtype(tuple(test.shape), np.int32)
input_signature = (input_sd, input_sd)
weights, state = compute_residual.init(input_signature)

output, state = compute_residual(
    x=(test.numpy(), )*2, 
    weights=weights,
    state=state
)

output, state

(DeviceArray([[ 1.1745598 ,  0.07205822,  0.6006763 ,  0.7884164 ],
              [-0.5550922 ,  1.0654054 ,  0.6392876 , -0.07618227],
              [-0.10679559,  0.02530393,  0.9922953 ,  0.02808513],
              [-0.27318668,  0.24871418,  0.8338994 ,  1.1428925 ]],            dtype=float32),
 array([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]]))

In [555]:
class ReversibleHalfResidual(nn.Module):
    
    def __init__(self, residual_layers: list):
        super(ReversibleHalfResidual, self).__init__()
        self.residual_layers = residual_layers
        self.model = nn.Sequential(*residual_layers)
        
        
    def compute_residual(self, inputs, **kwargs):
        """
        Replicating the JAX class of the same name step by step.
        
        """
        
        # replicating the compute_residuals operation
        output = self.model(inputs)
        # replicating the tl.Add() operation
        output += output
        return output, self.model
    
    def reverse(self, inputs):
        #inputs -= inputs
        output = self.compute_residual(inputs)[0] - inputs
        return output

In [569]:
test = [[1,0,0,0], [0,1,0,0],
       [0,0,1,0], [0,0,0,1]]
test = torch.tensor(test, dtype=torch.float32)
test

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [557]:
class GeLU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (
            1 + torch.tanh(math.sqrt(2 / math.pi) * (
                x + 0.044715 * torch.pow(x, 3))))

In [558]:
class RevNetBlock(nn.Module):
    
    def __init__(self, d_in, d_out, dropout=0.1, lol=[]):
        super(RevNetBlock, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.dropout = dropout

        layers = []
        if lol == list():
            layers.append(nn.LayerNorm((d_in,d_out)))
            layers.append(nn.Linear(d_in, d_out))
            layers.append(GeLU())
            layers.append(nn.Linear(d_in, d_out))
        else:
            for layer in lol:
                layers.append(layer)
        
        self.bottleneck_block = nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.cat((x, x), dim=1)
        x1, x2 = self.split(x)
        Fx2 = self.bottleneck_block(x2)
        y1 = Fx2 + x1
        return (x2, y1)
    
    def inverse(self, x):
        x2, y1 = x[0], x[1]
        Fx2 = - self.bottleneck_block(x2)
        x1 = Fx2 + y1
        return (x1, x2)

    @staticmethod
    def split(x):
        n = int(x.size()[1] / 2)
        x1 = x[:, :n].contiguous()
        x2 = x[:, n:].contiguous()
        return (x1, x2)

In [581]:
rb = RevNetBlock(d_in=4, d_out=4)
rb(test)

(tensor([[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]), tensor([[ 1.4723,  0.0600,  0.5037,  0.5120],
         [ 0.3461,  0.8839,  0.3724,  0.2543],
         [ 0.7384,  0.1297,  1.1142,  0.3288],
         [ 0.5839, -0.2881,  0.0728,  1.2639]], grad_fn=<AddBackward0>))

In [1025]:
class RevNetHalfAttnBlock(nn.Module):
    
    def __init__(self, d_in, d_out, dropout=0.1, lol=[]):
        super(RevNetHalfAttnBlock, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.dropout = dropout

        layers = []
        if lol == list():
            layers.append(nn.LayerNorm((d_in,d_out)))
            layers.append(nn.Linear(d_out, d_out))
            layers.append(GeLU())
            layers.append(nn.Linear(d_out, d_out))
        else:
            for layer in lol:
                layers.append(layer)
        
        self.bottleneck_block = nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.cat((x, x), dim=1)
        x1, x2 = self.split(x)
        Fx2 = self.bottleneck_block(x2)
        y1 = Fx2 + x1
        return (x2, y1)
    
    def inverse(self, x):
        x2, y1 = x[0], x[1]
        Fx2 = - self.bottleneck_block(x2)
        x1 = Fx2 + y1
        return (x1, x2)

    @staticmethod
    def split(x):
        n = int(x.size()[1] / 2)
        x1 = x[:, :n].contiguous()
        x2 = x[:, n:].contiguous()
        return (x1, x2)

In [343]:
class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
    """Half of a RevNet-style residual (only updates part of the hidden state)."""

    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial(
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        )

        layers = [
            self.compute_residual,
            tl.Parallel(tl.Add(), [])
        ]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]

    def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
        reconstructed_x = output
        rng = kwargs.pop('rng', None)
        rngs = (None,) * self._n_layers
        if rng is not None:
            rngs = backend.random.split(rng, self._n_layers)
        # Note that self.sublayers aligns exactly with self.reverse_layers in
        # terms of parameter and rng usage, so no re-ordering is required.
        for layer, p, s, ns, rng in zip(
            self.reverse_layers, weights, state, new_state, rngs):
            reconstructed_x = layer(reconstructed_x, weights=p,
                                  state=s, new_state=ns, rng=rng, **kwargs)
        return reconstructed_x

    def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = (None,) * self._n_layers
        if rng is not None:
            rngs = backend.random.split(rng, self._n_layers)

        def call_compute_residual(x, weights):
            res = self.compute_residual(x, weights=weights, state=state[0],
                                  rng=rngs[0], **kwargs)
            return res

        assert len(ct) == 2
        ct = ((ct[0], ct[0], ct[1]))

        stack_with_residual, vjpfun = jax.vjp(
            call_compute_residual, output, weights[0])
        reconstructed_x = self.subtract_top(
            stack_with_residual, weights=weights[-1], state=state[-1], rng=rngs[-1],
            **kwargs)

        x_ct, residual_weights_ct = vjpfun(ct)
        assert not jax.tree_util.tree_leaves(weights[-1])
        add_top_weights_ct = weights[-1]
        return reconstructed_x, (x_ct, [residual_weights_ct, add_top_weights_ct])

In [560]:
sfo = SplitForOutput()

In [612]:
t = torch.rand((4,4,64)).numpy()
a, b = sfo.forward((t,t))
c, d = sfo.reverse((t,t))

In [826]:
class ComputeAttentionHeads(nn.Module):
    def __init__(self, n_heads=1, d_head=64):
        super(ComputeAttentionHeads, self).__init__()
        self.n_heads= n_heads
        self.d_head = d_head
        
    def forward(self, x):
        seqlen = x.shape[1]
        res = x
        
        # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
        res = np.reshape(res, (x.shape[0], seqlen, self.n_heads, self.d_head))
        # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
        res = np.transpose(res, (0, 2, 1, 3))
        # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
        res = np.reshape(res, (-1, seqlen, self.d_head))
        res = nn.Linear(res.shape[-1], res.shape[-1])(res)
        return res
    
class ComputeAttentionOutput(nn.Module):
    def __init__(self):
        super(ComputeAttentionOutput, self).__init__()
        self.n_heads = 1
        
    def forward(self, x):
        seqlen = x.shape[1]
        d_head = x.shape[2]
        
        x = np.reshape(a=x, newshape=(-1, self.n_heads, seqlen, d_head))
        x = np.reshape(x, (-1, self.n_heads, seqlen, d_head))
        x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
        x = np.reshape(x, (-1, seqlen, self.n_heads * d_head))
        x = torch.tensor(x)
        x = nn.Linear(x.shape[-1], x.shape[-1])(x)
        return x

In [1057]:
class DecoderBlock(nn.Module):
    
    def __init__(self, 
                 d_in, 
                 d_out, 
                 attn_k=64, 
                 attn_v=64, 
                 n_heads=1, 
                 n_chunks=2, 
                 share_qk=True, 
                 attn_type=None, 
                 dropout=None, 
                 ff_activation=None, 
                 ff_use_sru=None, 
                 mode='train'):
        super(DecoderBlock, self).__init__()
        
        self.d_in = d_in
        self.d_out = d_out
        self.attn_k = attn_k
        self.attn_v = attn_v
        self.n_heads = n_heads
        self.n_chunks = n_chunks
        self.attn_type = attn_type
        self.dropout = dropout
        self.share_qk = share_qk
        self.ff_activation = ff_activation
        self.ff_use_sru = ff_use_sru
        self.mode = mode
        
        
    def pre_attention(self, x):
        x1, x2 = torch.chunk(x, self.n_chunks)
        #x1 = x2 = torch.cat(torch.chunk(x, chunks=self.n_chunks, dim=-2))
        k_layers = [ComputeAttentionHeads(self.n_heads, self.attn_k), nn.LayerNorm((x.shape[1], x.shape[2]))]
        k_model = nn.Sequential(*k_layers)

        v_layers = [ComputeAttentionHeads(self.n_heads, self.attn_v), nn.LayerNorm((x.shape[1], x.shape[2]))]
        v_model = nn.Sequential(*v_layers)

        k = k_model(x1)
        v = v_model(x2)

        if not self.share_qk:
            q_layers= k_layers
            q_model = nn.Sequential(*q_layers)
            q = q_model(x1)
            
            q = Fx1
            
            return (q, k, v)
        else:
            return (k, k, v)
    
    
    def attention(self, inputs):
        
        assert len(inputs) == 2 or len(inputs) == 3
        if len(inputs) == 2:
            k, v = inputs
            q = k
        else:
            q, k, v = inputs
        
        mask_size = q.shape[-2]
        mask = torch.tril(
            torch.ones((1, mask_size, mask_size), dtype=torch.bool), 
            diagonal=0
        )
        
        attn = self.dotproductattention(q, k, v, mask)
        return attn
    
    
    def dotproductattention(self, q, k, v, mask, dropout=0.1):
        
        depth = q.shape[-1]
        dots = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(depth)
        dots = F.log_softmax(torch.where(mask, dots, torch.full_like(dots, -1e9)), dim=0)
        
        keep_prob = 1 - dropout
        keep = np.random.binomial(n=1, p=keep_prob, size=dots.shape)
        
        dots = torch.where(
            torch.tensor(keep, dtype=torch.bool), 
            dots / torch.tensor(keep_prob), 
            torch.zeros_like(dots)
        )
        attn = torch.matmul(dots, v)
        return attn
    
    
    def post_attention(self, x):
        
        cao = ComputeAttentionOutput()
        unchunk = Unchunk(n_sections = self.n_chunks)
        bd = pBroadcastedDropout(rate=self.dropout)
        
        res = cao(x)
        #res = torch.cat((res, res), dim=-3)
        res = bd(res)
        return res
    
    def forward(self, x):
        x = self.pre_attention(x)
        x = tuple(torch.tensor(y) for y in x)
        x = self.attention(x)
        x = self.post_attention(x)
        
        return x

In [1067]:
db = DecoderBlock(d_in=10, d_out=10, dropout=0.1)
q,k,v = db.pre_attention(torch.tensor(t))
attn = db.attention((k,v))
res = db.post_attention(torch.tensor(attn))
res = db(torch.tensor(t))

  after removing the cwd from sys.path.


In [1068]:
rb = RevNetHalfAttnBlock(d_in=res.shape[-2], d_out=res.shape[-1])
output = rb(res)
output = torch.cat(output)
output.shape

torch.Size([4, 4, 64])

In [1071]:
output

tensor([[[-2.3375e+00,  7.5782e-01, -1.3530e+00,  ..., -4.8981e-01,
           6.5929e-01,  2.1808e-01],
         [-6.1448e+00,  2.0287e+00, -4.2357e+00,  ..., -1.6251e+00,
           2.1483e+00,  1.1392e+00],
         [-5.0683e+00,  1.7662e+00, -2.7701e+00,  ..., -1.2889e+00,
           1.2900e+00,  7.0707e-01],
         [-5.2799e+00,  1.8458e+00, -2.8506e+00,  ..., -1.3264e+00,
           1.3344e+00,  7.0618e-01]],

        [[-6.7761e+00,  4.8747e+00, -2.2219e+00,  ..., -1.0755e+00,
           1.4934e+00, -9.0276e-01],
         [-3.8855e+00,  2.2801e+00, -1.2036e+00,  ..., -1.2019e+00,
           7.6750e-01, -1.4830e-03],
         [-5.1901e+00,  3.2412e+00, -1.5614e+00,  ..., -1.1380e+00,
           7.8082e-01, -3.8353e-01],
         [-4.1313e+00,  2.7029e+00, -1.3616e+00,  ..., -8.7370e-01,
           7.8209e-01, -2.3808e-01]],

        [[-2.5208e+00,  5.6488e-01, -1.2040e+00,  ..., -4.2917e-01,
           6.0526e-01,  1.7405e-01],
         [-6.6358e+00,  1.6249e+00, -3.8189e+00,  .

In [103]:

def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).
  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train', 'eval', or 'predict'
  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)
    decoder_blocks.append(decoder_block)

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )


In [104]:

def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
  """Reversible transformer language model with shortening.
  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape
    [batch, length // F, d_model]
  almost until the end -- at the end it's un-shortend and a SRU is applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.
  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'
  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )