In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import math
from transformer.layers import MultiHeadAttention

## Example to calculate multi-head attention

In [2]:
# Define the dimensions, number of heads, and dropout rate
d_model = 512
h = 8
dropout = 0.1

# Create an instance of the MultiHeadAttention class
multi_head_attention = MultiHeadAttention(d_model, h, dropout)

# Create random tensors to represent a batch of sequences for query, key, and value
query = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512
key = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512
value = torch.rand(10, 20, d_model)  # batch_size=10, seq_len=20, d_model=512

# Pass the tensors through the multi-head attention layer
output = multi_head_attention(query, key, value)

print(output.shape)  # Should print: torch.Size([10, 20, 512])

torch.Size([10, 20, 512])


## Example to calculate attention and attention score on single head

In [3]:
# Define the dimensions
d_k = 64

# Create an instance of the MultiHeadAttention class
multi_head_attention = MultiHeadAttention(d_model=512, h=8, dropout=0.1)

# Create random tensors to represent a batch of sequences for query, key, and value
query_k = torch.rand(1, 8, 6, d_k)  # batch_size=1, h=8, seq_len=6, d_k=64
key_k = torch.rand(1, 8, 6, d_k)  # batch_size=1, h=8, seq_len=6, d_k=64
value_k = torch.rand(1, 8, 6, d_k)  # batch_size=1, h=8, seq_len=6, d_k=64

# Pass the tensors through the attention method
output, attention_score = MultiHeadAttention.attention(query_k, key_k, value_k, d_k, dropout=nn.Dropout(0.1))

print(output.shape)  # Should print: torch.Size([10, 8, 20, 64])
print(attention_score.shape)  # Should print: torch.Size([10, 8, 20, 20])

torch.Size([1, 8, 6, 64])
torch.Size([1, 8, 6, 6])


In [4]:
print(attention_score)

tensor([[[[0.1968, 0.1998, 0.1576, 0.1692, 0.1650, 0.2228],
          [0.1798, 0.1859, 0.1614, 0.0000, 0.1887, 0.2079],
          [0.1716, 0.1867, 0.1676, 0.2008, 0.1963, 0.0000],
          [0.1787, 0.2006, 0.1427, 0.1793, 0.1731, 0.2368],
          [0.1854, 0.2040, 0.1572, 0.1860, 0.2016, 0.1769],
          [0.1720, 0.1963, 0.0000, 0.1823, 0.1769, 0.2008]],

         [[0.1991, 0.1750, 0.1876, 0.0000, 0.2303, 0.1730],
          [0.1958, 0.1988, 0.1624, 0.1635, 0.2220, 0.1686],
          [0.0000, 0.1982, 0.1763, 0.0000, 0.2195, 0.1687],
          [0.0000, 0.2095, 0.1748, 0.1474, 0.1863, 0.1750],
          [0.2216, 0.1963, 0.1626, 0.1343, 0.2349, 0.1615],
          [0.1771, 0.1786, 0.1691, 0.0000, 0.2365, 0.1819]],

         [[0.1622, 0.2040, 0.1964, 0.0000, 0.1743, 0.2016],
          [0.0000, 0.2116, 0.1659, 0.1522, 0.2108, 0.2241],
          [0.1482, 0.2000, 0.1733, 0.1731, 0.1960, 0.2204],
          [0.1593, 0.2224, 0.1824, 0.1546, 0.1952, 0.1973],
          [0.1462, 0.2149, 0.1588, 0