In [2]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    q, k and v are each projected by a bias-free ``hidden_dim -> hidden_dim``
    linear layer, split into ``num_heads`` heads of size
    ``hidden_dim // num_heads``, attended per head, re-concatenated and
    projected by ``Wo``. The per-head attention weights of the most recent
    forward pass are stored in ``self.attention_weights``.
    """

    def __init__(self, hidden_dim=256, num_heads=4):
        """
        hidden_dim: feature dimension of the inputs and of the output.
        num_heads: number of attention heads the features are split into;
            must divide hidden_dim.

        Holds four bias-free (hidden_dim x hidden_dim) projections:
        Wq, Wk, Wv for query/key/value and Wo for the output.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be the integer times of num_heads"
        # query, key, value, and output projections
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _check_scaled_dot_product_attention_inputs(self, x):
        """
        Assert that ``x`` is already split into heads, i.e. has shape
        (batch_size, num_heads, sequence_length, hidden_dim // num_heads).

        Raises AssertionError otherwise.
        """
        head_dim = self.hidden_dim // self.num_heads
        # The last axis holds one head's features, not the full hidden_dim.
        assert x.dim() == 4 and x.size(1) == self.num_heads and x.size(3) == head_dim, \
            f"expects that x has shape as: (-1, {self.num_heads}, -1, {head_dim}), " \
            f"but get {tuple(x.size())}"

    def _scaled_dot_product_attention(self, query, key, value,
                                      attention_mask=None, key_padding_mask=None):
        """
        Scaled dot-product attention over inputs that are already split into heads
        (the outputs of ``_split_into_heads``); head_dim = hidden_dim // num_heads.

        query: tensor, shape is (batch_size, num_heads, query_sequence_length, head_dim)
        key: tensor, shape is (batch_size, num_heads, key_sequence_length, head_dim)
        value: tensor, shape is (batch_size, num_heads, key_sequence_length, head_dim)
        attention_mask: optional tensor, shape is (query_sequence_length, key_sequence_length).
            It is ADDED to the attention logits (e.g. 0 = keep, -inf = block) and is
            shared by every batch element and head.
        key_padding_mask: optional tensor, shape is (batch_size, key_sequence_length).
            It is ADDED to the logits of every head and query position of that batch element.

        Returns (output, attention):
            output: (batch_size, num_heads, query_sequence_length, head_dim)
            attention: (batch_size, num_heads, query_sequence_length, key_sequence_length)

        Raises AssertionError on mis-shaped q/k/v or attention_mask, and ValueError
        on a mask of the wrong rank.
        """
        for x in (query, key, value):
            self._check_scaled_dot_product_attention_inputs(x)
        assert key.size(-2) == value.size(-2), "key and value must have the same sequence length"

        d_k = key.size(-1)
        tgt_len, src_len = query.size(-2), key.size(-2)

        # (B, H, tgt_len, d_k) @ (B, H, d_k, src_len) -> (B, H, tgt_len, src_len)
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Use `is not None`: bool() of a multi-element tensor raises RuntimeError.
        if attention_mask is not None:
            if attention_mask.dim() != 2:
                raise ValueError(f"attention_mask.size() is invalid: {attention_mask.size()}")
            assert attention_mask.size() == (tgt_len, src_len)
            # (tgt_len, src_len) broadcasts over the leading (B, H) axes.
            logits = logits + attention_mask

        if key_padding_mask is not None:
            if key_padding_mask.dim() != 2 or key_padding_mask.size(-1) != src_len:
                raise ValueError(f"key_padding_mask.size() is invalid: {key_padding_mask.size()}")
            # (B, src_len) -> (B, 1, 1, src_len) to broadcast over heads and query positions.
            logits = logits + key_padding_mask.unsqueeze(1).unsqueeze(2)

        attention = torch.softmax(logits, dim=-1)
        # (B, H, tgt_len, src_len) @ (B, H, src_len, d_k) -> (B, H, tgt_len, d_k)
        output = torch.matmul(attention, value)
        return output, attention

    def _split_into_heads(self, x, num_heads):
        """
        Split the feature axis into heads.

        x: tensor, shape is (batch_size, seq_length, hidden_dim), e.g. the output of Wq.
        Each position's hidden_dim features are regrouped as
        (num_heads, hidden_dim // num_heads), so head h owns a contiguous slice of them.

        Returns a tensor of shape (batch_size, num_heads, seq_length, hidden_dim // num_heads).
        """
        batch_size, seq_length, hidden_dim = x.size()
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(batch_size, seq_length, num_heads, hidden_dim // num_heads)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        """
        Inverse of ``_split_into_heads``.

        x: tensor, shape is (batch_size, num_heads, seq_length, head_dim)
        Returns a tensor of shape (batch_size, seq_length, num_heads * head_dim).
        """
        batch_size, num_heads, seq_length, head_hidden_dim = x.size()
        # transpose yields a non-contiguous tensor, so .contiguous() is needed before .view.
        return x.transpose(1, 2).contiguous().view(
            batch_size, seq_length, num_heads * head_hidden_dim
        )

    def forward(self, q, k, v, attention_mask=None, key_padding_mask=None):
        """
        q: tensor, shape is (batch_size, query_sequence_length, hidden_dim)
        k: tensor, shape is (batch_size, key_sequence_length, hidden_dim)
        v: tensor, shape is (batch_size, key_sequence_length, hidden_dim)
        attention_mask: optional additive mask, shape is (query_sequence_length, key_sequence_length)
        key_padding_mask: optional additive mask, shape is (batch_size, key_sequence_length)

        Returns a tensor of shape (batch_size, query_sequence_length, hidden_dim).
        The per-head attention weights, shape
        (batch_size, num_heads, query_sequence_length, key_sequence_length), are stored in
        ``self.attention_weights`` (not detached, so they keep the last call's autograd graph).
        """
        q = self._split_into_heads(self.Wq(q), self.num_heads)
        k = self._split_into_heads(self.Wk(k), self.num_heads)
        v = self._split_into_heads(self.Wv(v), self.num_heads)

        attention_values, attention_weights = self._scaled_dot_product_attention(
            q, k, v, attention_mask, key_padding_mask)
        output = self.Wo(self.combine_heads(attention_values))

        self.attention_weights = attention_weights
        return output


In [20]:
import torch
import torch.nn as nn

# nn.Linear(in_features, out_features) stores its weight as (out_features, in_features);
# with bias=False the `bias` attribute is None.
l = nn.Linear(20, 30, bias=False)
print(l.weight.shape)
print(l.bias)

# Batched matmul: the last two axes are multiplied as matrices and the leading
# (batch) axes must broadcast: (4, 2, 3, 4) @ (4, 2, 4, 5) -> (4, 2, 3, 5).
t1 = torch.arange(24).reshape(2, 3, 4).repeat(4, 1, 1, 1)
# A right operand of shape (2, 2, 4, 5) (e.g. torch.arange(80).reshape(2, 2, 4, 5))
# would raise: batch axes (4, 2) and (2, 2) do not broadcast.
t2 = torch.arange(40).reshape(2, 4, 5).repeat(4, 1, 1, 1)
t3 = t1 @ t2
print(t3.shape)

# Each output matrix is the product of the matching input matrices ...
assert torch.allclose(t3[0, 1], t1[0, 1] @ t2[0, 1])
# ... and because `repeat` tiles identical copies along axis 0, t1[2] equals t1[0].
assert torch.allclose(t3[0, 1], t1[2, 1] @ t2[0, 1])

torch.Size([30, 20])
None
torch.Size([4, 2, 3, 5])


In [25]:
import torch

# A (3, 8) matrix holding 0..23 in row-major order (arange is contiguous, so view works).
x = torch.arange(24).view(3, 8)
# Reinterpret the same storage as (-1, 3, 4); -1 is inferred as 24 // (3 * 4) == 2.
# view does not copy: z shares memory with x and keeps the flat element order.
z = x.view(-1, 3, 4)
print(x, z, sep="\n")

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23]])
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
