In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum


In [50]:
class DynamicPositionBias(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    @staticmethod
    def fetch_module_kwargs(kwargs):
        return {
            'dim': kwargs.get('dpos_dim', 64),
            'depth': kwargs.get('dpos_depth', 2),
            'log_distance': kwargs.get('dpos_log_distance', False),
            'norm': kwargs.get('dpos_norm', False)
        }


    def forward(self, i, j, device, dtype):
        # get the (i x j) matrix of distances
        assert i >= 1 and j >= 1 and i <= j, 'I should be in the range [1, j] and j >= 1'
        seq_arange = torch.arange(i, device = device)
        context_arange = torch.arange(j, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (j-1)
        
        # input to continuous positions MLP
        pos = torch.arange(-i + 1, (j+i), device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')
     
        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases        
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

In [57]:
DynamicPositionBias(
    heads=8,
    **DynamicPositionBias.fetch_module_kwargs({'dpos_dim':222})
)

DynamicPositionBias(
  (mlp): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=1, out_features=222, bias=True)
      (1): Identity()
      (2): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=222, out_features=222, bias=True)
      (1): Identity()
      (2): ReLU()
    )
    (2): Linear(in_features=222, out_features=8, bias=True)
  )
)

In [49]:
pos = DynamicPositionBias(64, heads = 8, depth = 2, log_distance = False)
pos = pos(4,3, device = 'cpu', dtype = torch.float32)
pos.squeeze(-1)

AssertionError: I should be in the range [1, j] and j >= 1

In [4]:
class ReLUSquared(nn.Module):
    def forward(self, x):
        return torch.pow(F.relu(x), 2)

def l2norm(t, groups = 1, dim = -1):
    if groups == 1:
        return F.normalize(t, p = 2, dim = dim)
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = dim)
    return rearrange(t, '... g d -> ... (g d)')


