In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [51]:
B = 8
L = 2048
D = 128

ATT_WINDOW = 512

In [52]:
Q = torch.randn(B, L, D)
K = torch.randn(B, L, D)
V = torch.randn(B, L, D)

In [60]:
# local attention
# 0, 1, 2, 3
# 0, 1, 2, 3
# 0, 2, 3, 4
# 0, 3, 4, 5
# ...
offset_start = -ATT_WINDOW // 2 + 1
offset_end = ATT_WINDOW // 2 + 1
offsets = torch.arange(offset_start, offset_end).view(1, -1)

min_idx = 1 - offset_start
max_idx = L - offset_end
start_idxs = torch.arange(0, L).clip(min_idx, max_idx).view(-1, 1)
idxs = start_idxs + offsets

# add 0 to the left
idxs = torch.cat([torch.zeros_like(idxs[:, :1]), idxs], dim=1).long()

print(idxs[[0, 1, -2, -1]])
print(idxs.shape)

tensor([[   0,    1,    2,  ...,  510,  511,  512],
        [   0,    1,    2,  ...,  510,  511,  512],
        [   0, 1536, 1537,  ..., 2045, 2046, 2047],
        [   0, 1536, 1537,  ..., 2045, 2046, 2047]])
torch.Size([2048, 513])


In [54]:
# i-th token has to attend to tokens idxs[i]
print(K[:, idxs, :].shape)  # [B, L, W, D]
print(V[:, idxs, :].shape)  # [B, L, W, D]
print(Q.shape) # [B, L, D]

attention_scores = torch.einsum('b l w d , b l d -> b l w', K[:, idxs, :], Q) / (D ** 0.5)  # [B, L, W]
attention_weights = F.softmax(attention_scores, dim=-1)  # [B, L, W]
outputs = torch.einsum('b l w , b l w d -> b l d', attention_weights, V[:, idxs, :])  # [B, L, D]

print(outputs.shape)

torch.Size([8, 2048, 512, 128])
torch.Size([8, 2048, 512, 128])
torch.Size([8, 2048, 128])
torch.Size([8, 2048, 128])


In [55]:
attention_scores

tensor([[[-5.2706e-01, -2.2102e-01,  1.8632e+00,  ..., -3.5318e-01,
          -1.5379e+00, -5.8125e-01],
         [ 4.7987e-01,  1.1657e+00, -6.7558e-01,  ...,  1.6880e+00,
           1.4380e+00,  1.6206e-01],
         [ 1.0809e+00,  1.1431e+00,  5.4405e-01,  ..., -5.2006e-01,
          -4.1561e-01,  7.4239e-02],
         ...,
         [ 1.8549e+00,  6.6663e-01, -1.7309e+00,  ..., -5.3762e-01,
           6.6817e-01,  1.5919e+00],
         [ 3.5798e-01,  3.0497e-01,  4.1004e-03,  ..., -2.3174e-01,
          -2.0589e-01,  3.4689e-01],
         [ 1.5518e-01, -7.4561e-01,  1.1700e+00,  ..., -1.2351e+00,
           1.9037e+00, -2.4481e-01]],

        [[ 1.8422e-01,  1.0892e+00, -2.0175e-01,  ...,  9.9590e-01,
          -5.2433e-01, -1.0689e+00],
         [ 1.8433e+00,  2.1716e-01,  1.6302e+00,  ..., -7.1174e-01,
          -1.7439e-01,  9.7945e-01],
         [ 3.6468e-01, -3.7308e-01,  9.5988e-01,  ...,  1.4037e+00,
          -3.8576e-01,  9.7395e-01],
         ...,
         [ 3.0268e-01,  1