In [None]:
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
import numpy as np
from einops import rearrange, repeat
from torch import einsum

In [None]:
class DynamicPositionBias(nn.Module):
    """Continuous relative-position bias produced by a small MLP.

    Every relative offset ``i - j`` in ``[-(n - 1), n - 1]`` is passed through an MLP
    (1 -> dim -> ... -> heads); the per-offset outputs are then gathered into a
    ``(heads, n, n)`` bias matrix.

    Args:
        dim: hidden width of the MLP.
        heads: number of output bias channels (one per attention head).
        depth: number of hidden Linear(+norm)+ReLU blocks (>= 1).
        log_distance: feed ``sign(d) * log(|d| + 1)`` instead of the raw offset ``d``.
        norm: insert a LayerNorm after each hidden Linear.
    """

    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        def hidden_block(in_features):
            # Linear -> (optional) LayerNorm -> ReLU
            return nn.Sequential(
                nn.Linear(in_features, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU(),
            )

        # first block takes the scalar offset; the remaining depth - 1 blocks are dim -> dim
        blocks = [hidden_block(1)] + [hidden_block(dim) for _ in range(depth - 1)]
        blocks.append(nn.Linear(dim, heads))
        self.mlp = nn.ModuleList(blocks)

    def forward(self, n, device, dtype):
        """Return the ``(heads, n, n)`` bias for a sequence of length ``n``."""
        # lookup table index for pair (i, j): the offset i - j shifted into [0, 2n - 2]
        positions = torch.arange(n, device = device)
        rel_index = positions.unsqueeze(1) - positions.unsqueeze(0) + (n - 1)

        # one MLP input per distinct offset, as a column of scalars
        offsets = torch.arange(-n + 1, n, device = device, dtype = dtype).unsqueeze(-1)
        if self.log_distance:
            offsets = offsets.sign() * (offsets.abs() + 1).log()

        for layer in self.mlp:
            offsets = layer(offsets)

        # (n, n, heads) -> (heads, n, n)
        return offsets[rel_index].permute(2, 0, 1)

torch.Size([5, 9, 94, 32, 64]) dots

In [None]:
# random stand-in tensor with the 5-D shape recorded in the preceding cell's output
dots = torch.randn((5, 9, 94, 32, 64))

In [None]:
def ChunkGrid(Total_Size, Block_Size):
    """Absolute chunk-distance grid between query chunks and key positions.

    With ``P = Total_Size // Block_Size`` chunks, returns an int64 tensor of shape
    ``(P, P * Block_Size)`` whose entry ``[i, k]`` is ``|i - k // Block_Size|``.
    Any remainder of ``Total_Size`` beyond ``P * Block_Size`` is dropped.
    """
    n_chunks = Total_Size // Block_Size
    chunk_ids = torch.arange(0, n_chunks)
    # [i, j] = |j - i| between query chunk i and key chunk j
    distance = (chunk_ids.unsqueeze(0) - chunk_ids.unsqueeze(1)).abs()
    # expand every key chunk into Block_Size key positions
    return distance.repeat_interleave(Block_Size, dim=1)

In [None]:
# preview the distance grid for a 3008-long sequence split into 32-wide chunks
ChunkGrid(Total_Size=3008, Block_Size=32)

In [None]:
def ChunkGrid(Total_Size, Block_Size):
    """Int64 grid ``[i, k] = |i - k // Block_Size|`` of shape ``(P, P * Block_Size)``.

    ``P = Total_Size // Block_Size``; positions past ``P * Block_Size`` are dropped.
    The grid is deliberately left unnormalised (a row-wise ``1 - grid / row_max``
    variant was tried and disabled).
    """
    num_chunks = Total_Size // Block_Size
    ids = torch.arange(num_chunks)
    grid = torch.abs(ids[None, :] - ids[:, None])
    return torch.repeat_interleave(grid, Block_Size, dim=1)

In [None]:
# Noisy chunk-distance grid: |chunk distance| minus Pareto(scale=3, alpha=2) noise
chunkgrid = ChunkGrid(4800, 48)
pareto_noise = torch.distributions.pareto.Pareto(torch.tensor(3.0), torch.tensor(2.0))
pareto = pareto_noise.sample(chunkgrid.shape)
chunkgrid = chunkgrid - pareto

# for query chunk `column` (a row of the grid), inspect the lowest-scoring key positions
column = 0
query_scores = chunkgrid[column]
print(query_scores.topk(384, largest=False).indices.max())
query_scores.topk(200, largest=False).indices.sort(-1).values

In [None]:
# shape of the noisy grid: (n_chunks, n_chunks * block_size)
chunkgrid.size()

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Inspect the Pareto(scale=3, alpha=2) noise that is subtracted from the chunk grid.
# `sample` takes a shape (tuple / torch.Size), not a tensor.
dist = torch.distributions.pareto.Pareto(3.0, 2.0).sample((100,))
fig, ax = plt.subplots()
ax.hist(dist.numpy())
ax.set(title="100 samples of Pareto(scale=3, alpha=2)", xlabel="value", ylabel="count")
plt.show()
print(dist.sort(-1).values)

In [None]:
# the 40 key positions with the lowest noisy distance for query chunk 0
closest_40 = chunkgrid[0].topk(40, largest=False).indices
print(closest_40.max())
closest_40.sort(-1).values


In [None]:
# block-causal key mask for 3 windows of size 3: True marks keys in future windows
torch.triu(torch.ones(3, 3), diagonal=1).to(torch.bool).repeat_interleave(3, dim=1)

In [None]:
class GumbelSigmoid():
    """
    TAKEN FROM: https://github.com/yandexdataschool/gumbel_lstm/blob/master/gumbel_sigmoid.py
    A Gumbel-sigmoid nonlinearity with Gumbel(0, 1) noise.
    In short, it is a differentiable stand-in for the indicator [a > 0], where a is the logit.

    Explanation and motivation: https://arxiv.org/abs/1611.01144

    Math:
    Sigmoid is a softmax over two logits, a and 0:
    e^a / (e^a + e^0) = 1 / (1 + e^(0 - a)) = sigm(a)

    Gumbel-sigmoid is the Gumbel-softmax for the same two logits:
    gumbel_sigm(a) = e^([a + g1]/t) / [ e^([a + g1]/t) + e^(g2/t) ]
    where t is the temperature and g1, g2 are Gumbel samples -log(-log(uniform(0, 1))).
    gumbel_sigm(a) = 1 / (1 + e^(-[a + g1 - g2]/t)) = sigm([a + g1 - g2]/t)

    For computation:
    g1 - g2 = -log(-log(u1)) + log(-log(u2)) = log( log(u2) / log(u1) )
    Because g1 and g2 are i.i.d., g1 - g2 and g2 - g1 have the same (logistic)
    distribution, so the code uses noise = -log( log(u2) / log(u1) ):
    gumbel_sigm(a) = sigm([a - log(log(u2)/log(u1))]/t)

    :param t: temperature of sampling. Lower means more spike-like sampling (float or broadcastable tensor).
    :param eps: a small number used for numerical stability
    :returns: a callable that can (and should) be used as a nonlinearity
    """
    def __init__(self, t=0.1, eps=1e-20):
        self.temperature = t
        self.eps = eps

    def __call__(self, logits):
        """Compute a Gumbel-sigmoid sample with the same shape as ``logits``.

        The uniform draws are created on ``logits.device`` (previously they were always
        on CPU, which failed for CUDA logits). They keep torch's default floating dtype
        so that ``eps`` stays representable even for half-precision logits; the result
        follows normal type promotion with ``logits``.
        """
        uniform1 = torch.rand(logits.shape, device=logits.device)
        uniform2 = torch.rand(logits.shape, device=logits.device)

        # logistic noise (difference of two Gumbel samples), see class docstring
        noise = -torch.log(torch.log(uniform2 + self.eps) / torch.log(uniform1 + self.eps) + self.eps)

        # draw a sample from the Gumbel-sigmoid distribution
        return ((logits + noise) / self.temperature).sigmoid()

In [None]:
class GumbelSigmoid():
    """
    adapted from: https://github.com/yandexdataschool/gumbel_lstm/blob/master/gumbel_sigmoid.py
    and https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    Gumbel-sigmoid (binary-concrete) relaxation:
        sample = sigmoid((logits + g1 - g2) / t),  g1, g2 ~ i.i.d. Gumbel(0, 1)

    A sigmoid is a two-way softmax against a zero logit, so it needs the *difference*
    of two independent Gumbel samples (a logistic sample). Using a single Gumbel
    sample is biased: at logit 0 it gives P(sample > 0.5) = 1 - e^-1 ~= 0.63
    instead of sigmoid(0) = 0.5.
    """
    def __init__(self, t=0.1, eps=None):
        # `eps` is accepted only for signature compatibility with the earlier version; unused.
        self.temperature = t

    def __call__(self, logits):
        """Return a Gumbel-sigmoid sample with the same shape, dtype and device as ``logits``."""
        # -log(Exp(1)) is a Gumbel(0, 1) sample (same trick as torch.nn.functional.gumbel_softmax)
        gumbel1 = -torch.empty_like(logits).exponential_().log()
        gumbel2 = -torch.empty_like(logits).exponential_().log()
        return ((logits + gumbel1 - gumbel2) / self.temperature).sigmoid()

In [None]:
# Random mask: each of the first N_VALID positions is True with probability 0.6
N_VALID, N_TOTAL = 10, 100
mask = (torch.randint(0, 100, (1, N_VALID)) < 60)
print(mask)  # was `print(mask.)`, a SyntaxError
# pad to N_TOTAL positions; the padded positions are marked True
mask = F.pad(mask, (0, N_TOTAL - N_VALID), value=True)
mask

In [None]:
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)

