In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention whose per-head projections are shared across heads.

    The embedding is split into ``num_head`` contiguous chunks of size
    ``head_dim = embedding_dim // num_head``. Every chunk is projected by the same
    bias-free ``head_dim x head_dim`` linear layers (``wq`` / ``wk`` / ``wv``),
    attended independently, and the head outputs are concatenated back in the
    same order they were split. There is no output projection.

    Args:
        embedding_dim: size of the last dimension of query / key / value.
        num_head: number of attention heads; must divide ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int, num_head: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_head = num_head
        self.head_dim = self.embedding_dim // self.num_head

        assert self.embedding_dim == self.num_head * self.head_dim, "embedding dimension should be divisible by num of heads"
        # A single set of weights is applied to every head, so the projections
        # map head_dim -> head_dim rather than embedding_dim -> embedding_dim.
        self.wq = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.wk = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.wv = nn.Linear(self.head_dim, self.head_dim, bias=False)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [batch, seq_len, embedding_dim] -> [batch, num_head, seq_len, head_dim].

        Head ``h`` receives the contiguous features ``h*head_dim : (h+1)*head_dim``,
        which is exactly what the merge at the end of ``forward`` writes back.
        """
        batch_size, seq_len, _ = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        return x.reshape(batch_size, seq_len, self.num_head, self.head_dim).transpose(1, 2)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Attend from ``query`` over ``key`` / ``value``.

        Args:
            query: tensor of shape [batch, query_len, embedding_dim].
            key: tensor of shape [batch, kv_len, embedding_dim]; kv_len may differ
                from query_len.
            value: tensor of shape [batch, kv_len, embedding_dim].
            mask: optional tensor broadcastable to [batch, num_head, query_len, kv_len].
                Positions where ``mask == 0`` (or ``False``) are excluded from attention.

        Returns:
            Tensor of shape [batch, query_len, embedding_dim].

        Raises:
            ValueError: if an input is not 3-D with last dimension ``embedding_dim``,
                or if key and value disagree in batch size / length.
        """
        for name, tensor in (("query", query), ("key", key), ("value", value)):
            if tensor.dim() != 3 or tensor.shape[-1] != self.embedding_dim:
                raise ValueError(
                    f"{name} must have shape [batch, seq_len, {self.embedding_dim}], "
                    f"got {tuple(tensor.shape)}"
                )
        if key.shape[:2] != value.shape[:2]:
            raise ValueError(
                f"key and value must share batch and length, got {tuple(key.shape)} and {tuple(value.shape)}"
            )

        batch_size, query_len, _ = query.shape

        # [batch, num_head, seq_len, head_dim]; key/value use their own length.
        q = self.wq(self._split_heads(query))
        k = self.wk(self._split_heads(key))
        v = self.wv(self._split_heads(value))

        # [batch, num_head, query_len, kv_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf so a fully
            # masked row yields a uniform distribution rather than NaNs.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        attention_out = torch.matmul(weights, v)

        # Inverse of _split_heads: concatenate heads back in their original order.
        return attention_out.transpose(1, 2).reshape(batch_size, query_len, self.embedding_dim)

In [None]:
# Smoke test: random query/key/value of shape [batch=1, seq_len=3, embedding_dim]
# should produce an output of the same shape. The configuration variables below
# drive every shape, so the inputs and the module can't drift apart.
torch.manual_seed(0)  # reproducible inputs and weight initialization
embedding_dim = 80
num_heads = 8
query = torch.randn([1, 3, embedding_dim])
key = torch.randn([1, 3, embedding_dim])
value = torch.randn([1, 3, embedding_dim])
multi_head_self_attention = MultiHeadSelfAttention(embedding_dim=embedding_dim, num_head=num_heads)
out = multi_head_self_attention(query, key, value)
assert out.shape == (1, 3, embedding_dim), out.shape
out.shape