In [1]:
from mask import (
    create_decoder_atn_mask,
    create_encoder_atn_mask,
)
import torch

In [2]:
# Seed PyTorch's global RNG first so every random tensor drawn further down
# is identical after "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Tensor dimensions shared by the cells below.
batch_size = 2
src_len = 4
tgt_len = 5
num_heads = 2

# Bounds passed to torch.randint when sampling values (low inclusive, high exclusive).
min_val, max_val = 0, 10

In [3]:
# Draw random integer ids for a source batch and a target batch.
src = torch.randint(low=min_val, high=max_val, size=(batch_size, src_len))
tgt = torch.randint(low=min_val, high=max_val, size=(batch_size, tgt_len))

# Cell output: the two shapes.
(src.shape, tgt.shape)

(torch.Size([2, 4]), torch.Size([2, 5]))

In [4]:
# Build the three masks and cast each to int64.
# src_tgt_mask is produced from `src` exactly like src_mask; it is the one
# later applied to the (tgt_len, src_len) attention weights.
src_mask = create_encoder_atn_mask(src).long()
tgt_mask = create_decoder_atn_mask(tgt).long()
src_tgt_mask = create_encoder_atn_mask(src).long()

In [5]:
# Display the int64 encoder mask computed from `src`.
src_mask

tensor([[[[8, 6, 4, 3]]],


        [[[5, 0, 2, 9]]]])

In [6]:
# Display the int64 decoder mask computed from `tgt`.
tgt_mask

tensor([[[[1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0]]],


        [[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 1, 0],
          [1, 1, 0, 1, 0]]]])

In [7]:
# Random stand-ins for attention score tensors, one per attention type,
# sampled as integers and then converted to float32.
attn_src_weight = torch.randint(
    min_val, max_val, (batch_size, num_heads, src_len, src_len)
).float()
attn_tgt_weight = torch.randint(
    min_val, max_val, (batch_size, num_heads, tgt_len, tgt_len)
).float()
attn_src_tgt_weight = torch.randint(
    min_val, max_val, (batch_size, num_heads, tgt_len, src_len)
).float()

In [8]:
# Replace every score whose mask entry is 0 with -inf; the masks are passed
# as-is and rely on broadcasting against the weight tensors.
w1 = attn_src_weight.masked_fill(src_mask.eq(0), float("-inf"))
w2 = attn_tgt_weight.masked_fill(tgt_mask.eq(0), float("-inf"))
w3 = attn_src_tgt_weight.masked_fill(src_tgt_mask.eq(0), float("-inf"))

In [9]:
# Shapes of the three unmasked weight tensors.
tuple(w.shape for w in (attn_src_weight, attn_tgt_weight, attn_src_tgt_weight))

(torch.Size([2, 2, 4, 4]), torch.Size([2, 2, 5, 5]), torch.Size([2, 2, 5, 4]))

In [10]:
# Shapes of the three masks, for comparison with the weight shapes.
tuple(m.shape for m in (src_mask, tgt_mask, src_tgt_mask))

(torch.Size([2, 1, 1, 4]), torch.Size([2, 1, 5, 5]), torch.Size([2, 1, 1, 4]))

In [11]:
# Shapes after masked_fill.
tuple(w.shape for w in (w1, w2, w3))

(torch.Size([2, 2, 4, 4]), torch.Size([2, 2, 5, 5]), torch.Size([2, 2, 5, 4]))

In [12]:
# Explicitly expand the encoder mask so its head and query-row dimensions
# match attn_src_weight (-1 keeps a dimension unchanged).
src_mask_expand = src_mask.expand(
    -1,
    attn_src_weight.size(1),
    attn_src_weight.size(2),
    -1,
)
src_mask_expand.shape

torch.Size([2, 2, 4, 4])

In [13]:
# Display the expanded encoder mask.
src_mask_expand