class CosineAttention(nn.Module):
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.1,
        bias=False,
        temperature=15.5,
        return_attention=False,
        causal=False,
        activation='softmax',
        **kwargs
    ):
        super().__init__()
        assert activation in ['relusq', 'softmax']
        self.shared_kv = kwargs.get('shared_kv', False)
        self.talking_heads = kwargs.get('talking_heads', False)
        self.cache_kv = kwargs.get('cache_kv', False) # whether prev key and values are used 
        

        self.n_feats, self.head_dim, self.n_heads = n_feats, head_dim, n_heads
        self.dropout = nn.Dropout(dropout)
        self.bias = bias
        self.return_attention = return_attention
        self.causal = causal

        if self.talking_heads:
            self._head_proj = nn.Conv2d(n_heads, n_heads, (1, 1))

        self.temperature = torch.nn.Parameter(torch.tensor(temperature), requires_grad=True) if isinstance(temperature, float) else temperature

        self.activation = ReLUSquared() if activation == 'relusq' else nn.Softmax(dim=-1)

        if not self.shared_kv:
            self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
            self.qkv = lambda x: rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=n_heads, d=head_dim)
        else:
            self.q_proj, self.kv_proj = [nn.Linear(n_feats, el, bias=bias) for el in [n_heads * head_dim, 2 * head_dim]]
            map_q, map_kv = lambda q: rearrange(q, 'b n (h d) -> b h n d', h=n_heads), lambda kv: rearrange(kv, 'b n (kv d) -> kv b () n d', kv=2, d=head_dim)
            self.qkv = lambda x: (map_q(self.q_proj(x)), *map_kv(self.kv_proj(x)))

        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

        if self.cache_kv:
            cache_heads = n_heads if not self.shared_kv else 1
            self.cache_vector = torch.nn.Parameter(torch.zeros(1, cache_heads, 1, head_dim), requires_grad=True)
            print(self.cache_vector.shape)
    
    def head_proj(self, dots):
        if not self.talking_heads:
            return dots
        dots = self._head_proj(dots)
        return dots      

    def attend(self, query, key, value, mask, k_mask, pos_fn):
        dots = einsum('bhid,bhjd->bhij', query, key) * self.temperature
        dots = self.head_proj(dots)

        dots += pos_fn(i=dots.shape[-2], j=dots.shape[-1], device=dots.device, dtype=dots.dtype)
        qmask, kmask = ~mask, ~k_mask
        attn_mask = ~(rearrange(qmask, "b n -> b () n ()") * rearrange(kmask, "b n -> b () () n"))
    
        if self.causal: # create a regular causal mask
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)
        
        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)
    
        attn = self.activation(dots)
     
        attn = self.dropout(attn)
        return einsum("bhij,bhjd->bhid", attn, value)

    def lengths_from_mask(self, x, mask):
        if mask is None:
            return x.shape[-2]
        return (~mask).sum(dim=-1)

    def attach_cache(self, k, v, mask, cache_kv, cache_mask):
        if cache_kv is None:
            return k, v, mask
        
        cache_k, cache_v = cache_kv
        
        cache_k, cache_v = cache_k.to(k.device), cache_v.to(k.device)
        cache_vector = self.cache_vector.to(k.device)
        cache_k, cache_v = cache_k + cache_vector, cache_v + cache_vector
        cache_lens = self.lengths_from_mask(cache_k, cache_mask)
        max_cache_len = cache_lens.max()    
        x_lens = self.lengths_from_mask(k, mask)
        new_lens = x_lens + cache_lens
        max_new_len = new_lens.max()
        # so we want to remove excess padding and only have padding at the end of the sequence
        # otherwise things get weird with the position encoding
        # lets used gather to do this (not sure if there is a faster way)
        # fk ill use a for loop to get the indices
        indices = []
        new_k, new_v = torch.cat([cache_k, k], dim=-2), torch.cat([cache_v, v], dim=-2)
   
        # add zero to last dimension to use to fetch the padding
        B,H,N,D= new_k.shape
        zero_vector = torch.zeros((B,H,1,D), device=new_k.device)
        new_k,new_v = torch.cat([new_k, zero_vector], dim=-2), torch.cat([new_v, zero_vector], dim=-2)
    

        for i in range(new_k.shape[0]):
            cache_indices = torch.arange(cache_lens[i], device='cpu')
            total_length = cache_lens[i] + x_lens[i]
            diff_from_max_len = max_new_len - total_length
            x_indices = torch.arange(x_lens[i]+diff_from_max_len, device='cpu') + cache_k.shape[-2]
            if diff_from_max_len > 0:
                x_indices[-diff_from_max_len:] = new_k.shape[-2] - 1
            new_indices = torch.cat([cache_indices, x_indices])
            indices.append(new_indices)

        indices = torch.stack(indices, dim=0)
        # NOW LETS GATHER
        indices = rearrange(indices, 'b n -> () b () n ()').expand(2,B,H,N,D)
        new_kv = torch.stack([new_k, new_v], dim=0) # avoid double gather
        new_k, new_v = torch.gather(new_kv, dim=-2, index=indices)
       
        # just create the new mask
        new_mask = torch.arange(max_new_len, device=mask.device) >= new_lens[:, None]
        return new_k, new_v, new_mask
       
    
        
        

    def forward(self, x, pos_fn, mask=None, cached_kv=None, cached_mask=None):
        assert pos_fn is not None, 'pls provide a position function'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim
        #print(x.shape, mask.shape)

        if mask is None:
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)

        q, k, v = self.qkv(x)
        q, k = map(l2norm, (q, k))

        if self.cache_kv:
            k, v, k_mask = self.attach_cache(k, v, mask, cached_kv, cached_mask)
     
        out = self.attend(q, k, v, mask, k_mask, pos_fn)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.out_proj(out)
        return out if not self.cache_kv else (out, k, v)

In [67]:
z = torch.randn(2, 1, 8, 256)
v = torch.randn(2, 1, 1, 256)
(v.expand_as(z) + 0).shape
a = nn.Identity(d=2)
a(1)

1

In [7]:
!

220

In [8]:
pos_fn = DynamicPositionBias(dim = 64, heads = 8, log_distance = False, depth = 2)
cAttn = CosineAttention(n_feats=256, head_dim=32, n_heads=8, temperature=15.5, cache_kv=True, shared_kv=True, talking_heads=True)
x = torch.randn(3, 128, 256)
cached_kv = torch.randn(1,3,1,92,32).repeat(2,1,1,1,1)
print(cached_kv.shape)
cached_mask = torch.zeros(3, 92, dtype=torch.bool)
cached_mask[0, -3:] = True
print(cached_kv.shape, cached_mask.shape)
cached_kv.masked_fill_(rearrange(cached_mask, 'b n -> () b () n ()'), 0)
mask = torch.zeros(3, 128, dtype=torch.bool)

out = cAttn(x, pos_fn, mask, cached_kv, cached_mask)

torch.Size([1, 1, 1, 32])
torch.Size([2, 3, 1, 92, 32])
torch.Size([2, 3, 1, 92, 32]) torch.Size([3, 92])


In [830]:
cached_kv[0,0,0,:,0][89]

tensor(0.)

In [831]:
out[0][0,0,:,0][89]

tensor(0.4142, grad_fn=<SelectBackward0>)

