# Masked Multi-Head Attention

1. Implement masked multi-head attention
2. Implement encoder-decoder attention

## Import packages

In [None]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

## Data refinement

To make the values and shapes easy to inspect in detail, we use a deliberately small sample.

In [None]:
# Token id appended to shorter sequences, and the vocabulary size of the embedding table.
pad_id, vocab_size = 0, 100

# Toy batch of token-id sequences with different lengths (4 to 10 tokens).
data = [
  [62, 13, 47, 39, 78, 33, 56, 13],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
  [66, 88, 98, 47],
  [77, 65, 51, 77, 19, 15, 35, 19, 23],
]

In [None]:
def padding(data, pad_token=None):
  """Right-pad every sequence in ``data`` to the length of the longest one.

  Args:
    data: list of token-id lists; sequences may have different lengths.
    pad_token: id appended to shorter sequences. When omitted, the
      notebook-level ``pad_id`` is used.

  Returns:
    A tuple ``(padded, max_len)``. ``padded`` is a NEW list whose sequences
    all have length ``max_len``; the input list is left unmodified.
    An empty ``data`` yields ``([], 0)``.
  """
  if pad_token is None:
    pad_token = pad_id  # notebook-level padding id defined with the sample data

  # `default=0` keeps an empty batch from raising ValueError in max().
  max_len = max((len(seq) for seq in data), default=0)
  print(f"\nMaximum sequence length: {max_len}")

  # Build new lists rather than assigning into `data`, so the caller's list is not mutated.
  padded = [seq + [pad_token] * (max_len - len(seq)) for seq in tqdm(data)]

  return padded, max_len

In [None]:
# Pad every sequence up to the longest one; `max_len` is the L used in the shape comments below.
data, max_len = padding(data)

100%|██████████| 5/5 [00:00<00:00, 1282.27it/s]


Maximum sequence length: 10





In [None]:
# Display the padded sequences (notebook cell output).
data

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

## Set Hyperparameters and Embedding

In [None]:
d_model = 8    # hidden size of the model
num_heads = 2  # number of attention heads
inf = 1e12     # large finite magnitude; masked attention scores are filled with -1 * inf

# Each head gets d_model // num_heads dimensions and the heads are split with .view(),
# so d_model must be an exact multiple of num_heads.
if d_model % num_heads != 0:
  raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

In [None]:
# Lookup table mapping each of the vocab_size token ids to a d_model-dim vector.
embedding = nn.Embedding(vocab_size, d_model)

# Shape legend: B = batch size, L = maximum sequence length
batch = torch.tensor(data, dtype=torch.long)  # (B, L) int64 token ids
batch_emb = embedding(batch)                  # (B, L, d_model)

In [None]:
# Show the embedded batch followed by its shape, one per line.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[-1.3026, -0.1526,  0.6810,  1.1066,  0.9289,  0.8636, -0.2004,
          -1.2961],
         [ 1.7260,  0.4165,  1.6349, -0.7804, -0.8938, -1.2166,  1.1108,
           1.0554],
         [-0.4178, -1.0309,  1.5177,  1.2440, -0.9330,  1.5603, -0.3513,
          -0.6272],
         [ 0.3972,  0.6800,  1.1626, -0.4897,  0.4080,  1.3365, -0.4225,
          -0.7903],
         [-0.6695,  0.2338,  1.7958, -0.4542, -0.7971,  0.0782,  0.1521,
          -0.7014],
         [-0.4815,  1.1098,  0.3544,  0.0872, -0.2117, -1.0081, -0.0151,
           0.1299],
         [-0.0355, -1.5054,  0.0409,  1.2526, -0.7749,  1.4308,  0.4267,
          -0.1816],
         [ 1.7260,  0.4165,  1.6349, -0.7804, -0.8938, -1.2166,  1.1108,
           1.0554],
         [ 0.1417,  0.6713,  0.3581, -1.1337, -2.0103, -1.4942,  0.7149,
           0.4122],
         [ 0.1417,  0.6713,  0.3581, -1.1337, -2.0103, -1.4942,  0.7149,
           0.4122]],

        [[-0.2745,  2.0970, -1.2221, -0.1496, -0.8902,  0.6538, -0.1

## Masking

Positions whose value is `True` take part in attention.

Positions whose value is `False` are masked out.

In [None]:
# True for real tokens, False for padding. The inserted middle axis lets the
# mask broadcast across every query position later on.
padding_mask = batch.ne(pad_id)[:, None, :]  # (B, 1, L)

print(padding_mask, padding_mask.shape, sep="\n")

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
torch.Size([5, 1, 10])


In [None]:
# Lower-triangular ("no peek") mask: query position i may only see key positions <= i.
nopeak_mask = torch.ones(max_len, max_len, dtype=torch.bool).tril().unsqueeze(0)  # (1, L, L)

print(nopeak_mask, nopeak_mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
torch.Size([1, 10, 10])


In [None]:
# A key position is usable only if it is a real token AND not in the future:
# (B, 1, L) combined with (1, L, L) broadcasts to (B, L, L).
mask = torch.logical_and(padding_mask, nopeak_mask)

print(mask, mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  T

## Linear transformation & splitting into several heads

In [None]:
# Query, key and value projections, created in that order (same init RNG order).
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

# Output projection applied after the heads are concatenated back together.
w_0 = nn.Linear(d_model, d_model)

In [None]:
# Project the embeddings to queries, keys and values: each (B, L, d_model)
q, k, v = w_q(batch_emb), w_k(batch_emb), w_v(batch_emb)

batch_size = q.shape[0]
d_k = d_model // num_heads  # size of each head

# Split d_model into heads and move the head axis forward:
# (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q, k, v = (t.view(batch_size, -1, num_heads, d_k).transpose(1, 2) for t in (q, k, v))

print(*(t.shape for t in (q, k, v)), sep="\n")

torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


## Implement masked self-attention

In [None]:
# Scaled dot product: (B, num_heads, L, d_k) @ (B, num_heads, d_k, L) -> (B, num_heads, L, L)
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)

In [None]:
# Add a head axis so the (B, L, L) mask broadcasts over every head.
masks = mask.unsqueeze(1)  # (B, 1, L, L)

# Out-of-place fill: `attn_scores` keeps the raw (unmasked) scores afterwards,
# whereas `masked_fill_` would silently overwrite it in place.
masked_attn_scores = attn_scores.masked_fill(~masks, -1 * inf)  # (B, num_heads, L, L)

print(masked_attn_scores)
print(masked_attn_scores.shape)

tensor([[[[-3.6156e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 3.3317e-01, -1.4631e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-6.3422e-01,  1.1257e-01, -2.9731e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 2.4420e-02,  1.1599e-01, -9.1046e-02, -3.1467e-01, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.6400e-01,  7.8098e-02, -2.4817e-01, -9.4597e-02, -3.0795e-02,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 2.9746e-01, -2.6459e-02,  1.7034e-01,  5.8131e-02,  4.9967e-02,
            2.4040e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-4.1200e-01, -4.2126e-02, -1.2988e-01,  1.4401e-02, -1.3483e-01,
      

Scores filled with `-1 * inf` become (effectively) `0` once softmax is applied.

In [None]:
# Softmax over the key axis turns each row of scores into attention weights;
# scores filled with -1 * inf end up with (effectively) zero weight.
attn_dists = F.softmax(masked_attn_scores, dim=-1)  # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.6176, 0.3824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2217, 0.4678, 0.3105, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2703, 0.2962, 0.2408, 0.1926, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1567, 0.2438, 0.1759, 0.2051, 0.2186, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2027, 0.1466, 0.1785, 0.1596, 0.1583, 0.1542, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1047, 0.1515, 0.1388, 0.1603, 0.1381, 0.1635, 0.1432, 0.0000,
           0.0000, 0.0000],
          [0.1623, 0.1005, 0.1397, 0.1178, 0.1166, 0.1160, 0.1468, 0.1005,
           0.0000, 0.0000],
          [0.1278, 0.1197, 0.1210, 0.1106, 0.1346, 0.1346, 0.1319, 0.1197,
           0.0000, 0.0000],
          [0.1278, 0.1197, 0.1210, 0.1106, 0.1346, 0.1346, 0.1319, 0.1197

In [None]:
# Weighted sum of the value vectors for every query position.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 10, 4])


## The whole code

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with an optional boolean mask.

  Works for self-attention (q, k, v from the same sequence) and for
  encoder-decoder attention (k and v may be longer or shorter than q).

  Args:
    model_dim: hidden size of the model. Defaults to the notebook-level
      ``d_model`` when omitted.
    n_heads: number of attention heads; must divide ``model_dim``.
      Defaults to the notebook-level ``num_heads`` when omitted.

  Raises:
    ValueError: if ``model_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, model_dim=None, n_heads=None):
    super().__init__()

    # Fall back to the notebook hyperparameters so `MultiHeadAttention()` keeps working.
    self.d_model = d_model if model_dim is None else model_dim
    self.num_heads = num_heads if n_heads is None else n_heads
    if self.d_model % self.num_heads != 0:
      raise ValueError(
          f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
      )
    # Per-head size, derived here instead of read from a global set in another cell.
    self.d_k = self.d_model // self.num_heads

    # Learnable Q, K, V projections (created in this order, as before).
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation for the concatenated head outputs.
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Project q/k/v, attend within each head, and merge the heads back.

    Args:
      q: queries, (B, L_q, d_model).
      k: keys, (B, L_k, d_model).
      v: values, (B, L_k, d_model).
      mask: optional mask of shape (B, L_q, L_k) or (B, 1, L_k); entries that
        are False (or zero) are masked out.

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L_q, d_k)
    # (B, num_heads, L_q, d_k) -> (B, L_q, num_heads, d_k) -> (B, L_q, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into (B, num_heads, L, d_k)."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention over already-split heads.

    Args:
      q: (B, num_heads, L_q, d_k).
      k: (B, num_heads, L_k, d_k).
      v: (B, num_heads, L_k, d_k).
      mask: optional (B, L_q, L_k) or (B, 1, L_k) mask; False/zero entries are masked.

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    head_dim = q.shape[-1]
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)  # (B, num_heads, L_q, L_k)

    if mask is not None:
      mask = mask.unsqueeze(1)  # add a head axis: (B, 1, L_q, L_k) or (B, 1, 1, L_k)
      # Out-of-place fill with the dtype's most negative finite value: softmax sends
      # these positions to 0, and unlike a fixed -1e12 it cannot overflow in fp16.
      fill_value = torch.finfo(attn_scores.dtype).min
      attn_scores = attn_scores.masked_fill(mask.logical_not(), fill_value)

    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)
    return torch.matmul(attn_dists, v)           # (B, num_heads, L_q, d_k)

In [None]:
multihead_attn = MultiHeadAttention()

# Masked self-attention: queries, keys and values all come from batch_emb.
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb, mask=mask)  # (B, L, d_model)

In [None]:
# Show the attention output and its shape, one per line.
print(outputs, outputs.shape, sep="\n")

tensor([[[ 1.0917e+00,  2.8661e-02, -3.0873e-01,  3.3833e-01, -4.0447e-01,
           2.5567e-01, -3.2935e-01,  4.2472e-01],
         [ 1.1372e-01, -3.3286e-02,  3.8443e-01, -5.6257e-02,  3.1793e-02,
           4.5371e-01,  1.1532e-01, -2.6525e-02],
         [ 5.1314e-01, -1.1402e-01,  7.1595e-02,  2.2340e-01, -2.2535e-01,
           3.4830e-01,  1.8477e-02,  1.8242e-01],
         [ 3.4497e-01,  1.1171e-01,  8.5337e-03,  2.3904e-01, -2.2966e-02,
           2.0238e-01,  2.6598e-01,  2.4121e-01],
         [ 3.9921e-01, -3.2567e-02,  8.6769e-02,  2.3958e-01, -7.2909e-02,
           3.1351e-01,  2.0750e-01,  1.2090e-01],
         [ 2.5025e-01, -7.7398e-03,  2.3225e-01,  1.6333e-01,  2.7752e-03,
           3.4592e-01,  2.4483e-01,  6.9345e-02],
         [ 3.1394e-01, -1.8729e-02,  1.6940e-01,  1.6482e-01, -9.5959e-02,
           2.9457e-01,  1.2831e-01,  1.8098e-01],
         [ 1.3698e-01,  1.4190e-03,  2.8889e-01,  9.6515e-02,  2.6682e-02,
           3.4300e-01,  2.4716e-01,  6.7828e-02],


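## Encoder-Decoder Attention

Encoder-decoder attention differs from self-attention only in where the query, key, and value come from: the query comes from the decoder, while the key and value come from the encoder output.

Otherwise the implementation is the same.

First, let's build a target batch to serve as the decoder input.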

In [None]:
# Target-side token ids (decoder input); lengths differ and are padded right below.
trg_data = [
  [33, 11, 49, 10],
  [88, 34, 5, 29, 99, 45, 11, 25],
  [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88, 19],
  [16, 58, 91, 47, 12, 5, 8],
  [71, 63, 62, 7, 9, 11, 55, 91, 32, 48],
]

trg_data, trg_max_len = padding(trg_data)

100%|██████████| 5/5 [00:00<00:00, 2959.99it/s]


Maximum sequence length: 12





In [None]:
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch # (B, S_L)
trg_batch = torch.LongTensor(trg_data)  # (B, T_L)

print(src_batch.shape)
print(trg_batch.shape)

torch.Size([5, 10])
torch.Size([5, 12])


In [None]:
# Embed both batches with the same embedding table.
src_emb = embedding(src_batch)  # (B, S_L, d_model)
trg_emb = embedding(trg_batch)  # (B, T_L, d_model)

print(src_emb.shape)
print(trg_emb.shape)

torch.Size([5, 10, 8])
torch.Size([5, 12, 8])


Suppose that:

- `src_emb` is the output of the encoder
- `trg_emb` is the output of the decoder's masked multi-head attention

In [None]:
# Encoder-decoder attention: queries come from the decoder side (trg_emb), while
# keys and values come from the encoder output (src_emb).
q = w_q(trg_emb)  # (B, T_L, d_model)
k = w_k(src_emb)  # (B, S_L, d_model)
v = w_v(src_emb)  # (B, S_L, d_model)

batch_size = q.shape[0]
d_k = d_model // num_heads

q = q.view(batch_size, -1, num_heads, d_k)  # (B, T_L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)

q = q.transpose(1, 2)  # (B, num_heads, T_L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, S_L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, S_L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 12, 4])


In [None]:
# Scaled dot-product scores between the queries and the keys: (B, num_heads, T_L, S_L)
# NOTE(review): no padding mask is applied here, so padded key positions still
# receive attention weight.
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
attn_dists = attn_scores.softmax(dim=-1)  # (B, num_heads, T_L, S_L)

print(attn_dists.shape)

torch.Size([5, 2, 12, 12])


In [None]:
# Each query position mixes the value vectors according to its attention weights.
attn_values = attn_dists @ v  # (B, num_heads, T_L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 12, 4])
