In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def exists(val):
    return val is not None


def shift(t, amount, mask = None):
    if amount == 0:
        return t
    else:
        amount = min(amount, t.shape[1])

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return F.pad(t, (0, 0, amount, -amount), value = 0.)

class ShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)


In [98]:
test = torch.randn(10, 16, 320, 320)
l = nn.Conv2d(16, 8, (1,1))

test = l(test)
print(test.shape)
v = torch.randn(10, 1, 320, 128)
print(v.shape)
out = torch.matmul(test, v)
print(out.shape)

torch.Size([10, 8, 320, 320])
torch.Size([10, 1, 320, 128])
torch.Size([10, 8, 320, 128])


In [89]:
l.weight.shape

torch.Size([8, 16, 1, 1])

In [93]:
nn.Parameter(torch.ones(1, 8, 1, 1) * 15.5 + torch.randn(1, 8, 1, 1) * 5, requires_grad=True)

Parameter containing:
tensor([[[[14.2882]],

         [[25.2246]],

         [[15.2835]],

         [[ 9.7837]],

         [[15.9891]],

         [[ 4.2996]],

         [[18.0781]],

         [[14.7954]]]], requires_grad=True)

In [9]:
tk_shift = ShiftTokens(range(0, 2), nn.Identity())
inputs = torch.randn(10, 300, 768)
tk_shift(inputs).shape

torch.Size([10, 300, 768])

In [37]:
import torch
import math
from torch import nn
import torch.nn.functional as F
from operator import mul
from functools import reduce


# constant

TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work

# helper functions

def exists(val):
    return val is not None

def default(value, d):
    return d if not exists(value) else value

def to(t):
    return {'device': t.device, 'dtype': t.dtype}

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

def expand_dim(t, dim, k, unsqueeze=True):
    if unsqueeze:
        t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple
    if m.is_integer():
        return tensor
    remainder = math.ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value=value)

def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    padded_x = F.pad(x, (*dims, backward, forward), value= pad_value)
    tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
    return torch.cat(tensors, dim=dim)

# main class

class LocalAttention(nn.Module):
    def __init__(
        self,
        window_size,
        causal = False,
        look_backward = 1,
        look_forward = None,
        dropout = 0.,
        shared_qk = False,
        rel_pos_emb_config = None,
        dim = None,
        autopad = False,
        exact_windowsize = False
    ):
        super().__init__()
        look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and look_forward > 0), 'you cannot look forward if causal'

        self.window_size = window_size
        self.causal = causal
        self.look_backward = look_backward
        self.look_forward = look_forward
        self.exact_windowsize = exact_windowsize
        self.autopad = autopad

        self.dropout = nn.Dropout(dropout)

        self.shared_qk = shared_qk

       

    def forward(self, q, k, v, input_mask = None):
        shape = q.shape

        merge_into_batch = lambda t: t.reshape(-1, *t.shape[-2:])
        q, k, v = map(merge_into_batch, (q, k, v))

     

        if self.autopad:
            orig_t = q.shape[1]
            q, k, v = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))

        window_size, causal, look_backward, look_forward, shared_qk = self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk
        b, t, e, device, dtype = *q.shape, q.device, q.dtype
        assert (t % window_size) == 0, f'sequence length {t} must be divisible by window size {window_size} for local attention'

        windows = t // window_size

        if shared_qk:
            k = F.normalize(k, 2, dim=-1).type_as(q)

        ticker = torch.arange(t, device=device, dtype=torch.long)[None, :]
        b_t = ticker.reshape(1, windows, window_size)

        bucket_fn = lambda t: t.reshape(b, windows, window_size, -1)
        bq, bk, bv = map(bucket_fn, (q, k, v))

        look_around_kwargs = {'backward': look_backward, 'forward': look_forward}
        bk = look_around(bk, **look_around_kwargs)
        bv = look_around(bv, **look_around_kwargs)

        bq_t = b_t
        bq_k = look_around(b_t, **look_around_kwargs)

        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (e ** -0.5)

        mask_value = max_neg_value(dots)

        if shared_qk:
            mask = bq_t[:, :, :, None] == bq_k[:, :, None, :]
            dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
            del mask

        if causal:
            mask = bq_t[:, :, :, None] < bq_k[:, :, None, :]

            if self.exact_windowsize:
                max_causal_window_size = (self.window_size * self.look_backward)
                mask = mask | (bq_t[:, :, :, None] > (bq_k[:, :, None, :] + max_causal_window_size))

            dots.masked_fill_(mask, mask_value)
            del mask

        mask = bq_k[:, :, None, :] == -1
        dots.masked_fill_(mask, mask_value)
        del mask

        if input_mask is not None:
            h = b // input_mask.shape[0]
            if self.autopad:
                input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False)
            input_mask = input_mask.reshape(-1, windows, window_size)
            mq = mk = input_mask
            mk = look_around(mk, pad_value=False, **look_around_kwargs)
            mask = (mq[:, :, :, None] * mk[:, :, None, :])
            mask = merge_dims(0, 1, expand_dim(mask, 1, h))
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)
        attn = self.dropout(attn)

        out = torch.einsum('bhij,bhje->bhie', attn, bv)
        print(out.shape)
        out = out.reshape(-1, t, e)

        if self.autopad:
            out = out[:, :orig_t, :]
     
        return out.reshape(*shape)

