In [1]:
import torch
import torch.nn as nn
import math

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None: 
        # d_model: feature length of token
        # h: number of heads
        super().__init__()
        self.d_model = d_model
        self.h = h

        # d_model % num_heads should be zero
        assert d_model % h == 0, "d_model % num_heads should be zero" 
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model, bias=False) # parameter matrix for query W_Q
        self.w_k = nn.Linear(d_model, d_model, bias=False) # parameter matrix for key W_K
        self.w_v = nn.Linear(d_model, d_model, bias=False) # parameter matrix for value W_V
        self.w_o = nn.Linear(d_model, d_model, bias=False) # parameter matrix for output W_O
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query_k, key_k, value_k, d_k, mask=None, dropout=nn.Dropout):
        # query_k: [batch_size, h, seq_len, d_k]
        # key_k: [batch_size, h, seq_len, d_k]
        # value_k: [batch_size, h, seq_len, d_k]
        # mask: [batch_size, 1, seq_len, seq_len]

        attention_score = (query_k @ key_k.transpose(-2, -1)) / math.sqrt(d_k) # [batch_size, h, seq_len, seq_len]

        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, -1e9)
        
        attention_score = torch.softmax(attention_score, dim=-1) # [batch_size, h, seq_len, seq_len]
        attention_score = dropout(attention_score)

        return attention_score @ value_k, attention_score # [batch_size, h, seq_len, d_k], [batch_size, h, seq_len, seq_len]
    
    def forward(self, query, key, value, mask=None):
        # query: [batch_size, seq_len, d_model]
        # key: [batch_size, seq_len, d_model]
        # value: [batch_size, seq_len, d_model]
        # mask: [batch_size, 1, seq_len, seq_len]

        query_k = self.w_q(query) # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
        key_k = self.w_k(key) # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
        value_k = self.w_v(value) # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]

        # [batch_size, seq_len, d_model] -> [batch_size, seq_len, h, d_k] -> [batch_size, h, seq_len, d_k]
        query_k = query_k.view(query_k.shape[0], query_k.shape[1], self.h, self.d_k).transpose(1, 2)
        key_k = key_k.view(key_k.shape[0], key_k.shape[1], self.h, self.d_k).transpose(1, 2)
        value_k = value_k.view(value_k.shape[0], value_k.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention  
        attention, _ = self.attention(query_k, key_k, value_k, self.d_k, mask, self.dropout)

        # Concatenate h heads
        # [batch_size, h, seq_len, d_k] -> [batch_size, seq_len, h, d_k] -> [batch_size, seq_len, d_model]
        attention = attention.transpose(1, 2).contiguous().view(attention.shape[0], -1, self.d_model)

        return self.w_o(attention) # [batch_size, seq_len, d_model]

## Example to calculate multi-head attention

In [8]:
# Define the dimensions, number of heads, and dropout rate
d_model = 512
h = 8
dropout = 0.1

# Create an instance of the MultiHeadAttention class
multi_head_attention = MultiHeadAttention(d_model, h, dropout)

# Create random tensors to represent a batch of sequences for query, key, and value
query = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512
key = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512
value = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512

# Pass the tensors through the multi-head attention layer
output = multi_head_attention(query, key, value)

print(output.shape)  # Should print: torch.Size([10, 20, 512])

torch.Size([10, 20, 512])


## Example to calculate attention and attention score on single head

In [16]:
# Define the dimensions
d_k = 64

# Create an instance of the MultiHeadAttention class
multi_head_attention = MultiHeadAttention(d_model=512, h=8, dropout=0.1)

# Create random tensors to represent a batch of sequences for query, key, and value
query_k = torch.rand(1, 8, 6, d_k)  # batch_size=1, h=8, seq_len=6, d_k=64
key_k = torch.rand(1, 8, 6, d_k)  # batch_size=1, h=8, seq_len=6, d_k=64
value_k = torch.rand(1, 8, 6, d_k)  # batch_size=1, h=8, seq_len=6, d_k=64

# Pass the tensors through the attention method
output, attention_score = MultiHeadAttention.attention(query_k, key_k, value_k, d_k, dropout=nn.Dropout(0.1))

print(output.shape)  # Should print: torch.Size([10, 8, 20, 64])
print(attention_score.shape)  # Should print: torch.Size([10, 8, 20, 20])

torch.Size([1, 8, 6, 64])
torch.Size([1, 8, 6, 6])


In [17]:
print(attention_score)

tensor([[[[0.1626, 0.1937, 0.1655, 0.1733, 0.2434, 0.1725],
          [0.1569, 0.2007, 0.1504, 0.1910, 0.2289, 0.1832],
          [0.1969, 0.0000, 0.1830, 0.1584, 0.2203, 0.1938],
          [0.1353, 0.0000, 0.1941, 0.1564, 0.2266, 0.2024],
          [0.1752, 0.1983, 0.1429, 0.1692, 0.2243, 0.0000],
          [0.1461, 0.2089, 0.1711, 0.1633, 0.2581, 0.1637]],

         [[0.0000, 0.1600, 0.1884, 0.2143, 0.1807, 0.1928],
          [0.2084, 0.1580, 0.2117, 0.0000, 0.1749, 0.1757],
          [0.1747, 0.1441, 0.2124, 0.2295, 0.1589, 0.1916],
          [0.1970, 0.0000, 0.2190, 0.1816, 0.1817, 0.1895],
          [0.1533, 0.1618, 0.2093, 0.2504, 0.1704, 0.1659],
          [0.1727, 0.1565, 0.0000, 0.2222, 0.0000, 0.1967]],

         [[0.1684, 0.1952, 0.1737, 0.2013, 0.1876, 0.1848],
          [0.1809, 0.2249, 0.1764, 0.1714, 0.1840, 0.1735],
          [0.1814, 0.2028, 0.1689, 0.1806, 0.0000, 0.0000],
          [0.2120, 0.2057, 0.1746, 0.1879, 0.1720, 0.0000],
          [0.1787, 0.2027, 0.1890, 0