class DynamicPositionBias(nn.Module):
    """MLP-generated relative-position bias of shape ``(heads, n, n)``.

    The MLP maps each scalar offset ``d = i - j`` (optionally ``sign(d) * log(|d| + 1)``)
    to ``heads`` values; ``depth`` hidden blocks of Linear -> (LayerNorm) -> ReLU
    precede a final Linear(dim, heads).
    """
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])
        # the first hidden block reads the scalar offset; the rest are dim -> dim
        for in_dim in [1] + [dim] * (depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(in_dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))
        # final projection: one bias value per head
        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):
        """Return the ``(heads, n, n)`` relative-position bias for a length-``n`` sequence."""
        # MLP inputs: offsets -(n - 1) .. n - 1 as a column of scalars
        offsets = rearrange(torch.arange(-n + 1, n, device = device, dtype = dtype), '... -> ... 1')
        if self.log_distance:
            # signed log distance: sign(d) * log(|d| + 1)
            offsets = torch.sign(offsets) * torch.log(offsets.abs() + 1)

        hidden = offsets
        for layer in self.mlp:
            hidden = layer(hidden)

        # entry (i, j) reads offset i - j, stored at row (i - j) + (n - 1)
        positions = torch.arange(n, device = device)
        lookup = rearrange(positions, 'i -> i 1') - rearrange(positions, 'j -> 1 j') + (n - 1)

        return rearrange(hidden[lookup], 'i j h -> h i j')

class MyopicAttention(nn.Module):
    """Windowed ("myopic") multi-head attention with stochastic, distance-biased key selection.

    The (right-padded) sequence is split into windows of ``chunk_window`` queries. For
    each window every key is scored by its absolute chunk distance minus a
    Pareto(scale=3, alpha=2) sample (resampled on every forward call); excluded keys
    (padding, and future windows when ``causal``) get the largest finite score, and the
    ``max_keep_keys`` lowest-scoring keys are kept. Attention is then computed only
    against those keys, plus a ``DynamicPositionBias`` term gathered at the kept positions.

    Mask convention (shared with the rest of this notebook): ``True`` marks a position
    to exclude (padding). ``-1`` for ``max_keep_keys`` / ``chunk_window`` means
    "use the whole sequence".
    """
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.0,
        max_keep_keys=50,
        chunk_window=3,
        bias=True,
        return_attention=False,
        causal=False,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        # NOTE(review): stored but not applied anywhere in forward.
        self.dropout = dropout
        self.bias = bias
        self.return_attention = return_attention

        self.causal = causal

        self.scale = head_dim ** -0.5

        self.max_keep_keys = max_keep_keys
        self.W = chunk_window

        self.positional_bias = DynamicPositionBias(
            dim = n_feats,
            heads = n_heads,
            depth = 2,
            log_distance = False,
            norm = False
        )

        self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

    def pad_to_window_size(self, x, window_size, axis=3, mask=None):
        """
        Right-pad ``x`` with zeros so its sequence length is a multiple of ``window_size``.

        ``x`` is unpacked as (QKV, batch, heads, seq, hidden) and the zero block is built
        with the sequence on dim 3; ``axis`` is only used for the concatenation. ``mask``
        (batch, seq) is extended with True (= padding) for the new positions.

        Returns:
            (padded x, padding_length, padded mask)
        """
        QKV, batch_size, heads, sequence_length, hidden_size = x.shape
        padding_length = (window_size - sequence_length % window_size) % window_size
        padding = torch.zeros(QKV, batch_size, heads, padding_length, hidden_size,
            device=x.device,
            dtype=x.dtype,
        )
        mask = F.pad(mask, (0, padding_length), value=True)
        return torch.cat([x, padding], axis=axis), padding_length, mask

    def unpad(self, x, padding_length):
        """
        Undo padding: drop the last ``padding_length`` positions along dim 1.
        """
        if padding_length > 0:
            return x[:, :-padding_length]
        return x

    def ChunkGrid(self, Total_Size, Block_Size):
        """Int64 grid ``[i, k] = |i - k // Block_Size|`` of shape (P, P * Block_Size), P = Total_Size // Block_Size."""
        Psize = Total_Size // Block_Size
        chunk_grid = (torch.arange(0, Psize).repeat(Psize,1) - torch.arange(0, Psize).repeat(Psize,1).T ).repeat_interleave(Block_Size, dim=1).abs()
        # not normalised: normalising would stretch the distribution with the sequence length
        return chunk_grid

    def causal_windowed_mask(self, window_number, window_size, device):
        '''
        Block causal mask (window_number, window_number * window_size): True for keys in
        later windows, to prevent selecting future tokens in the top-k key selection.
        '''
        return torch.ones(window_number, window_number, device=device).triu(1).bool().repeat_interleave(window_size, dim=1)

    def standard_forward(self, qkv, mask, return_attn=False):
        """Dense attention over all keys (reference path; not used by ``forward``).

        Args:
            qkv: stacked (3, batch, heads, seq, head_dim) queries, keys and values.
            mask: (batch, seq) bool, True = padding.
            return_attn: also return the attention weights.
        Returns:
            (batch, heads, seq, head_dim), or ``(out, attn)`` when ``return_attn``.
        """
        query, key, value = qkv
        dots = torch.einsum('bhid,bhjd->bhij', query, key) * self.scale
        positions = self.positional_bias(dots.shape[-1], device=dots.device, dtype=dots.dtype)
        dots += positions
        # exclude a pair if EITHER the query or the key is padding
        attn_mask = rearrange(mask, "b n -> b () n ()") | rearrange(mask, "b n -> b () () n")

        if self.causal:
            # create a regular causal mask
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)

        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, value)
        return (out, attn) if return_attn else out

    def forward(self, x, mask, return_attention=False):
        """Apply myopic attention.

        Args:
            x: input features, unpacked as (batch, seq, n_feats).
            mask: (batch, seq) bool, True = padding. Required.
            return_attention: also return the windowed attention weights
                (batch, heads, n_windows, window, kept_keys).
        Returns:
            (batch, seq, n_feats), or ``(out, attn)`` when ``return_attention``.
        """
        assert mask is not None, 'pls wear a mask'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim

        tokeep = min(self.max_keep_keys, N) if self.max_keep_keys != -1 else N # number of keys to keep
        W = min(self.W, N) if self.W != -1 else N # window size

        qkv = rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=H, d=D) # qkv projection

        qkv, pad_n, mask = self.pad_to_window_size(qkv, W, axis=3, mask=mask) # add padding so it's divisible by W
        q, kv = qkv[0], qkv[1:] # keep k and v stacked: both get the same key selection

        q = rearrange(q, "b h (n w) d -> b h n w d", w=W) # split q into windows of size W
        q_mask = repeat(rearrange(mask, "b (n w) -> b n w", w=W), "b n w -> b h n w", h=H) # same for the mask

        kv = repeat(kv, "kv b h n d -> kv b h nw n d", nw=q.shape[2]) # one copy of k/v per window
        KV, B, H, NW, N, D = kv.shape # N is now the padded length

        # key score = |chunk distance| - Pareto noise; lower scores are kept
        chunkgrid = self.ChunkGrid(Total_Size=N, Block_Size=W).to(q.device)
        chunkgrid = repeat(chunkgrid, "w n -> b h w n", b=B, h=H).contiguous()

        SCALE = torch.tensor(3.0, device=q.device, dtype=q.dtype)
        ALPHA = torch.tensor(2.0, device=q.device, dtype=q.dtype)
        pareto_dist = torch.distributions.pareto.Pareto(SCALE, ALPHA).sample(chunkgrid.shape).to(q.device)
        chunkgrid = chunkgrid - pareto_dist

        chunkgrid = repeat(chunkgrid, "b h w n -> kv b h w n", kv=2)

        cmask = repeat(mask, 'b n -> kv b h nw n', kv=2, h=H, nw=NW)

        if self.causal:
            causal_mask = self.causal_windowed_mask(window_number=NW, window_size=W, device=q.device)
            cmask = torch.logical_or(cmask, causal_mask)

        # excluded keys get the max score so the ascending top-k picks them last
        chunkgrid = chunkgrid.masked_fill(cmask, torch.finfo(q.dtype).max)

        keep_indices = chunkgrid.topk(k=tokeep, dim=-1, sorted=False, largest=False).indices.sort(dim=-1).values
        kv = kv.gather(-2, repeat(keep_indices, "kv b h w n -> kv b h w n d", d=D))

        kv_mask = repeat(mask, "b n -> b h nw n", h=H, nw=NW)
        kv_mask = kv_mask.gather(-1, keep_indices[0])

        k, v = kv
        # n: window, p: query offset inside the window, z: kept key
        dots = einsum("b h n p d, b h n z d -> b h n p z ", q, k) * self.scale

        ## positional bias for every (query, kept key) pair
        pos_bias = self.positional_bias(N, device=dots.device, dtype=dots.dtype)
        pos_bias = repeat(pos_bias, 'h i j -> b h i j', b = B)
        pos_bias = rearrange(pos_bias, 'b h (n w) j -> b h n w j', w = W)

        keep_indices = repeat(keep_indices, "kv b h nw n -> kv b h nw w n", w=W)[0]
        pos_bias = pos_bias.gather(-1, keep_indices)

        dots = dots + pos_bias

        mask_val = -torch.finfo(dots.dtype).max

        # exclude a pair if EITHER the query or the key is padding (True = padding)
        qk_mask = rearrange(q_mask, "b h n w -> b h n w ()") | rearrange(kv_mask, "b h w n -> b h w () n")

        if self.causal:
            # absolute position of the query at (window n, offset w) is n * W + w
            query_pos = rearrange(torch.arange(0, N, device=q.device), "(nw w) -> nw w ()", w=W)
            causal_mask = keep_indices > query_pos
            qk_mask = torch.logical_or(qk_mask, causal_mask)

        dots.masked_fill_(qk_mask, mask_val)

        attn = dots.softmax(dim=-1)

        out = einsum("b h n w z, b h n z d -> b h n w d", attn, v)
        out = rearrange(out, "b h n w d -> b (n w) (h d)")
        out = self.unpad(out, pad_n)
        out = self.out_proj(out)

        return out if not return_attention else (out, attn)