In [49]:
la = LocalAttention(
    window_size=128,
    causal=True,
)
inputs = torch.randn(10, 1024, 256)
q,k,v = rearrange(inputs, 'b n (h d) -> b h n d', h = 4).chunk(3, dim=-1)
la(q,k,v).shape

torch.Size([40, 8, 128, 20])


RuntimeError: shape '[-1, 1024, 22]' is invalid for input of size 819200

In [22]:
model = transformer_lm(
    dim=256,
    vocab_size=29,
    depth=12,
    heads=8,
    dim_head=32,
    causal=True,
    shared_kv=True
)
inputs = torch.randint(0, 29, (1, 512))
model(inputs)

{'out': tensor([[[ 0.2689,  0.6911, -0.5602,  ..., -0.3779,  0.6310, -0.2672],
          [ 0.3529,  0.8222, -0.0128,  ..., -1.3357, -0.6433, -0.3723],
          [ 0.8861, -0.3546, -0.8761,  ..., -0.0750,  0.4937,  0.2132],
          ...,
          [ 0.2952, -0.8775, -0.7833,  ...,  0.3043,  1.0915,  0.3812],
          [ 0.1907,  0.2562,  0.1314,  ...,  0.8126, -0.2748,  0.4826],
          [-0.2756, -1.3920,  1.3631,  ..., -0.5297,  0.3971,  1.0681]]],
        grad_fn=<AddBackward0>),
 'interim_logits': tensor([[[[-1.3953e-01, -8.2908e-02, -5.8428e-01,  ...,  5.3182e-02,
             6.9809e-01, -2.3322e-01],
           [-2.6622e-01,  2.8328e-01, -1.5494e-01,  ..., -1.2958e+00,
            -3.5514e-01, -4.9902e-01],
           [ 7.5474e-01, -7.4899e-01, -1.4107e+00,  ...,  3.1767e-01,
             3.0542e-01,  5.9867e-03],
           ...,
           [ 5.9971e-01, -7.5500e-01, -1.1011e+00,  ...,  3.0627e-01,
             4.2294e-01, -1.1166e-01],
           [ 2.2694e-03, -5.0096e-01,  1.

In [433]:
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from torch import einsum
from torch.utils.checkpoint import checkpoint # # gradient/activation checkpointing
from functools import partial


class DynamicPositionBias(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device = device)
        context_arange = torch.arange(n, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (n - 1)
        
        # input to continuous positions MLP
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases        
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

class ScaledSinuEmbedding(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        return emb * self.scale

class ReLUSquared(nn.Module):
    def forward(self, x):
        return torch.pow(F.relu(x), 2)

def l2norm(t, groups = 1, dim = -1):
    if groups == 1:
        return F.normalize(t, p = 2, dim = dim)
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = dim)
    return rearrange(t, '... g d -> ... (g d)')

class CosineAttention(nn.Module):
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.1,
        bias=False,
        temperature=15.5,
        return_attention=False,
        causal=False,
        activation='softmax',
        **kwargs
    ):
        super().__init__()
        assert activation in ['relusq', 'softmax']
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.bias = bias
        self.return_attention = return_attention

        self.causal = causal

        self.temperature = torch.nn.Parameter(torch.tensor(temperature), requires_grad=True) if isinstance(temperature, float) else temperature

        self.activation = ReLUSquared() if activation == 'relusq' else nn.Softmax(dim=-1)

        self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

    def head_diversity(self, dots):

        dots_cc = einsum('b h i j -> b i j h', dots) #  permute heads to last dim
        dots_cc = rearrange(dots_cc, 'b i j h -> b i j h ()')
        dots_cc = einsum('b i j h e, b i j o e -> b i j h o', dots_cc, dots_cc) #  compute pairwise dot products
        # (T = seq_len)
        dots_cc = dots_cc.sum(dim=1) / dots.shape[1]
        print(dots_cc.shape)
        dots_cc = rearrange(dots_cc, 'b i j h -> b h i j')
        dots_identity = dots * torch.eye(dots.shape[-1], device=dots.device)
        print(dots_identity.shape)
        dots_dv = dots_cc - dots_identity
        dots_dv = dots_dv.sum((2, 3)) / (dots.shape[2] * dots.shape[3])
       
        return dots_dv.mean((1,0))

       

    def attend(self, qkv, mask, pos_fn):
        query, key, value = qkv
        
        query, key = map(l2norm, (query, key))

        dots = einsum('bhid,bhjd->bhij', query, key) * self.temperature
        self.head_diversity(dots)
        dots += pos_fn(dots.shape[-1], device=dots.device, dtype=dots.dtype)
        qkmask = ~mask
        attn_mask = ~(rearrange(qkmask, "b n -> b () n ()") * rearrange(qkmask, "b n -> b () () n"))
    
        if self.causal: # create a regular causal mask
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)
        
        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)
    
        attn = self.activation(dots)   
        attn = self.dropout(attn)
        return einsum("bhij,bhjd->bhid", attn, value)


    def forward(self, x, pos_fn, mask=None):
        assert pos_fn is not None, 'pls provide a position function'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim
        #print(x.shape, mask.shape)

        if mask is None:
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)

        qkv = rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=H, d=D) # qkv projection
    
        out = self.attend(qkv, mask, pos_fn)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.out_proj(out)
        return out

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(self.norm(x), *args, **kwargs)


class GLU(nn.Module):
    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim = -1)
        return x * self.act(gate)



class transformer(nn.Module):
    def __init__(
            self, 
            dim, 
            depth, 
            heads, 
            dim_head, 
            causal=True,
            temperature=15.5,
            shared_temperture=False,
            intermediate_loss=True,
            dropout = 0.1,
            checkpoint = True,
            **kwargs
        ):
        super().__init__()
        if depth == 1:
            intermediate_loss = False

        ff_mult = kwargs.get('ff_mult', 4)

     
        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=True) if shared_temperture else temperature

        self.intermediate_loss = intermediate_loss

        self.depth = depth
        self.positional_bias = DynamicPositionBias(
            dim = dim // 4,
            heads = heads,
            depth = 2,
            log_distance = False,
            norm = False
        )
        self.grad_checkpointing = checkpoint
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, CosineAttention(
                    dim, 
                    n_heads=heads, 
                    head_dim=dim_head, 
                    causal=causal,
                    temperature=self.temperature,
                    dropout=dropout,
                    **kwargs
                )),
                PreNorm(dim, self.ff(dim, mult=ff_mult))
            ]))

    @staticmethod
    def ff(dim, mult=4, dropout=0.1):
        return nn.Sequential(
            GLU(dim, dim * mult, nn.SiLU()),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    @staticmethod
    def create_custom_forward(module):
        def custom_forward(*args, **kwargs):
            return module(*args, **kwargs)
        return custom_forward

    def checkpoint(self, layer, module, *args, **kwargs):
        condition = self.training and self.grad_checkpointing and layer < self.depth - 1
        return checkpoint(self.create_custom_forward(module), *args, **kwargs) if condition else module(*args, **kwargs)


    def forward(self, x, mask=None, self_condtioning=None):
        intermediate_logits = []
        for i, (attn, ff) in enumerate(self.layers):
            x = self.checkpoint(i, attn, x, self.positional_bias, mask) + x
            x = self.checkpoint(i, ff, x) + x

            if i < self.depth - 1 and self_condtioning is not None:
                x, logits = self_condtioning(x)
                intermediate_logits.append(logits)

        # stack intermediate logits
        if len(intermediate_logits) > 0:
            intermediate_logits = torch.stack(intermediate_logits, dim=0) # D x B x N x V

        return x, intermediate_logits

class shared_embedding_output_layer(nn.Module):
    '''
    Pass a embedding layer and then use this module as the output layer
    '''
    def __init__(self, embedding_layer, bias=False):
        super().__init__()
        self.embedding_layer = embedding_layer
        self.use_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(embedding_layer.weight.shape[0]))#
            nn.init.xavier_uniform_(self.bias)

    def forward(self, x):
        return F.linear(x, weight=self.embedding_layer.weight, bias=self.bias if self.use_bias else None)


