In [61]:
import torch
import torch.nn.functional as F
import torch.nn as nn

batches, block_size, channels = 4,8,2

x = torch.randn(batches, block_size, channels)
x[0]


tensor([[ 0.6308,  0.0204],
        [-1.0404, -0.7600],
        [ 0.8160, -0.3312],
        [ 0.9665, -1.7312],
        [-0.3889, -0.6368],
        [-0.1730,  0.1489],
        [-0.6770, -0.3378],
        [-1.1074, -0.1256]])

The tokens should flow from the previous context to the future tokens. 

## Averaging
All previous tokens up until token t will be averaged. There will be considerable information loss in this process. 

$ X_{b,t} = \frac{1}{T} \sum_{i=0}^T X_{b,i} $

In [62]:
xbow = torch.zeros((batches,block_size,channels))
for batch in range(batches):
    for block in range(block_size):
        xprev = x[batch,:block+1]
        xbow[batch,block] = torch.mean(xprev,0)
        
xbow[0], x[0], (-1.1870-0.3101)/2

(tensor([[ 0.6308,  0.0204],
         [-0.2048, -0.3698],
         [ 0.1355, -0.3569],
         [ 0.3432, -0.7005],
         [ 0.1968, -0.6878],
         [ 0.1352, -0.5483],
         [ 0.0192, -0.5182],
         [-0.1217, -0.4692]]),
 tensor([[ 0.6308,  0.0204],
         [-1.0404, -0.7600],
         [ 0.8160, -0.3312],
         [ 0.9665, -1.7312],
         [-0.3889, -0.6368],
         [-0.1730,  0.1489],
         [-0.6770, -0.3378],
         [-1.1074, -0.1256]]),
 -0.74855)

In order to get the same result, we can use matrix multiplication using the lower triangle matrix and then calculating average for each rowxcol multiplication.

### Torch method - tril

In [63]:
weights = torch.tril(torch.ones(block_size,block_size))
weights = weights / weights.sum(1, keepdim=True)

print (weights)

xbow2 = weights @ x

torch.allclose(xbow2, xbow)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

### Softmax

In [64]:
tril = torch.tril(torch.ones(block_size, block_size))
weights = torch.zeros((block_size, block_size))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
print(weights)

xbow2 = weights @ x

torch.allclose(xbow2, xbow)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

### Data dependent self-attention head

This section shows an example of scaled dot-product attention sample from the 3.2.1 of paper (attn is all you need !). 

In [85]:
batches, block_size, no_of_embeddings = 4,8,32
attn_head_size = 16

test = torch.randn(batches, block_size, no_of_embeddings)

key = nn.Linear(no_of_embeddings, attn_head_size, bias=False)
query = nn.Linear(no_of_embeddings, attn_head_size, bias=False)
value = nn.Linear(no_of_embeddings, attn_head_size, bias=False)

k = key(test) # batches x blocksize x attn_head_size
q = query(test) # batches x blocksize x attn_head_size
v = value(test)

# Normalize to ensure that the softmax does not eventually starts converging the max values into 1. 
weights = q @ k.transpose(-2, -1) * attn_head_size**-0.5 # Multiply only last two dimensions above 

# The above operation mulitplies the following. 
# q ( batches, block_size, attn_head_size ) x k( batches, attn_head_size, block_size)
# result = batches x block_size x block_size

tril = torch.tril(torch.ones(block_size, block_size))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

print (weights)

xbow2 = weights @ v

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7700, 0.2300, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5087, 0.2784, 0.2129, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2172, 0.2747, 0.1749, 0.3331, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1902, 0.1937, 0.1839, 0.2478, 0.1843, 0.0000, 0.0000, 0.0000],
         [0.1083, 0.2550, 0.1970, 0.1581, 0.2000, 0.0815, 0.0000, 0.0000],
         [0.1058, 0.1034, 0.0718, 0.1642, 0.1555, 0.2547, 0.1446, 0.0000],
         [0.1171, 0.1342, 0.2769, 0.0855, 0.1072, 0.0905, 0.1322, 0.0564]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3494, 0.6506, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3506, 0.3445, 0.3049, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2934, 0.2030, 0.2560, 0.2476, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2679, 0.1502, 0.2152, 0.1830, 0.1836, 0.0000, 0.0000, 0.0000],
         [0.1583, 0.183