In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum


In [374]:
class DynamicPositionBias(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, i, j, device, dtype):
        # get the (n x n) matrix of distances
        seq_arange = torch.arange(i, device = device)
        context_arange = torch.arange(j, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (j-1)
        
        # input to continuous positions MLP
        pos = torch.arange(-i + 1, (j+i), device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')
    
        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases        
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

In [375]:
pos = DynamicPositionBias(64, heads = 8, depth = 2, log_distance = False)
pos = pos(6,7, device = 'cpu', dtype = torch.float32)
pos.shape

torch.Size([8, 6, 7])

In [696]:
class ReLUSquared(nn.Module):
    def forward(self, x):
        return torch.pow(F.relu(x), 2)

def l2norm(t, groups = 1, dim = -1):
    if groups == 1:
        return F.normalize(t, p = 2, dim = dim)
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = dim)
    return rearrange(t, '... g d -> ... (g d)')


class CosineAttention(nn.Module):
    def __init__(
        self,
        n_feats,
        head_dim,
        n_heads,
        dropout=0.1,
        bias=False,
        temperature=15.5,
        return_attention=False,
        causal=False,
        activation='softmax',
        **kwargs
    ):
        super().__init__()
        assert activation in ['relusq', 'softmax']
        self.shared_kv = kwargs.get('shared_kv', False)
        self.talking_heads = kwargs.get('talking_heads', False)
        self.cache_kv = kwargs.get('cache_kv', False) # whether prev key and values are used 
        

        self.n_feats, self.head_dim, self.n_heads = n_feats, head_dim, n_heads
        self.dropout = nn.Dropout(dropout)
        self.bias = bias
        self.return_attention = return_attention
        self.causal = causal

        if self.talking_heads:
            self._head_proj = nn.Conv2d(n_heads, n_heads, (1, 1))

        self.temperature = torch.nn.Parameter(torch.tensor(temperature), requires_grad=True) if isinstance(temperature, float) else temperature

        self.activation = ReLUSquared() if activation == 'relusq' else nn.Softmax(dim=-1)

        if not self.shared_kv:
            self.qkv_proj = nn.Linear(n_feats, 3 * n_heads * head_dim, bias=bias)
            self.qkv = lambda x: rearrange(self.qkv_proj(x), "b n (h d qkv) -> qkv b h n d", qkv=3, h=n_heads, d=head_dim)
        else:
            self.q_proj, self.kv_proj = [nn.Linear(n_feats, el, bias=bias) for el in [n_heads * head_dim, 2 * head_dim]]
            map_q, map_kv = lambda q: rearrange(q, 'b n (h d) -> b h n d', h=n_heads), lambda kv: rearrange(kv, 'b n (kv d) -> kv b () n d', kv=2, d=head_dim)
            self.qkv = lambda x: (map_q(self.q_proj(x)), *map_kv(self.kv_proj(x)))

        self.out_proj = nn.Linear(n_heads * head_dim, n_feats, bias=bias)

        if self.cache_kv:
            cache_heads = n_heads if not self.shared_kv else 1
            self.cache_vector = torch.nn.Parameter(torch.zeros(1, cache_heads, 1, head_dim), requires_grad=True)
            print(self.cache_vector.shape)
    
    def head_proj(self, dots):
        if not self.talking_heads:
            return dots
        dots = self._head_proj(dots)
        return dots      

    def attend(self, query, key, value, mask, pos_fn):
        dots = einsum('bhid,bhjd->bhij', query, key) * self.temperature
        dots = self.head_proj(dots)

        dots += pos_fn(i=dots.shape[-2], j=dots.shape[-1], device=dots.device, dtype=dots.dtype)
        qkmask = ~mask
        attn_mask = ~(rearrange(qkmask, "b n -> b () n ()") * rearrange(qkmask, "b n -> b () () n"))
    
        if self.causal: # create a regular causal mask
            causal_mask = torch.ones(dots.shape[-2], dots.shape[-1], device=dots.device).triu(1).bool()
            attn_mask = torch.logical_or(attn_mask, causal_mask)
        
        dots.masked_fill_(attn_mask, -torch.finfo(dots.dtype).max)
    
        attn = self.activation(dots)
     
        attn = self.dropout(attn)
        return einsum("bhij,bhjd->bhid", attn, value)

    def lengths_from_mask(self, x, mask):
        if mask is None:
            return x.shape[-2]
        return (~mask).sum(dim=-1)

    def attach_cache(self, k, v, mask, cache_kv, cache_mask):
        if cache_kv is None:
            return k, v, mask
        
        cache_k, cache_v = cache_kv
        
        cache_k, cache_v = cache_k.to(k.device), cache_v.to(k.device)
        cache_vector = self.cache_vector.to(k.device)
        cache_k, cache_v = cache_k + cache_vector, cache_v + cache_vector
        cache_lens = self.lengths_from_mask(cache_k, cache_mask)
        max_cache_len = cache_lens.max()    
        x_lens = self.lengths_from_mask(k, mask)
        new_lens = x_lens + cache_lens
        max_new_len = new_lens.max()
        # so we want to remove excess padding and only have padding at the end of the sequence
        # otherwise things get weird with the position encoding
        # lets used gather to do this (not sure if there is a faster way)
        # fk ill use a for loop to get the indices
        indices = []
        new_k, new_v = torch.cat([cache_k, k], dim=-2), torch.cat([cache_v, v], dim=-2)
   
        # add zero to last dimension to use to fetch the padding
        B,H,N,D= new_k.shape
        zero_vector = torch.zeros((B,H,1,D), device=new_k.device)
        new_k,new_v = torch.cat([new_k, zero_vector], dim=-2), torch.cat([new_v, zero_vector], dim=-2)
    

        for i in range(new_k.shape[0]):
            cache_indices = torch.arange(cache_lens[i], device='cpu')
            total_length = cache_lens[i] + x_lens[i]
            diff_from_max_len = max_new_len - total_length
            x_indices = torch.arange(x_lens[i]+diff_from_max_len, device='cpu') + cache_k.shape[-2]
            if diff_from_max_len > 0:
                x_indices[-diff_from_max_len:] = new_k.shape[-2] - 1
            new_indices = torch.cat([cache_indices, x_indices])
            indices.append(new_indices)

        indices = torch.stack(indices, dim=0)
        # NOW LETS GATHER
        indices = rearrange(indices, 'b n -> () b () n ()').expand(2,B,H,N,D)
        new_kv = torch.stack([new_k, new_v], dim=0) # avoid double gather
        new_kv = torch.gather(new_kv, dim=-2, index=indices)
        new_k, new_v = new_kv
        # just create the new mask
        new_mask = torch.arange(max_new_len, device=mask.device) >= new_lens[:, None]
        return new_k, new_v, new_mask, k, v
       
    
        
    

        

    def forward(self, x, pos_fn, mask=None, cached_kv=None, cached_mask=None):
        assert pos_fn is not None, 'pls provide a position function'
        B, N, C, H, D = *x.shape, self.n_heads, self.head_dim
        #print(x.shape, mask.shape)

        if mask is None:
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)

        q, k, v = self.qkv(x)
        q, k = map(l2norm, (q, k))
        if self.cache_kv:
            new_k, new_v, new_mask, k, v = self.attach_cache(k, v, mask, cached_kv, cached_mask)
        return new_k, new_v, new_mask, k, v
        out = self.attend(q, k, v, mask, pos_fn)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.out_proj(out)
        return out if not self.cache_kv else (out, k, v)

