In [11]:
import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
import math


In [None]:
# Let's get real! Now we know attention, which is of course "all you need" ;) We will 
# implement encoder properly with in pytorch
# we use einops to simplify matrix handling syntax among other things, to install, uncomment:
# !pip3 install einops

In [2]:
print(torch.__version__) # einsum in early versions can be too slow

1.3.0


In [6]:
'''
BLOCK:
batch of tensors of embedding dim n -> self attention + id
-> layer norm -> MLP for each in batch (just plain default) + id
-> layer norm
'''
sentence_len = 9 # batch size
emb_dim = 10 
num_heads = 3
input_batch = torch.rand(sentence_len, emb_dim)
input_batch

tensor([[1.1204e-01, 3.3179e-01, 3.7844e-01, 7.2353e-01, 4.4927e-01, 1.1372e-01,
         2.7527e-01, 2.0844e-01, 3.0714e-01, 1.5609e-01],
        [4.5294e-01, 9.9741e-01, 3.5441e-01, 2.1147e-01, 1.8738e-01, 7.6277e-01,
         5.6122e-01, 9.7209e-04, 2.1788e-01, 8.6360e-01],
        [3.4282e-01, 6.3288e-01, 5.9995e-01, 5.8601e-01, 7.1222e-02, 5.2544e-01,
         3.2955e-01, 2.8433e-01, 3.7653e-01, 6.3157e-01],
        [8.2733e-01, 4.8969e-01, 4.8225e-01, 1.3652e-02, 5.5434e-01, 2.6493e-01,
         2.8982e-01, 6.8646e-01, 1.6388e-01, 4.6974e-01],
        [5.5633e-01, 5.6522e-01, 9.8118e-01, 4.0900e-01, 2.4492e-01, 1.8387e-01,
         2.1073e-01, 7.3774e-01, 1.6522e-01, 6.4002e-01],
        [5.3758e-01, 8.9740e-01, 8.6875e-01, 8.8043e-01, 1.6907e-01, 3.2457e-01,
         7.3095e-01, 7.8696e-01, 7.5072e-01, 4.4058e-01],
        [9.2194e-01, 4.5239e-01, 8.8996e-01, 9.5824e-01, 6.2432e-01, 6.7026e-01,
         7.6904e-01, 5.4834e-01, 5.6674e-01, 9.1058e-01],
        [8.5816e-01, 5.0199

In [9]:
def attention(Q, K, V):
    ''' Functional implementation for scaled dot product attention formula'''
    dot_prod = torch.matmul(Q, torch.transpose(K, -2, -1)) #swap last 2 dims, regardless of batch dim
    K_dim = K.size(-1)
    softmax = F.softmax(dot_prod/math.sqrt(K_dim), dim = -1)
    attention = torch.matmul(softmax, V)
    return attention

class SelfAttentionWide(nn.Module):
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # we want to output the same dim as embedding to enable residual connection
        # init 3 matrices all the same
        self.M_Q, self.M_V, self.M_K = \
            [nn.Linear(emb_dim, emb_dim*num_heads, bias=False) for _ in range(3)]
        
        self.M_merge_heads = nn.Linear(emb_dim*num_heads, emb_dim, bias=False)
        
    def forward(self, x):
        # get Q, K, V
        Q = self.M_Q(x)
        K = self.M_K(x)
        V = self.M_V(x)
        multi_att = attention(Q, K, V)
        return self.M_merge_heads(multi_att)
        

In [12]:
attention_layer = SelfAttentionWide(emb_dim, num_heads)
x = attention_layer(input_batch)

In [13]:
x

tensor([[-0.0544,  0.1883, -0.0467, -0.3533,  0.0380, -0.1548, -0.0787, -0.2424,
          0.0815, -0.0482],
        [-0.0539,  0.1869, -0.0471, -0.3514,  0.0367, -0.1543, -0.0785, -0.2426,
          0.0806, -0.0478],
        [-0.0539,  0.1880, -0.0472, -0.3528,  0.0371, -0.1549, -0.0784, -0.2432,
          0.0812, -0.0480],
        [-0.0536,  0.1882, -0.0479, -0.3532,  0.0360, -0.1555, -0.0780, -0.2444,
          0.0817, -0.0481],
        [-0.0527,  0.1890, -0.0477, -0.3537,  0.0364, -0.1552, -0.0781, -0.2444,
          0.0816, -0.0481],
        [-0.0533,  0.1897, -0.0476, -0.3542,  0.0377, -0.1554, -0.0786, -0.2441,
          0.0811, -0.0480],
        [-0.0536,  0.1890, -0.0465, -0.3529,  0.0385, -0.1545, -0.0791, -0.2424,
          0.0806, -0.0479],
        [-0.0542,  0.1894, -0.0477, -0.3542,  0.0375, -0.1558, -0.0786, -0.2441,
          0.0812, -0.0480],
        [-0.0542,  0.1886, -0.0484, -0.3538,  0.0359, -0.1564, -0.0779, -0.2453,
          0.0816, -0.0481]], grad_fn=<MmBackwar

In [14]:
layernorm = nn.LayerNorm(emb_dim)
layernorm(x)

tensor([[ 0.0855,  1.7223,  0.1377, -1.9306,  0.7086, -0.5916, -0.0786, -1.1822,
          1.0019,  0.1272],
        [ 0.0901,  1.7223,  0.1362, -1.9274,  0.7040, -0.5913, -0.0768, -1.1898,
          1.0018,  0.1309],
        [ 0.0898,  1.7229,  0.1351, -1.9279,  0.7041, -0.5921, -0.0754, -1.1880,
          1.0019,  0.1297],
        [ 0.0937,  1.7229,  0.1322, -1.9259,  0.6974, -0.5932, -0.0707, -1.1921,
          1.0052,  0.1305],
        [ 0.0983,  1.7243,  0.1314, -1.9269,  0.6976, -0.5918, -0.0726, -1.1915,
          1.0019,  0.1293],
        [ 0.0940,  1.7264,  0.1322, -1.9272,  0.7049, -0.5923, -0.0763, -1.1878,
          0.9967,  0.1295],
        [ 0.0896,  1.7257,  0.1373, -1.9292,  0.7111, -0.5911, -0.0824, -1.1836,
          0.9947,  0.1279],
        [ 0.0891,  1.7257,  0.1325, -1.9267,  0.7054, -0.5937, -0.0748, -1.1870,
          0.9989,  0.1305],
        [ 0.0913,  1.7243,  0.1304, -1.9233,  0.6976, -0.5956, -0.0681, -1.1936,
          1.0047,  0.1325]], grad_fn=<NativeLay

tensor([[ 0.0576,  0.5200,  0.3318,  0.3702,  0.4872, -0.0411,  0.1965, -0.0339,
          0.3886,  0.1079],
        [ 0.3991,  1.1843,  0.3074, -0.1399,  0.2241,  0.6084,  0.4827, -0.2417,
          0.2985,  0.8158],
        [ 0.2889,  0.8209,  0.5527,  0.2332,  0.1083,  0.3705,  0.2511,  0.0411,
          0.4577,  0.5836],
        [ 0.7737,  0.6778,  0.4344, -0.3396,  0.5903,  0.1094,  0.2118,  0.4421,
          0.2455,  0.4216],
        [ 0.5037,  0.7542,  0.9334,  0.0553,  0.2813,  0.0287,  0.1327,  0.4934,
          0.2469,  0.5920],
        [ 0.4843,  1.0871,  0.8212,  0.5263,  0.2067,  0.1691,  0.6523,  0.5429,
          0.8318,  0.3926],
        [ 0.8683,  0.6413,  0.8434,  0.6053,  0.6629,  0.5157,  0.6899,  0.3060,
          0.6473,  0.8626],
        [ 0.8040,  0.6914,  0.4524,  0.5720,  0.3570,  0.8059,  0.8505,  0.4750,
          0.3507,  0.0386],
        [ 0.4226,  0.7014,  0.1006, -0.3317,  0.1930,  0.6897,  0.1996,  0.5422,
          0.8889,  0.8813]], grad_fn=<AddBackwa