In [833]:
out[-2][0,0,:,1]

tensor([-0.6663, -0.3249, -0.4769,  0.2107, -0.3240,  0.4785,  0.9103,  0.3106,
         0.2695, -1.3441, -0.2908,  0.7281,  0.1858, -0.2704,  0.1194, -1.1572,
        -0.7689, -0.6562, -0.6445, -0.5988,  0.3206, -0.5878,  0.2903, -0.3617,
        -0.3463, -0.4374, -0.2675,  0.4168, -0.0143, -0.7390, -0.7874, -0.8283,
        -0.7971,  1.3905, -0.3582,  0.5042,  0.0984,  0.5525, -0.8537, -0.7623,
        -0.0656,  0.4852,  0.4649, -0.2858, -1.0995, -0.5842,  0.0458, -0.5944,
        -0.0380,  0.7912,  0.8925,  0.4276,  0.3902, -0.3050, -0.0830,  1.5923,
        -1.3425,  0.0548,  1.2778,  0.2260, -1.4920, -0.6034, -0.5314,  0.4827,
        -0.8079, -0.0481,  0.9936,  0.0947,  0.0714,  0.5514, -0.2284,  0.8921,
        -0.1457,  0.5982, -0.6773,  0.0830, -0.2965,  0.5944, -0.5165, -0.4533,
         0.5581, -1.0282,  0.3890, -0.2361,  0.6638, -0.2990,  0.4568, -0.9507,
        -0.1035, -0.4476,  0.0744, -0.0944, -0.2519, -0.0906,  0.2232, -0.3658,
         1.7562, -0.5497, -0.4318, -0.29

tensor([ 1.6012, -0.6469, -0.3006, -1.0750,  1.1811,  0.6629, -0.3181, -1.3653,
        -2.1848, -0.7446,  0.6312,  0.1165,  0.3485, -0.7637, -0.0511, -0.1435,
        -0.4605, -0.5894, -0.3992,  0.1334,  0.1373,  0.9801, -0.9805,  1.1095,
        -0.8132, -0.3352,  1.6133, -3.2370, -1.1195,  0.6321,  0.7622, -0.1708],
       grad_fn=<SelectBackward0>)

In [35]:
import torch.nn as nn, torch

class HydraAttention(nn.Module):
    def __init__(self, d_model, output_layer='scale_and_bias'):
        '''
        output_layer: 'scale_and_bias' | 'linear' | 'none'
        '''
        super(HydraAttention, self).__init__()
        self.d_model = d_model
        self.qkv = nn.Linear(d_model, d_model * 3)
        if output_layer == 'scale_and_bias':
            self.scale = nn.Parameter(torch.ones(1, 1, d_model))
            self.bias = nn.Parameter(torch.zeros(1, 1, d_model))
            self.out = lambda x: x * self.scale + self.bias
        elif output_layer == 'linear':
            self.out = nn.Linear(d_model, d_model)
        elif output_layer == 'none':
            self.out = nn.Identity()

    def forward(self, x):
        '''x: (B, T, D)'''
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q / q.norm(dim=-1, keepdim=True)
        k = k / k.norm(dim=-1, keepdim=True)
        kv = (k * v).sum(dim=-2, keepdim=True)
        out = q * kv
        return self.out(out)

In [37]:
HydraAttention(256)(x)

tensor([[[-2.6228e-04,  1.3423e-01,  3.1196e-02,  ..., -5.5381e-02,
           2.3028e-02, -1.9155e-02],
         [-8.7564e-06,  8.4757e-02,  1.5684e-02,  ...,  1.5657e-02,
          -5.4329e-03, -1.1544e-02],
         [-9.8780e-05, -1.9900e-02,  2.5104e-02,  ..., -4.6537e-02,
          -2.8425e-02, -1.3837e-02],
         ...,
         [ 3.7448e-04, -7.5378e-02, -6.1634e-02,  ...,  2.0541e-02,
           1.8774e-03, -7.0934e-02],
         [ 3.6962e-04, -4.1401e-02,  4.3923e-03,  ..., -1.4942e-02,
           3.6164e-02, -1.0721e-02],
         [-1.0050e-04, -7.9818e-03, -3.8386e-02,  ..., -3.7135e-02,
           1.2968e-02, -7.7314e-02]]], grad_fn=<AddBackward0>)

In [2]:
%cd ../tedlium

/exp/exp1/acp21rjf/deliberation/speachy/tedlium


In [131]:
import non_iid_dataloader as niiddl, torch

In [4]:
import tools

In [85]:
from importlib import reload as rl
rl(niiddl)

<module 'non_iid_dataloader' from '/exp/exp1/acp21rjf/deliberation/speachy/tedlium/non_iid_dataloader.py'>

In [5]:
corpus = tools.load_corpus()

In [6]:
tkn = tools.load_tokenizer('./tokenizers/tokenizer_spe_bpe_v128/tokenizer.model')

In [323]:
dl = niiddl.get_data_loader(
    split = corpus['train'],
    tokenizer = tkn,
    batch_size = 15,
    max_duration = 30,
)

In [324]:
for i in dl:
    z = i
    break

In [389]:
def create_subbatches(audio, audio_lens, tokens, token_lens, segment_lens): # for loops ):
    max_segment_len = segment_lens.max()

    culm_seglens = segment_lens.cumsum(dim=0)
    cur_positions = culm_seglens - segment_lens
    sub_batches_indices = []

    # first get indices for each sub batch of the "rnn"
    for ix in range(max_segment_len):
        indices = []
        for iz in range(len(segment_lens)):
            pos = cur_positions[iz].item()
            if pos < culm_seglens[iz]:
                indices.append(pos)
                cur_positions[iz] += 1
            else:
                indices.append(-1)
        sub_batches_indices.append(torch.tensor(indices, dtype=torch.long))
    ####
    ### after each forward pass the model will return the cached kvs
    # this gets the indices of the correct kvs for the next forward pass
    non_empty_indices = torch.arange(len(segment_lens), dtype=torch.long)
    prev_non_empty_fetch = []
    for i in range(len(sub_batches_indices)):
        cur = sub_batches_indices[i]
        cur = cur[sub_batches_indices[i-1] != -1] if i > 0 else cur
        non_empty_indices = non_empty_indices[cur != -1]
        prev_non_empty_fetch.append(non_empty_indices.clone())
        non_empty_indices = torch.arange(len(non_empty_indices), dtype=torch.long)
    ####
    sub_batches = []
    for i, ix in enumerate(sub_batches_indices):
        sbi = ix[ix != -1]
        cur_audio, cur_audio_lens, cur_tokens, cur_token_lens = audio[sbi], audio_lens[sbi], tokens[sbi], token_lens[sbi]
        # trim audio and tokens to max length in sub batch
        max_cur_audio_len, max_cur_token_len = cur_audio_lens.max(), cur_token_lens.max()
        cur_audio, cur_tokens = cur_audio[:, :max_cur_audio_len], cur_tokens[:, :max_cur_token_len]
        sub_batches.append({
            'audio': cur_audio,
            'audio_lens': cur_audio_lens,
            'tokens': cur_tokens,
            'token_lens': cur_token_lens,
            'prev_state_indices': prev_non_empty_fetch[i] if i > 0 else None, # for the first sub batch there is no previous state  
        })
        
    return sub_batches

In [395]:
def move_to_device(sub_batch, device):
    for k, v in sub_batch.items():
        if isinstance(v, torch.Tensor):
            sub_batch[k] = v.to(device)
    return sub_batch

In [396]:
sb = create_subbatches(**z)

In [403]:
sb[0].keys()

dict_keys(['audio', 'audio_lens', 'tokens', 'token_lens', 'prev_state_indices'])

In [413]:
sb[3]['audio_lens'][sb[4]['prev_state_indices']

tensor([ 15040,  49600, 119200], dtype=torch.int32)

In [411]:
sb[3]['audio_lens']

tensor([133440, 112480,  15040,  49600, 119200,  93280, 125440],
       dtype=torch.int32)

In [441]:
q = torch.randn(1, 12,256)
k = q.clone() + torch.randn(1, 12,256)*0.1
dots = torch.einsum('bnd,bmd->bnm', q, k) / q.shape[-1]**0.5
dots.softmax(dim=-1)[0][1]

tensor([3.3991e-07, 9.9999e-01, 8.4997e-07, 2.7006e-06, 3.9144e-07, 3.5621e-07,
        1.0080e-06, 1.0284e-06, 6.0161e-06, 1.2595e-06, 2.9979e-07, 6.0915e-07])

In [100]:
z['metadata'][2]

[{'unique_id': 'f6a3719a-1aa1-48c0-9792-4a28ffd69cb8',
  'timings': {'segment_start': 437.65, 'segment_end': 441.61},
  'recording_id': 'GeverTulley_2007U',
  'utterance_id': 'GeverTulley_2007U-31',
  'speaker': 'GeverTulley_2007U'}]

In [105]:
z.keys()

dict_keys(['audio', 'audio_lens', 'tokens', 'token_lens', 'segment_lens'])

In [106]:
z['audio'].shape

torch.Size([6, 265920])