In [524]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import jax
import trax
from trax import backend
from trax import layers as tl
from trax.backend import numpy as np
from trax.layers.combinators import _pop_rng_and_split
from trax.shapes import signature, ShapeDtype

In [74]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [75]:
from trax.models.research.reformer import ReformerLM

# Map

In [87]:
class pMap(nn.Module):
    """Combinator for applying a layers to a list/tuple"""
    def __init__(self, layer, n_sections=1, check_shapes=True):
        """
        
        :param layers: the layer you wish to apply to each element.
        :param n_sections: number of sections to map to. defaults to 1.
        :param check_shapes: tests to see the shapes are identical.
        :returns: new layers with mapped layer to all elements.
        """
        super(pMap, self).__init__()
        if layer is None or isinstance(layer, (list, tuple)):
            layer = nn.Sequential(*layer)
        self.layer = layer
        self.check_shapes = check_shapes
        self.n_sections = n_sections
        self.n_in = n_sections
        self.n_out = n_sections
        
        
    def forward(self, inputs, **kwargs):
        """
        Trying to replace with a basic forward pass. Trax implementation
        uses a PRNG key split into subsections zipped with the inputs var
        then returns a list of those forward passed subsections as results
        """
            
        if self.n_sections == 1:
            results = self.layer(inputs, **kwargs)
        
        else:
            results = [self.layer(x) for x in inputs] 
            
        return results, self.layer.state_dict()

In [88]:
pmap = pMap(layer=[
    nn.LayerNorm(1,1), 
    nn.Linear(1,1),
    nn.LogSoftmax()], n_sections=10)

pmap

pMap(
  (layer): Sequential(
    (0): LayerNorm((1,), eps=1, elementwise_affine=True)
    (1): Linear(in_features=1, out_features=1, bias=True)
    (2): LogSoftmax()
  )
)

In [89]:
test = torch.tensor([0,1,2,3,4,5,6,7,8,9], dtype=torch.float)
print(test.unsqueeze(0).shape)
print(test)

torch.Size([1, 10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])


In [90]:
pmap(inputs=test.view(-1,1))

  input = module(input)


([tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>),
  tensor([0.], grad_fn=<LogSoftmaxBackward>)],
 OrderedDict([('0.weight', tensor([1.])),
              ('0.bias', tensor([0.])),
              ('1.weight', tensor([[0.8039]])),
              ('1.bias', tensor([0.5340]))]))

# BroadcastedDropout

In [1422]:
class BroadcastedDropout(nn.Module):
    
    def __init__(self, rate=0.0, mode='train', broadcast_dims=(-2,)):
        super(BroadcastedDropout, self).__init__()
        
        self.rate = rate
        if self.rate >= 1.0:
            raise ValueError(f'Dropout rate ({self.rate}) must be lower than 1')
        elif self.rate < 0:
            raise ValueError(f'Dropout rate ({self.rate}) must be at least 0.0')
        
        self.broadcast_dims = broadcast_dims
        self.mode = mode
        
    def forward(self, x: torch.tensor, **kwargs):
        if self.mode == 'train' and self.rate > 0.0:
            noise_shape = list(x.shape)
            
            for dim in self.broadcast_dims:
                noise_shape[dim] = 1
                
            keep_prob = 1 - self.rate
            keep = np.random.binomial(n=1, p=keep_prob, size=tuple(noise_shape))
            keep = torch.tensor(keep)
            multiplier = keep / keep_prob
            return x * multiplier
        else:
            return x

In [1423]:
bd = BroadcastedDropout(mode='train', rate=0.2)

In [1424]:
bd(test.unsqueeze(0))

tensor([[[1.2500, 0.0000, 0.0000, 0.0000],
         [0.0000, 1.2500, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 1.2500]]])

# Simple FeedForward

In [1425]:
def FeedForward(sample, 
                d_input: int, 
                d_output: int, 
                dropout: float, 
                activation: bool, 
                mode: str):
    """
    Building a simple Feed Forward NN
    :param sample: sample data used to set the LayerNorm dimensions.
    :param d_input: model input dimension.
    :param d_output: model output dimension.
    :param dropout: amount of dropout to apply on range [0.0, 1.0).
    :param activation: True applies ReLU activation.
    :param mode: whether to run dropout in train or eval mode.
    """
    
    if len(sample.shape) == 1:
        norm_output = sample.shape[0]
        norm_input = 1
    elif len(sample.shape) == 2:
        norm_output = sample.shape[1]
        norm_input = sample.shape[0]
    else:
        norm_output = sample.shape[-1]
        norm_input = sample.shape[-2]
        
    
    return nn.Sequential(
        nn.LayerNorm(normalized_shape=(norm_input, norm_output)),
        nn.Linear(d_input, d_output),
        BroadcastedDropout(rate=dropout, mode=mode),
        activation,
        nn.Linear(d_output, d_output),
        BroadcastedDropout(rate=dropout, mode=mode)
    )
    

In [549]:
ff = FeedForward(sample=test.view(1,-1), 
                 d_input=test.view(1,-1).shape[1], 
                 d_output=10, 
                 dropout=0.2, 
                 activation=nn.ReLU(), 
                 mode='train')
ff(test.view(1,-1))

AttributeError: module 'jax.numpy' has no attribute 'random'

