In [1]:
from torch import nn
from torch.nn import functional as F
import torch
from einops import rearrange, repeat

class DynamicPositionBias(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device = device)
        context_arange = torch.arange(n, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (n - 1)
        
        # input to continuous positions MLP
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')
        print(pos.shape)

        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases        
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

In [25]:
class DynamicNSPPositionBias(nn.Module):
    '''Adapted From Phil Wang's x-transformers library for specific case of cross-attention'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, qn, kn, device, dtype, lengths): # set dtype and device to the same as q and k

        max_len, min_len = lengths.max().item(), lengths.min().item()
        padding_lens = max_len - lengths
        max_padding = padding_lens.max().item()

        # get the (qn x kn) matrix of distances
        seq_arange = torch.arange(kn, device = device, dtype = torch.long) - kn # -kn ... -1
        seq_arange = repeat(seq_arange, 'k -> q k', q = qn) # repeat for each query
        seq_arange = seq_arange - torch.arange(qn, device = device, dtype = torch.long).unsqueeze(-1) # matrix of relative distance between query and keys

        seq_arange = repeat(seq_arange, 'q k -> b q k', b = len(lengths)) # repeat for each el in batch
        seq_arange = seq_arange + padding_lens[:, None, None] # add padding offsets
        
        minval = seq_arange.min()
        seq_arange -= minval # shift to positive values

        pos = torch.arange(minval, max_padding, device = device, dtype = dtype).unsqueeze(-1)
        #return pos[seq_arange].squeeze() # shows the positions that are being used (look at this to debug)
       
        for layer in self.mlp:
            pos = layer(pos)

        bias = pos[seq_arange]
      
        return rearrange(bias, 'b q k h -> b h q k') # add this to dot product of q and k
        

In [26]:
dim_model = 512
n_heads = 8
dpos = DynamicNSPPositionBias(
    dim = dim_model // 4,
    heads = n_heads,
    depth = 2,
    log_distance = False,
    norm = False, 
)

In [4]:
torch.arange(10)[:,None,None].shape

torch.Size([10, 1, 1])

In [5]:
lengths - lengths.max()

NameError: name 'lengths' is not defined

In [42]:
pos.shape

torch.Size([10, 30, 150])

In [44]:
pos.squeeze()[-1,1]

tensor([-151, -150, -149, -148, -147, -146, -145, -144, -143, -142, -141, -140,
        -139, -138, -137, -136, -135, -134, -133, -132, -131, -130, -129, -128,
        -127, -126, -125, -124, -123, -122, -121, -120, -119, -118, -117, -116,
        -115, -114, -113, -112, -111, -110, -109, -108, -107, -106, -105, -104,
        -103, -102, -101, -100,  -99,  -98,  -97,  -96,  -95,  -94,  -93,  -92,
         -91,  -90,  -89,  -88,  -87,  -86,  -85,  -84,  -83,  -82,  -81,  -80,
         -79,  -78,  -77,  -76,  -75,  -74,  -73,  -72,  -71,  -70,  -69,  -68,
         -67,  -66,  -65,  -64,  -63,  -62,  -61,  -60,  -59,  -58,  -57,  -56,
         -55,  -54,  -53,  -52,  -51,  -50,  -49,  -48,  -47,  -46,  -45,  -44,
         -43,  -42,  -41,  -40,  -39,  -38,  -37,  -36,  -35,  -34,  -33,  -32,
         -31,  -30,  -29,  -28,  -27,  -26,  -25,  -24,  -23,  -22,  -21,  -20,
         -19,  -18,  -17,  -16,  -15,  -14,  -13,  -12,  -11,  -10,   -9,   -8,
          -7,   -6,   -5,   -4,   -3,   

In [33]:
lengths

tensor([ 36,  73, 134, 136,  54, 144, 144, 114,  60, 150])

In [39]:
pos[pos > 0]

IndexError: index 64 is out of bounds for dimension 0 with size 10

In [32]:
pos.squeeze()[0][0] 

tensor([-36, -35, -34, -33, -32, -31, -30, -29, -28, -27, -26, -25, -24, -23,
        -22, -21, -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10,  -9,
         -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,
          6,   7,   8,   9,  10,  11,  12,  13,  14,  15,  16,  17,  18,  19,
         20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,
         34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
         48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,
         62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,
         76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
         90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113])

In [7]:
pos.squeeze()[-1].shape

NameError: name 'pos' is not defined

In [8]:
lengths - lengths.max()

NameError: name 'lengths' is not defined

In [9]:
pos.squeeze()[-2][-1]

NameError: name 'pos' is not defined

In [10]:
pos.min(-1).values.min(-1)  

NameError: name 'pos' is not defined

In [11]:
repeat(pos, 'q k -> b q k', b = 10).shape

NameError: name 'pos' is not defined

In [36]:
q = torch.randn(10, 8, 30, 32)
k = torch.randn(10, 8, 150, 32)
lengths = torch.randint(30, 150, (10,))
lengths[-1] = 150

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) # cross attention for dot product
qn, kn = dots.shape[-2:]
pos = dpos(qn, kn, device='cpu', dtype=torch.float32, lengths=lengths) # position bias only needs to be calculated once and can be reused for all layers
dots = dots + pos # POSITIONIFIED 


RuntimeError: The size of tensor a (8) must match the size of tensor b (10) at non-singleton dimension 1