# Multi-Head Attention

<img src="./multihead_attention.png" width="500" height="400">

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional causal mask.

    Queries, keys and values are each projected by their own
    ``d_model -> d_model`` linear layer and split into ``n_head`` heads of
    size ``d_model // n_head``. Attention is computed per head; the heads are
    then merged back to ``d_model`` features and passed through an output
    projection.

    Args:
        d_model: Feature dimension of the inputs and of the output.
        n_head: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        # Fail early with a clear message instead of a cryptic ``view`` error
        # inside ``forward``.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.d_model = d_model
        self.n_head = n_head
        self.n_d = d_model // n_head  # per-head feature size

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, causal=True):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape ``(batch, q_len, d_model)``.
            k: Key tensor of shape ``(batch, kv_len, d_model)``.
            v: Value tensor of shape ``(batch, kv_len, d_model)``.
            causal: If True (default), query position ``i`` may only attend
                to key positions ``0..i`` (lower-triangular mask).

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size, q_len, _ = q.shape
        # Keys/values may have a different length than the queries
        # (e.g. cross-attention), so take their length from ``k``.
        kv_len = k.shape[1]

        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        # (batch, len, d_model) -> (batch, n_head, len, n_d)
        q = q.view(batch_size, q_len, self.n_head, self.n_d).permute(0, 2, 1, 3)
        k = k.view(batch_size, kv_len, self.n_head, self.n_d).permute(0, 2, 1, 3)
        v = v.view(batch_size, kv_len, self.n_head, self.n_d).permute(0, 2, 1, 3)

        # (batch, n_head, q_len, kv_len), scaled by sqrt(head size).
        score = q @ k.transpose(2, 3) / math.sqrt(self.n_d)

        if causal:
            # Build the mask on the same device as the scores so the module
            # also works on GPU; column 0 is always kept, so no row is fully
            # masked (which would make softmax produce NaN).
            mask = torch.ones(
                q_len, kv_len, dtype=torch.bool, device=score.device
            ).tril()
            score = score.masked_fill(~mask, float("-inf"))

        out = F.softmax(score, dim=-1) @ v
        # Merge heads: (batch, n_head, q_len, n_d) -> (batch, q_len, d_model)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.d_model)
        return self.w_o(out)

# Demo: run the attention layer over a random batch of 128 sequences,
# each 64 positions long with d_model features per position.
d_model = 512
n_head = 8
X = torch.rand(128, 64, d_model)

attention = MultiHeadAttention(d_model, n_head)
output = attention(X, X, X)  # self-attention: q, k and v are all X
print(output, output.shape)








tensor([[[ 0.1877,  0.1657, -0.0410,  ..., -0.2167, -0.0197,  0.1195],
         [ 0.0885,  0.1367, -0.0774,  ..., -0.1746, -0.0669,  0.1598],
         [ 0.1100,  0.1292, -0.0874,  ..., -0.1688, -0.0893,  0.2096],
         ...,
         [ 0.0604,  0.1330, -0.1173,  ..., -0.0713, -0.0913,  0.1809],
         [ 0.0596,  0.1345, -0.1183,  ..., -0.0724, -0.0928,  0.1801],
         [ 0.0593,  0.1320, -0.1162,  ..., -0.0743, -0.0932,  0.1779]],

        [[ 0.0258, -0.1764, -0.1212,  ..., -0.0511, -0.2387,  0.2261],
         [ 0.1300,  0.0085, -0.1791,  ..., -0.1085, -0.2460,  0.1905],
         [ 0.0437,  0.0695, -0.1528,  ..., -0.0671, -0.2192,  0.1942],
         ...,
         [ 0.0515,  0.1371, -0.1193,  ..., -0.0898, -0.1102,  0.1712],
         [ 0.0509,  0.1384, -0.1182,  ..., -0.0896, -0.1107,  0.1692],
         [ 0.0517,  0.1390, -0.1204,  ..., -0.0894, -0.1110,  0.1687]],

        [[-0.0492,  0.1312, -0.1478,  ..., -0.1601, -0.2125,  0.2068],
         [ 0.0322,  0.0366, -0.1446,  ..., -0