class transformer_lm(nn.Module):
    def __init__(
        self,
        dim,
        vocab_size,
        depth,
        heads,
        dim_head,
        causal=True,
        temperature=15.5,
        dropout=0.,
        shared_temperture=True,
        self_conditioning=False,
        intermediate_loss=True,
        use_abs_pos=False,
        **kwargs
    ):
        super().__init__()
        if depth == 1:
            self_conditioning == False

        self.self_conditioning = True if self_conditioning else None
        self.intermediate_loss = intermediate_loss

        self.use_abs_pos = use_abs_pos
        if self.use_abs_pos:
            self.abs_pos_fn = ScaledSinuEmbedding(dim=dim)
        self.abs_pos = lambda x: x + self.abs_pos_fn(x) if self.use_abs_pos else x

        if self_conditioning:
            self.reprojection_layer = nn.Linear(vocab_size, dim)

        self.layers = transformer(
            dim = dim, 
            depth = depth, 
            heads = heads, 
            dim_head = dim_head, 
            causal = causal, 
            dropout = dropout,
            temperature = temperature,
            shared_temperture = shared_temperture,
            intermediate_loss = intermediate_loss,
            **kwargs
        )

        self.tie_embedding = kwargs.get('tie_embedding', False)
        print('Tie embedding:', self.tie_embedding) if self.tie_embedding else None
 
        self.embedding = nn.Embedding(vocab_size, dim)
        self.to_logits = shared_embedding_output_layer(self.embedding) if self.tie_embedding else nn.Linear(dim, vocab_size)
        

        self.post_norm = nn.LayerNorm(dim)


    def self_condition_fn(self):
        def self_condition(x):
            logits = self.to_logits(self.post_norm(x))
            if self.self_conditioning:
                z = F.softmax(logits, dim=-1)
                z = self.reprojection_layer(z)
                x = z + x
            return x, logits
        return self_condition if (self.self_conditioning or self.intermediate_loss) and self.training else None


    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.abs_pos(x)
        x, interim_logits = self.layers(x, mask=~mask if mask is not None else None, self_condtioning=self.self_condition_fn())
        x = self.post_norm(x)
        x = self.to_logits(x)

        return  { 'out': x, 'interim_logits': interim_logits } if self.training else x


