# 手撕Transformer之CrossAttention

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [5]:
class Multiheadattention(nn.Module):
    """Multi-head scaled dot-product attention with separate q/k/v inputs.

    The query may come from a different sequence than the key/value pair
    (cross-attention), so the query length and the key/value length are
    allowed to differ. All three inputs must have ``input_dim`` features in
    their last dimension; the output has the query's sequence length and
    ``input_dim`` features.

    Args:
        input_dim: Feature size of the incoming q/k/v tensors and of the output.
        heads: Number of attention heads.
        d_model: Total width of the projected q/k/v; split evenly across heads.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, input_dim, heads, d_model):
        super().__init__()
        # Without this check a non-divisible d_model is not reliably caught
        # later: ``view(bs, -1, heads, head_dim)`` can still succeed by folding
        # the leftover features into the sequence axis, silently scrambling
        # the tensor instead of raising.
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"number of heads (got heads={heads})"
            )
        self.d_model = d_model
        self.head_dim = self.d_model // heads
        self.heads_num = heads
        self.input_dim = input_dim

        # Projections act on the last dimension only:
        # (..., input_dim) -> (..., d_model) for q/k/v, and back for the output.
        self.to_q = nn.Linear(self.input_dim, self.d_model)
        self.to_k = nn.Linear(self.input_dim, self.d_model)
        self.to_v = nn.Linear(self.input_dim, self.d_model)
        self.to_out = nn.Linear(self.d_model, self.input_dim)

    def _split_heads(self, x, bs):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        return x.view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor, (batch, q_len, input_dim).
            k: Key tensor, (batch, kv_len, input_dim).
            v: Value tensor, (batch, kv_len, input_dim).

        Returns:
            Tensor of shape (batch, q_len, input_dim).
        """
        bs = q.shape[0]
        q = self._split_heads(self.to_q(q), bs)  # (batch, heads, q_len, head_dim)
        k = self._split_heads(self.to_k(k), bs)  # (batch, heads, kv_len, head_dim)
        v = self._split_heads(self.to_v(v), bs)  # (batch, heads, kv_len, head_dim)

        # Scale by sqrt(head_dim) before the softmax over the key axis.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)  # (batch, heads, q_len, kv_len)

        out = torch.matmul(weights, v)  # (batch, heads, q_len, head_dim)
        # Merge heads back; transpose makes the tensor non-contiguous, so
        # ``contiguous()`` is required before ``view``.
        out = out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.to_out(out)

In [6]:
# Smoke test: the query sequence is longer than the key/value sequence, so
# the output is expected to follow the query's sequence length.
heads, batch_size, input_dim = 2, 4, 32
query_len, context_len = 256, 77

# d_model is set equal to input_dim here.
multiheadattn = Multiheadattention(
    input_dim=input_dim,
    heads=heads,
    d_model=input_dim,
)

# Tensors are drawn in q, k, v order so a seeded RNG yields the same values.
q = torch.randn((batch_size, query_len, input_dim))
k = torch.randn((batch_size, context_len, input_dim))
v = torch.randn((batch_size, context_len, input_dim))

out = multiheadattn(q, k, v)
print(out.shape)

torch.Size([4, 256, 32])
