In [33]:
import torch, torch.nn as nn
import math

In [34]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking).

    A single fused linear layer produces Q, K and V. Each is split into
    ``num_heads`` heads of size ``hidden_size // num_heads``. The heads are
    attended independently, merged back, and passed through an output projection.

    Args:
        hidden_size: model/embedding dimension C; must be divisible by num_heads.
        num_heads: number of attention heads.
        bias: whether the Q/K/V and output projections use a bias term.
        dropout: dropout probability applied to the attention weights and to
            the projected output (default 0.1, the value previously hard-coded).
    """

    def __init__(self, hidden_size: int, num_heads: int, bias: bool = False,
                 dropout: float = 0.1):
        # Must be `super().__init__()`: the former `super.__init__()` invoked the
        # unbound descriptor on the `super` type and raised TypeError. nn.Module
        # has to be initialised before any submodule is assigned.
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Fused linear layer that produces the Q, K and V projections in one matmul.
        self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias)

        # Final output projection, applied after the heads are merged.
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias)

        # Dropout on the attention weights and on the projected output.
        self.attn_dropout = nn.Dropout(p=dropout)
        self.out_dropout = nn.Dropout(p=dropout)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply multi-head self-attention.

        Args:
            X: input of shape (B, S, C) = (batch size, sequence length, hidden_size).

        Returns:
            Tensor of shape (B, S, C).
        """
        B, S, C = X.shape
        head_dim = C // self.num_heads

        # (B, S, 3C) -> (B, S, 3, H, D) -> (B, H, 3, S, D). Unbinding the "3" axis
        # yields q, k, v, each of shape (B, H, S, D).
        qkv = self.Wqkv(X).reshape(B, S, 3, self.num_heads, head_dim)
        q, k, v = qkv.transpose(3, 1).unbind(dim=2)

        # Attention scores (B, H, S, S), scaled by sqrt(d_k).
        attn = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # Weighted sum of the values: (B, H, S, D).
        out = attn @ v

        # Merge the heads: (B, H, S, D) -> (B, S, H, D) -> (B, S, C), then project.
        out = self.Wo(out.transpose(1, 2).reshape(B, S, C))
        return self.out_dropout(out)

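Step-by-step explanation with a worked example: the cells below redefine a truncated `MultiHeadAttention` whose `forward` returns only the fused Q/K/V projection. They then reproduce by hand the reshape → transpose → unbind steps used to split that projection into per-head `q`, `k`, `v`.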

In [35]:
class MultiHeadAttention(nn.Module):
    """Truncated variant used for the step-by-step walkthrough.

    Note: this re-defines (shadows) the earlier ``MultiHeadAttention``. It builds
    the same projection layers, but ``forward`` stops after the fused Q/K/V
    projection so the following cells can inspect how it is split into heads.
    """

    def __init__(self, hidden_size: int, num_heads: int, bias: bool = False):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size, self.num_heads = hidden_size, num_heads

        # One fused projection yields Q, K and V concatenated on the last axis.
        self.Wqkv = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        # Output projection (not used by this truncated forward).
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(self, X: torch.Tensor):
        """Return the fused projection ``Wqkv(X)``, of shape (B, S, 3 * hidden_size)."""
        # The unpacking rejects inputs that are not 3-D, as the original did.
        _batch, _seq, _dim = X.shape
        return self.Wqkv(X)

In [36]:
# Seed before building the layer so its random weight init, and the random input
# drawn in the next cell, are reproducible under Restart & Run All.
torch.manual_seed(0)
attn_network = MultiHeadAttention(hidden_size=8, num_heads=2)  # hidden_dim = embedding_dim = 8

In [37]:
# Toy input with batch_size=1 and seq_len=3. The width is taken from the layer,
# so it always matches hidden_size instead of repeating the literal 8.
X = torch.randn(1, 3, attn_network.hidden_size)
output = attn_network(X)

In [38]:
output.size()  # (batch, seq, 3 * hidden_size)

torch.Size([1, 3, 24])

In [39]:
# Raw fused Q/K/V projection for the 3 tokens: shape (1, 3, 24).
output

tensor([[[ 1.5973,  0.2666,  0.2746,  0.4936,  0.6298,  0.4234,  0.4988,
           0.0929, -0.3498,  0.1395,  0.3523, -0.6470, -1.6516,  0.0221,
          -0.8677, -0.2388,  0.3012,  0.7000,  0.7043,  0.1389,  0.3321,
           0.0668,  0.1967,  0.3129],
         [ 1.1910, -0.0927,  0.3379,  0.7205,  0.1519, -0.6670,  0.0610,
           0.7899, -0.8511, -0.9471,  0.1259, -0.1026, -0.6632,  0.1835,
          -0.6266, -1.1363, -0.3523,  0.9304, -0.0053,  0.4833, -0.8110,
          -0.1459,  0.0292,  0.4406],
         [ 0.8885, -0.7931,  0.5758,  1.5568,  0.1038, -0.8108,  0.6113,
          -0.1354, -1.1147, -0.6673,  0.7566,  0.4825, -0.2541,  0.4908,
           0.0200, -0.2089, -0.3127,  1.8513, -0.6349, -0.6648, -0.6157,
          -0.7970,  0.1724, -0.1503]]], grad_fn=<UnsafeViewBackward0>)

In [40]:
# Split the fused (B, S, 3*C) projection into (B, S, 3, H, D), as `forward` does,
# with H = num_heads and D = C // H. The dimensions are derived from X and the
# layer instead of being hard-coded as (1, 3, 3, 2, 4).
output = output.reshape(
    X.shape[0], X.shape[1], 3, attn_network.num_heads,
    attn_network.hidden_size // attn_network.num_heads,
)

In [41]:
# The same values viewed as (batch, seq, qkv, head, head_dim) = (1, 3, 3, 2, 4).
output

tensor([[[[[ 1.5973,  0.2666,  0.2746,  0.4936],
           [ 0.6298,  0.4234,  0.4988,  0.0929]],

          [[-0.3498,  0.1395,  0.3523, -0.6470],
           [-1.6516,  0.0221, -0.8677, -0.2388]],

          [[ 0.3012,  0.7000,  0.7043,  0.1389],
           [ 0.3321,  0.0668,  0.1967,  0.3129]]],


         [[[ 1.1910, -0.0927,  0.3379,  0.7205],
           [ 0.1519, -0.6670,  0.0610,  0.7899]],

          [[-0.8511, -0.9471,  0.1259, -0.1026],
           [-0.6632,  0.1835, -0.6266, -1.1363]],

          [[-0.3523,  0.9304, -0.0053,  0.4833],
           [-0.8110, -0.1459,  0.0292,  0.4406]]],


         [[[ 0.8885, -0.7931,  0.5758,  1.5568],
           [ 0.1038, -0.8108,  0.6113, -0.1354]],

          [[-1.1147, -0.6673,  0.7566,  0.4825],
           [-0.2541,  0.4908,  0.0200, -0.2089]],

          [[-0.3127,  1.8513, -0.6349, -0.6648],
           [-0.6157, -0.7970,  0.1724, -0.1503]]]]],
       grad_fn=<ReshapeAliasBackward0>)

In [42]:
# Swap the sequence axis (1) with the head axis (3): (B, S, 3, H, D) -> (B, H, 3, S, D),
# as `forward` does. The tensor is rebuilt from attn_network(X) rather than by
# transposing `output` in place. Transposing twice is the identity, so re-running
# the old cell silently undid the swap and broke the unbind below.
B, S, C = X.shape
output = (
    attn_network(X)
    .reshape(B, S, 3, attn_network.num_heads, C // attn_network.num_heads)
    .transpose(3, 1)
)
output

tensor([[[[[ 1.5973,  0.2666,  0.2746,  0.4936],
           [ 1.1910, -0.0927,  0.3379,  0.7205],
           [ 0.8885, -0.7931,  0.5758,  1.5568]],

          [[-0.3498,  0.1395,  0.3523, -0.6470],
           [-0.8511, -0.9471,  0.1259, -0.1026],
           [-1.1147, -0.6673,  0.7566,  0.4825]],

          [[ 0.3012,  0.7000,  0.7043,  0.1389],
           [-0.3523,  0.9304, -0.0053,  0.4833],
           [-0.3127,  1.8513, -0.6349, -0.6648]]],


         [[[ 0.6298,  0.4234,  0.4988,  0.0929],
           [ 0.1519, -0.6670,  0.0610,  0.7899],
           [ 0.1038, -0.8108,  0.6113, -0.1354]],

          [[-1.6516,  0.0221, -0.8677, -0.2388],
           [-0.6632,  0.1835, -0.6266, -1.1363],
           [-0.2541,  0.4908,  0.0200, -0.2089]],

          [[ 0.3321,  0.0668,  0.1967,  0.3129],
           [-0.8110, -0.1459,  0.0292,  0.4406],
           [-0.6157, -0.7970,  0.1724, -0.1503]]]]],
       grad_fn=<TransposeBackward0>)

In [43]:
output.size()  # (batch, heads, qkv, seq, head_dim)

torch.Size([1, 2, 3, 3, 4])

In [44]:
# Peel the Q/K/V axis apart; each piece is (batch, heads, seq, head_dim).
q, k, v = torch.unbind(output, dim=2)
q

tensor([[[[ 1.5973,  0.2666,  0.2746,  0.4936],
          [ 1.1910, -0.0927,  0.3379,  0.7205],
          [ 0.8885, -0.7931,  0.5758,  1.5568]],

         [[ 0.6298,  0.4234,  0.4988,  0.0929],
          [ 0.1519, -0.6670,  0.0610,  0.7899],
          [ 0.1038, -0.8108,  0.6113, -0.1354]]]], grad_fn=<UnbindBackward0>)

In [45]:
q.size()  # (batch, heads, seq, head_dim)

torch.Size([1, 2, 3, 4])