In [434]:
model = transformer_lm(dim=256, vocab_size=128, depth=6, heads=8, dim_head=32, causal=True)
model(x, mask)

torch.Size([30, 190, 8, 8])
torch.Size([30, 8, 190, 190])


RuntimeError: The size of tensor a (8) must match the size of tensor b (190) at non-singleton dimension 3

In [92]:
N = 190
B = 30
x = torch.randint(0, 128, (B, N))
mask = torch.ones(B, N).bool()
mask[:, 50:] = False


In [19]:
mask

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])

In [300]:
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from torch import einsum
from torch.utils.checkpoint import checkpoint # # gradient/activation checkpointing
from functools import partial





class transformer(nn.Module):
    def __init__(
            self, 
            dim, 
            depth, 
            heads, 
            dim_head, 
            causal=True,
            temperature=15.5,
            shared_temperture=False,
            intermediate_loss=True,
            dropout = 0.1,
            checkpoint = True,
            **kwargs
        ):
        super().__init__()
        if depth == 1:
            intermediate_loss = False

        ff_mult = kwargs.get('ff_mult', 4)
     
        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=True) if shared_temperture else temperature

        self.intermediate_loss = intermediate_loss

        self.depth = depth
 
        self.grad_checkpointing = checkpoint
        self.MLP_fnet = PreNorm(dim, MLPAttenion(
            dim, 
            n_heads=heads, 
            head_dim=dim_head, 
            causal=causal,
            temperature=self.temperature,
            dropout=dropout,
            **kwargs
        ))
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                self.MLP_fnet,
                PreNorm(dim, self.ff(dim, mult=ff_mult))
            ]))

    @staticmethod
    def ff(dim, mult=4, dropout=0.1):
        return nn.Sequential(
            GLU(dim, dim * mult, nn.SiLU()),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    @staticmethod
    def create_custom_forward(module):
        def custom_forward(*args, **kwargs):
            return module(*args, **kwargs)
        return custom_forward

    def checkpoint(self, layer, module, *args, **kwargs):
        condition = self.training and self.grad_checkpointing and layer < self.depth - 1
        return checkpoint(self.create_custom_forward(module), *args, **kwargs) if condition else module(*args, **kwargs)


    def forward(self, x, mask=None, self_condtioning=None):
        intermediate_logits = []
        for i, (attn, ff) in enumerate(self.layers):
            x = self.checkpoint(i, attn, x, mask) + x
            x = self.checkpoint(i, ff, x) + x

            if i < self.depth - 1 and self_condtioning is not None:
                x, logits = self_condtioning(x)
                intermediate_logits.append(logits)

        # stack intermediate logits
        if len(intermediate_logits) > 0:
            intermediate_logits = torch.stack(intermediate_logits, dim=0) # D x B x N x V
   
        return x, intermediate_logits


class transformer_lm(nn.Module):
    def __init__(
        self,
        dim,
        vocab_size,
        depth,
        heads,
        dim_head,
        causal=True,
        temperature=15.5,
        dropout=0.,
        shared_temperture=True,
        self_conditioning=False,
        intermediate_loss=True,
        use_abs_pos=False,
        **kwargs
    ):
        super().__init__()
        if depth == 1:
            self_conditioning == False

        self.self_conditioning = True if self_conditioning else None
        self.intermediate_loss = intermediate_loss

        self.use_abs_pos = use_abs_pos
        if self.use_abs_pos:
            self.abs_pos_fn = ScaledSinuEmbedding(dim=dim)
        self.abs_pos = lambda x: x + self.abs_pos_fn(x) if self.use_abs_pos else x

        if self_conditioning:
            self.reprojection_layer = nn.Linear(vocab_size, dim)

        self.layers = transformer(
            dim = dim, 
            depth = depth, 
            heads = heads, 
            dim_head = dim_head, 
            causal = causal, 
            dropout = dropout,
            temperature = temperature,
            shared_temperture = shared_temperture,
            intermediate_loss = intermediate_loss,
            **kwargs
        )
 
        self.to_logits = nn.Linear(dim, vocab_size)
        self.embedding = nn.Embedding(vocab_size, dim)
        self.post_norm = nn.LayerNorm(dim)

    def self_condition_fn(self):
        def self_condition(x):
            logits = self.to_logits(self.post_norm(x))
            if self.self_conditioning:
                z = F.softmax(logits, dim=-1)
                z = self.reprojection_layer(z)
                x = z + x
            return x, logits
        return self_condition if (self.self_conditioning or self.intermediate_loss) and self.training else None


    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.abs_pos(x)
        x, interim_logits = self.layers(x, mask=~mask if mask is not None else None, self_condtioning=self.self_condition_fn())
        x = self.post_norm(x)
        x = self.to_logits(x)

        return  { 'out': x, 'interim_logits': interim_logits } if self.training else x




def unpad(x, padding_length):
    """
    Undo padding.
    """
    if padding_length > 0:
        return x[:, :-padding_length]
    return x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(self.norm(x), *args, **kwargs)


class GLU(nn.Module):
    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim = -1)
        return x * self.act(gate)

