In [1]:
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 2
sequence_size = 10
dimension_size = 512

In [3]:
X = torch.rand(batch_size, sequence_size, dimension_size)

In [4]:
X.shape

torch.Size([2, 10, 512])

In [5]:
X

tensor([[[0.7972, 0.9107, 0.3476,  ..., 0.0161, 0.4684, 0.8573],
         [0.6835, 0.6850, 0.1866,  ..., 0.7891, 0.0558, 0.3691],
         [0.0484, 0.6103, 0.1663,  ..., 0.6652, 0.4410, 0.2054],
         ...,
         [0.5687, 0.9382, 0.0334,  ..., 0.1783, 0.8655, 0.0550],
         [0.8593, 0.1502, 0.9706,  ..., 0.7714, 0.2001, 0.9810],
         [0.1401, 0.7820, 0.0012,  ..., 0.7534, 0.1224, 0.1316]],

        [[0.7725, 0.8038, 0.5548,  ..., 0.8106, 0.9348, 0.1920],
         [0.2415, 0.3012, 0.1258,  ..., 0.5033, 0.4164, 0.5353],
         [0.6066, 0.8001, 0.3196,  ..., 0.1229, 0.9644, 0.0695],
         ...,
         [0.2941, 0.5257, 0.9837,  ..., 0.3689, 0.8956, 0.8976],
         [0.6220, 0.1326, 0.5128,  ..., 0.6187, 0.3986, 0.8743],
         [0.6737, 0.7558, 0.0177,  ..., 0.2288, 0.7728, 0.2693]]])

In [6]:
# torch.bmm is a batched matrix multiplication
raw_weights = torch.bmm(X, X.transpose(1, 2))

In [7]:
raw_weights[0]

tensor([[175.7629, 132.4185, 130.6548, 130.1246, 132.7341, 129.2438, 137.8691,
         135.2160, 127.7514, 129.3416],
        [132.4185, 168.0042, 128.7352, 122.8057, 124.6702, 126.4963, 134.2721,
         130.4998, 121.4668, 127.2200],
        [130.6548, 128.7352, 171.8701, 124.8773, 125.3165, 122.9939, 136.0140,
         132.5821, 123.7235, 131.2590],
        [130.1246, 122.8057, 124.8773, 159.6731, 123.1352, 118.1296, 129.1505,
         128.0027, 121.0963, 122.9911],
        [132.7341, 124.6702, 125.3165, 123.1352, 164.8947, 124.8289, 131.9359,
         126.1846, 117.9964, 125.3555],
        [129.2438, 126.4963, 122.9939, 118.1296, 124.8289, 160.5341, 131.3461,
         127.8684, 122.5303, 125.5314],
        [137.8691, 134.2721, 136.0140, 129.1505, 131.9359, 131.3461, 184.3858,
         138.7914, 135.6242, 137.1956],
        [135.2160, 130.4998, 132.5821, 128.0027, 126.1846, 127.8684, 138.7914,
         174.8904, 125.0862, 134.0635],
        [127.7514, 121.4668, 123.7235, 121.0963,

In [8]:
weights = F.softmax(raw_weights, dim = -1)

In [9]:
Y = torch.bmm(weights, X)

In [10]:
Y

tensor([[[0.7972, 0.9107, 0.3476,  ..., 0.0161, 0.4684, 0.8573],
         [0.6835, 0.6850, 0.1866,  ..., 0.7891, 0.0558, 0.3691],
         [0.0484, 0.6103, 0.1663,  ..., 0.6652, 0.4410, 0.2054],
         ...,
         [0.5687, 0.9382, 0.0334,  ..., 0.1783, 0.8655, 0.0550],
         [0.8593, 0.1502, 0.9706,  ..., 0.7714, 0.2001, 0.9810],
         [0.1401, 0.7820, 0.0012,  ..., 0.7534, 0.1224, 0.1316]],

        [[0.7725, 0.8038, 0.5548,  ..., 0.8106, 0.9348, 0.1920],
         [0.2415, 0.3012, 0.1258,  ..., 0.5033, 0.4164, 0.5353],
         [0.6066, 0.8001, 0.3196,  ..., 0.1229, 0.9644, 0.0695],
         ...,
         [0.2941, 0.5257, 0.9837,  ..., 0.3689, 0.8956, 0.8976],
         [0.6220, 0.1326, 0.5128,  ..., 0.6187, 0.3986, 0.8743],
         [0.6737, 0.7558, 0.0177,  ..., 0.2288, 0.7728, 0.2693]]])