In [21]:
import torch
from torch import nn
import numpy as np

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention for batched 3-D tensors.

    Computes ``softmax(q @ k^T / scale) @ v``, optionally blocking positions
    selected by a boolean mask.

    Args:
        scale: Divisor applied to the raw scores before the softmax.
    """

    def __init__(self, scale):
        super().__init__()

        self.scale = scale
        # Normalise across dim 2 of the score tensor (the key axis of q @ k^T).
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Apply attention.

        Args:
            q: Tensor of shape (batch, n_q, d).
            k: Tensor of shape (batch, n_k, d).
            v: Tensor of shape (batch, n_k, d_v).
            mask: Optional boolean tensor broadcastable to (batch, n_q, n_k);
                True marks positions that must receive zero attention weight.

        Returns:
            Tuple ``(attn, output)`` where ``attn`` has shape
            (batch, n_q, n_k) and ``output`` has shape (batch, n_q, d_v).
            A query row whose mask is True for every key gets all-zero
            attention weights (and therefore an all-zero output row) instead
            of NaN.
        """
        u = torch.bmm(q, k.transpose(1, 2))  # 1. batched matrix product q @ k^T
        u = u / self.scale  # 2. scale

        fully_masked = None
        if mask is not None:
            # 3. mask. A row filled entirely with -inf would make the softmax
            # return NaN (and poison gradients), so fully masked rows are left
            # unfilled here and their weights are zeroed after the softmax.
            fully_masked = mask.all(dim=-1, keepdim=True)
            u = u.masked_fill(mask & ~fully_masked, float("-inf"))

        attn = self.softmax(u)  # 4. softmax
        if fully_masked is not None:
            attn = attn.masked_fill(fully_masked, 0.0)

        output = torch.bmm(attn, v)  # 5. weighted sum of the values

        return attn, output

if __name__ == "__main__":
    # Smoke test: push random tensors through ScaledDotProductAttention.
    batch = 32
    n_q, n_k, n_v = 2, 4, 4       # sizes of dim 1 for q, k and v
    d_q, d_k, d_v = 128, 128, 64  # sizes of dim 2 for q, k and v

    q = torch.randn(batch, n_q, d_q)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_v, d_v)

    # All-False mask: no position is blocked.
    mask = torch.zeros(batch, n_q, n_k, dtype=torch.bool)

    attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = attention(q, k, v, mask=mask)

    # Uncomment to inspect the results:
    # print(attn)
    # print(output)


In [43]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Args:
        n_head: Number of attention heads.
        d_k_: Input feature size of the query and key tensors.
        d_v_: Input feature size of the value tensor.
        d_k: Per-head feature size of the projected queries and keys.
        d_v: Per-head feature size of the projected values.
        d_o: Feature size of the final output.
    """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Input projections: each one produces the features of all heads at once.
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)

        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))

        # Output projection applied to the concatenated heads.
        self.fc_o = nn.Linear(n_head * d_v, d_o)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q: Tensor of shape (batch, n_q, d_k_).
            k: Tensor of shape (batch, n_k, d_k_).
            v: Tensor of shape (batch, n_v, d_v_).
            mask: Optional mask for one head; it is tiled ``n_head`` times
                along dim 0 before being handed to the inner attention.

        Returns:
            Tuple ``(attn, output)``: ``attn`` is whatever the inner attention
            module returns for the head-folded inputs, and ``output`` has
            shape (batch, n_q, d_o).
        """
        heads = self.n_head
        qk_width, v_width = self.d_k, self.d_v

        _, n_q, _ = q.size()
        _, n_k, _ = k.size()
        batch, n_v, _ = v.size()

        def split_heads(x, length, width):
            # (batch, length, heads * width) -> (heads * batch, length, width)
            x = x.view(batch, length, heads, width)
            x = x.permute(2, 0, 1, 3).contiguous()
            return x.view(-1, length, width)

        # 1. Project the inputs and fold the head axis into the batch axis.
        q = split_heads(self.fc_q(q), n_q, qk_width)
        k = split_heads(self.fc_k(k), n_k, qk_width)
        v = split_heads(self.fc_v(v), n_v, v_width)

        # The folded batch axis is head-major, so the mask is tiled once per head.
        if mask is not None:
            mask = mask.repeat(heads, 1, 1)

        # 2. Treat every (head, sample) pair as an independent single-head problem.
        attn, output = self.attention(q, k, v, mask=mask)

        # 3. Concatenate the heads back along the feature axis.
        output = output.view(heads, batch, n_q, v_width)
        output = output.permute(1, 2, 0, 3).contiguous()
        output = output.view(batch, n_q, -1)

        # 4. Final affine map to the output width.
        output = self.fc_o(output)

        return attn, output


