In [1]:
from torch import nn
from torch.nn import functional as F
import torch
from einops import rearrange, repeat

class DynamicPositionBias(nn.Module):
    '''taken From Phil Wang's x-transformers library'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device = device)
        context_arange = torch.arange(n, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (n - 1)
        
        # input to continuous positions MLP
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')
        print(pos.shape)

        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases        
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

In [52]:
class DynamicNSPPositionBias(nn.Module):
    '''Adapted From Phil Wang's x-transformers library for specific case of cross-attention'''
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, qn, kn, device, dtype): # set dtype and device to the same as q and k

        # get the (qn x kn) matrix of distances
        seq_arange = torch.arange(kn, device = device, dtype = torch.long) - kn # -kn ... -1
        seq_arange = repeat(seq_arange, 'k -> q k', q = qn) # repeat for each query
        seq_arange = seq_arange - torch.arange(qn, device = device, dtype = torch.long).unsqueeze(-1) # matrix of relative distance between query and keys
        
        minval = -(kn - 1) - (qn - 1) - 1 # extra -1 cus we start at -1
        seq_arange -= minval # shift to positive values to use for indexing

        pos = torch.arange(minval, 0, device = device, dtype = dtype).unsqueeze(-1)
      
        for layer in self.mlp:
            pos = layer(pos)

        bias = pos[seq_arange]
        
        return rearrange(bias, 'qn kn h -> h qn kn') # add this to dot product of q and k
        

In [53]:
dim_model = 512
n_heads = 8
dpos = DynamicNSPPositionBias(
    dim = dim_model // 4,
    heads = n_heads,
    depth = 2,
    log_distance = False,
    norm = False, 
)

In [60]:
pos.squeeze()[1]

tensor([-146., -145., -144., -143., -142., -141., -140., -139., -138., -137.,
        -136., -135., -134., -133., -132., -131., -130., -129., -128., -127.,
        -126., -125., -124., -123., -122., -121., -120., -119., -118., -117.,
        -116., -115., -114., -113., -112., -111., -110., -109., -108., -107.,
        -106., -105., -104., -103., -102., -101., -100.,  -99.,  -98.,  -97.,
         -96.,  -95.,  -94.,  -93.,  -92.,  -91.,  -90.,  -89.,  -88.,  -87.,
         -86.,  -85.,  -84.,  -83.,  -82.,  -81.,  -80.,  -79.,  -78.,  -77.,
         -76.,  -75.,  -74.,  -73.,  -72.,  -71.,  -70.,  -69.,  -68.,  -67.,
         -66.,  -65.,  -64.,  -63.,  -62.,  -61.,  -60.,  -59.,  -58.,  -57.,
         -56.,  -55.,  -54.,  -53.,  -52.,  -51.,  -50.,  -49.,  -48.,  -47.,
         -46.,  -45.,  -44.,  -43.,  -42.,  -41.,  -40.,  -39.,  -38.,  -37.,
         -36.,  -35.,  -34.,  -33.,  -32.,  -31.,  -30.,  -29.,  -28.,  -27.,
         -26.,  -25.,  -24.,  -23.,  -22.,  -21.,  -20.,  -19., 

In [55]:
q = torch.randn(10, 8, 30, 32)
k = torch.randn(10, 8, 150, 32)

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) # cross attention for dot product
qn, kn = dots.shape[-2:]
pos = dpos(qn, kn, device='cpu', dtype=torch.float32) # position bias only needs to be calculated once and can be reused for all layers
dots = dots + pos # POSITIONIFIED 


RuntimeError: The size of tensor a (30) must match the size of tensor b (150) at non-singleton dimension 2