In [15]:
cache_len = torch.tensor([2,5,2,8])
cur_lens = torch.tensor([5,2,5,7])
total_len = cache_lens + cur_lens
print(total_len)

tensor([ 7,  7,  7, 15])


In [28]:
causal_mask.shape

torch.Size([4, 15])

In [21]:
torch.arange(causal_mask.shape[0])[None].T

tensor([[0],
        [1],
        [2],
        [3]])

In [107]:
(cache_lens[:,None,None] + torch.arange(cur_lens.max())[None,:,None]).shape

torch.Size([4, 7, 1])

In [71]:
torch.arange(total_len.max())[None,None,:]

torch.Size([1, 1, 15])

In [119]:
print(total_len, cur_lens, cache_len)

tensor([ 7,  7,  7, 15]) tensor([5, 2, 5, 7]) tensor([2, 5, 2, 8])


In [129]:
causal_mask = repeat(torch.arange(total_len.max()), 'i -> b r i', b=len(total_len), r=cur_lens.max())
causal_mask = causal_mask >= (cache_lens[:,None,None] + torch.arange(cur_lens.max())[None,:,None] + 1)
causal_mask[-3,0]


torch.Size([4, 7, 15]) torch.Size([4, 1, 1])


tensor([False, False, False, False, False, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True])

In [25]:
causal_mask = torch.arange(total_len.max()).expand(len(total_len), -1)
causal_mask = causal_mask > cache_len[:,None]
causal_mask = causal_mask < torch.arange(causal_mask.shape[0])[None].T + cache_len[:,None] 

In [27]:
torch.arange(causal_mask.shape[0])[None].T + cache_len[:,None] 

tensor([[ 2],
        [ 6],
        [ 4],
        [11]])

In [357]:
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from torch import einsum
from torch.utils.checkpoint import checkpoint # # gradient/activation checkpointing
from functools import partial
from typing import Dict, List, Optional, Tuple, Union


def exists(val):
    return val is not None

# token shifting
# lucidrains implementation: https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
# BlinkDL idea from RWKV-LM https://github.com/BlinkDL/RWKV-LM
def shift(t, amount, mask = None):
    if amount == 0:
        return t
    else:
        amount = min(amount, t.shape[1])

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return F.pad(t, (0, 0, amount, -amount), value = 0.)

class ShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)