if __name__ == "__main__":
    # Smoke test: run MultiHeadAttention on random inputs.
    # `batch` is defined here so this cell does not depend on a variable
    # left over from another cell (it must survive Restart & Run All).
    batch = 32
    n_q, n_k, n_v = 2, 4, 4
    d_q_, d_k_, d_v_ = 128, 128, 64

    q = torch.randn(batch, n_q, d_q_)
    k = torch.randn(batch, n_k, d_k_)
    v = torch.randn(batch, n_v, d_v_)
    mask = torch.zeros(batch, n_q, n_k).bool()  # all False: nothing is blocked

    # Pass the input widths declared above instead of repeating the literals,
    # so the tensors and the module cannot drift apart.
    mha = MultiHeadAttention(n_head=8, d_k_=d_k_, d_v_=d_v_, d_k=256, d_v=128, d_o=128)
    attn, output = mha(q, k, v, mask=mask)

    # Uncomment to inspect the result shapes:
    # print(attn.size())
    # print(output.size())

## 我自己写了一个self-attention

In [1]:
import torch
from torch import nn
import numpy as np

class self_attention(nn.Module):
    """Unmasked scaled dot-product attention over batched 3-D tensors.

    The constructor stores every size argument as an attribute and creates
    three linear layers (``FC_q``, ``FC_k``, ``FC_v``). ``forward`` does not
    call those layers; it only computes ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    NOTE(review): the linear layers are sized from ``n_q`` / ``n_k`` / ``n_v``
    and are never used -- confirm whether per-head projections were intended.
    """

    def __init__(self, num_head, n_q, n_k, n_v, d_q, d_k, d_v):
        super().__init__()
        self.num_head = num_head
        self.n_q = n_q
        self.n_k = n_k
        self.n_v = n_v
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v

        # Registered as parameters but not referenced by forward().
        self.FC_q = nn.Linear(n_q, num_head * n_q)
        self.FC_k = nn.Linear(n_k, num_head * n_k)
        self.FC_v = nn.Linear(n_v, num_head * n_v)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        """Return the attention-weighted combination of ``v``.

        Scores are ``q @ k^T`` divided by ``sqrt(self.d_k)``, normalised with
        a softmax over dim 2, then multiplied with ``v``.
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / np.sqrt(self.d_k)
        weights = self.softmax(scores)
        return torch.bmm(weights, v)
        
        

if __name__ == "__main__":
    # Demo: feed random tensors through self_attention and show the output size.
    batch_size = 32
    num_head = 6
    n_q, n_k, n_v = 100, 200, 200  # sizes of dim 1 for q, k and v
    d_q, d_k, d_v = 128, 128, 128  # sizes of dim 2 for q, k and v

    # Build the module before sampling inputs (its layers are randomly initialised).
    atten = self_attention(
        num_head=num_head,
        n_q=n_q, n_k=n_k, n_v=n_v,
        d_q=d_q, d_k=d_k, d_v=d_v,
    )

    q = torch.randn(batch_size, n_q, d_q)
    k = torch.randn(batch_size, n_k, d_k)
    v = torch.randn(batch_size, n_v, d_v)

    output = atten(q, k, v)

    print(output.shape)

torch.Size([32, 100, 128])