In [1372]:
class SplitForOutput(nn.Module):
    """
    Splits activations into sections, to be used prior to the output layer.
    
    After the reversible portion of the network, there is a final output portion that's 
    non-reversible where the minimum includes normalization, output projection, and log-softmax. 
    The output portion needs to operate on chucks of the sequence to avoid running out of memory
    for large vocabulary sizes.
    
    This layer concatenates the two subparts of the activations along the feature dimension
    then splits into chunks along the time dimension. 
    
    """
    
    def __init__(self, n_sections=2, axis=-2, n_in=2):
        super(SplitForOutput, self).__init__()
        self.n_sections = n_sections
        self.axis = axis
        self.n_in = 2
        self.n_out = n_sections
            
    def forward(self, inputs: torch.tensor):
                    
        x1, x2 = inputs

        x1_split = np.split(x1, self.n_sections, self.axis)
        x2_split = np.split(x2, self.n_sections, self.axis)

        res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
        return tuple(res)

    def reverse(self, output, **kwargs):
        
        x1_split = []
        x2_split = []
        for y in output:
            y1, y2 = np.split(y, 2, -1)
            x1_split.append(y1)
            x2_split.append(y2)

        x1 = np.concatenate(x1_split, self.axis)
        x2 = np.concatenate(x2_split, self.axis)

        return (x1, x2)
        

In [1316]:
class Chunk(nn.Module):
    
    def __init__(self, n_sections=2):
        super(Chunk, self).__init__()
        self.n_sections = n_sections
        
    def forward(self, x):
        assert x.shape[1] % self.n_sections == 0
        return torch.cat(torch.chunk(x, chunks=self.n_sections, dim=-2))
    
class Unchunk(nn.Module):
    
    def __init__(self, n_sections=2, dim=-3):
        super(Unchunk, self).__init__()
        self.n_sections = n_sections
        self.dim = dim
        
    def forward(self, x):
        assert x.shape[0] % self.n_sections == 0
        return torch.cat(torch.chunk(x, chunks=self.n_sections, dim=self.dim), dim=-2)

In [666]:
def Chunk(x, n_sections=2):
    assert x.shape[1] % n_sections == 0
    
    return torch.cat(torch.chunk(x, chunks=n_sections, dim=-2))

def Unchunk(x, n_sections=2):
    assert x.shape[0] % n_sections == 0
    
    return torch.cat(torch.chunk(x, chunks=n_sections, dim=-3), dim=-2)

In [552]:
test = torch.tensor([[1,0,0,0],[0,1,0,0],
                    [0,0,1,0], [0,0,0,1]])

In [580]:
torch.chunk(test, chunks=2, dim=-2)

(tensor([[1., 0., 0., 0.],
         [0., 1., 0., 0.]]), tensor([[0., 0., 1., 0.],
         [0., 0., 0., 1.]]))

In [553]:
test

tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])

## compute_residual

In [554]:
#test = test.view(1,-1)

ff = feed_forward(d_model=4, d_ff=4, dropout=0.2, activation=True, mode='train')

residual_layers = ff
compute_residual = tl.Serial(
    tl.Parallel([], tl.Dup()),
    tl.Swap(),
    tl.Parallel(residual_layers, [], [])
)

layers = [compute_residual, tl.Parallel(tl.Add(), [])]
compute_residual = tl.Serial(layers)

input_sd = ShapeDtype(tuple(test.shape), np.int32)
input_signature = (input_sd, input_sd)
weights, state = compute_residual.init(input_signature)

output, state = compute_residual(
    x=(test.numpy(), )*2, 
    weights=weights,
    state=state
)

output, state

(DeviceArray([[ 1.1745598 ,  0.07205822,  0.6006763 ,  0.7884164 ],
              [-0.5550922 ,  1.0654054 ,  0.6392876 , -0.07618227],
              [-0.10679559,  0.02530393,  0.9922953 ,  0.02808513],
              [-0.27318668,  0.24871418,  0.8338994 ,  1.1428925 ]],            dtype=float32),
 array([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]]))

In [555]:
class ReversibleHalfResidual(nn.Module):
    
    def __init__(self, residual_layers: list):
        super(ReversibleHalfResidual, self).__init__()
        self.residual_layers = residual_layers
        self.model = nn.Sequential(*residual_layers)
        
        
    def compute_residual(self, inputs, **kwargs):
        """
        Replicating the JAX class of the same name step by step.
        
        """
        
        # replicating the compute_residuals operation
        output = self.model(inputs)
        # replicating the tl.Add() operation
        output += output
        return output, self.model
    
    def reverse(self, inputs):
        #inputs -= inputs
        output = self.compute_residual(inputs)[0] - inputs
        return output

In [569]:
test = [[1,0,0,0], [0,1,0,0],
       [0,0,1,0], [0,0,0,1]]
test = torch.tensor(test, dtype=torch.float32)
test

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [557]:
class GeLU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (
            1 + torch.tanh(math.sqrt(2 / math.pi) * (
                x + 0.044715 * torch.pow(x, 3))))