tensor([[[[8, 6, 4, 3],
          [8, 6, 4, 3],
          [8, 6, 4, 3],
          [8, 6, 4, 3]],

         [[8, 6, 4, 3],
          [8, 6, 4, 3],
          [8, 6, 4, 3],
          [8, 6, 4, 3]]],


        [[[5, 0, 2, 9],
          [5, 0, 2, 9],
          [5, 0, 2, 9],
          [5, 0, 2, 9]],

         [[5, 0, 2, 9],
          [5, 0, 2, 9],
          [5, 0, 2, 9],
          [5, 0, 2, 9]]]])

In [14]:
# Same masking as w1, but using the explicitly expanded mask.
w1_expand = attn_src_weight.masked_fill(src_mask_expand.eq(0), float("-inf"))

In [15]:
# Shape of the result obtained with the expanded mask.
w1_expand.size()

torch.Size([2, 2, 4, 4])

In [16]:
# Element-wise comparison of the expanded-mask result with the broadcast-mask result.
torch.eq(w1_expand, w1)

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]],


        [[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])

In [17]:
# Expand only the head dimension of the decoder mask to match attn_tgt_weight.
tgt_mask_expand = tgt_mask.expand(-1, attn_tgt_weight.shape[1], -1, -1)

In [18]:
# Same masking as w2, but using the explicitly expanded mask.
w2_expand = attn_tgt_weight.masked_fill(tgt_mask_expand.eq(0), float("-inf"))

In [19]:
# Element-wise comparison of the expanded-mask result with the broadcast-mask result.
torch.eq(w2_expand, w2)

tensor([[[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]],

         [[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]],


        [[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]],

         [[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]]])

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadSlidingWindowSelfAttention(nn.Module):
    """Multi-head self-attention restricted to a local sliding window.

    Query position ``i`` only attends to key/value positions in
    ``[i - window_size // 2, i + window_size // 2]``, clipped to the sequence
    bounds.

    Args:
        embed_size: Size of the last input dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Window width; ``window_size // 2`` positions are visible
            on each side of the query position.
    """

    def __init__(self, embed_size, num_heads, window_size):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)
        # Registered as a buffer so it follows the module across .to()/.cuda();
        # non-persistent so the state_dict keys stay the same as before.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(self.head_dim)])),
            persistent=False,
        )

    def forward(self, x, mask=None, verbose=False):
        """Apply windowed self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, embed_size)``.
            mask: Optional 4-D tensor indexed as ``(batch, head, query, key)``.
                Size-1 query/key dimensions are expanded to ``seq_length`` and
                size-1 batch/head dimensions broadcast, so both
                ``(B, 1, 1, S)`` and ``(B, H, S, S)`` masks are accepted.
                Entries equal to 0 block attention to that key.
            verbose: When True, print the mask and the slice used at every
                query position (debugging trace). Silent by default.

        Returns:
            Tensor of shape ``(batch_size, seq_length, embed_size)``.

        Raises:
            ValueError: If ``mask`` is given but is not 4-dimensional.

        Note:
            If every key inside a query's window is masked, the softmax is
            taken over all ``-inf`` scores and that position becomes NaN.
        """
        batch_size, seq_length, embed_size = x.shape

        def split_heads(projected):
            # (B, S, E) -> (B, H, S, head_dim)
            return projected.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            ).permute(0, 2, 1, 3)

        # Compute Q, K, V and reshape them into per-head tensors.
        Q = split_heads(self.query(x))
        K = split_heads(self.key(x))
        V = split_heads(self.value(x))

        full_mask = None
        if mask is not None:
            if mask.dim() != 4:
                raise ValueError(
                    "mask must be 4-D (batch, head, query, key), got "
                    f"{mask.dim()} dimension(s)"
                )
            # Materialise broadcast query/key axes so the per-position slicing
            # below also works for masks such as (B, 1, 1, S).
            full_mask = mask.expand(-1, -1, seq_length, seq_length)

        half_window = self.window_size // 2
        # Allocate directly on the input's device with the projection dtype.
        attention = torch.zeros(
            batch_size,
            self.num_heads,
            seq_length,
            self.head_dim,
            device=x.device,
            dtype=V.dtype,
        )

        for i in range(seq_length):
            start = max(0, i - half_window)
            end = min(seq_length, i + half_window + 1)

            Q_slice = Q[:, :, i, :].unsqueeze(2)  # (batch_size, num_heads, 1, head_dim)
            K_slice = K[:, :, start:end, :]  # (batch_size, num_heads, k_len, head_dim)
            V_slice = V[:, :, start:end, :]  # (batch_size, num_heads, k_len, head_dim)

            # Scaled attention scores of query i against the keys in its window.
            scores = torch.matmul(Q_slice, K_slice.transpose(-2, -1)) / self.scale

            if verbose:
                print(f"Length = {i}")
                print("Mask")
                print(mask)

            # Apply the mask restricted to row i and the current window.
            if full_mask is not None:
                mask_slice = full_mask[:, :, i, start:end].unsqueeze(2)
                if verbose:
                    print("Mask slice")
                    print(mask_slice)
                    print()
                scores = scores.masked_fill(mask_slice == 0, float('-inf'))

            attention_weights = F.softmax(scores, dim=-1)

            # Weighted sum of the values inside the window.
            attention[:, :, i, :] = torch.matmul(attention_weights, V_slice).squeeze(2)

        # Merge the heads back together and apply the output projection.
        attention = attention.permute(0, 2, 1, 3).contiguous()
        attention = attention.view(batch_size, seq_length, embed_size)
        out = self.fc_out(attention)

        return out