class DynamicPositionBias(nn.Module):
    '''Adapted From Phil Wang's x-transformers library to handle non-square matrices'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False, activation=nn.SiLU):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            activation()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                activation()
            ))

        self.mlp.append(nn.Linear(dim, heads))



    def forward(self, i, j, device, dtype):
        # get the (i x j) matrix of distances
        assert i >= 1 and j >= 1 and i <= j, 'I should be in the range [1, j] and j >= 1'
        seq_arange = torch.arange(i, device = device)
        context_arange = torch.arange(j, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (j-1)
        
        # input to continuous positions MLP
        pos = torch.arange(-i + 1, (j+i), device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')
     
        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos) 

        # get position biases        
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

class ScaledSinuEmbedding(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        return emb * self.scale

class ReLUSquared(nn.Module):
    def forward(self, x):
        return torch.pow(F.relu(x), 2)

def l2norm(t, groups = 1, dim = -1):
    if groups == 1:
        return F.normalize(t, p = 2, dim = dim)
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = dim)
    return rearrange(t, '... g d -> ... (g d)')

class CosineAttention(nn.Module):
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.1,
        bias=False,
        temperature=15.5,
        return_attention=False,
        causal=False,
        activation='softmax',
        **kwargs
    ):
        super().__init__()
        assert activation in ['relusq', 'softmax']
        self.shared_kv = kwargs.get('shared_kv', False)
        self.talking_heads = kwargs.get('talking_heads', False)

        self.n_feats, self.head_dim, self.n_heads = n_feats, head_dim, n_heads
        self.dropout = nn.Dropout(dropout)
        self.bias = bias
        self.return_attention = return_attention
        self.causal = causal

        if self.talking_heads:
            self._head_proj = nn.Conv2d(n_heads, n_heads, (1, 1))

        self.temperature = torch.nn.Parameter(torch.tensor(temperature), requires_grad=True) if isinstance(temperature, float) else temperature

        self.activation = ReLUSquared() if activation == 'relusq' else nn.Softmax(dim=-1)

        if not self.shared_kv:
            self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
            self.qkv = lambda x: rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=n_heads, d=head_dim)
        else:
            self.q_proj, self.kv_proj = [nn.Linear(n_feats, el, bias=bias) for el in [n_heads * head_dim, 2 * head_dim]]
            map_q, map_kv = lambda q: rearrange(q, 'b n (h d) -> b h n d', h=n_heads), lambda kv: rearrange(kv, 'b n (kv d) -> kv b () n d', kv=2, d=head_dim)
            self.qkv = lambda x: (map_q(self.q_proj(x)), *map_kv(self.kv_proj(x)))

        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)
    
    def head_proj(self, dots):
        if not self.talking_heads:
            return dots
        dots = self._head_proj(dots)
        return dots      

    def attend(self, query, key, value, attn_mask, pos_bias):
        query, key = map(l2norm, (query, key))
        
        dots = einsum('bhid,bhjd->bhij', query, key) * self.temperature
        dots = self.head_proj(dots)

        #dots += pos_bias
     

        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)

        attn = self.activation(dots)
     
        attn = self.dropout(attn)
        return einsum("bhij,bhjd->bhid", attn, value)


    def attach_cache(self, kv, cache, cache_indices):
        kv = torch.stack(kv, dim=0)
        if cache is None:
            return kv
        zero_vector = torch.zeros_like(kv[:, :, :, :1, :])
        kv_w_cache = torch.cat([cache, kv, zero_vector], dim=-2)
        print(kv_w_cache.shape, 'pre_gather')
        kv_w_cache = torch.gather(kv_w_cache, dim=-2, index=cache_indices) # we do this to remove unnecessary padding
        return kv_w_cache

    def forward(self, x, pos_bias, mask, cache=None, cache_indices=None):
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim
    
        q, k, v  = self.qkv(x)
        kv = self.attach_cache([k, v], cache, cache_indices)
        k, v = kv

        out = self.attend(q, k, v, mask, pos_bias)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.out_proj(out)
        return out, kv

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(self.norm(x), *args, **kwargs)


class GLU(nn.Module):
    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim = -1)
        return x * self.act(gate)




class transformer(nn.Module):
    def __init__(
            self, 
            dim, 
            depth, 
            heads, 
            dim_head, 
            causal=True,
            temperature=15.5,
            shared_temperture=False,
            intermediate_loss=True,
            dropout = 0.1,
            **kwargs
        ):
        super().__init__()
        if depth == 1:
            intermediate_loss = False

        ff_mult = kwargs.get('ff_mult', 4)
        self.checkpoint_every_n = kwargs.get('checkpoint_every_n', 0)
        self.token_shift = kwargs.get('token_shift', False)
        self.causal = causal

        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=True) if shared_temperture else temperature
    

        self.intermediate_loss = intermediate_loss

        self.depth = depth
        self.positional_bias = DynamicPositionBias(
            dim = dim // 4,
            heads = heads,
            depth = 2,
            log_distance = False,
            norm = False
        )
        

        self.token_shifter = lambda x: x
        if self.token_shift:
            self.token_shifter = ShiftTokens(range(0, 2), nn.Identity())
        self.token_shift = lambda x: self.token_shifter(x)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, CosineAttention(
                    dim, 
                    n_heads=heads, 
                    head_dim=dim_head, 
                    causal=causal,
                    temperature=self.temperature,
                    dropout=dropout,
                    **kwargs
                )),
                PreNorm(dim, self.ff(dim, mult=ff_mult))
            ]))

    @staticmethod
    def ff(dim, mult=4, dropout=0.1):
        return nn.Sequential(
            GLU(dim, dim * mult, nn.SiLU()),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    @staticmethod
    def create_custom_forward(module):
        def custom_forward(*args, **kwargs):
            return module(*args, **kwargs)
        return custom_forward

    def checkpoint(self, layer, module, *args, **kwargs):
        condition = self.training and self.checkpoint_every_n != 0 and layer < self.depth - 1 and layer % self.checkpoint_every_n == 0
        return checkpoint(self.create_custom_forward(module), *args, **kwargs) if condition else module(*args, **kwargs)


    @staticmethod
    def get_cache(cache, layer):
        if cache is None:
            return None
        return cache['cache'][layer]

    @staticmethod
    def get_cache_indices(x_lens, cache_lens, cache_kv, x):  
        # used later w/ gather to remove padding when cache is concatenated with current input to remove padding
        max_new_len = (x_lens + cache_lens).max()

        B, H, N, D = x.shape[0], 1, (x.shape[1] + cache_kv.shape[-2]), cache_kv.shape[-1]
        indices = []
        for i in range(B): # stinky for loop to sort out indices for gather 
            cache_indices = torch.arange(cache_lens[i], device='cpu')
            total_length = cache_lens[i] + x_lens[i]
            diff_from_max_len = max_new_len - total_length
            x_indices = torch.arange(x_lens[i]+diff_from_max_len, device='cpu') + cache_kv.shape[-2]
            if diff_from_max_len > 0:
                x_indices[-diff_from_max_len:] = N # last index will be used for padding
            new_indices = torch.cat([cache_indices, x_indices])
            indices.append(new_indices)

        indices = torch.stack(indices, dim=0)
        
        indices = rearrange(indices, 'b n -> () b () n ()').expand(2, B, H,-1, D) # 2 for key and value
        return indices.to(x.device)


    def create_masks(self, x, length, cache): # could clean this up ):
        x_len = length if length is not None else torch.tensor(x.shape[-2]).expand(x.shape[0])
        cache_len = cache['cache_lengths'] if exists(cache) else 0
        total_len = x_len + cache_len
        kv_mask = torch.arange(total_len.max(), device=x.device).expand(len(total_len), -1) >= total_len.unsqueeze(-1)
        q_mask = torch.arange(x_len.max(), device=x.device).expand(len(x_len), -1) >= x_len.unsqueeze(-1)
        attn_mask = ~(rearrange(~q_mask, "b n -> b () n ()") * rearrange(~kv_mask, "b n -> b () () n"))

        if self.causal:
            causal_mask = repeat(torch.arange(total_len.max()), 'i -> b r i', b=len(total_len), r=x_len.max())
            causal_mask = causal_mask >= ((cache_len[:,None,None] if exists(cache) else cache_len) + torch.arange(x_len.max())[None,:,None] + 1)
            print(attn_mask.shape, causal_mask[:,None].shape)
            attn_mask = torch.logical_or(attn_mask, causal_mask[:,None])
            
        return q_mask, attn_mask, total_len, x_len, cache_len

    def forward(self, x, length=None, self_condtioning=None, cache=None):
        intermediate_logits = []
        cached_kvs = []
    
        mask, attn_mask, total_lens, x_len, cache_len = self.create_masks(x, length, cache)

        cache_indices = self.get_cache_indices(x_len, cache_len, cache['cache'], x) if exists(cache) else None
      
        pos_bias = self.positional_bias(i = attn_mask.shape[-2], j = attn_mask.shape[-1], device=x.device, dtype=x.dtype)

        for i, (attn, ff) in enumerate(self.layers):

            x = self.token_shift(x)
            a_out, kv = self.checkpoint(i, attn, x, pos_bias, attn_mask, self.get_cache(cache, layer=i), cache_indices)
            x = a_out + x
            cached_kvs.append(kv)
            x = self.checkpoint(i, ff, x) + x   

            if i < self.depth - 1 and self_condtioning is not None:
                x, logits = self_condtioning(x)
                intermediate_logits.append(logits)

        if len(intermediate_logits) > 0: # stack intermediate logits
            intermediate_logits = torch.stack(intermediate_logits, dim=0) # D x B x N x L

        cached_kvs = torch.stack(cached_kvs, dim=0) if len(cached_kvs) > 0 else None
        cached_kvs = {'cache_lengths': total_lens, 'cache': cached_kvs} if exists(cached_kvs) else None


        return x, intermediate_logits, cached_kvs

class shared_embedding_output_layer(nn.Module):
    '''Pass a embedding layer and then use this module as the output layer'''
    def __init__(self, embedding_layer, bias=False):
        super().__init__()
        self.embedding_layer = embedding_layer
        self.use_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(embedding_layer.weight.shape[0]))#
            nn.init.xavier_uniform_(self.bias)

    def forward(self, x):
        return F.linear(x, weight=self.embedding_layer.weight, bias=self.bias if self.use_bias else None)


class transformer_lm(nn.Module):
    def __init__(
        self,
        dim,
        vocab_size,
        depth,
        heads,
        dim_head,
        causal=True,
        temperature=15.5,
        dropout=0.,
        shared_temperture=True,
        self_conditioning=False,
        intermediate_loss=True,
        use_abs_pos=False,
        **kwargs
    ):
        super().__init__()
        if depth == 1:
            self_conditioning == False

        self.self_conditioning = True if self_conditioning else None
        self.intermediate_loss = intermediate_loss

        self.use_abs_pos = use_abs_pos
        if self.use_abs_pos:
            self.abs_pos_fn = ScaledSinuEmbedding(dim=dim)
        self.abs_pos = lambda x: x + self.abs_pos_fn(x) if self.use_abs_pos else x

        if self_conditioning:
            self.reprojection_layer = nn.Linear(vocab_size, dim)


        self.layers = transformer(
            dim = dim, 
            depth = depth, 
            heads = heads, 
            dim_head = dim_head, 
            causal = causal, 
            dropout = dropout,
            temperature = temperature,
            shared_temperture = shared_temperture,
            intermediate_loss = intermediate_loss,
            **kwargs
        )

        self.tie_embedding = kwargs.get('tie_embedding', False)
        print('Tie embedding:', self.tie_embedding) if self.tie_embedding else None
 
        self.embedding = nn.Embedding(vocab_size, dim)

        self.to_logits = shared_embedding_output_layer(self.embedding) if self.tie_embedding else nn.Linear(dim, vocab_size)
        

        self.post_norm = nn.LayerNorm(dim)


    def self_condition_fn(self):
        def self_condition(x):
            logits = self.to_logits(self.post_norm(x))
            if self.self_conditioning: # not effective for LMs (intermediate loss is tho)
                z = F.softmax(logits, dim=-1)
                z = self.reprojection_layer(z)
                x = z + x
            return x, logits
        return self_condition if (self.self_conditioning or self.intermediate_loss) and self.training else None


    def forward(self, x, length=None, cache:Dict=None):
        '''
        x: [B, N] (embedding indices)
        length: [B] (length of each sequence)
        cache: {cache_lengths: [B, N], cache: [L, KV, B, H, N, D]} KV: key and value (2)
        '''
        x = self.embedding(x)
        x = self.abs_pos(x) 
  
        x, interim_logits, cached_kvs = self.layers(x, length, self_condtioning=self.self_condition_fn(), cache=cache)
        x = self.post_norm(x)
        x = self.to_logits(x)

        return  x, interim_logits, cached_kvs

In [358]:
import string
class CharacterTokenizer(): # only for testing
    def __init__(self):
        self.vocab = ['#', '/'] + list(string.ascii_lowercase) + [' '] # bos/eos -> /, pad -> #
        self.vocab_size = len(self.vocab)
        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for i, token in enumerate(self.vocab)}
    
    def __call__(self, text):
        return self.tokenize(text)

    def tokenize(self, text):
        return [self.token_to_id[token] for token in text]

tokenizer = CharacterTokenizer()

In [359]:
model = transformer_lm(
    dim = 256,
    vocab_size = tokenizer.vocab_size,
    depth = 10,
    heads = 1,
    dim_head = 32,
    dropout=0.0,
    causal = True,
    shared_kv = True,
)
model.eval()
print()




In [360]:
def collate_fn(tensors:List[torch.Tensor], pad_token:int):
    max_len = max([t.shape[0] for t in tensors])
    lengths = torch.tensor([t.shape[0] for t in tensors])
    padded_tensors = [torch.cat([t, torch.full((max_len - t.shape[0],), pad_token, dtype=t.dtype)], dim=0) for t in tensors]
    return torch.stack(padded_tensors, dim=0), lengths

In [361]:
print(b1_lengths + b2_lengths, b1_lengths, b2_lengths)
print(fb_lengths)

tensor([13, 13, 10]) tensor([6, 7, 6]) tensor([7, 6, 4])
tensor([13, 13, 10])


In [362]:
# USE THE LENGTHS TO MAKE THE CAUSAL MASK RATHER THAN THE OTHER THINGY

In [364]:
s1_b1, s2_b1, s3_b1 = torch.tensor(tokenizer('/hello')), torch.tensor(tokenizer('/buenos')), torch.tensor(tokenizer('/whats'))
s1_b2, s2_b2, s3_b2 = torch.tensor(tokenizer(' world/')), torch.tensor(tokenizer(' dias/')), torch.tensor(tokenizer(' up/'))
b1, b1_lengths = collate_fn([s1_b1, s2_b1, s3_b1], pad_token=tokenizer.token_to_id['#'])
b2, b2_lengths = collate_fn([s1_b2, s2_b2, s3_b2], pad_token=tokenizer.token_to_id['#'])
# comparsion set
f_1, f_2, f_3 = torch.tensor(tokenizer('/hello world/')), torch.tensor(tokenizer('/buenos dias/')), torch.tensor(tokenizer('/whats up/'))
fb, fb_lengths = collate_fn([f_1, f_2, f_3], pad_token=tokenizer.token_to_id['#'])


with torch.no_grad():
    logits_s1, interim_logits, cached_kvs = model(b1, length=b1_lengths)
    #print(cached_kvs['cache'].shape, 'cache', b2.shape)
    print('second')
    logits_s2, interim_logits, cached_kvs_s2 = model(b2, length=b2_lengths, cache=cached_kvs)
    print('third')
    logits_fs, interim_logits, cached_kvs_fs = model(fb, length=fb_lengths)

#logits_s1.masked_fill_(b1_lengths_mask.unsqueeze(-1), 0)

#print('logits_s2:', logits_s2.shape, 'logits_fs:', logits_fs.shape)
D, B = 1, 1
#print(logits_s2[B, :, D], logits_fs[B, :, D])



torch.Size([3, 1, 7, 7]) torch.Size([3, 1, 7, 7])
second
torch.Size([3, 1, 7, 13]) torch.Size([3, 1, 7, 13])
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
torch.Size([2, 3, 1, 15, 32]) pre_gather
third
torch.Size([3, 1, 13, 13]) torch.Size([3, 1, 13, 13])


In [365]:
print(cached_kvs['cache'][L,0,1,0,6])

tensor([-0.6397,  0.7431, -0.8576, -0.5721,  0.2620, -0.6212, -0.3207, -0.0375,
        -0.7380, -0.3104,  0.1822, -0.6504, -0.6416,  0.2827,  0.6039,  0.6327,
         0.2097, -0.0468, -0.2050, -0.2130, -0.9321,  0.1500,  0.7238, -0.2692,
        -0.0446,  0.3955,  0.0053,  0.3986, -1.2466,  0.5881,  1.2487, -0.2498])


In [366]:
cached_kvs_fs['cache_lengths']

tensor([13, 13, 10])

In [382]:
print('shapes: ', cached_kvs_fs['cache'].shape, cached_kvs_s2['cache'].shape)
N = 8
L = -1
I = 0
print(cached_kvs_fs['cache'][L,0,I,0,N])
print()
print(cached_kvs_s2['cache'][L,0,I,0,N])
assert torch.allclose(cached_kvs_fs['cache'][L,:,I,:,N], cached_kvs_s2['cache'][L,:,I,:,N], rtol=0.5)

shapes:  torch.Size([10, 2, 3, 1, 13, 32]) torch.Size([10, 2, 3, 1, 13, 32])
tensor([-0.1548,  1.1113, -0.1597, -0.4324,  0.2550, -0.1549, -0.1684,  0.3414,
         0.1699, -0.1366,  0.1756, -0.9724, -0.9767,  0.4447, -1.7030,  1.8319,
         0.0057, -0.2066,  0.2175, -0.2070, -0.8515,  0.0823,  0.7015, -0.0918,
        -0.1286,  0.0922,  0.9617,  0.2686, -0.9577,  0.5732,  1.2578,  0.1263])

tensor([-0.1548,  1.1113, -0.1597, -0.4324,  0.2550, -0.1549, -0.1684,  0.3414,
         0.1699, -0.1366,  0.1756, -0.9724, -0.9767,  0.4447, -1.7030,  1.8319,
         0.0057, -0.2066,  0.2175, -0.2070, -0.8515,  0.0823,  0.7015, -0.0918,
        -0.1286,  0.0922,  0.9617,  0.2686, -0.9577,  0.5732,  1.2578,  0.1263])


In [298]:
print(cached_kvs_fs['cache'].shape)
print(cached_kvs_s2['cache'].shape)

torch.Size([10, 2, 3, 1, 13, 32])
torch.Size([10, 2, 3, 1, 13, 32])


In [269]:
print(logits_s1[B, :, D], logits_fs[B, :logits_s1.shape[1], D])

tensor([ 0.9226,  0.3684,  0.1545,  0.1768, -0.6461,  0.1462,  0.3829]) tensor([ 0.9226,  0.3684,  0.1545,  0.1768, -0.6460,  0.1462,  0.3829])


In [76]:
with torch.no_grad():
    x_og, i_logits, cache_kv = model(torch.tensor(tokenizer('/hello bro/')).unsqueeze(0))

    x, i_logits, cache_kv = model(torch.tensor(tokenizer('/hello')).unsqueeze(0))
    x_c, i_logits, cache_kv = model(torch.tensor(tokenizer(' bro/')).unsqueeze(0), cache=cache_kv)
        #print(cache_kv.shape)

    #print(x_og.shape, x_c.shape)
print(x_og[0,6:, 0], x_c[0,:, 0])

torch.Size([2, 11, 32])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
tensor([ 0.9778,  0.3346,  0.4991, -0.7083, -0.3909]) tensor([ 0.9778,  0.3346,  0.4991, -0.7083, -0.3909])


In [None]:
torch.cat([torch.tensor(tokenizer('/hello')), torch.tensor(tokenizer(' bro/'))])

tensor([ 1,  9,  6, 13, 13, 16, 28,  3, 19, 16,  1])

In [None]:
print(x_og.shape, x_c.shape)
print(x_og[0,6:, 0], x_c[0,:, 0])

torch.Size([1, 11, 29]) torch.Size([1, 5, 29])
tensor([ 0.4475, -0.1402, -0.5746,  0.9537, -1.0230]) tensor([ 0.3507, -0.1993, -0.5755,  0.7726, -1.1193])


In [218]:
def create_masks(x, length, cache):
    x_len = length if length is not None else torch.tensor(x.shape[-2]).expand(x.shape[0])
    cache_len = 0 if cache is None else cache['length']
    total_len = x_len + cache_len
    kv_mask = torch.arange(total_len.max(), device=x.device).expand(len(total_len), -1) >= total_len.unsqueeze(-1)
    q_mask = torch.arange(x_len.max(), device=x.device).expand(len(x_len), -1) >= x_len.unsqueeze(-1)
    attn_mask = ~(rearrange(~q_mask, "b n -> b () n ()") * rearrange(~kv_mask, "b n -> b () () n"))

    if 1==1: #causal
        causal_mask = torch.ones(attn_mask.shape[-2], attn_mask.shape[-1], device=x.device).triu(1 + attn_mask.shape[-2] - attn_mask.shape[-1]).bool()
        attn_mask = torch.logical_or(attn_mask, causal_mask)
        
    return q_mask, attn_mask, x_len

In [217]:
create_masks(x=torch.rand(1, 10,3), length=None, cache={'length': torch.tensor([5])})

(tensor([[False, False, False, False, False, False, False, False, False, False]]),
 tensor([[[[False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
           [False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
           [False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
           [False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
           [False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
           [False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
           [False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False],
       