In [49]:
import torch
import torch.nn as nn
import math

In [71]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, d_model:int, n_heads:int, ):
        self.d_model = d_model
        self.n_heads = n_heads
        assert self.d_model % self.n_heads == 0, "d_model should be divisible by n_heads!"
        self.d_h = self.d_model//self.n_heads
        
        self.LinearQ = nn.Linear(d_model, d_model)
        self.LinearK = nn.Linear(d_model, d_model)
        self.LinearV = nn.Linear(d_model, d_model)
        
        
    def forward(self, q, k, v):
        # q, k, v size: (batch size, seq_len, d_model)
        q_proj = self.LinearQ(q)  # (batch size, seq_len, d_model)
        k_proj = self.LinearK(k)
        v_proj = self.LinearV(v)
        
        q_proj = q_proj.view(q_proj.shape[0], q_proj.shape[1], self.n_heads, self.d_h).transpose(1,2) # (batch size, n_heads, seq_len, d_h)
        k_proj = k_proj.view(q_proj.shape[0], q_proj.shape[1], self.n_heads, self.d_h).transpose(1,2)
        v_proj = v_proj.view(q_proj.shape[0], q_proj.shape[1], self.n_heads, self.d_h).transpose(1,2)

        att_prob = (q_proj @ k_proj.transpose(-1,-2)/math.sqrt(self.d_h)).softmax(dim=-1) # (batch size, n_heads, seq_len, seq_len)
        att_val = (att_prob @ v_proj).transpose(1,2).view(q_proj.shape[0], q_proj.shape[1], self.d_model) # (batch size, seq_len, n_model)
        return att_val

In [48]:
m = nn.Softmax(dim=-1)
m(a @ torch.transpose(b, 1, 2))

tensor([[[0.3564, 0.1849, 0.1144, 0.1004, 0.2439],
         [0.2444, 0.2232, 0.1591, 0.1228, 0.2505],
         [0.3624, 0.1600, 0.1398, 0.1140, 0.2238],
         [0.3129, 0.1503, 0.1475, 0.1525, 0.2368],
         [0.2616, 0.2201, 0.1191, 0.1384, 0.2607]],

        [[0.1406, 0.2157, 0.1067, 0.1946, 0.3423],
         [0.1062, 0.1445, 0.0759, 0.2854, 0.3880],
         [0.1470, 0.2324, 0.0975, 0.1865, 0.3366],
         [0.1071, 0.1688, 0.0805, 0.2838, 0.3597],
         [0.0894, 0.0771, 0.1371, 0.3367, 0.3597]]])

In [69]:
a.view((1,100))

tensor([[5.1662e-01, 8.5651e-01, 2.7627e-02, 7.0715e-04, 6.7368e-01, 6.1441e-01,
         3.5076e-01, 7.8261e-01, 2.2942e-01, 9.4171e-01, 4.2464e-02, 5.7357e-01,
         6.7729e-01, 5.2001e-01, 2.0488e-01, 9.5282e-01, 5.0492e-03, 1.2875e-01,
         3.7501e-01, 9.3959e-01, 3.9402e-01, 4.0903e-01, 3.5688e-01, 5.0833e-01,
         1.6406e-01, 2.1066e-02, 3.5486e-01, 6.4652e-01, 4.9309e-03, 5.5206e-01,
         2.4028e-01, 8.2871e-01, 3.9795e-01, 1.2829e-02, 4.0384e-01, 5.3361e-01,
         7.6448e-02, 8.9446e-02, 4.1684e-01, 2.7728e-01, 2.8358e-01, 6.3485e-01,
         3.3976e-01, 3.3906e-01, 7.3859e-01, 2.3730e-02, 3.3382e-01, 9.5803e-01,
         7.1400e-02, 2.2049e-02, 3.7260e-01, 1.1008e-01, 1.9810e-01, 6.3477e-02,
         4.0891e-01, 6.1942e-01, 8.5863e-01, 5.7828e-01, 1.6506e-01, 8.8968e-01,
         3.9389e-01, 9.9047e-01, 5.4204e-01, 5.0099e-01, 7.5437e-01, 4.4286e-01,
         5.6542e-02, 2.2373e-01, 2.4036e-01, 7.7569e-01, 7.7603e-01, 3.0141e-02,
         9.5410e-01, 4.1722e

In [None]:
# implement transpose
import numpy as np

def transpose(arr, dim1, dim2):
    s = arr.shape
    

In [27]:
import numpy as np

np.zeros((2,3,4)).shape

(2, 3, 4)

In [24]:
a.shape

torch.Size([2, 5, 10])

In [23]:
torch.transpose(a, 0,2)

tensor([[[ 0, 50],
         [10, 60],
         [20, 70],
         [30, 80],
         [40, 90]],

        [[ 1, 51],
         [11, 61],
         [21, 71],
         [31, 81],
         [41, 91]],

        [[ 2, 52],
         [12, 62],
         [22, 72],
         [32, 82],
         [42, 92]],

        [[ 3, 53],
         [13, 63],
         [23, 73],
         [33, 83],
         [43, 93]],

        [[ 4, 54],
         [14, 64],
         [24, 74],
         [34, 84],
         [44, 94]],

        [[ 5, 55],
         [15, 65],
         [25, 75],
         [35, 85],
         [45, 95]],

        [[ 6, 56],
         [16, 66],
         [26, 76],
         [36, 86],
         [46, 96]],

        [[ 7, 57],
         [17, 67],
         [27, 77],
         [37, 87],
         [47, 97]],

        [[ 8, 58],
         [18, 68],
         [28, 78],
         [38, 88],
         [48, 98]],

        [[ 9, 59],
         [19, 69],
         [29, 79],
         [39, 89],
         [49, 99]]])

In [20]:
a @ torch.transpose(b, -1, -2)

tensor([[[  285,   735,  1185,  1635,  2085],
         [  735,  2185,  3635,  5085,  6535],
         [ 1185,  3635,  6085,  8535, 10985],
         [ 1635,  5085,  8535, 11985, 15435],
         [ 2085,  6535, 10985, 15435, 19885]],

        [[29785, 35235, 40685, 46135, 51585],
         [35235, 41685, 48135, 54585, 61035],
         [40685, 48135, 55585, 63035, 70485],
         [46135, 54585, 63035, 71485, 79935],
         [51585, 61035, 70485, 79935, 89385]]])

In [14]:
a @ b.transpose()

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 10] but got: [2, 5].

In [6]:
a.shape

torch.Size([2, 4, 10])

In [7]:
a.view(2,4,5,2)

tensor([[[[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]]]])