In [558]:
class RevNetBlock(nn.Module):
    
    def __init__(self, d_in, d_out, dropout=0.1, lol=[]):
        super(RevNetBlock, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.dropout = dropout

        layers = []
        if lol == list():
            layers.append(nn.LayerNorm((d_in,d_out)))
            layers.append(nn.Linear(d_in, d_out))
            layers.append(GeLU())
            layers.append(nn.Linear(d_in, d_out))
        else:
            for layer in lol:
                layers.append(layer)
        
        self.bottleneck_block = nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.cat((x, x), dim=1)
        x1, x2 = self.split(x)
        Fx2 = self.bottleneck_block(x2)
        y1 = Fx2 + x1
        return (x2, y1)
    
    def inverse(self, x):
        x2, y1 = x[0], x[1]
        Fx2 = - self.bottleneck_block(x2)
        x1 = Fx2 + y1
        return (x1, x2)

    @staticmethod
    def split(x):
        n = int(x.size()[1] / 2)
        x1 = x[:, :n].contiguous()
        x2 = x[:, n:].contiguous()
        return (x1, x2)

In [581]:
rb = RevNetBlock(d_in=4, d_out=4)
rb(test)

(tensor([[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]), tensor([[ 1.4723,  0.0600,  0.5037,  0.5120],
         [ 0.3461,  0.8839,  0.3724,  0.2543],
         [ 0.7384,  0.1297,  1.1142,  0.3288],
         [ 0.5839, -0.2881,  0.0728,  1.2639]], grad_fn=<AddBackward0>))

In [1025]:
class RevNetHalfAttnBlock(nn.Module):
    
    def __init__(self, d_in, d_out, dropout=0.1, lol=[]):
        super(RevNetHalfAttnBlock, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.dropout = dropout

        layers = []
        if lol == list():
            layers.append(nn.LayerNorm((d_in,d_out)))
            layers.append(nn.Linear(d_out, d_out))
            layers.append(GeLU())
            layers.append(nn.Linear(d_out, d_out))
        else:
            for layer in lol:
                layers.append(layer)
        
        self.bottleneck_block = nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.cat((x, x), dim=1)
        x1, x2 = self.split(x)
        Fx2 = self.bottleneck_block(x2)
        y1 = Fx2 + x1
        return (x2, y1)
    
    def inverse(self, x):
        x2, y1 = x[0], x[1]
        Fx2 = - self.bottleneck_block(x2)
        x1 = Fx2 + y1
        return (x1, x2)

    @staticmethod
    def split(x):
        n = int(x.size()[1] / 2)
        x1 = x[:, :n].contiguous()
        x2 = x[:, n:].contiguous()
        return (x1, x2)

In [343]:
class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
    """Half of a RevNet-style residual (only updates part of the hidden state)."""

    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial(
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        )

        layers = [
            self.compute_residual,
            tl.Parallel(tl.Add(), [])
        ]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]

    def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
        reconstructed_x = output
        rng = kwargs.pop('rng', None)
        rngs = (None,) * self._n_layers
        if rng is not None:
            rngs = backend.random.split(rng, self._n_layers)
        # Note that self.sublayers aligns exactly with self.reverse_layers in
        # terms of parameter and rng usage, so no re-ordering is required.
        for layer, p, s, ns, rng in zip(
            self.reverse_layers, weights, state, new_state, rngs):
            reconstructed_x = layer(reconstructed_x, weights=p,
                                  state=s, new_state=ns, rng=rng, **kwargs)
        return reconstructed_x

    def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = (None,) * self._n_layers
        if rng is not None:
            rngs = backend.random.split(rng, self._n_layers)

        def call_compute_residual(x, weights):
            res = self.compute_residual(x, weights=weights, state=state[0],
                                  rng=rngs[0], **kwargs)
            return res

        assert len(ct) == 2
        ct = ((ct[0], ct[0], ct[1]))

        stack_with_residual, vjpfun = jax.vjp(
            call_compute_residual, output, weights[0])
        reconstructed_x = self.subtract_top(
            stack_with_residual, weights=weights[-1], state=state[-1], rng=rngs[-1],
            **kwargs)

        x_ct, residual_weights_ct = vjpfun(ct)
        assert not jax.tree_util.tree_leaves(weights[-1])
        add_top_weights_ct = weights[-1]
        return reconstructed_x, (x_ct, [residual_weights_ct, add_top_weights_ct])

In [560]:
sfo = SplitForOutput()

In [1213]:
t = torch.rand((4,4,64)).numpy()
a, b = sfo.forward((t,t))
c, d = sfo.reverse((t,t))

In [1214]:
t

array([[[0.9332608 , 0.37275243, 0.24749881, ..., 0.3007813 ,
         0.01336372, 0.91984534],
        [0.06810844, 0.31975365, 0.4596733 , ..., 0.6953241 ,
         0.516515  , 0.7335064 ],
        [0.2682591 , 0.4061324 , 0.4862305 , ..., 0.00894731,
         0.70863795, 0.68239385],
        [0.61330193, 0.73361546, 0.7224821 , ..., 0.38717663,
         0.71035814, 0.3269791 ]],

       [[0.80395615, 0.35125375, 0.64326906, ..., 0.7773044 ,
         0.2102406 , 0.01249462],
        [0.9474387 , 0.17382026, 0.3857957 , ..., 0.6967543 ,
         0.88277763, 0.80197996],
        [0.42688662, 0.26215553, 0.9343588 , ..., 0.73154384,
         0.8642012 , 0.03059494],
        [0.683786  , 0.95499694, 0.22347832, ..., 0.8070612 ,
         0.9337378 , 0.18357903]],

       [[0.60924685, 0.17127681, 0.5226319 , ..., 0.8543361 ,
         0.13858616, 0.07278848],
        [0.77171665, 0.9567531 , 0.4942938 , ..., 0.386685  ,
         0.8730096 , 0.46646416],
        [0.3734027 , 0.38423145, 0.6

In [1130]:
t.shape

(4, 4, 64)

In [1141]:
print(t.shape)
print(np.reshape(t, (t.shape[0], t.shape[1], 1, 64)).shape)
print(torch.reshape(torch.tensor(t), (t.shape[0], t.shape[1], 1, 64)).shape)
res = np.reshape(t, (t.shape[0], t.shape[1], 1, 64))
print(np.transpose(res, (0,2,1,3)).shape)
print(torch.transpose(torch.tensor(res), 1,2).shape)

(4, 4, 64)
(4, 4, 1, 64)
torch.Size([4, 4, 1, 64])
(4, 1, 4, 64)
torch.Size([4, 1, 4, 64])


In [1528]:
class ComputeAttentionHeads(nn.Module):
    def __init__(self, n_heads=1, d_head=64):
        super(ComputeAttentionHeads, self).__init__()
        self.n_heads= n_heads
        self.d_head = d_head
        
    def forward(self, x):
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype=torch.float)
        
        seqlen = x.shape[1]
        res = x
        
        # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
        res = torch.reshape(res, (x.shape[0], seqlen, self.n_heads, self.d_head))
        # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
        res = torch.transpose(res, 1, 2)
        # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
        res = torch.reshape(res, (-1, seqlen, self.d_head))
        res = nn.Linear(res.shape[-1], res.shape[-1])(res)
        return res
    
class ComputeAttentionOutput(nn.Module):
    def __init__(self, n_heads=1):
        super(ComputeAttentionOutput, self).__init__()
        self.n_heads = n_heads
        
    def forward(self, x):
        
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype=torch.float32)
        
        seqlen = x.shape[1]
        d_head = x.shape[2]
        
        x = torch.reshape(x, (-1, self.n_heads, seqlen, d_head))
        x = torch.transpose(x, 1, 2)
        x = torch.reshape(x, (-1, seqlen, self.n_heads * d_head))
        x = nn.Linear(x.shape[-1], x.shape[-1])(x)
        return x

In [1529]:
class DecoderBlock(nn.Module):
    
    def __init__(self, 
                 d_in, 
                 d_out, 
                 attn_k=64, 
                 attn_v=64, 
                 n_heads=1, 
                 n_chunks=2, 
                 share_qk=True, 
                 attn_type=None, 
                 dropout=None, 
                 ff_activation=None, 
                 ff_use_sru=None, 
                 mode='train'):
        super(DecoderBlock, self).__init__()
        
        self.d_in = d_in
        self.d_out = d_out
        self.attn_k = attn_k
        self.attn_v = attn_v
        self.n_heads = n_heads
        self.n_chunks = n_chunks
        self.attn_type = attn_type
        self.dropout = dropout
        self.share_qk = share_qk
        self.ff_activation = ff_activation
        self.ff_use_sru = ff_use_sru
        self.mode = mode
        
        
    def pre_attention(self, x):
        
        x1, x2 = torch.chunk(x, self.n_chunks)
        k_layers = [ComputeAttentionHeads(self.n_heads, self.attn_k), nn.LayerNorm((x.shape[1], x.shape[2]))]
        k_model = nn.Sequential(*k_layers)

        v_layers = [ComputeAttentionHeads(self.n_heads, self.attn_v), nn.LayerNorm((x.shape[1], x.shape[2]))]
        v_model = nn.Sequential(*v_layers)

        k = k_model(x1)
        v = v_model(x2)

        if not self.share_qk:
            q_layers = k_layers
            q_model = nn.Sequential(*q_layers)
            q = q_model(x1)
            
            q = Fx1
            
            return (q, k, v)
        else:
            return (k, k, v)
    
    
    def attention(self, inputs):
        
        assert len(inputs) == 2 or len(inputs) == 3
        if len(inputs) == 2:
            k, v = inputs
            q = k
        else:
            q, k, v = inputs
        
        mask_size = q.shape[-2]
        mask = torch.tril(
            torch.ones((1, mask_size, mask_size), dtype=torch.bool), 
            diagonal=0
        )
        
        attn = self.dotproductattention(q, k, v, mask)
        return attn
    
    
    def dotproductattention(self, q, k, v, mask, dropout=0.1):
        
        depth = q.shape[-1]
        dots = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(depth)
        dots = F.log_softmax(torch.where(mask, dots, torch.full_like(dots, -1e9)), dim=0)
        
        keep_prob = 1 - dropout
        keep = np.random.binomial(n=1, p=keep_prob, size=dots.shape)
        
        dots = torch.where(
            torch.tensor(keep, dtype=torch.bool), 
            dots / torch.tensor(keep_prob), 
            torch.zeros_like(dots)
        )
        attn = torch.matmul(dots, v)
        return attn
    
    
    def post_attention(self, x):
        
        cao = ComputeAttentionOutput()
        unchunk = Unchunk(n_sections = self.n_chunks, dim=-2)
        bd = BroadcastedDropout(rate=self.dropout)
        
        res = cao(x)
        #res = torch.cat((res, res), dim=-3)
        res = unchunk(res)
        res = bd(res)
        return res
    
    def forward(self, x):
        
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype=torch.float32)
        
        x = self.pre_attention(x)
        #x = tuple(torch.tensor(y) for y in x)
        x = self.attention(x)
        x = self.post_attention(x)
        return torch.cat((x,x))

In [1530]:
db = DecoderBlock(d_in=10, d_out=10, dropout=0.1)
q,k,v = db.pre_attention(t)
attn = db.attention((k,v))
res = db.post_attention(torch.tensor(attn))
res = db(torch.tensor(t))
res.shape

  after removing the cwd from sys.path.
  """


torch.Size([4, 4, 64])

In [1531]:
t = torch.tensor(t)

  """Entry point for launching an IPython kernel.


In [1532]:
len(sfo.reverse(t))

2

In [1533]:
t.shape

torch.Size([4, 4, 64])

In [1534]:
Chunk()(torch.tensor(t)).shape

  """Entry point for launching an IPython kernel.


torch.Size([8, 2, 64])

In [1535]:
db = DecoderBlock(d_in=10, d_out=10, dropout=0.1)
q,k,v = db.pre_attention(torch.tensor(t))
attn = db.attention((k,v))
res = db.post_attention(torch.tensor(attn))
res = db(torch.tensor(t))

# q,k,v = db.pre_attention(res)
# attn = db.attention((k,v))
# res = db.post_attention(torch.tensor(attn))
# res = db(torch.tensor(t))

# q,k,v = db.pre_attention(res)
# attn = db.attention((k,v))
# res = db.post_attention(torch.tensor(attn))
# res = db(torch.tensor(t))

len(torch.chunk(res, chunks=2))

  
  after removing the cwd from sys.path.
  """


2

In [1536]:
rb = RevNetHalfAttnBlock(d_in=res.shape[-2], d_out=res.shape[-1])
output = rb(res)
output = torch.cat(output)
output.shape

torch.Size([8, 4, 64])

In [1537]:
output

tensor([[[-1.3323, -2.1675, -1.0479,  ...,  0.6190, -0.4930, -1.8913],
         [-1.6069, -2.6593, -1.1764,  ...,  0.7585, -0.3427, -2.3052],
         [-1.0038, -1.1933, -0.6131,  ...,  0.4285, -0.6670, -1.0500],
         [-2.3543, -2.9169, -1.4736,  ...,  1.0477, -1.2215, -2.5753]],

        [[-0.6914, -0.0000, -1.1049,  ...,  0.2784, -0.8299,  0.0807],
         [-0.2466, -0.0000, -1.2179,  ...,  0.3073, -0.7717, -0.4986],
         [-0.3930, -0.0000, -2.6215,  ...,  0.8104, -2.0397, -1.2980],
         [-0.0596, -0.0000, -1.1229,  ...,  0.3441, -0.9026, -0.7586]],

        [[-1.3323, -2.1675, -1.0479,  ...,  0.6190, -0.4930, -1.8913],
         [-1.6069, -2.6593, -1.1764,  ...,  0.7585, -0.3427, -2.3052],
         [-1.0038, -1.1933, -0.6131,  ...,  0.4285, -0.6670, -1.0500],
         [-2.3543, -2.9169, -1.4736,  ...,  1.0477, -1.2215, -2.5753]],

        ...,

        [[-0.4555, -0.0740, -1.2646,  ...,  0.3365, -0.8418,  0.1093],
         [ 0.0178, -0.0785, -1.3758,  ...,  0.3962, -0.78

In [1538]:
class ReformerLM(nn.Module):
    
    def __init__(self,
                vocab_size,
                d_in,
                d_out,
                attn_k=64,
                attn_v=64,
                n_layers=6,
                n_heads=1,
                dropout=0.1,
                max_len=2048,
                n_chunks=2,
                n_attention_chunks=2,
                share_qk=True,
                axial_pos_shape=(),
                d_axial_pos_embs=None,
                mode='train'):
        super(ReformerLM, self).__init__()
        
        self.vocab_size = vocab_size
        self.d_in = d_in
        self.d_out = d_out
        self.attn_k = attn_k
        self.attn_v = attn_v
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dropout = dropout
        self.max_len = max_len
        self.n_chunks = n_chunks
        self.n_attention_chunks = n_attention_chunks
        self.share_qk = share_qk
        self.axial_pos_shape = axial_pos_shape
        self.d_axial_pos_embs = d_axial_pos_embs
        self.mode = mode
 
        self.layers = []
        self.layers.append(
            DecoderBlock(
                    d_in = self.d_in,
                    d_out = self.d_out, 
                    attn_k = self.attn_k,
                    attn_v = self.attn_v,
                    n_heads = self.n_heads,
                    n_chunks = self.n_attention_chunks,
                    share_qk = self.share_qk,
                    attn_type = None,
                    dropout = self.dropout
                )
        )
        
        for layer in range(self.n_layers - 1):
            #self.layers.append(Chunk(n_sections=self.n_attention_chunks))
            self.layers.append(
                DecoderBlock(
                    d_in = self.d_out,
                    d_out = self.d_out, 
                    attn_k = self.attn_k,
                    attn_v = self.attn_v,
                    n_heads = self.n_heads,
                    n_chunks = self.n_attention_chunks,
                    share_qk = self.share_qk,
                    attn_type = None,
                    dropout = self.dropout
                )
            )
            
        self.ff_layers = [
            nn.LayerNorm((1, self.d_out * self.d_in)),
            nn.Linear(self.d_out * self.d_in, self.d_out * self.d_in),
            BroadcastedDropout(rate=self.dropout, mode=self.mode),
            GeLU(),
            nn.Linear(self.d_out * self.d_in, self.vocab_size),
            nn.LogSoftmax()
        ]
        
        self.model = nn.Sequential(*self.layers)
        self.ff_model = nn.Sequential(*self.ff_layers)
    
    
    def forward(self, x):
        
        x = self.model(x)
        # Flattening
        x = x.view(x.shape[0],1,-1)
        x = self.ff_model(x)
        
        return x

In [1539]:
rlm = ReformerLM(vocab_size=100, 
                 d_in=t.shape[-2], 
                 d_out=t.shape[-1], 
                 n_layers=6)
#rlm._build_model()
output = rlm(t)
print(torch.tensor(t).shape)
print(output.shape)

torch.Size([4, 4, 64])
torch.Size([4, 1, 100])


  import sys


In [1540]:
output.shape

torch.Size([4, 1, 100])

# Prelim-Testing

In [1541]:
nn.MultiheadAttention(test.shape[-2], test.shape[-2])(test)

TypeError: forward() missing 2 required positional arguments: 'key' and 'value'

In [1562]:
from tqdm import tqdm, tqdm_notebook

In [1571]:
batch_sizes = range(4,16,2)
ydim = range(4,32,2)
seqlen = range(10,100,10)

for bs in tqdm_notebook(batch_sizes):
    for y in tqdm_notebook(ydim):
        for seq in tqdm_notebook(seqlen):
            test = torch.rand((bs, y, seq))
            try:
                rlm = ReformerLM(vocab_size=100, 
                                 d_in=test.shape[-2], 
                                 d_out=test.shape[-1], 
                                 n_layers=6, 
                                 n_heads=1, 
                                 attn_k=test.shape[-1], 
                                 attn_v=test.shape[-1], 
                                )

                output = rlm(test)
                assert output.shape == torch.Size([bs, 1, rlm.vocab_size])
            except AssertionError as e:
                print(f'Error on: {test.shape}\n{e}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 4, 10])

Error on: torch.Size([6, 4, 20])

Error on: torch.Size([6, 4, 30])

Error on: torch.Size([6, 4, 40])

Error on: torch.Size([6, 4, 50])

Error on: torch.Size([6, 4, 60])

Error on: torch.Size([6, 4, 70])

Error on: torch.Size([6, 4, 80])

Error on: torch.Size([6, 4, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 6, 10])

Error on: torch.Size([6, 6, 20])

Error on: torch.Size([6, 6, 30])

Error on: torch.Size([6, 6, 40])

Error on: torch.Size([6, 6, 50])

Error on: torch.Size([6, 6, 60])

Error on: torch.Size([6, 6, 70])

Error on: torch.Size([6, 6, 80])

Error on: torch.Size([6, 6, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 8, 10])

Error on: torch.Size([6, 8, 20])

Error on: torch.Size([6, 8, 30])

Error on: torch.Size([6, 8, 40])

Error on: torch.Size([6, 8, 50])

Error on: torch.Size([6, 8, 60])

Error on: torch.Size([6, 8, 70])

Error on: torch.Size([6, 8, 80])

Error on: torch.Size([6, 8, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 10, 10])

Error on: torch.Size([6, 10, 20])

Error on: torch.Size([6, 10, 30])

Error on: torch.Size([6, 10, 40])

Error on: torch.Size([6, 10, 50])

Error on: torch.Size([6, 10, 60])

Error on: torch.Size([6, 10, 70])

Error on: torch.Size([6, 10, 80])

Error on: torch.Size([6, 10, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 12, 10])

Error on: torch.Size([6, 12, 20])

Error on: torch.Size([6, 12, 30])

Error on: torch.Size([6, 12, 40])

Error on: torch.Size([6, 12, 50])

Error on: torch.Size([6, 12, 60])

Error on: torch.Size([6, 12, 70])

Error on: torch.Size([6, 12, 80])

Error on: torch.Size([6, 12, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 14, 10])

Error on: torch.Size([6, 14, 20])

Error on: torch.Size([6, 14, 30])

Error on: torch.Size([6, 14, 40])

Error on: torch.Size([6, 14, 50])

Error on: torch.Size([6, 14, 60])

Error on: torch.Size([6, 14, 70])

Error on: torch.Size([6, 14, 80])

Error on: torch.Size([6, 14, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 16, 10])

Error on: torch.Size([6, 16, 20])

Error on: torch.Size([6, 16, 30])

Error on: torch.Size([6, 16, 40])

Error on: torch.Size([6, 16, 50])

Error on: torch.Size([6, 16, 60])

Error on: torch.Size([6, 16, 70])

Error on: torch.Size([6, 16, 80])

Error on: torch.Size([6, 16, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 18, 10])

Error on: torch.Size([6, 18, 20])

Error on: torch.Size([6, 18, 30])

Error on: torch.Size([6, 18, 40])

Error on: torch.Size([6, 18, 50])

Error on: torch.Size([6, 18, 60])

Error on: torch.Size([6, 18, 70])

Error on: torch.Size([6, 18, 80])

Error on: torch.Size([6, 18, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 20, 10])

Error on: torch.Size([6, 20, 20])

Error on: torch.Size([6, 20, 30])

Error on: torch.Size([6, 20, 40])

Error on: torch.Size([6, 20, 50])

Error on: torch.Size([6, 20, 60])

Error on: torch.Size([6, 20, 70])

Error on: torch.Size([6, 20, 80])

Error on: torch.Size([6, 20, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 22, 10])

Error on: torch.Size([6, 22, 20])

Error on: torch.Size([6, 22, 30])

Error on: torch.Size([6, 22, 40])

Error on: torch.Size([6, 22, 50])

Error on: torch.Size([6, 22, 60])

Error on: torch.Size([6, 22, 70])

Error on: torch.Size([6, 22, 80])

Error on: torch.Size([6, 22, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 24, 10])

Error on: torch.Size([6, 24, 20])

Error on: torch.Size([6, 24, 30])

Error on: torch.Size([6, 24, 40])

Error on: torch.Size([6, 24, 50])

Error on: torch.Size([6, 24, 60])

Error on: torch.Size([6, 24, 70])

Error on: torch.Size([6, 24, 80])

Error on: torch.Size([6, 24, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 26, 10])

Error on: torch.Size([6, 26, 20])

Error on: torch.Size([6, 26, 30])

Error on: torch.Size([6, 26, 40])

Error on: torch.Size([6, 26, 50])

Error on: torch.Size([6, 26, 60])

Error on: torch.Size([6, 26, 70])

Error on: torch.Size([6, 26, 80])

Error on: torch.Size([6, 26, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 28, 10])

Error on: torch.Size([6, 28, 20])

Error on: torch.Size([6, 28, 30])

Error on: torch.Size([6, 28, 40])

Error on: torch.Size([6, 28, 50])

Error on: torch.Size([6, 28, 60])

Error on: torch.Size([6, 28, 70])

Error on: torch.Size([6, 28, 80])

Error on: torch.Size([6, 28, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([6, 30, 10])

Error on: torch.Size([6, 30, 20])

Error on: torch.Size([6, 30, 30])

Error on: torch.Size([6, 30, 40])

Error on: torch.Size([6, 30, 50])

Error on: torch.Size([6, 30, 60])

Error on: torch.Size([6, 30, 70])

Error on: torch.Size([6, 30, 80])

Error on: torch.Size([6, 30, 90])





HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 4, 10])

Error on: torch.Size([10, 4, 20])

Error on: torch.Size([10, 4, 30])

Error on: torch.Size([10, 4, 40])

Error on: torch.Size([10, 4, 50])

Error on: torch.Size([10, 4, 60])

Error on: torch.Size([10, 4, 70])

Error on: torch.Size([10, 4, 80])

Error on: torch.Size([10, 4, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 6, 10])

Error on: torch.Size([10, 6, 20])

Error on: torch.Size([10, 6, 30])

Error on: torch.Size([10, 6, 40])

Error on: torch.Size([10, 6, 50])

Error on: torch.Size([10, 6, 60])

Error on: torch.Size([10, 6, 70])

Error on: torch.Size([10, 6, 80])

Error on: torch.Size([10, 6, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 8, 10])

Error on: torch.Size([10, 8, 20])

Error on: torch.Size([10, 8, 30])

Error on: torch.Size([10, 8, 40])

Error on: torch.Size([10, 8, 50])

Error on: torch.Size([10, 8, 60])

Error on: torch.Size([10, 8, 70])

Error on: torch.Size([10, 8, 80])

Error on: torch.Size([10, 8, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 10, 10])

Error on: torch.Size([10, 10, 20])

Error on: torch.Size([10, 10, 30])

Error on: torch.Size([10, 10, 40])

Error on: torch.Size([10, 10, 50])

Error on: torch.Size([10, 10, 60])

Error on: torch.Size([10, 10, 70])

Error on: torch.Size([10, 10, 80])

Error on: torch.Size([10, 10, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 12, 10])

Error on: torch.Size([10, 12, 20])

Error on: torch.Size([10, 12, 30])

Error on: torch.Size([10, 12, 40])

Error on: torch.Size([10, 12, 50])

Error on: torch.Size([10, 12, 60])

Error on: torch.Size([10, 12, 70])

Error on: torch.Size([10, 12, 80])

Error on: torch.Size([10, 12, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 14, 10])

Error on: torch.Size([10, 14, 20])

Error on: torch.Size([10, 14, 30])

Error on: torch.Size([10, 14, 40])

Error on: torch.Size([10, 14, 50])

Error on: torch.Size([10, 14, 60])

Error on: torch.Size([10, 14, 70])

Error on: torch.Size([10, 14, 80])

Error on: torch.Size([10, 14, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 16, 10])

Error on: torch.Size([10, 16, 20])

Error on: torch.Size([10, 16, 30])

Error on: torch.Size([10, 16, 40])

Error on: torch.Size([10, 16, 50])

Error on: torch.Size([10, 16, 60])

Error on: torch.Size([10, 16, 70])

Error on: torch.Size([10, 16, 80])

Error on: torch.Size([10, 16, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 18, 10])

Error on: torch.Size([10, 18, 20])

Error on: torch.Size([10, 18, 30])

Error on: torch.Size([10, 18, 40])

Error on: torch.Size([10, 18, 50])

Error on: torch.Size([10, 18, 60])

Error on: torch.Size([10, 18, 70])

Error on: torch.Size([10, 18, 80])

Error on: torch.Size([10, 18, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 20, 10])

Error on: torch.Size([10, 20, 20])

Error on: torch.Size([10, 20, 30])

Error on: torch.Size([10, 20, 40])

Error on: torch.Size([10, 20, 50])

Error on: torch.Size([10, 20, 60])

Error on: torch.Size([10, 20, 70])

Error on: torch.Size([10, 20, 80])

Error on: torch.Size([10, 20, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 22, 10])

Error on: torch.Size([10, 22, 20])

Error on: torch.Size([10, 22, 30])

Error on: torch.Size([10, 22, 40])

Error on: torch.Size([10, 22, 50])

Error on: torch.Size([10, 22, 60])

Error on: torch.Size([10, 22, 70])

Error on: torch.Size([10, 22, 80])

Error on: torch.Size([10, 22, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 24, 10])

Error on: torch.Size([10, 24, 20])

Error on: torch.Size([10, 24, 30])

Error on: torch.Size([10, 24, 40])

Error on: torch.Size([10, 24, 50])

Error on: torch.Size([10, 24, 60])

Error on: torch.Size([10, 24, 70])

Error on: torch.Size([10, 24, 80])

Error on: torch.Size([10, 24, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 26, 10])

Error on: torch.Size([10, 26, 20])

Error on: torch.Size([10, 26, 30])

Error on: torch.Size([10, 26, 40])

Error on: torch.Size([10, 26, 50])

Error on: torch.Size([10, 26, 60])

Error on: torch.Size([10, 26, 70])

Error on: torch.Size([10, 26, 80])

Error on: torch.Size([10, 26, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 28, 10])

Error on: torch.Size([10, 28, 20])

Error on: torch.Size([10, 28, 30])

Error on: torch.Size([10, 28, 40])

Error on: torch.Size([10, 28, 50])

Error on: torch.Size([10, 28, 60])

Error on: torch.Size([10, 28, 70])

Error on: torch.Size([10, 28, 80])

Error on: torch.Size([10, 28, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([10, 30, 10])

Error on: torch.Size([10, 30, 20])

Error on: torch.Size([10, 30, 30])

Error on: torch.Size([10, 30, 40])

Error on: torch.Size([10, 30, 50])

Error on: torch.Size([10, 30, 60])

Error on: torch.Size([10, 30, 70])

Error on: torch.Size([10, 30, 80])

Error on: torch.Size([10, 30, 90])





HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 4, 10])

Error on: torch.Size([14, 4, 20])

Error on: torch.Size([14, 4, 30])

Error on: torch.Size([14, 4, 40])

Error on: torch.Size([14, 4, 50])

Error on: torch.Size([14, 4, 60])

Error on: torch.Size([14, 4, 70])

Error on: torch.Size([14, 4, 80])

Error on: torch.Size([14, 4, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 6, 10])

Error on: torch.Size([14, 6, 20])

Error on: torch.Size([14, 6, 30])

Error on: torch.Size([14, 6, 40])

Error on: torch.Size([14, 6, 50])

Error on: torch.Size([14, 6, 60])

Error on: torch.Size([14, 6, 70])

Error on: torch.Size([14, 6, 80])

Error on: torch.Size([14, 6, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 8, 10])

Error on: torch.Size([14, 8, 20])

Error on: torch.Size([14, 8, 30])

Error on: torch.Size([14, 8, 40])

Error on: torch.Size([14, 8, 50])

Error on: torch.Size([14, 8, 60])

Error on: torch.Size([14, 8, 70])

Error on: torch.Size([14, 8, 80])

Error on: torch.Size([14, 8, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 10, 10])

Error on: torch.Size([14, 10, 20])

Error on: torch.Size([14, 10, 30])

Error on: torch.Size([14, 10, 40])

Error on: torch.Size([14, 10, 50])

Error on: torch.Size([14, 10, 60])

Error on: torch.Size([14, 10, 70])

Error on: torch.Size([14, 10, 80])

Error on: torch.Size([14, 10, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 12, 10])

Error on: torch.Size([14, 12, 20])

Error on: torch.Size([14, 12, 30])

Error on: torch.Size([14, 12, 40])

Error on: torch.Size([14, 12, 50])

Error on: torch.Size([14, 12, 60])

Error on: torch.Size([14, 12, 70])

Error on: torch.Size([14, 12, 80])

Error on: torch.Size([14, 12, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 14, 10])

Error on: torch.Size([14, 14, 20])

Error on: torch.Size([14, 14, 30])

Error on: torch.Size([14, 14, 40])

Error on: torch.Size([14, 14, 50])

Error on: torch.Size([14, 14, 60])

Error on: torch.Size([14, 14, 70])

Error on: torch.Size([14, 14, 80])

Error on: torch.Size([14, 14, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 16, 10])

Error on: torch.Size([14, 16, 20])

Error on: torch.Size([14, 16, 30])

Error on: torch.Size([14, 16, 40])

Error on: torch.Size([14, 16, 50])

Error on: torch.Size([14, 16, 60])

Error on: torch.Size([14, 16, 70])

Error on: torch.Size([14, 16, 80])

Error on: torch.Size([14, 16, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 18, 10])

Error on: torch.Size([14, 18, 20])

Error on: torch.Size([14, 18, 30])

Error on: torch.Size([14, 18, 40])

Error on: torch.Size([14, 18, 50])

Error on: torch.Size([14, 18, 60])

Error on: torch.Size([14, 18, 70])

Error on: torch.Size([14, 18, 80])

Error on: torch.Size([14, 18, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 20, 10])

Error on: torch.Size([14, 20, 20])

Error on: torch.Size([14, 20, 30])

Error on: torch.Size([14, 20, 40])

Error on: torch.Size([14, 20, 50])

Error on: torch.Size([14, 20, 60])

Error on: torch.Size([14, 20, 70])

Error on: torch.Size([14, 20, 80])

Error on: torch.Size([14, 20, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 22, 10])

Error on: torch.Size([14, 22, 20])

Error on: torch.Size([14, 22, 30])

Error on: torch.Size([14, 22, 40])

Error on: torch.Size([14, 22, 50])

Error on: torch.Size([14, 22, 60])

Error on: torch.Size([14, 22, 70])

Error on: torch.Size([14, 22, 80])

Error on: torch.Size([14, 22, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 24, 10])

Error on: torch.Size([14, 24, 20])

Error on: torch.Size([14, 24, 30])

Error on: torch.Size([14, 24, 40])

Error on: torch.Size([14, 24, 50])

Error on: torch.Size([14, 24, 60])

Error on: torch.Size([14, 24, 70])

Error on: torch.Size([14, 24, 80])

Error on: torch.Size([14, 24, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 26, 10])

Error on: torch.Size([14, 26, 20])

Error on: torch.Size([14, 26, 30])

Error on: torch.Size([14, 26, 40])

Error on: torch.Size([14, 26, 50])

Error on: torch.Size([14, 26, 60])

Error on: torch.Size([14, 26, 70])

Error on: torch.Size([14, 26, 80])

Error on: torch.Size([14, 26, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 28, 10])

Error on: torch.Size([14, 28, 20])

Error on: torch.Size([14, 28, 30])

Error on: torch.Size([14, 28, 40])

Error on: torch.Size([14, 28, 50])

Error on: torch.Size([14, 28, 60])

Error on: torch.Size([14, 28, 70])

Error on: torch.Size([14, 28, 80])

Error on: torch.Size([14, 28, 90])




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

Error on: torch.Size([14, 30, 10])

Error on: torch.Size([14, 30, 20])

Error on: torch.Size([14, 30, 30])

Error on: torch.Size([14, 30, 40])

Error on: torch.Size([14, 30, 50])

Error on: torch.Size([14, 30, 60])

Error on: torch.Size([14, 30, 70])

Error on: torch.Size([14, 30, 80])

Error on: torch.Size([14, 30, 90])






In [1572]:
test = torch.rand((4,4,40))

rlm = ReformerLM(vocab_size=100, 
                 d_in=test.shape[-2], 
                 d_out=test.shape[-1], 
                 n_layers=6, 
                 n_heads=1, 
                 attn_k=test.shape[-1], 
                 attn_v=test.shape[-1], 
                )

output = rlm(test)
output.shape

torch.Size([4, 1, 100])