## Masked Multi-Head Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$
- $M$: Look-Ahead Mask
    - $M = \begin{bmatrix} 
0 & -\infty & -\infty & -\infty \\
0 & 0 & -\infty & -\infty \\
0 & 0 & 0 & -\infty \\
0 & 0 & 0 & 0 
\end{bmatrix}$
    - $0$인 부분: 볼 수 있는 과거와 현재
    - $-\infty$인 부분: 볼 수 없는 미래(Softmax를 취하면 $e^{-\infty} \approx 0$이 되어 확률이 사라짐)

In [48]:
import math
from typing import Optional

import torch
import torch.nn as nn


In [16]:
# Toy input: L tokens, each a dim-dimensional vector -> shape [L, dim]
L, dim = 10, 64
x = torch.randn((L, dim), dtype=torch.float32).requires_grad_()

# [L, L] matrix of ones. tril() returns a new tensor, so the expression below
# only displays the lower triangle; `mask` itself stays all ones.
mask = torch.ones((L, L))
mask.tril()

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [None]:
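# torch.tril: lower triangle
    # keeps the diagonal and everything below it (entries above become 0)
# torch.triu: upper triangle
    # keeps the diagonal and everything above it (entries below become 0)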

In [20]:
# Strictly upper triangle (diagonal=1 excludes the main diagonal):
# 1 marks a future position, 0 marks the current/past positions.
add_mask = torch.ones(L, L).triu(diagonal=1)
add_mask

tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [22]:
# NOTE(review): Tensor.masked_fill is out-of-place -- it returns a new tensor and
# leaves add_mask untouched. Neither result below is assigned, so add_mask still
# holds its 0/1 values afterwards, and the value displayed for the second call is
# just the original 0/1 matrix. To actually build a 0 / -inf additive mask, keep
# the result: add_mask = add_mask.masked_fill(add_mask == 1, float('-inf'))
# (or use the in-place masked_fill_).
add_mask.masked_fill(add_mask == 1, float('-inf'))
add_mask.masked_fill(add_mask == 0, 0.0)


tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [25]:
# Lower triangle of an all-ones [L, L] matrix (main diagonal included).
torch.ones(L, L).tril()

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [26]:
# Upper triangle of an all-ones [L, L] matrix (main diagonal included).
torch.ones(L, L).triu()

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

### 방법 1
$$\text{Attention}(Q, K, V) = \text{softmax}(\text{masked\_fill}(\frac{QK^T}{\sqrt{d_k}}, \text{Mask}==0, -\infty))V$$

1. Binary Mask
    - 1(True): 유효한 영역
    - 0(False): 가려야 할 영역
1. Mapping
    - `masked_fill` 사용
    - 마스크가 0인 위치의 값만 강제로 $-\infty$로 덮어씌우고, 나머지 값은 건드리지 않고 그대로 둔다.

In [31]:
# Method 1: binary mask + masked_fill
L, dim = 10, 64
# Treat x as the pre-softmax attention scores, shape [L, L]
x = torch.randn((L, L)).requires_grad_()
# 1 on/below the diagonal (visible), 0 above the diagonal (future, to be hidden)
causal = torch.ones((L, L), dtype=torch.long).tril()
print(causal)
# Overwrite the hidden positions with a large negative score;
# everything after this step is the same as in unmasked attention.
x.masked_fill(causal.eq(0), -1e9)
# 이후는 동일함

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


tensor([[ 1.0155e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 1.5688e+00,  5.9407e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.9798e-01,  4.3771e-01, -1.4651e+00, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 6.7507e-01, -1.8383e+00,  8.3622e-01, -2.4066e-02, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-3.5347e-01,  7.4248e-01, -1.0590e+00,  4.7114e-01, -7.9914e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.3309e+00,  7.7494e-01, -1.1387e+00,  1.9499e-01,  8.2780e-02,
          4.9165e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-5.0612e-01,  1.9958e-01, -1.0390e+00, -3.3517e-02,  9.5005e-01,
         -1.1797e+00, -2.0787e+0

In [43]:
B, L, dim, H = 4, 10, 64, 8
# x stands in for per-head attention scores: [B, H, L, L]
x = torch.randn((B, H, L, L), dtype=torch.float32).requires_grad_()
# Token id 0 is treated as padding.
input_ids = torch.tensor(
    [
        [0, 0, 0, 0, 0, 5, 6, 7, 8, 9],
        [0, 0, 3, 4, 5, 6, 7, 8, 9, 1],
        [2, 3, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5, 6, 0, 0, 0, 0],
    ],
    dtype=torch.long,
)

# Padding mask: True where the key token is real.
# [B, L] -> [B, 1, 1, L]
pad_mask = input_ids.ne(0).unsqueeze(1).unsqueeze(2)
print(pad_mask.shape)

# Causal mask: 1 on/below the diagonal.
# [L, L] -> [1, 1, L, L]
causal = torch.ones((L, L), dtype=torch.long).tril()
causal = causal.unsqueeze(0).unsqueeze(0)

# Combine by broadcasting; a position survives only if both masks allow it.
#   padding: [B, 1, 1, L] -> [B, 1, L, L]
#   causal:  [1, 1, L, L] -> [B, 1, L, L]
mask = pad_mask * causal

# The size-1 head axis of the mask broadcasts over all H heads of x.
output = x.masked_fill(mask.eq(0), -1e9)
print(output.shape)
# Last head of the first batch element.
print(output[0][-1])

torch.Size([4, 1, 1, 10])
torch.Size([4, 8, 10, 10])
tensor([[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
          6.1687e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.00

### 방법2
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M_{\text{additive}}\right)V$$

1. Additive Mask 행렬
    - $0$: 유효한 영역 (값 보존)
    - $-\infty$ (또는 $-1e9$): 가려야 할 영역 (값 파괴)

In [36]:
# Method 2: additive mask. x plays the role of the pre-softmax scores, [L, L].
x = torch.randn(L, L).requires_grad_()

# Start from all zeros (adding 0 leaves a score unchanged) ...
additive_mask = torch.zeros((L, L))

# ... then write -1e9 into every strictly-upper-triangular (future) position.
# triu_indices returns a [2, N] tensor: row indices first, column indices second.
mask_indices = torch.triu_indices(L, L, offset=1)
additive_mask[tuple(mask_indices)] = -1e9

# Adding the mask pushes future positions to a hugely negative score.
torch.add(x, additive_mask)

tensor([[-4.1055e-02, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.4350e-01,  1.3683e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 1.5323e+00, -6.2992e-01, -1.7739e-01, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 4.2355e-01,  5.5030e-01,  7.4291e-01, -1.7069e-01, -1.0000e+09,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 7.6101e-03,  1.5022e-01,  7.8821e-01, -1.6962e+00, -7.8850e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-2.2091e-01, -5.7895e-01,  9.2281e-01, -1.9385e-01, -2.0747e-01,
         -2.1172e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-3.4514e-01,  6.9409e-01,  1.3448e-01,  7.5586e-01, -5.8678e-01,
         -2.6284e-02,  2.7105e-0

In [50]:
class MultiHeadAttention:
    """Multi-head scaled dot-product attention built from plain tensors.

    The class does not inherit from ``nn.Module``; it keeps its projection
    weights as leaf tensors with ``requires_grad=True`` and exposes a small
    Module-like API (``parameters``, ``train``/``eval``, ``to``, ``zero_grad``).

    Mask convention (used by both ``attn_mask`` and ``key_padding_mask``):
    a nonzero / ``True`` entry means "may attend", a zero / ``False`` entry
    means "blocked". Note that this is the opposite of the boolean-mask
    convention of ``torch.nn.MultiheadAttention``. Additive masks (0 / -inf)
    are not supported here.

    Args:
        d_model: Model (embedding) width; must be divisible by ``heads``.
        heads: Number of attention heads; must be positive.
        bias: Whether the four projections carry a bias vector.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        heads: int,
        bias: bool = True
    ):
        # Reject heads <= 0 explicitly: 0 would otherwise surface as a
        # ZeroDivisionError and a negative value would pass the modulo check.
        if heads <= 0:
            raise ValueError("heads must be a positive integer.")
        if d_model % heads != 0:
            raise ValueError("d_model must be divisible by heads.")

        self.d_model = d_model
        self.heads = heads
        self.head_dim = d_model // heads
        self.bias = bias
        # Stored for API parity with nn.Module; nothing in this class reads it.
        self.training = True

        # Weight-initialisation scale (NOT the attention scale, which is
        # sqrt(head_dim) and is applied in forward()).
        self.scale = math.sqrt(d_model)
        self.w_q = self._new_weight()
        self.w_k = self._new_weight()
        self.w_v = self._new_weight()
        self.w_o = self._new_weight()

        if bias:
            self.b_q = torch.zeros(d_model, requires_grad=True)
            self.b_k = torch.zeros(d_model, requires_grad=True)
            self.b_v = torch.zeros(d_model, requires_grad=True)
            self.b_o = torch.zeros(d_model, requires_grad=True)
        else:
            self.b_q = self.b_k = self.b_v = self.b_o = None

    def _new_weight(self) -> torch.Tensor:
        """Return a trainable [d_model, d_model] matrix drawn from N(0, 1) / sqrt(d_model)."""
        weight = torch.randn(self.d_model, self.d_model) / self.scale
        return weight.requires_grad_()

    @staticmethod
    def _project(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply ``x @ weight`` and add ``bias`` when one is present."""
        projected = torch.matmul(x, weight)
        if bias is not None:
            projected = projected + bias
        return projected

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape [B, len, d_model] -> [B, heads, len, head_dim]."""
        return x.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _combine_masks(
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Merge both masks into one boolean keep-mask (True = may attend).

        A 2-D ``attn_mask`` [L_q, L_k] becomes [1, 1, L_q, L_k]; a 2-D
        ``key_padding_mask`` [B, L_k] becomes [B, 1, 1, L_k]. Masks of any
        other rank are used as given and must broadcast against the scores.
        Returns ``None`` when neither mask is supplied.
        """
        keep = None
        if attn_mask is not None:
            # Comparing against 0 makes bool / int / float masks interchangeable.
            keep = attn_mask != 0
            if keep.dim() == 2:
                keep = keep[None, None, :, :]
        if key_padding_mask is not None:
            pad_keep = key_padding_mask != 0
            if pad_keep.dim() == 2:
                pad_keep = pad_keep[:, None, None, :]
            # A position survives only when both masks allow it.
            keep = pad_keep if keep is None else keep & pad_keep
        return keep

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Make the object callable; simply delegates to :meth:`forward`."""
        return self.forward(query, key, value, key_padding_mask, attn_mask)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute multi-head attention.

        Args:
            query: [B, L_q, d_model].
            key: [B, L_k, d_model].
            value: [B, L_k, d_model].
            key_padding_mask: Optional [B, L_k] mask; zero marks padded keys.
            attn_mask: Optional [L_q, L_k] mask; zero marks blocked pairs
                (e.g. a lower-triangular matrix for causal attention).

        Returns:
            Tensor of shape [B, L_q, d_model].
        """
        batch_size = query.size(0)

        # Linear projections, then split into heads:
        # [B, len, d_model] -> [B, heads, len, head_dim]
        q = self._split_heads(self._project(query, self.w_q, self.b_q), batch_size)
        k = self._split_heads(self._project(key, self.w_k, self.b_k), batch_size)
        v = self._split_heads(self._project(value, self.w_v, self.b_v), batch_size)

        # Scaled dot-product scores:
        # [B, heads, L_q, head_dim] @ [B, heads, head_dim, L_k] -> [B, heads, L_q, L_k]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        keep = self._combine_masks(attn_mask, key_padding_mask)
        if keep is not None:
            # A large finite negative is used instead of -inf so that a row whose
            # keys are all blocked softmaxes to a uniform distribution, not NaN.
            attn_scores = attn_scores.masked_fill(~keep, -1e9)

        # [B, heads, L_q, L_k] @ [B, heads, L_k, head_dim] -> [B, heads, L_q, head_dim]
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        # Concatenate the heads:
        # [B, heads, L_q, head_dim] -> [B, L_q, heads, head_dim] -> [B, L_q, d_model]
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.d_model)

        return self._project(context, self.w_o, self.b_o)

    def parameters(self):
        """Return the trainable tensors: four weights, plus four biases if enabled."""
        params = [self.w_q, self.w_k, self.w_v, self.w_o]
        if self.b_q is not None:
            params.extend([self.b_q, self.b_k, self.b_v, self.b_o])
        return params

    def train(self, mode: bool = True):
        """Set the ``training`` flag and return ``self``."""
        self.training = mode
        return self

    def eval(self):
        """Shorthand for ``train(False)``."""
        return self.train(False)

    def to(self, device: torch.device):
        """Move every parameter to ``device`` and return ``self``.

        The move is done in place on each parameter's ``.data`` so the tensor
        objects themselves are preserved. Replacing them with new tensors would
        silently disconnect any optimizer that was built from ``parameters()``
        before this call. Accumulated gradients are discarded.
        """
        for param in self.parameters():
            param.data = param.data.to(device)
            param.grad = None
        return self

    def zero_grad(self):
        """Zero the gradient of every parameter that currently has one."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

In [53]:
# Fix the RNG seed so the cells below are reproducible.
torch.manual_seed(0)

# Prefer the Apple-silicon MPS backend when it is available, otherwise use the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("device:", device)

device: mps


In [55]:
# Shape check with different query (Lq) and key/value (Lk) lengths.
B, Lq, Lk, d_model, heads = 2, 5, 7, 64, 8

mha = MultiHeadAttention(d_model=d_model, heads=heads).to(device)

q = torch.randn((B, Lq, d_model), device=device)
k = torch.randn((B, Lk, d_model), device=device)
v = torch.randn((B, Lk, d_model), device=device)

# The output follows the query length: [B, Lq, d_model].
out = mha(query=q, key=k, value=v)
print("out shape:", out.shape)

out shape: torch.Size([2, 5, 64])


In [56]:
# Gradient-flow check: after backward(), every projection weight should own a gradient.
B, L, d_model, heads = 2, 6, 64, 8
mha = MultiHeadAttention(d_model=d_model, heads=heads).to(device)

x = torch.randn((B, L, d_model), device=device)
out = mha(x, x, x)  # self-attention: query = key = value

loss = out.mean()
mha.zero_grad()
loss.backward()

for name in ("w_q", "w_k", "w_v", "w_o"):
    p = getattr(mha, name)
    g = p.grad
    grad_norm = g.norm().item() if g is not None else None
    print(name, "grad is None?", g is None, "| grad norm:", grad_norm)

w_q grad is None? False | grad norm: 0.10621220618486404
w_k grad is None? False | grad norm: 0.1277574896812439
w_v grad is None? False | grad norm: 0.2637781798839569
w_o grad is None? False | grad norm: 0.28294530510902405


In [57]:
from torch import optim

# Optimizer check: one SGD step should visibly change w_q.
B, L, d_model, heads = 2, 6, 64, 8
mha = MultiHeadAttention(d_model=d_model, heads=heads).to(device)

opt = optim.SGD(mha.parameters(), lr=1e-1)

x = torch.randn((B, L, d_model), device=device)

# Snapshot of the weight before the update.
wq_before = mha.w_q.detach().clone()

out = mha(x, x, x)
loss = out.pow(2).mean()

opt.zero_grad()
loss.backward()
opt.step()

# Total absolute change of w_q caused by the step.
wq_delta = mha.w_q.detach() - wq_before
diff = wq_delta.abs().sum().item()
print("w_q changed abs-sum:", diff)

w_q changed abs-sum: 0.7707235813140869


In [58]:
# Padding-mask check: masking keys should change the output.
B, L, d_model, heads = 1, 6, 64, 8
mha = MultiHeadAttention(d_model=d_model, heads=heads).to(device)

x = torch.randn((B, L, d_model), device=device)

# Pretend the last 3 tokens are padding (0).
key_padding_mask = torch.tensor([[1] * 3 + [0] * 3], device=device)

out_no_mask = mha(x, x, x, key_padding_mask=None, attn_mask=None)
out_pad_mask = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=None)

print("diff(out) abs-sum:", torch.sub(out_no_mask, out_pad_mask).abs().sum().item())

diff(out) abs-sum: 132.52188110351562


In [59]:
# Causal-mask check: hiding future positions should change the output.
B, L, d_model, heads = 1, 6, 64, 8
mha = MultiHeadAttention(d_model=d_model, heads=heads).to(device)

x = torch.randn((B, L, d_model), device=device)

# Lower-triangular [L, L] mask: 1 = visible, 0 = future position.
attn_mask = torch.ones(L, L, device=device).tril()

out_no_mask = mha(x, x, x, attn_mask=None)
out_causal = mha(x, x, x, attn_mask=attn_mask)

print("diff(out) abs-sum:", torch.sub(out_no_mask, out_causal).abs().sum().item())

diff(out) abs-sum: 149.92823791503906


In [60]:
# Combined check: padding mask and causal mask applied together.
B, L, d_model, heads = 2, 6, 64, 8
mha = MultiHeadAttention(d_model=d_model, heads=heads).to(device)
x = torch.randn((B, L, d_model), device=device)

# Per-sequence padding: 1 = real token, 0 = pad.
key_padding_mask = torch.tensor(
    [
        [1, 1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0, 0],
    ],
    device=device,
)

# Lower-triangular [L, L] causal mask.
attn_mask = torch.ones(L, L, device=device).tril()

out = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print("out shape:", out.shape)
print("out finite?", out.isfinite().all().item())

out shape: torch.Size([2, 6, 64])
out finite? True