In [697]:
128+92

220

In [699]:
pos_fn = DynamicPositionBias(dim = 64, heads = 8, log_distance = False, depth = 2)
cAttn = CosineAttention(n_feats=256, head_dim=32, n_heads=8, temperature=15.5, cache_kv=True, shared_kv=True, talking_heads=True)
x = torch.randn(3, 128, 256)
cached_kv = torch.randn(1,3,1,92,32).expand(2,-1,-1,-1,-1)
cached_mask = torch.zeros(3, 92, dtype=torch.bool)
cached_mask[0, -3:] = True
mask = torch.zeros(3, 128, dtype=torch.bool)

out = cAttn(x, pos_fn, mask, cached_kv, cached_mask)

torch.Size([1, 1, 1, 32])


In [702]:
out[-2][0,0,:,0]

tensor([ 0.3906,  0.0710,  0.0172, -0.2074, -0.0701,  0.2890,  0.3494, -0.0264,
        -0.1111, -0.4142, -0.0788,  0.0519, -0.0026,  0.0767, -0.0388, -0.0247,
        -0.1080,  0.2853,  0.0020, -0.1601, -0.2944,  0.0677, -0.1763, -0.0125,
         0.1275, -0.2477, -0.1884, -0.0234,  0.1822, -0.0721, -0.3212,  0.1506,
         0.2194, -0.1693,  0.0064,  0.0831, -0.1904,  0.3044,  0.1207, -0.1494,
        -0.0674,  0.1864,  0.0760,  0.1781,  0.0145, -0.1953,  0.2129,  0.0245,
        -0.3183,  0.1486, -0.3075,  0.0647, -0.0644, -0.1859, -0.1389,  0.0577,
        -0.0197, -0.1308,  0.2411,  0.3444,  0.1453,  0.0641,  0.2817, -0.0930,
         0.0313, -0.0715, -0.1922, -0.0799, -0.3564,  0.1434,  0.2053,  0.2263,
        -0.1312,  0.0874, -0.0185,  0.2556, -0.2807,  0.1241,  0.2018,  0.0927,
        -0.1649, -0.2063, -0.0261,  0.0648,  0.0185,  0.0333, -0.1019, -0.1100,
         0.1062, -0.0849,  0.0339, -0.1424,  0.0999, -0.1539, -0.0010, -0.0622,
         0.2464,  0.0281,  0.0983, -0.04

In [709]:
out[0][0,0,:,0][92]

tensor(-0.2074, grad_fn=<SelectBackward0>)

In [35]:
import torch.nn as nn, torch

class HydraAttention(nn.Module):
    def __init__(self, d_model, output_layer='scale_and_bias'):
        '''
        output_layer: 'scale_and_bias' | 'linear' | 'none'
        '''
        super(HydraAttention, self).__init__()
        self.d_model = d_model
        self.qkv = nn.Linear(d_model, d_model * 3)
        if output_layer == 'scale_and_bias':
            self.scale = nn.Parameter(torch.ones(1, 1, d_model))
            self.bias = nn.Parameter(torch.zeros(1, 1, d_model))
            self.out = lambda x: x * self.scale + self.bias
        elif output_layer == 'linear':
            self.out = nn.Linear(d_model, d_model)
        elif output_layer == 'none':
            self.out = nn.Identity()

    def forward(self, x):
        '''x: (B, T, D)'''
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q / q.norm(dim=-1, keepdim=True)
        k = k / k.norm(dim=-1, keepdim=True)
        kv = (k * v).sum(dim=-2, keepdim=True)
        out = q * kv
        return self.out(out)

In [37]:
HydraAttention(256)(x)

tensor([[[-2.6228e-04,  1.3423e-01,  3.1196e-02,  ..., -5.5381e-02,
           2.3028e-02, -1.9155e-02],
         [-8.7564e-06,  8.4757e-02,  1.5684e-02,  ...,  1.5657e-02,
          -5.4329e-03, -1.1544e-02],
         [-9.8780e-05, -1.9900e-02,  2.5104e-02,  ..., -4.6537e-02,
          -2.8425e-02, -1.3837e-02],
         ...,
         [ 3.7448e-04, -7.5378e-02, -6.1634e-02,  ...,  2.0541e-02,
           1.8774e-03, -7.0934e-02],
         [ 3.6962e-04, -4.1401e-02,  4.3923e-03,  ..., -1.4942e-02,
           3.6164e-02, -1.0721e-02],
         [-1.0050e-04, -7.9818e-03, -3.8386e-02,  ..., -3.7135e-02,
           1.2968e-02, -7.7314e-02]]], grad_fn=<AddBackward0>)