class MLPAttenion(nn.Module):
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.1,
        bias=False,
        temperature=15.5,
        return_attention=False,
        causal=False,
        activation='softmax',
        **kwargs
    ):
        super().__init__()
        assert activation in ['relusq', 'softmax']
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.bias = bias
        self.return_attention = return_attention

        self.in_proj = nn.Linear(n_feats, n_heads * head_dim)

        window_sizes = [4, 8, 16]
        self.RS_layers = nn.ModuleList(
            [recurrent_shift(dim_head=head_dim, window_size=ws, n_heads=n_heads, dropout=dropout) for ws in window_sizes]
        )
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats)


    def forward(self, x, mask=None):
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim
       
        if mask is None:
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)

        x = self.in_proj(x)
        x = rearrange(x, 'b n (h d) -> b h n d', h=H)       
    
        #mask = rearrange(mask, 'b n -> b () n ()')
        for rs in self.RS_layers:
            x = rs(x, mask=mask)

        out = rearrange(x, 'b h n d -> b n (h d)')

        out = self.out_proj(out)
      
       
        return out

def pad_to_window_size(x, window_size, axis=3, mask=None):
    """
    Pad the input on two sides to be divisible by `window_size`
    """
    batch_size, heads, sequence_length, hidden_size = x.shape
    if sequence_length % window_size == 0:
        return x, 0, mask
    padding_length = (window_size - sequence_length % window_size) % window_size
    padding = torch.zeros(batch_size, heads, padding_length, hidden_size,
        device=x.device,
        dtype=x.dtype,
    )
    mask = F.pad(mask, (0, padding_length), value=True) if mask is not None else None
    return torch.cat([x, padding], axis=axis), padding_length, mask

