In [None]:
import numpy as np
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        d_k: value whose square root divides the raw ``q @ k^T`` scores.
    """

    def __init__(self,d_k:int)->None:
        super().__init__()
        self.d_k=d_k

    def forward(
        self,
        q:torch.Tensor,
        k:torch.Tensor,
        v:torch.Tensor,
        mask:torch.Tensor=None,
    )->torch.Tensor:
        """Compute the attention output for ``q``, ``k`` and ``v``.

        Args:
            q: queries, shape ``(..., q_len, d_k)``.
            k: keys, shape ``(..., k_len, d_k)``.
            v: values, shape ``(..., k_len, d_v)``.
            mask: optional tensor with the same number of dimensions as the
                score matrix ``(..., q_len, k_len)``. Non-zero / ``True``
                entries mark positions that must receive no attention.

        Returns:
            Tensor of shape ``(..., q_len, d_v)``.

        Raises:
            ValueError: if ``mask`` has a different number of dimensions than
                the score matrix.
        """
        scalar=float(np.sqrt(self.d_k))
        # Contract over the feature axis; the last two axes are used so that
        # any number of leading batch axes is accepted.
        attention_weight=torch.matmul(q,k.transpose(-2,-1))/scalar

        if mask is not None:
            if mask.dim()!=attention_weight.dim():
                raise ValueError(
                    "mask.dim!=attention_weight.dim,mask.dim={},attention_weight.dim={}".format(
                        mask.dim(),attention_weight.dim()
                    )
                )
            # Out-of-place fill on the tensor itself (not ``.data``) keeps the
            # scores attached to the autograd graph, so gradients still reach
            # q and k. The fill value is the most negative finite number of
            # the scores' own dtype, so softmax maps it to zero weight.
            attention_weight=attention_weight.masked_fill(
                mask.to(torch.bool),torch.finfo(attention_weight.dtype).min
            )

        # Normalise over the key axis.
        attention_weight=nn.functional.softmax(attention_weight,dim=-1)
        return torch.matmul(attention_weight,v)

In [None]:
import torch
from layers.transformer.ScaledDOtProductAttention import ScaledDotProductAttention
from torch import nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Each of the ``h`` heads has its own query/key/value projection matrix.
    The per-head attention outputs are concatenated along the feature axis
    and passed through a final linear layer.

    Args:
        d_model: feature size of the inputs and of the returned tensor.
        h: number of attention heads. The per-head sizes are
            ``d_k = d_v = d_model // h``.
    """

    def __init__(self,d_model:int,h:int)->None:
        super().__init__()
        self.d_model=d_model
        self.h=h
        self.d_k=d_model//h
        self.d_v=d_model//h

        # One (d_model, d_k) / (d_model, d_v) projection matrix per head.
        self.W_k=self._init_head_weights(h,d_model,self.d_k)
        self.W_q=self._init_head_weights(h,d_model,self.d_k)
        self.W_v=self._init_head_weights(h,d_model,self.d_v)

        self.scaled_dot_product_attention=ScaledDotProductAttention(self.d_k)

        self.linear=nn.Linear(h*self.d_v,d_model)

    @staticmethod
    def _init_head_weights(h:int,d_in:int,d_out:int)->nn.Parameter:
        """Return an ``(h, d_in, d_out)`` parameter, Xavier-initialised per head."""
        # torch.empty gives uninitialised memory (possibly NaN/inf), so every
        # head's matrix is filled explicitly before becoming a Parameter.
        weight=torch.empty(h,d_in,d_out)
        for head_weight in weight:
            nn.init.xavier_uniform_(head_weight)
        return nn.Parameter(weight)

    def forward(
        self,
        q:torch.Tensor,
        k:torch.Tensor,
        v:torch.Tensor,
        mask_3d:torch.Tensor=None,
    )->torch.Tensor:
        """Apply multi-head attention.

        Args:
            q: queries, shape ``(batch, q_len, d_model)``.
            k: keys, shape ``(batch, k_len, d_model)``.
            v: values, shape ``(batch, k_len, d_model)``.
            mask_3d: optional mask of shape ``(batch, q_len, k_len)``. It is
                tiled once per head and handed to the attention module.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size=q.size(0)
        q_len,k_len,v_len=q.size(1),k.size(1),v.size(1)

        # Project the inputs with every head's matrix in one contraction:
        # (batch, seq, d_model) x (h, d_model, d) -> (h, batch, seq, d).
        q=torch.einsum("bsd,hdk->hbsk",q,self.W_q)
        k=torch.einsum("bsd,hdk->hbsk",k,self.W_k)
        v=torch.einsum("bsd,hdk->hbsk",v,self.W_v)

        # Fold the head axis into the batch axis (head-major order) so the
        # attention module sees ordinary 3-D inputs. reshape is used because
        # the einsum result is not guaranteed to be contiguous.
        q=q.reshape(self.h*batch_size,q_len,self.d_k)
        k=k.reshape(self.h*batch_size,k_len,self.d_k)
        v=v.reshape(self.h*batch_size,v_len,self.d_v)

        if mask_3d is not None:
            # Tiling along dim 0 matches the head-major batch layout above.
            mask_3d=mask_3d.repeat(self.h,1,1)

        attention_output=self.scaled_dot_product_attention(
            q,k,v,mask_3d
        )

        # Split back into h per-head chunks and concatenate them along the
        # feature axis: (h*batch, q_len, d_v) -> (batch, q_len, h*d_v).
        attention_output=torch.chunk(attention_output,self.h,dim=0)
        attention_output=torch.cat(attention_output,dim=2)

        output=self.linear(attention_output)
        return output