In [1]:
# Step up to the parent directory; the later `sys.path.append('.')` and the
# `src.` import are resolved relative to the working directory set here.
%cd ..

/home/pablo/long-transformers


In [2]:
import sys

# Make the current working directory importable for project-local packages.
# Guarded so that re-running the cell does not pile up duplicate entries.
if '.' not in sys.path:
    sys.path.append('.')

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Problem sizes for the local-attention experiment; the tensors below are
# created as [B, L, D].
B, L, D = 8, 2048, 128

# Number of consecutive positions in each local attention window.
ATT_WINDOW = 512

In [5]:
# Three independent standard-normal tensors of shape [B, L, D],
# drawn in the order Q, K, V.
Q, K, V = (torch.randn(B, L, D) for _ in range(3))

In [6]:
# Local attention index table.
# Row i holds the key/value positions that query i reads: index 0 (prepended
# to every row) followed by ATT_WINDOW consecutive positions from 1..L-1.
# Near the sequence edges the window is shifted inwards rather than truncated,
# so all rows have the same width. With a window of 3 the rows look like:
#   0, 1, 2, 3
#   0, 1, 2, 3
#   0, 1, 2, 3
#   0, 2, 3, 4
#   0, 3, 4, 5
#   ...
if L <= ATT_WINDOW:
    # Otherwise clip() below is called with min > max and the windows leave
    # the range 1..L-1: index 0 gets duplicated or negative indices appear,
    # which silently wrap around when the table is used for indexing.
    raise ValueError(f"L ({L}) must be greater than ATT_WINDOW ({ATT_WINDOW})")

offset_start = -ATT_WINDOW // 2 + 1  # parsed as (-ATT_WINDOW) // 2 + 1
offset_end = ATT_WINDOW // 2 + 1  # exclusive upper bound
offsets = torch.arange(offset_start, offset_end).view(1, -1)  # [1, ATT_WINDOW]

# Clamp the per-row anchor so that anchor + offsets stays within 1..L-1.
min_idx = 1 - offset_start
max_idx = L - offset_end
start_idxs = torch.arange(0, L).clip(min_idx, max_idx).view(-1, 1)  # [L, 1]
idxs = start_idxs + offsets  # [L, ATT_WINDOW]

# add 0 to the left (the table is already int64, so no cast is needed)
idxs = torch.cat([torch.zeros_like(idxs[:, :1]), idxs], dim=1)

print(idxs[[0, 1, -2, -1]])
print(idxs.shape)

tensor([[   0,    1,    2,  ...,  510,  511,  512],
        [   0,    1,    2,  ...,  510,  511,  512],
        [   0, 1536, 1537,  ..., 2045, 2046, 2047],
        [   0, 1536, 1537,  ..., 2045, 2046, 2047]])
torch.Size([2048, 513])


In [7]:
# i-th token has to attend to tokens idxs[i]
# Indexing with the [L, W] integer table copies data into a fresh
# [B, L, W, D] tensor every time it is evaluated, so each gather is done
# once, reused for both the shape print and the einsum, and released as soon
# as it is no longer needed.
K_windows = K[:, idxs, :]
print(K_windows.shape)  # [B, L, W, D]

# Scaled dot product of every query with the keys listed in its own row.
attention_scores = torch.einsum('b l w d , b l d -> b l w', K_windows, Q) / (D ** 0.5)  # [B, L, W]
del K_windows
attention_weights = F.softmax(attention_scores, dim=-1)  # [B, L, W]

V_windows = V[:, idxs, :]
print(V_windows.shape)  # [B, L, W, D]
print(Q.shape) # [B, L, D]

# Weighted sum of the gathered values.
outputs = torch.einsum('b l w , b l w d -> b l d', attention_weights, V_windows)  # [B, L, D]
del V_windows

print(outputs.shape)

torch.Size([8, 2048, 513, 128])
torch.Size([8, 2048, 513, 128])
torch.Size([8, 2048, 128])
torch.Size([8, 2048, 128])


In [8]:
# Display the raw scaled scores (before the softmax), shape [B, L, W].
attention_scores