class OffsetScale(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, heads, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, heads, 1, 1, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        return x * self.gamma + self.beta

class FNetBlock(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, x):
    x = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
    return x



class recurrent_shift(nn.Module):
    def __init__(self, dim_head, n_heads, window_size, dropout=0.1, bias=True):
        super().__init__()
        self.dim = dim_head
        self.dropout = nn.Dropout(dropout)
        self.bias = bias
        self.WINDOW_SIZE = window_size
     

        self.project_out = nn.Linear(self.dim * 2, self.dim, bias=bias)

        self.fnet = FNetBlock()
        self.zero_vector = torch.zeros([1, n_heads, 1, 1, self.dim])

        self.v_offset = OffsetScale(dim_head, n_heads)

 

    def forward(self, x, mask):
        WINDOW_SIZE = self.WINDOW_SIZE
      
        x, pad_n, new_mask = pad_to_window_size(x, window_size=WINDOW_SIZE, axis=-2, mask=mask) 

       
        x = rearrange(x, 'b h (w n) d -> b h w n d', w=x.shape[-2]// WINDOW_SIZE, n=WINDOW_SIZE) # group into windows
        v = self.v_offset(x)
        v_mask = rearrange(new_mask, 'b (w n) -> b 1 w n', w=new_mask.shape[-1]// WINDOW_SIZE, n=WINDOW_SIZE) # group into windows
        v = self.fnet(v)
        v.masked_fill_(v_mask.unsqueeze(-1), 0)
     
       
        zero_vector = self.zero_vector.to(v.device).expand(v.shape[0], -1, -1, v.shape[-2], -1)
    
        v = torch.cat([zero_vector, v], dim=-3)[:,:,:-1] # shift by one
      
        x = torch.cat([x, v], dim=-1)
       
        x = self.project_out(x)
        x = rearrange(x, 'b h w n d -> b h (w n) d')
        
        if pad_n > 0:
            x = x[:,:,:-pad_n]

        print(x.shape, mask.shape)
        x.masked_fill_(rearrange(mask, 'b n -> b () n ()'), 0)
        return x
        

In [301]:
model = transformer_lm(dim=256, vocab_size=8000, depth=12, heads=8, dim_head=32, causal=True)
model(x, mask)

torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) torch.Size([30, 190])
torch.Size([30, 8, 190, 32]) to

{'out': tensor([[[-0.2982,  0.1873, -1.9930,  ..., -0.3897, -0.3533, -1.2466],
          [-0.6645,  0.1250, -1.3112,  ...,  0.6729,  0.1355, -0.7009],
          [-0.9118, -0.2027, -1.1339,  ...,  0.5862, -0.2717, -1.0989],
          ...,
          [-0.4473, -0.8250, -0.5216,  ...,  0.4069,  0.3177, -1.0214],
          [-0.7444,  1.0751,  0.0451,  ...,  0.0867, -0.3953, -0.5858],
          [-0.4822, -0.2591, -0.3168,  ...,  0.1166,  0.7044, -1.1715]],
 
         [[-0.2937,  0.5568, -0.9290,  ...,  0.4349, -0.1473, -1.1261],
          [-0.1165, -0.8903, -0.3944,  ...,  0.8364, -0.7816, -0.3695],
          [-0.5482,  0.4587,  0.3291,  ...,  0.0801, -0.9207, -1.1145],
          ...,
          [-0.4250, -0.6316, -0.2543,  ..., -0.5047,  0.5773, -0.4286],
          [-0.1323, -0.0953, -0.7231,  ..., -0.1950,  0.0868, -0.6367],
          [-0.2676,  0.5273, -0.8669,  ...,  0.0447,  0.1551, -0.7647]],
 
         [[-0.6735, -0.1389, -0.9193,  ..., -0.0415, -0.2693,  0.5203],
          [-1.0036, -