# Usage example
embed_size = 1
num_heads = 1
window_size = 5
seq_length = 5
batch_size = 1

model = MultiheadSlidingWindowSelfAttention(embed_size, num_heads, window_size)
x = torch.randn(batch_size, seq_length, embed_size)

# Build a (batch_size, num_heads, seq_length, seq_length) mask. Random integers
# are used instead of 0/1 so each printed slice can be located in the full mask.
# To block attention to key position j, set mask[..., j] = 0.
mask = torch.randint(0, 100, (batch_size, num_heads, seq_length, seq_length))

# verbose=True prints the per-position mask slices selected by the sliding window.
output = model(x, mask, verbose=True)
print(output)  # tensor of shape (batch_size, seq_length, embed_size)


Length = 0
Mask
tensor([[[[49, 43, 86, 60, 65],
          [84, 90,  1, 67, 19],
          [74, 21, 57, 11, 38],
          [31, 16, 19,  2, 41],
          [40, 66, 78, 79, 84]]]])
Mask slice
tensor([[[[49, 43, 86]]]])

Length = 1
Mask
tensor([[[[49, 43, 86, 60, 65],
          [84, 90,  1, 67, 19],
          [74, 21, 57, 11, 38],
          [31, 16, 19,  2, 41],
          [40, 66, 78, 79, 84]]]])
Mask slice
tensor([[[[84, 90,  1, 67]]]])

Length = 2
Mask
tensor([[[[49, 43, 86, 60, 65],
          [84, 90,  1, 67, 19],
          [74, 21, 57, 11, 38],
          [31, 16, 19,  2, 41],
          [40, 66, 78, 79, 84]]]])
Mask slice
tensor([[[[74, 21, 57, 11, 38]]]])

Length = 3
Mask
tensor([[[[49, 43, 86, 60, 65],
          [84, 90,  1, 67, 19],
          [74, 21, 57, 11, 38],
          [31, 16, 19,  2, 41],
          [40, 66, 78, 79, 84]]]])
Mask slice
tensor([[[[16, 19,  2, 41]]]])

Length = 4
Mask
tensor([[[[49, 43, 86, 60, 65],
          [84, 90,  1, 67, 19],
          [74, 21, 57, 11, 38],