tensor([[[ 4.5883e-01,  2.7065e+00, -2.9602e+00,  ...,  1.3942e+00,
           7.2766e-01, -4.6214e-01],
         [ 4.2545e-01,  6.6448e-01, -3.2067e-01,  ...,  1.0622e+00,
          -9.0218e-01, -6.7893e-01],
         [-8.4039e-01, -4.7138e-01, -1.8840e+00,  ..., -4.0091e-01,
           2.4124e-01,  6.7348e-02],
         ...,
         [-7.4906e-01,  1.6485e+00,  3.8024e-01,  ..., -3.7892e-01,
          -1.0379e+00, -6.9358e-01],
         [-6.5863e-01, -9.9215e-01,  2.2255e-02,  ...,  1.3609e-01,
           1.4068e+00, -1.1733e+00],
         [-1.0103e+00, -4.9050e-01,  6.5288e-02,  ..., -6.6780e-01,
          -6.5068e-01, -2.5657e-01]],

        [[ 4.9925e-01,  1.0430e+00,  1.4871e-01,  ...,  2.0554e-01,
          -1.1469e+00, -1.3343e+00],
         [-1.4126e+00,  6.1874e-01,  9.4093e-01,  ..., -8.7320e-02,
          -1.7498e+00,  3.6725e-01],
         [-1.6615e+00,  1.8562e-01, -4.5033e-01,  ..., -4.8201e-01,
           3.0066e-01, -8.1920e-02],
         ...,
         [-7.5891e-01,  2

In [9]:
# Prepend two extra columns holding -1 and -2 to every row of the index table.
# NOTE(review): used as indices, -1 / -2 select the last two positions —
# confirm that is where the global tokens are meant to live.
# NOTE: `idxs` is rebound here, so re-running this cell adds the columns again.
global_tokens = torch.tensor([[-1, -2]]).repeat(L, 1)  # [L, 2]
idxs = torch.cat((global_tokens, idxs), dim=-1)
print(idxs[[0, 1, -2, -1]], idxs.shape, sep='\n')

tensor([[  -1,   -2,    0,  ...,  510,  511,  512],
        [  -1,   -2,    0,  ...,  510,  511,  512],
        [  -1,   -2,    0,  ..., 2045, 2046, 2047],
        [  -1,   -2,    0,  ..., 2045, 2046, 2047]])
torch.Size([2048, 515])


## Testing the module

In [10]:
import torch

from src.models.modules.local_multihead_self_attention import LocalMultiheadSelfAttention

In [11]:
# Sizes for the module smoke test (B and H are both 8).
B = H = 8
L, D = 2048, 128

# Passed as the third constructor argument of the attention module.
ATT_WINDOW = 512

In [12]:
# Build the project's local multi-head self-attention module.
# NOTE(review): the positional arguments are presumably (embedding dim,
# number of heads, attention window) — confirm against the constructor.
mha = LocalMultiheadSelfAttention(D, H, ATT_WINDOW)

In [13]:
# Smoke test: random [B, L, D] input together with an all-True boolean
# [B, L] mask, pushed through the module.
embeddings = torch.randn((B, L, D))
attention_mask = torch.ones(B, L, dtype=torch.bool)
outputs = mha(embeddings, attention_mask)

In [14]:
# Show the shape of the module output (the input was [B, L, D]).
outputs.shape

torch.Size([8, 2048, 128])

## Trying unfold

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.masked import masked_tensor

In [10]:
# Sizes used by the unfold experiment.
B, L, D = 8, 2048, 128

In [11]:
# Negated distance from the middle index L - 1, scaled by 1 / L:
#   l_rng[k] = -|k - (L - 1)| / L   for k = 0 .. 2L - 2
# i.e. 0 in the middle and -(L - 1) / L at both ends.
l_rng = -(torch.arange(2 * L - 1) - (L - 1)).abs() / L
l_rng

tensor([-0.9995, -0.9990, -0.9985,  ..., -0.9985, -0.9990, -0.9995])

In [12]:
# Sliding windows of length L over l_rng, then two leading singleton axes,
# then broadcast along the first axis to B.
attn_mask = (
    l_rng
    .unfold(0, L, 1)        # [L, L]
    .view(1, 1, L, L)       # [1, 1, L, L]
    .expand(B, -1, -1, -1)  # [B, 1, L, L]
)

In [15]:
# Row i of the unfolded view is l_rng[i:i + L]. With l_rng defined as
# -|k - (L - 1)| / L, entry (i, j) equals -|L - 1 - i - j| / L; the disabled
# `.flip(dims=(1,))` would turn that into -|i - j| / L.
attn_mask = l_rng.unfold(0, L, 1)  # [L, L]

# Random boolean mask: True wherever the uniform sample exceeds 0.5.
padding_mask = 0.5 < torch.rand(B, L)  # [B, L]

In [14]:
# Bring both masks to a common [B, 1, L, L] shape:
#   attn_mask    [L, L] -> [1, 1, L, L] -> broadcast along the first axis
#   padding_mask [B, L] -> [B, 1, 1, L] -> broadcast along the third axis
# NOTE: both names are rebound, so this cell cannot be re-run on its own output.
attn_mask = attn_mask.view(1, 1, L, L).expand(B, -1, -1, -1)
padding_mask = padding_mask.view(B, 1, 1, L).expand(-1, -1, L, -1)

In [17]:
# Size of the untyped (byte-level) storage backing each mask.
tuple(len(mask.untyped_storage()) for mask in (attn_mask, padding_mask))

(16777216, 16384)

In [54]:
# Wrap the data tensor together with the boolean mask in a masked tensor.
mt = masked_tensor(attn_mask, padding_mask)



In [55]:
# Storage size of the masked tensor, for comparison with the plain tensors.
len(mt.untyped_storage())

134217728

## Trying conv

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Sizes for the convolution experiment (B and H are both 8).
B = H = 8
L, D = 2048, 128

# 2 * 32 + 1 = 65; presumably 32 positions on each side plus the centre.
WINDOW_SIZE = 2 * 32 + 1

In [5]:
# Random standard-normal tensor of shape [B, H, L, D].
# NOTE(review): presumably the per-head values — confirm.
V = torch.randn(B, H, L, D)

In [None]:
# torch.empty (lower case) is the tensor factory; torch.Empty does not exist
# and raises AttributeError.
# NOTE: torch.empty returns uninitialized memory, so both parameters still
# need an explicit initialization before they are used.
kernel_logits = nn.Parameter(torch.empty(H, WINDOW_SIZE, 1))  # [H, WINDOW_SIZE, 1]
alpha_logits = nn.Parameter(torch.empty(H, 1, 1))  # [H, 1, 1]