In [11]:
import torch

batches, block_size, channels = 4,8,2

x = torch.randn(batches, block_size, channels)
x[0]


tensor([[ 1.1117, -1.1870],
        [-0.4048, -0.3101],
        [ 1.6475,  0.9239],
        [-1.4163,  1.3245],
        [-0.2045, -0.9434],
        [-1.5406,  1.1350],
        [ 1.4756,  1.0943],
        [-0.0473, -2.3244]])

The tokens should flow from the previous context to the future tokens. 

## Averaging
All previous tokens up until token t will be averaged. There will be considerable information loss in this process. 

$ X_{b,t} = \frac{1}{T} \sum_{i=0}^T X_{b,i} $

In [24]:
xbow = torch.zeros((batches,block_size,channels))
for batch in range(batches):
    for block in range(block_size):
        xprev = x[batch,:block+1]
        xbow[batch,block] = torch.mean(xprev,0)
        
xbow[0], x[0], (-1.1870-0.3101)/2

(tensor([[ 1.1117, -1.1870],
         [ 0.3534, -0.7486],
         [ 0.7848, -0.1911],
         [ 0.2345,  0.1878],
         [ 0.1467, -0.0384],
         [-0.1345,  0.1571],
         [ 0.0955,  0.2910],
         [ 0.0776, -0.0359]]),
 tensor([[ 1.1117, -1.1870],
         [-0.4048, -0.3101],
         [ 1.6475,  0.9239],
         [-1.4163,  1.3245],
         [-0.2045, -0.9434],
         [-1.5406,  1.1350],
         [ 1.4756,  1.0943],
         [-0.0473, -2.3244]]),
 -0.74855)

In order to get the same result, we can use matrix multiplication using the lower triangle matrix and then calculating average for each rowxcol multiplication.

### Torch method - tril

In [48]:
weights = torch.tril(torch.ones(block_size,block_size))
weights = weights / weights.sum(1, keepdim=True)

print (weights)

xbow2 = weights @ x

torch.allclose(xbow2, xbow)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

### Softmax

In [56]:
import torch.nn.functional as F
tril = torch.tril(torch.ones(block_size, block_size))
weights = torch.zeros((block_size, block_size))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
print(weights)

xbow2 = weights @ x

torch.allclose(xbow2, xbow)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True