In [2]:
from math import sqrt
import torch
import torch.nn as nn

In [12]:
class SelfMultiHeadAttention(nn.Module):
    def __init__(self, input_dim, dim_k, dim_v, head_num):
        super(SelfMultiHeadAttention,self).__init__()
        
        assert dim_k % head_num==0
        assert dim_v % head_num==0
        
        self.q=nn.Linear(input_dim,dim_k)
        self.k=nn.Linear(input_dim,dim_k)
        self.v=nn.Linear(input_dim,dim_v)
        
        self.head_num = head_num
        self.dim_k = dim_k
        self.dim_v = dim_v
        self._norm_fact = 1 / sqrt(dim_k//head_num)
        
    def forward(self,x):
        # x shape: (batch_size, seq_len, input_dim)
        b, n, _ = x.shape
        
        # after transpose shape: (batch, head, seq_len, head_dim)
        Q = self.q(x).view(b, n, self.head_num, self.dim_k//self.head_num).transpose(1, 2)
        K = self.k(x).view(b, n, self.head_num, self.dim_k//self.head_num).transpose(1, 2)
        V = self.v(x).view(b, n, self.head_num, self.dim_v//self.head_num).transpose(1, 2)
        
        print(x.shape)
        print(Q.size())
        
        # scores shape: (batch, head, seq_len, seq_len)
        atten = torch.matmul(Q, K.transpose(-2, -1)) * self._norm_fact
        print(atten.size())
        atten = torch.softmax(atten, dim=-1)
        
        # output shape: (batch, seq_len, dim_v)
        output = torch.matmul(atten, V).transpose(1, 2).contiguous().view(b, n, -1)
        
        return output

In [13]:
x = torch.rand(1,3,4)
print(x)

atten = SelfMultiHeadAttention(4,4,4,2)
y=atten(x)

print(y.shape)

tensor([[[0.8009, 0.6667, 0.6815, 0.1197],
         [0.2812, 0.5741, 0.4124, 0.7451],
         [0.3060, 0.5470, 0.4089, 0.6769]]])
torch.Size([1, 3, 4])
torch.Size([1, 2, 3, 2])
torch.Size([1, 2, 3, 3])
torch.Size([1, 3, 4])