In [None]:
import time
class MyopicAttention3(nn.Module):
    """Windowed ("myopic") attention whose key selection is learned.

    Like ``MyopicAttention``, queries are split into windows of ``chunk_window`` and each
    window attends to at most ``max_keep_keys`` keys. Here every key is scored by a small
    network (``valuegrid``): key projection + a projected relative-chunk embedding ->
    activation -> scalar, normalised by the row sum, then passed through a Gumbel
    sigmoid. Keys and values are scaled by that score, and the highest-scoring keys are
    kept.

    Mask convention (shared with the rest of this notebook): ``True`` marks padding.
    Extra kwargs: ``grid_activation`` ('relu' | 'silu', default 'relu') and
    ``verbose`` (default False; prints shapes and per-stage timings when True).
    """
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.0,
        max_keep_keys=50,
        chunk_window=3,
        bias=True,
        return_attention=False,
        causal=False,
        **kwargs
    ):
        super().__init__()
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        # NOTE(review): stored but not applied anywhere in forward.
        self.dropout = dropout
        self.bias = bias
        self.return_attention = return_attention
        # debug output (shapes, timings) is opt-in
        self.verbose = kwargs.get("verbose", False)

        self.causal = causal

        self.scale = head_dim ** -0.5

        self.max_keep_keys = max_keep_keys
        self.W = chunk_window

        self.positional_bias = DynamicPositionBias(
            dim = n_feats,
            heads = n_heads,
            depth = 2,
            log_distance = False,
            norm = False
        )

        self.grid_pos_projection = nn.Linear(1, head_dim*n_heads)
        if self.verbose:
            print(self.grid_pos_projection.weight.shape)
        self.grid_k_projection = nn.Linear(head_dim, head_dim)
        grid_activation = kwargs.get("grid_activation", 'relu')
        assert grid_activation in ['relu', 'silu'], "grid_activation must be relu or silu"

        self.grid_activation = nn.SiLU() if grid_activation == 'silu' else nn.ReLU()
        if self.verbose:
            print("Using grid activation", self.grid_activation)
        self.grid_scaler_projection = nn.Linear(head_dim, 1)
        self.gumbel_sigmoid = GumbelSigmoid(t=0.1, eps=1e-20)

        self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

    def _log_time(self, label, start):
        """Print ``label`` with the seconds elapsed since ``start`` when ``self.verbose`` is set."""
        if self.verbose:
            print(f'{label} done in: {time.time() - start:.2f}s')

    def pad_to_window_size(self, x, window_size, axis=3, mask=None):
        """
        Right-pad ``x`` with zeros so its sequence length is a multiple of ``window_size``.

        ``x`` is unpacked as (QKV, batch, heads, seq, hidden) and the zero block is built
        with the sequence on dim 3; ``axis`` is only used for the concatenation. ``mask``
        (batch, seq) is extended with True (= padding) for the new positions.

        Returns:
            (padded x, padding_length, padded mask)
        """
        QKV, batch_size, heads, sequence_length, hidden_size = x.shape
        padding_length = (window_size - sequence_length % window_size) % window_size
        padding = torch.zeros(QKV, batch_size, heads, padding_length, hidden_size,
            device=x.device,
            dtype=x.dtype,
        )
        mask = F.pad(mask, (0, padding_length), value=True)
        return torch.cat([x, padding], axis=axis), padding_length, mask

    def unpad(self, x, padding_length):
        """
        Undo padding: drop the last ``padding_length`` positions along dim 1.
        """
        if padding_length > 0:
            return x[:, :-padding_length]
        return x

    def valuegrid(self, total_size, block_size, k):
        """Learned per-window key scores.

        Args:
            total_size: padded sequence length N.
            block_size: window size W (``total_size // block_size`` windows).
            k: keys, assumed (batch, heads, n_windows, N, head_dim) — as passed by ``forward``.
        Returns:
            (batch, heads, n_windows, N) Gumbel-sigmoid scores; higher means keep.
        """
        n = total_size // block_size
        device = k.device
        t = time.time()
        # relative chunk offset (query chunk i - key chunk j), shifted to index [0, 2n - 2]
        indices = (rearrange(torch.arange(n, device = device), 'i -> i 1') - rearrange(torch.arange(n, device = device), 'j -> 1 j')) + (n - 1)
        pos = torch.arange(-n + 1, n, device = device, dtype = torch.float32)
        pos = rearrange(pos, '... -> ... 1')
        pos = self.grid_pos_projection(pos)[indices]

        # expand key chunks to key positions, split heads: (1, h, n_windows, N, d)
        pos = pos.repeat_interleave(block_size, dim=1)
        pos = rearrange(pos, "p n (h d) -> () h p n d", h = self.n_heads, d = self.head_dim)
        self._log_time("pos", t)

        t = time.time()
        k_voting = self.grid_k_projection(k)
        self._log_time("k_voting k projection", t)
        k_voting = k_voting + pos
        t = time.time()
        k_voting = self.grid_activation(k_voting)
        self._log_time("k_voting activation", t)
        t = time.time()
        k_voting = self.grid_scaler_projection(k_voting).squeeze(-1)
        self._log_time("k_voting scaler projection", t)
        t = time.time()
        # NOTE(review): the row sum of raw (signed) scores can be ~0 or negative, which can
        # blow up or flip the scores; consider a softmax or an eps-guarded positive normaliser.
        k_voting = k_voting / k_voting.sum(dim=-1, keepdim=True)
        self._log_time("normalization", t)
        t = time.time()
        k_voting = self.gumbel_sigmoid(k_voting)
        self._log_time("gumbel sigmoid", t)
        return k_voting

    def causal_windowed_mask(self, window_number, window_size, device):
        '''Block causal mask (window_number, window_number * window_size): True for keys in later windows.'''
        return torch.ones(window_number, window_number, device=device).triu(1).bool().repeat_interleave(window_size, dim=1)

    def standard_forward(self, qkv, mask, return_attn=False):
        """Dense attention over all keys (reference path; not used by ``forward``).

        Args:
            qkv: stacked (3, batch, heads, seq, head_dim) queries, keys and values.
            mask: (batch, seq) bool, True = padding.
            return_attn: also return the attention weights.
        Returns:
            (batch, heads, seq, head_dim), or ``(out, attn)`` when ``return_attn``.
        """
        query, key, value = qkv
        dots = torch.einsum('bhid,bhjd->bhij', query, key) * self.scale
        positions = self.positional_bias(dots.shape[-1], device=dots.device, dtype=dots.dtype)
        dots += positions
        # exclude a pair if EITHER the query or the key is padding
        attn_mask = rearrange(mask, "b n -> b () n ()") | rearrange(mask, "b n -> b () () n")

        if self.causal:
            # create a regular causal mask
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)

        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, value)
        return (out, attn) if return_attn else out

    def forward(self, x, mask, return_attention=False):
        """Apply learned-selection myopic attention.

        Args:
            x: input features, unpacked as (batch, seq, n_feats).
            mask: (batch, seq) bool, True = padding. Required.
            return_attention: also return the windowed attention weights.
        Returns:
            (batch, seq, n_feats), or ``(out, attn)`` when ``return_attention``.
        """
        assert mask is not None, 'pls wear a mask'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim

        tokeep = min(self.max_keep_keys, N) if self.max_keep_keys != -1 else N # number of keys to keep
        W = min(self.W, N) if self.W != -1 else N # window size

        qkv = rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=H, d=D) # qkv projection

        qkv, pad_n, mask = self.pad_to_window_size(qkv, W, axis=3, mask=mask) # add padding so it's divisible by W
        q, kv = qkv[0], qkv[1:] # keep k and v stacked: both get the same key selection

        q = rearrange(q, "b h (n w) d -> b h n w d", w=W) # split q into windows of size W
        q_mask = repeat(rearrange(mask, "b (n w) -> b n w", w=W), "b n w -> b h n w", h=H) # same for the mask

        kv = repeat(kv, "kv b h n d -> kv b h nw n d", nw=q.shape[2]) # one copy of k/v per window
        KV, B, H, NW, N, D = kv.shape # N is now the padded length

        t = time.time()
        valuegrid = self.valuegrid(total_size=N, block_size=W, k=kv[0]).to(kv.device)
        valuegrid = repeat(valuegrid, "b h w n -> kv b h w n", kv=2)
        kv = kv * valuegrid.unsqueeze(-1) # maybe just do this for the keys? idk
        self._log_time('value gridding', t)

        cmask = repeat(mask, 'b n -> kv b h nw n', kv=2, h=H, nw=NW)

        if self.causal:
            causal_mask = self.causal_windowed_mask(window_number=NW, window_size=W, device=q.device)
            cmask = torch.logical_or(cmask, causal_mask)

        # excluded keys get the lowest score so the descending top-k picks them last
        valuegrid = valuegrid.masked_fill(cmask, -torch.finfo(q.dtype).max)

        t = time.time()
        keep_indices = valuegrid.topk(k=tokeep, dim=-1, sorted=False, largest=True).indices.sort(dim=-1).values
        self._log_time('topk', t)

        t = time.time()
        kv = kv.gather(-2, repeat(keep_indices, "kv b h w n -> kv b h w n d", d=D))
        self._log_time('gather', t)

        t = time.time()
        kv_mask = repeat(mask, "b n -> b h nw n", h=H, nw=NW)
        kv_mask = kv_mask.gather(-1, keep_indices[0])
        self._log_time('kv mask', t)

        k, v = kv
        t = time.time()
        # n: window, p: query offset inside the window, z: kept key
        dots = einsum("b h n p d, b h n z d -> b h n p z ", q, k) * self.scale
        self._log_time('dots', t)

        ## positional bias for every (query, kept key) pair
        t = time.time()
        pos_bias = self.positional_bias(N, device=dots.device, dtype=dots.dtype)
        pos_bias = repeat(pos_bias, 'h i j -> b h i j', b = B)
        pos_bias = rearrange(pos_bias, 'b h (n w) j -> b h n w j', w = W)

        keep_indices = repeat(keep_indices, "kv b h nw n -> kv b h nw w n", w=W)[0]
        pos_bias = pos_bias.gather(-1, keep_indices)

        dots = dots + pos_bias
        self._log_time('pos bias', t)

        mask_val = -torch.finfo(dots.dtype).max

        # exclude a pair if EITHER the query or the key is padding (True = padding)
        qk_mask = rearrange(q_mask, "b h n w -> b h n w ()") | rearrange(kv_mask, "b h w n -> b h w () n")

        if self.causal:
            # absolute position of the query at (window n, offset w) is n * W + w
            query_pos = rearrange(torch.arange(0, N, device=q.device), "(nw w) -> nw w ()", w=W)
            causal_mask = keep_indices > query_pos
            qk_mask = torch.logical_or(qk_mask, causal_mask)

        dots.masked_fill_(qk_mask, mask_val)

        attn = dots.softmax(dim=-1)

        out = einsum("b h n w z, b h n z d -> b h n w d", attn, v)
        out = rearrange(out, "b h n w d -> b (n w) (h d)")
        out = self.unpad(out, pad_n)
        out = self.out_proj(out)

        return out if not return_attention else (out, attn)

In [None]:
# prototype of the key-priority selection for N = 50: priorities in [0, 10) shared
# by runs of 5 keys, then the 30 highest-priority positions
(
    torch.randint(low=0, high=10, size=(1, 1, max(50 // 4, 1)))
    .repeat_interleave(5, dim=-1)
    .topk(30)
)

In [None]:
class SequenceDropoutAttention(nn.Module):
    """Multi-head attention that attends to a random subset of keys ("sequence dropout").

    In the sparse path each (batch, head) draws integer priorities in [0, 10) shared by
    runs of 5 consecutive keys (keys past the last run get priority 4), padded keys get
    priority -1, and the ``int(N * (1 - sequence_dropout))`` highest-priority keys are
    kept (clamped to [1, N]).

    Mask convention (shared with the rest of this notebook): ``True`` marks padding.

    NOTE(review): unlike SequenceMaskingAttention, the sparse path is also used in eval
    mode whenever ``standard_attention=False``.
    """
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.0,
        sequence_dropout=0.4,
        bias=False,
        return_attention=False,
        causal=False,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) # applied to the attention weights
        self.bias = bias
        self.return_attention = return_attention

        self.causal = causal

        self.scale = head_dim ** -0.5

        self.sequence_dropout = sequence_dropout

        self.positional_bias = DynamicPositionBias(
            dim = n_feats,
            heads = n_heads,
            depth = 2,
            log_distance = False,
            norm = False
        )

        self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

    def standard_forward(self, qkv, mask, pos_fn, return_attn=False):
        """Dense attention over all keys.

        Args:
            qkv: stacked (3, batch, heads, seq, head_dim) queries, keys and values.
            mask: (batch, seq) bool, True = padding.
            pos_fn: ``pos_fn(n, device=..., dtype=...)`` -> bias broadcastable to (batch, heads, n, n).
            return_attn: also return the attention weights (after dropout).
        Returns:
            (batch, heads, seq, head_dim), or ``(out, attn)`` when ``return_attn``.
        """
        query, key, value = qkv
        dots = torch.einsum('bhid,bhjd->bhij', query, key) * self.scale
        dots += pos_fn(dots.shape[-1], device=dots.device, dtype=dots.dtype)
        # exclude a pair if EITHER the query or the key is padding
        attn_mask = rearrange(mask, "b n -> b () n ()") | rearrange(mask, "b n -> b () () n")

        if self.causal: # create a regular causal mask
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)

        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim=-1)
        attn = self.dropout(attn)
        out = torch.einsum("bhij,bhjd->bhid", attn, value)
        return (out, attn) if return_attn else out

    def _sparse_forward(self, qkv, mask, pos_fn):
        """Attention against a random subset of keys; returns ``(out, attn)``."""
        q, kv = qkv[0], qkv[1:] # keep k and v stacked: both use the same key subset
        B, H, N, D = q.shape
        device = q.device

        # integer priorities in [0, 10), shared by runs of 5 keys (int8 to save memory)
        priority = torch.randint(0, 10, (B, H, max(N // 4, 1)), device=device, dtype=torch.int8)
        priority = priority.repeat_interleave(5, dim=-1)[..., :N]
        if priority.shape[-1] != N:
            # keys past the last full run get a fixed mid-range priority
            priority = F.pad(priority, (0, N - priority.shape[-1]), value=4)

        # padded keys get the lowest priority so real keys are kept first
        priority = priority.masked_fill(mask.unsqueeze(1), -1)
        # int(N * (1 - p)) is 0 for short sequences; always keep at least one key, at most N
        n_keep = min(N, max(1, int(N * (1 - self.sequence_dropout))))
        keep_indices = torch.topk(priority, n_keep, dim=-1, largest=True, sorted=False).indices.sort(dim=-1).values

        kv = kv.gather(-2, repeat(keep_indices, "b h n -> kv b h n d", kv=2, d=D))
        k_mask = repeat(mask, "b n -> b h n", h=H).gather(-1, keep_indices)
        k, v = kv

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # positional bias gathered at the kept key positions
        pos_bias = pos_fn(N, device=dots.device, dtype=dots.dtype)
        pos_bias = repeat(pos_bias, 'h i j -> b h i j', b = B)
        keep_indices = repeat(keep_indices, "b h n -> b h i n", i=N)
        dots = dots + pos_bias.gather(-1, keep_indices)

        # exclude a pair if EITHER the query or the kept key is padding
        qk_mask = rearrange(mask, "b n -> b () n ()") | rearrange(k_mask, "b h n -> b h () n")
        if self.causal:
            causal_mask = keep_indices > rearrange(torch.arange(0, N, device=device), "n -> n ()")
            qk_mask = torch.logical_or(qk_mask, causal_mask)

        dots.masked_fill_(qk_mask, -torch.finfo(dots.dtype).max)
        attn = self.dropout(dots.softmax(dim=-1))
        out = einsum("bhij,bhjd->bhid", attn, v)
        return out, attn

    def forward(self, x, pos_fn, mask=None, return_attention=False, standard_attention=False):
        """Apply attention.

        Args:
            x: input features, unpacked as (batch, seq, n_feats).
            pos_fn: ``pos_fn(n, device=..., dtype=...)`` -> (heads, n, n) position bias.
            mask: optional (batch, seq) bool, True = padding; defaults to no padding.
            return_attention: also return the attention weights.
            standard_attention: use dense attention instead of the sparse path.
        Returns:
            (batch, seq, n_feats), or ``(out, attn)`` when ``return_attention``.
        """
        assert pos_fn is not None, 'pls provide a position function'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim

        if mask is None:
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)

        qkv = rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=H, d=D) # qkv projection
        if standard_attention:
            out, attn = self.standard_forward(qkv, mask, pos_fn, return_attn=True)
        else:
            out, attn = self._sparse_forward(qkv, mask, pos_fn)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.out_proj(out)
        return out if not return_attention else (out, attn)

In [None]:
# Shape of the key-priority tensor SequenceDropoutAttention would draw for B=3, H=12, N=6.
# Uses the default device: `device=x.device` needed an `x` that only later cells define.
priority_shape = torch.randint(0, 10, (3, 12, max(6 // 4, 1)), dtype=torch.int8).shape
priority_shape

In [139]:
class SequenceMaskingAttention(nn.Module):
    """Multi-head attention that randomly zeroes attention to runs of keys while training.

    In the masked path each (batch, head) draws one keep/drop decision per run of 5
    consecutive keys (keys past the last run get their own draw), dropping with
    probability ``sequence_dropout`` at 1% resolution. Attention weights to dropped
    keys are set to 0 after the activation; the remaining weights are not renormalised.
    In eval mode the dense path is always used.

    Mask convention (shared with the rest of this notebook): ``True`` marks padding.

    NOTE(review): with the default ``sequence_dropout=1.0`` every key is dropped in the
    masked path, so the training-mode output is just the ``out_proj`` bias.
    NOTE(review): ``activation='relusq'`` passes validation but softmax is always used.
    """
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.0,
        sequence_dropout=1.0,
        bias=False,
        return_attention=False,
        causal=False,
        activation='softmax',
        **kwargs
    ):
        super().__init__()
        assert activation in ['relusq', 'softmax']
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) # applied to the attention weights
        self.bias = bias
        self.return_attention = return_attention

        self.causal = causal

        self.scale = head_dim ** -0.5

        self.sequence_dropout = sequence_dropout
        self.activation = torch.nn.Softmax(dim=-1)

        self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

    def _project_qkv(self, x):
        """Project ``x`` (batch, seq, n_feats) to stacked q/k/v of shape (3, batch, heads, seq, head_dim)."""
        B, N = x.shape[0], x.shape[1]
        # same layout as einops "b n (h d qkv) -> qkv b h n d": qkv is the fastest-varying factor
        return self.qkv_proj(x).reshape(B, N, self.n_heads, self.head_dim, 3).permute(4, 0, 2, 1, 3)

    def _attention_mask(self, mask, n):
        """(batch, seq) padding mask -> (batch, 1, n, n) bool mask.

        True where the query OR the key is padding, plus future keys when causal.
        """
        attn_mask = mask[:, None, :, None] | mask[:, None, None, :]
        if self.causal:
            future = torch.ones(n, n, device=mask.device, dtype=torch.bool).triu(1)
            attn_mask = attn_mask | future
        return attn_mask

    def _sample_sequence_mask(self, B, N, device):
        """Sample a (B, heads, N) bool mask of dropped keys (True = dropped)."""
        H = self.n_heads
        threshold = self.sequence_dropout * 100
        # one draw per run of 5 keys (int8 to save memory)
        seq_mask = torch.randint(0, 100, (B, H, max(N // 4, 1)), device=device, dtype=torch.int8) < threshold
        seq_mask = seq_mask.repeat_interleave(5, dim=-1)[..., :N]
        if seq_mask.shape[-1] != N:
            # keys past the last run get their own draw
            diff = N - seq_mask.shape[-1]
            pad_seq = torch.randint(0, 100, (B, H, diff), device=device, dtype=torch.int8) < threshold
            seq_mask = torch.cat((seq_mask, pad_seq), dim=-1)
        return seq_mask

    def standard_forward(self, qkv, mask, pos_fn, return_attn=False):
        """Dense attention over all keys.

        Args:
            qkv: stacked (3, batch, heads, seq, head_dim) queries, keys and values.
            mask: (batch, seq) bool, True = padding.
            pos_fn: ``pos_fn(n, device=..., dtype=...)`` -> bias broadcastable to (batch, heads, n, n).
            return_attn: also return the attention weights (after dropout).
        Returns:
            (batch, heads, seq, head_dim), or ``(out, attn)`` when ``return_attn``.
        """
        query, key, value = qkv
        dots = torch.einsum('bhid,bhjd->bhij', query, key) * self.scale
        dots = dots + pos_fn(dots.shape[-1], device=dots.device, dtype=dots.dtype)
        dots = dots.masked_fill(self._attention_mask(mask, dots.shape[-1]), -torch.finfo(dots.dtype).max)
        attn = self.dropout(self.activation(dots))
        out = torch.einsum('bhij,bhjd->bhid', attn, value)
        return (out, attn) if return_attn else out

    def forward(self, x, pos_fn, mask=None, return_attention=False, standard_attention=False):
        """Apply attention.

        Args:
            x: input features, unpacked as (batch, seq, n_feats).
            pos_fn: ``pos_fn(n, device=..., dtype=...)`` -> (heads, n, n) position bias.
            mask: optional (batch, seq) bool, True = padding; defaults to no padding.
            return_attention: also return the attention weights.
            standard_attention: use the dense path even in training mode.
        Returns:
            (batch, seq, n_feats), or ``(out, attn)`` when ``return_attention``.
        """
        assert pos_fn is not None, 'pls provide a position function'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim

        standard_attention = standard_attention if self.training else True # we don't want to use dropout during inference

        if mask is None:
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)

        qkv = self._project_qkv(x)
        if standard_attention:
            out, attn = self.standard_forward(qkv, mask, pos_fn, return_attn=True)
        else:
            q, k, v = qkv
            seq_mask = self._sample_sequence_mask(B, N, x.device)

            dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
            dots = dots + pos_fn(N, device=dots.device, dtype=dots.dtype)
            # previously qk_mask was only defined when causal=True (NameError otherwise)
            dots = dots.masked_fill(self._attention_mask(mask, N), -torch.finfo(dots.dtype).max)
            attn = self.activation(dots)
            # zero attention to dropped keys (broadcast over the query axis)
            attn = attn.masked_fill(seq_mask[:, :, None, :], 0)
            attn = self.dropout(attn)
            out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        out = out.transpose(1, 2).reshape(B, N, H * D)
        out = self.out_proj(out)
        return out if not return_attention else (out, attn)

In [140]:
# Smoke test: SequenceMaskingAttention in training mode (causal, sequence_dropout=1.0)
positional_bias = DynamicPositionBias(dim=216, heads=12, depth=2, log_distance=False, norm=False)
attention = SequenceMaskingAttention(
    n_feats=216, head_dim=24, n_heads=12, causal=True, sequence_dropout=1.0
)
N, B = 10, 30
x = torch.ones(B, N, 216) + torch.randn(B, N, 216) * 0.01
mask = torch.zeros(B, N).bool()
attention.requires_grad_(False)
print()
a = attention(x, pos_fn=positional_bias, mask=mask, standard_attention=False)


torch.Size([30, 12, 10, 10]) torch.Size([30, 12, 10])


In [131]:
# Inspect `a` (assigned from the attention call in the cell above).
# NOTE(review): the saved output below is a bool tensor and has execution count 131 < 140,
# so it came from an earlier run, not from the cell above — re-run to refresh.
a

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [None]:


# Smoke-test the attention forward pass for every sequence length 1..99.
# Quiet on success; on failure the error names the sequence length that broke.
for i in range(1, 100):
    attention = SequenceDropoutAttention(n_feats=216, head_dim=24, n_heads=12, causal=True, sequence_dropout=0.5)
    N = i
    B = 3
    x = torch.ones(B, N, 216) + torch.randn(B, N, 216) * 0.01
    mask = torch.zeros(B, N).bool()
    attention.eval()
    attention.requires_grad_(False)
    try:
        attn = attention(x, pos_fn=attention.positional_bias, mask=mask, standard_attention=False)
    except Exception as err:
        raise RuntimeError(f"attention forward failed for sequence length N={N}") from err

In [None]:
# Call `attention` again on the current `x` / `mask` (keyword order is irrelevant).
attn = attention(x, mask=mask, pos_fn=attention.positional_bias, standard_attention=False)

In [None]:
# Display the result of the most recent `attention(...)` call.
attn

In [None]:
# NOTE(review): `sa` is not assigned anywhere in this section; this cell only works
# with leftover kernel state or a cell elsewhere — confirm or delete.
sa

In [None]:
# Keep the current result under a second name (an alias to the same object, not a copy).
attn1 = attn

In [None]:

class MyopicAttention2(nn.Module):
    """
    Windowed attention in which each window of queries attends to a stochastic
    top-k subset of keys.

    The (padded) sequence is split into windows of ``chunk_window`` queries. Every key gets
    a score per window: its block distance to the window (``ChunkGrid``) times a learnable
    per-head ``distance_multiplier``, minus a reparameterised Pareto sample with learnable
    ``scale`` / ``alpha``. The ``max_keep_keys`` lowest-scoring keys are kept and the window's
    queries attend only to those; a ``DynamicPositionBias`` is added to the logits.

    Masks follow the convention ``True`` = padded / excluded position.

    Args:
        n_feats: input/output feature size.
        head_dim: per-head query/key/value size.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        max_keep_keys: keys kept per window (-1 keeps all).
        chunk_window: query window size (-1 uses a single window over the sequence).
        bias: whether the qkv / output projections have a bias.
        return_attention: stored on the module; ``forward`` uses its own argument.
        causal: forbid attending to future positions.
        **kwargs: ``half_precision_mode`` ('float16' | 'bfloat16' | anything else = no cast on CUDA),
            ``scale`` / ``alpha`` (Pareto initial values), ``distance_multiplier_prior``.
    """

    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.0,
        max_keep_keys=50,
        chunk_window=3,
        bias=True,
        return_attention=False,
        causal=False,
        **kwargs
    ):
        super().__init__()
        self.n_feats = n_feats
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = dropout  # raw probability, kept for backwards compatibility
        self.attn_dropout = nn.Dropout(dropout)  # actually applied to the attention weights
        self.bias = bias
        self.return_attention = return_attention

        self.causal = causal

        # Temperature for q·k. Kept under its own name: `self.scale` below is the learnable
        # Pareto scale, which used to overwrite this value (logits were multiplied by ~3.0).
        self.attn_scale = head_dim ** -0.5

        self.max_keep_keys = max_keep_keys  # -1 keeps every key
        self.W = chunk_window  # -1 uses a single window spanning the sequence

        self.positional_bias = DynamicPositionBias(
            dim = n_feats,
            heads = n_heads,
            depth = 2,
            log_distance = False,
            norm = False
        )

        self.half_precision_mode = kwargs.get("half_precision_mode", 'float16') # 'float16' or 'bfloat16' or float32

        # learnable Pareto(scale, alpha) noise and per-head distance weight used to rank keys
        self.scale = nn.Parameter(torch.tensor(kwargs.get('scale', 3.0), requires_grad=True))
        self.alpha = nn.Parameter(torch.tensor(kwargs.get('alpha', 2.0), requires_grad=True))
        self.distance_multiplier = nn.Parameter(torch.tensor([kwargs.get('distance_multiplier_prior', 1.0)]*n_heads, requires_grad=True))

        self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

    def pad_to_window_size(self, x, window_size, axis=3, mask=None):
        """
        Right-pad the sequence so its length is divisible by `window_size`.

        Args:
            x: stacked qkv tensor of shape (qkv, batch, heads, seq, head_dim).
            window_size: required divisor of the sequence length.
            axis: axis along which the zero padding is concatenated.
            mask: (batch, seq) mask; the padded positions are filled with True.
        Returns:
            (padded x, number of padded positions, padded mask)
        """
        QKV, batch_size, heads, sequence_length, hidden_size = x.shape
        padding_length = (window_size - sequence_length % window_size) % window_size
        padding = torch.zeros(QKV, batch_size, heads, padding_length, hidden_size,
            device=x.device,
            dtype=x.dtype,
        )
        mask = F.pad(mask, (0, padding_length), value=True)
        return torch.cat([x, padding], axis=axis), padding_length, mask

    def unpad(self, x, padding_length):
        """Drop the last `padding_length` positions along axis 1 (the sequence axis)."""
        if padding_length > 0:
            return x[:, :-padding_length]
        return x

    def ChunkGrid(self, Total_Size, Block_Size):
        """
        Integer block-distance grid of shape (Psize, Psize * Block_Size), Psize = Total_Size // Block_Size.

        Entry [i, k] is |i - k // Block_Size|: the distance between window i and the block
        containing key k. Created on the CPU; callers move it to the right device.
        """
        Psize = Total_Size // Block_Size
        chunk_grid = (torch.arange(0, Psize).repeat(Psize,1) - torch.arange(0, Psize).repeat(Psize,1).T ).repeat_interleave(Block_Size, dim=1).abs()
        return chunk_grid

    def causal_windowed_mask(self, window_number, window_size, device):
        '''
        Block-causal mask of shape (window_number, window_number * window_size): True where the
        key's window comes after the query window, so future windows are never chosen by the
        top-k key selection. Keys in the query's own window stay selectable.
        '''
        return torch.ones(window_number, window_number, device=device).triu(1).bool().repeat_interleave(window_size, dim=1)

    def half_precision_if_on_cuda(self, x, is_cuda):
        """Cast `x` to the configured half-precision dtype when running on CUDA; otherwise return it unchanged."""
        if not is_cuda:
            return x
        elif self.half_precision_mode == 'float16':
            return x.half()
        elif self.half_precision_mode == 'bfloat16':
            return x.bfloat16()
        else:
            return x

    def standard_forward(self, qkv, mask):
        """
        Full (non-windowed) attention over stacked ``(q, k, v)`` of shape (3, batch, heads, seq, head_dim).

        `mask` is (batch, seq) with True = padded; a logit is masked when its query or its key is padding.
        """
        query, key, value = qkv
        dots = torch.einsum('bhid,bhjd->bhij', query, key) * self.attn_scale
        positions = self.positional_bias(dots.shape[-1], device=dots.device, dtype=dots.dtype)
        dots = dots + positions
        attn_mask = torch.logical_or(
            rearrange(mask, "b n -> b () n ()"),
            rearrange(mask, "b n -> b () () n"),
        )

        if self.causal:
            # regular causal mask: query i cannot see keys j > i
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)

        dots = dots.masked_fill(attn_mask, -torch.finfo(dots.dtype).max)
        attn = dots.softmax(dim=-1)
        return torch.einsum('bhij,bhjd->bhid', attn, value)

    def forward(self, x, mask, return_attention=False):
        """
        Args:
            x: (batch, seq, n_feats) input.
            mask: (batch, seq) bool mask, True marks padded positions to ignore.
            return_attention: also return the windowed attention weights.
        Returns:
            (batch, seq, n_feats) output, or ``(output, attn)`` with attn of shape
            (batch, heads, n_windows, window, kept_keys).
        """
        assert mask is not None, 'pls wear a mask'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim

        tokeep = min(self.max_keep_keys, N) if self.max_keep_keys != -1 else N # number of keys to keep
        W = min(self.W, N) if self.W != -1 else N # window size

        qkv = rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=H, d=D) # qkv projection

        qkv, pad_n, mask = self.pad_to_window_size(qkv, W, axis=3, mask=mask) # add padding so it's divisible by W
        q, kv = qkv[0], qkv[1:] # k and v undergo the same selection, so keep them stacked

        q = rearrange(q, "b h (n w) d -> b h n w d", w=W) # split q into windows of size W
        q_mask = repeat(rearrange(mask, "b (n w) -> b n w", w=W), "b n w -> b h n w", h=H) # same split for the mask
        kv = repeat(kv, "kv b h n d -> kv b h nw n d", nw=q.shape[2]) # one copy of k and v per window

        KV, B, H, NW, N, D = kv.shape # N is now the padded length

        # --- score every key for every window: weighted block distance minus a Pareto sample ---
        chunkgrid = self.ChunkGrid(Total_Size=N, Block_Size=W).to(q.device)
        chunkgrid = self.half_precision_if_on_cuda(chunkgrid, q.is_cuda)
        chunkgrid = repeat(chunkgrid, "w n -> b h w n", b=B, h=H).contiguous()
        distance_multiplier = self.half_precision_if_on_cuda(self.distance_multiplier, q.is_cuda).to(q.device)
        distance_multiplier = repeat(distance_multiplier, "h -> b h w n", b=B, w=NW, n=N) # use the cast copy
        chunkgrid = chunkgrid * distance_multiplier

        SCALE = self.scale.to(q.device).to(chunkgrid.dtype)
        ALPHA = self.alpha.to(q.device).to(chunkgrid.dtype)
        pareto_dist = torch.distributions.pareto.Pareto(SCALE, ALPHA).rsample(chunkgrid.shape).to(chunkgrid.device) # rsample so we can backprop
        chunkgrid = chunkgrid - pareto_dist

        chunkgrid = repeat(chunkgrid, "b h w n -> kv b h w n", kv=2)

        cmask = repeat(mask, 'b n -> kv b h nw n', kv=2, h=H, nw=NW)
        if self.causal:
            causal_mask = self.causal_windowed_mask(window_number=NW, window_size=W, device=q.device)
            cmask = torch.logical_or(cmask, causal_mask)

        chunkgrid = chunkgrid.masked_fill(cmask, torch.finfo(chunkgrid.dtype).max) # max because topk keeps the smallest

        keep_indices = chunkgrid.topk(k=tokeep, dim=-1, sorted=False, largest=False)

        # Re-weight the quarter of kept keys with the largest scores: the softmax of their scores
        # (times random noise) multiplies those key vectors, giving the Pareto parameters a
        # gradient path so the model can learn which keys to keep.
        num_to_scatter = tokeep // 4 # number of keys to re-weight
        if num_to_scatter > 0: # guard: a `-0:` slice would select every kept key
            sorted_keep = keep_indices.values.sort(-1)
            top_slots = sorted_keep.indices[0, ..., -num_to_scatter:] # positions inside the top-k result (keys only)
            key_positions = keep_indices.indices[0].gather(-1, top_slots) # map them back to sequence positions
            key_positions = repeat(key_positions, 'b h w n -> b h w n d', d=D)
            scatter_vals = sorted_keep.values[0, ..., -num_to_scatter:].softmax(-1)
            scatter_vals = repeat(scatter_vals, 'b h w n -> b h w n d', d=D)
            scatter_vals = scatter_vals * torch.randn(D, device=scatter_vals.device, dtype=scatter_vals.dtype) # add some noise to the values

            keys = kv[0]
            # top-k indices are unique per row, so gather -> multiply -> scatter is a multiplicative scatter
            scaled_keys = keys.gather(-2, key_positions) * scatter_vals.to(keys.dtype)
            kv = torch.stack((keys.scatter(-2, key_positions, scaled_keys), kv[1]))

        keep_indices = keep_indices.indices.sort(dim=-1).values # kept key positions, in sequence order
        kv = kv.gather(-2, repeat(keep_indices, "kv b h w n -> kv b h w n d", d=D))

        kv_mask = repeat(mask, "b n -> b h nw n", h=H, nw=NW)
        kv_mask = kv_mask.gather(-1, keep_indices[0])

        k, v = kv
        # nw (number of windows) = n, p = query offset in window, z = kept key
        dots = einsum("b h n p d, b h n z d -> b h n p z ", q, k) * self.attn_scale

        # dynamic position bias, gathered at the kept key positions
        pos_bias = self.positional_bias(N, device=dots.device, dtype=dots.dtype)
        pos_bias = repeat(pos_bias, 'h i j -> b h i j', b = B)
        pos_bias = rearrange(pos_bias, 'b h (n w) j -> b h n w j', w = W)

        keep_indices = repeat(keep_indices, "kv b h nw n -> kv b h nw w n", w=W)[0]
        pos_bias = pos_bias.gather(-1, keep_indices)

        dots = dots + pos_bias

        mask_val = -torch.finfo(dots.dtype).max

        # a logit is masked when its query OR its key is padding (True = masked)
        qk_mask = torch.logical_or(
            rearrange(q_mask, "b h n w -> b h n w ()"),
            rearrange(kv_mask, "b h w n -> b h w () n"),
        )

        if self.causal:
            # absolute position of each query: window index * W + offset within the window
            query_pos = rearrange(torch.arange(0, N, device=q.device), "(nw w) -> nw w ()", w=W)
            causal_mask = keep_indices > query_pos
            qk_mask = torch.logical_or(qk_mask, causal_mask)

        dots.masked_fill_(qk_mask, mask_val)

        attn = dots.softmax(dim=-1)

        out = einsum("b h n w z, b h n z d -> b h n w d", self.attn_dropout(attn), v)
        out = rearrange(out, "b h n w d -> b (n w) (h d)")
        out = self.unpad(out, pad_n)
        out = self.out_proj(out)

        return out if not return_attention else (out, attn)

In [None]:
# NOTE(review): `max_keep_keys` / `chunk_window` and the positional `attention(x, mask)` call
# match MyopicAttention2's signature, not the `pos_fn=`-style calls used for
# SequenceDropoutAttention elsewhere — confirm which class is intended here.
attention = SequenceDropoutAttention(
    n_feats=216,
    head_dim=24,
    n_heads=12,
    max_keep_keys=256,
    chunk_window=128,
    causal=True,
)

# near-constant input with small noise; mask entries set to True (presumably = ignored)
x = torch.ones(10, 5000, 216) + torch.randn(10, 5000, 216) * 0.01
mask = torch.zeros(10, 5000, dtype=torch.bool)
mask[0, :10] = True
mask[2, 23:45] = True

attn = attention(x, mask)

In [None]:
from torch_scatter import scatter

# Multiplicative scatter along dim 0 in float16:
# row 0 of `src` multiplies row 2 of `out`, row 1 multiplies row 1.
src = torch.arange(1, 11, requires_grad=True, dtype=torch.float32).reshape((2, 5))
index = torch.tensor([[2] * 5, [1] * 5])
out = torch.ones(3, 5, dtype=src.dtype, requires_grad=True)

out2 = scatter(
    index=index.clone().long(),
    src=src.clone().to(torch.float16),
    out=out.clone().to(torch.float16),
    dim=0,
    reduce='mul',
)

out2


In [None]:
# Native equivalent of the torch_scatter example. The `reduce=` argument of `Tensor.scatter`
# is deprecated; `scatter_reduce(..., reduce="prod")` (include_self=True by default)
# multiplies each targeted entry of `out` by the scattered `src` value.
src = torch.arange(1, 11, requires_grad=True, dtype=torch.float32).reshape((2, 5))
index = torch.tensor([[2,2,2,2,2],[1,1,1,1,1]])
out = torch.ones(3, 5, dtype=src.dtype, requires_grad=True)
out = out.scatter_reduce(0, index, src, reduce="prod")
out

In [None]:
# compare the first 99 entries at [num1, num2] of the two elements of `attn`
num1, num2 = 0, 100
torch.allclose(attn[1][num1, num2][:99], attn[0][num1, num2][:99])

In [None]:
# First 99 entries at [num1, num2] of attn[0] (presumably the first output of a pair — confirm).
attn[0][num1,num2][:99]

In [None]:
# First 99 entries at [num1, num2] of attn[1], for comparison with attn[0].
attn[1][num1,num2][:99]

In [None]:
# Whether the two elements of `attn` are elementwise close (default rtol/atol).
torch.allclose(attn[0], attn[1])

In [None]:
# indices of the 3 smallest entries, re-sorted into ascending index order
(
    torch.tensor([0, 1, 2, 3, 4, 5, 6])
    .topk(k=3, largest=False, sorted=False)
    .indices
    .sort(dim=-1)
    .values
)

In [None]:
# random batch of 10 sequences x 1000 steps x 216 features, with an all-False mask
x = torch.randn(10, 1000, 216)
mask = torch.zeros(10, 1000, dtype=torch.bool)

In [None]:
### Legacy ChunkGrid experiment (superseded by MyopicAttention2.ChunkGrid)
def ChunkGrid(self, Total_Size, Block_Size):
    """
    Closeness grid of shape (Psize, Psize * Block_Size), Psize = Total_Size // Block_Size.

    Entry [i, k] is 1 - |i - k // Block_Size| / (max of row i): 1 for keys in block i and
    0 for the farthest block in that row. `self` is unused (copied out of a class).
    """
    Psize = Total_Size // Block_Size
    block_ids = torch.arange(0, Psize).repeat(Psize, 1)
    chunk_grid = (block_ids - block_ids.T).repeat_interleave(Block_Size, dim=1).abs()
    # clamp the divisor so a single-block grid (row max 0) gives ones instead of 0/0 = NaN
    row_max = chunk_grid.max(dim=-1)[0].clamp_min(1).unsqueeze(-1)
    chunk_grid = 1 - (chunk_grid / row_max)
    return chunk_grid

# NOTE(review): the lines below are scratch copied out of a forward pass; `chunkgrid`, `B`, `H`
# and `q` are not defined in this cell and must come from kernel state.
chunkgrid = repeat(chunkgrid, "w n -> b h w n", b=B, h=H).contiguous()
MEAN = torch.tensor(0, device=q.device, dtype=q.dtype)
STD = torch.tensor(0.125, device=q.device, dtype=q.dtype)
normal_noise = torch.distributions.normal.Normal(MEAN, STD).sample(chunkgrid.shape).to(q.device)  # Gaussian N(0, 0.125) noise
chunkgrid += normal_noise
chunkgrid = repeat(chunkgrid, "b h w n -> kv b h w n", kv=2)

In [None]:
def causal_windowed_mask(window_number, window_size, device):
    """
    Block-causal mask of shape (window_number, window_number * window_size).

    Row i is True for every key whose window index is greater than i (future windows).
    """
    future_windows = torch.triu(torch.ones(window_number, window_number, device=device), diagonal=1).bool()
    return future_windows.repeat_interleave(window_size, dim=1)

In [None]:
# 3 windows of 4 keys each -> (3, 12) mask
causal_windowed_mask(window_number=3, window_size=4, device='cpu')

In [None]:
# NOTE(review): `a` is assigned from an attention forward call earlier in this section, which
# returns a tensor (or an (out, attn) tuple) without `.indices`; this cell presumably relies on a
# topk/sort result left in kernel state — confirm.
a.indices.sort().values[0, 0, 0, 0]

In [None]:
# 0-d tensor -> shape (1,) -> repeated to (2,) -> broadcast view of shape (2, 2)
(
    torch.tensor(2)
    .unsqueeze(-1)
    .repeat(2)
    .expand(2, 2)
)

In [None]:
# `torch.randn()` without a size raises TypeError; an empty size gives a 0-d standard-normal sample.
torch.randn(())

In [None]:
# NOTE(review): `km` is not assigned anywhere in this section; relies on kernel state.
km.shape

In [None]:
# Sum over the first 100 entries of the last axis of `km`, at index 0 on axes 0, 1 and 3
# (axis meanings are not established in this section — TODO confirm).
km[0, 0, :, 0, :100].sum(-1)

In [None]:
# Same check as an earlier `km.shape` cell (duplicate).
km.shape

In [None]:
# pairwise offsets [i, j] = j - i for 3008 positions (via broadcasting), regrouped as 32 x ? x 3008
(torch.arange(0, 3008)[None, :] - torch.arange(0, 3008)[:, None]).reshape(32, -1, 3008).shape

- Duplicated across the KV axis (keys and values).
- Each batch element, head, and window has a different view of the keys.
- 94 is the number of windows, i.e. 94 × 32 (window size) = 3008 (sequence length).

In [None]:
# NOTE(review): `kv` is not assigned at top level in this section (the `kv` inside
# MyopicAttention2.forward is local); this relies on kernel state.
kv.shape

In [None]:
# Index `kv` with `km` and reshape to (2, 5, 8, 3, -1, 24);
# assumes `km` is a boolean mask over kv's leading axes — TODO confirm.
kv[km].reshape(2, 5, 8, 3, -1, 24).shape

In [None]:
# Broadcast the 2-D grid `cg` (W, N) to (KV=2, B=5, H=8, W, N) — shape check only.
repeat(cg, "W N -> KV B H W N", B=5, H=8, KV=2).shape

In [None]:
# `kv.s` raised AttributeError (tensors have no `.s`); the shape was presumably intended.
kv.shape

In [None]:
def ChunkGrid(N_BLOCKS, BLOCK_SIZE):
    """
    Row-normalised distance grid of shape (BLOCK_SIZE, N_BLOCKS * BLOCK_SIZE).

    Entry [i, k] is |k // BLOCK_SIZE - i| divided by the maximum of row i.

    NOTE(review): rows range over ``range(BLOCK_SIZE)`` while columns are grouped by block
    index, unlike the (n_blocks, n_blocks * block_size) grid of MyopicAttention2.ChunkGrid —
    confirm this is intended. When N_BLOCKS == 1, row 0 is all zeros and becomes NaN (0 / 0).
    This definition shadows any earlier top-level ``ChunkGrid``.
    """
    offsets = torch.arange(0, N_BLOCKS)[None, :] - torch.arange(0, BLOCK_SIZE)[:, None]
    distances = offsets.repeat_interleave(BLOCK_SIZE, dim=1).abs()
    return distances / distances.amax(dim=-1, keepdim=True)

In [None]:
# perturb the grid with U(0, 1) noise, then mark the 9 highest-scoring columns of each row
cg = ChunkGrid(41, 3)
uniform_dist = torch.distributions.uniform.Uniform(0, 1).sample(cg.shape)
cg = cg + uniform_dist
keep_indices = torch.topk(cg, 9, dim=-1).indices
keep_mask = torch.zeros_like(cg).scatter_(1, keep_indices, 1).bool()

In [None]:
# (BLOCK_SIZE, N_BLOCKS * BLOCK_SIZE) = (3, 123) for ChunkGrid(41, 3).
cg.shape

In [None]:
# Boolean mask, True at the 9 columns per row chosen by topk.
keep_mask