## Left-to-Right (Causal) Language Modeling by Masking Future Tokens, as in GPT

In [2]:
import torch

SEED = 42  # fixed seed so the random tensors drawn below are reproducible
torch.manual_seed(SEED)


<torch._C.Generator at 0x2407f905430>

In [3]:
# Toy input: b sequences, t time steps, c channels per step.
b, t, c = 4, 8, 2
x = torch.randn(b, t, c)
print(x.shape)

torch.Size([4, 8, 2])


## Average previous values in time dimension

Each channel is averaged independently of the other channels. This "bag of words" aggregation replaces every position with the mean of itself and all previous positions in the same sequence.

In [4]:
# Explicit (slow) reference implementation: for every sequence and every
# position, average the feature vectors of that position and all earlier ones.
xbow = torch.zeros((b, t, c))
for seq_idx in range(b):
    for pos in range(t):
        history = x[seq_idx, : pos + 1]  # positions 0..pos of this sequence
        xbow[seq_idx, pos] = history.mean(dim=0)

print(x[0])
print(xbow[0])

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])


## Efficient Matrix Multiplication Version

In [5]:
# Use distinct names: plain `b` / `c` would overwrite the batch size and channel
# count defined earlier, and the loop cell above would then break if re-run.
mat_a = torch.ones((3, 3))
mat_b = torch.randint(0, 10, (3, 2)).float()

# With an all-ones left matrix, every output row equals the column-wise sum of mat_b.
mat_c = mat_a @ mat_b

print(f"a = \n{mat_a}\n")
print(f"b = \n{mat_b}\n")
print(f"c = \n{mat_c}\n")

a = 
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

b = 
tensor([[0., 1.],
        [3., 0.],
        [1., 1.]])

c = 
tensor([[4., 2.],
        [4., 2.],
        [4., 2.]])



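Lower-triangular matrix trick: multiplying by `torch.tril(torch.ones(...))` makes each output row the sum of the current and all previous rows. This is what an autoregressive (causal, left-to-right) model needs, because a token must not see any future tokens or values.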

In [6]:
# Distinct names keep the earlier batch size `b` and channel count `c` intact.
mat_a = torch.tril(torch.ones((3, 3)))
mat_b = torch.randint(0, 10, (3, 2)).float()

# Row i of mat_c is the sum of rows 0..i of mat_b (a running / cumulative sum).
mat_c = mat_a @ mat_b

print(f"a = \n{mat_a}\n")
print(f"b = \n{mat_b}\n")
print(f"c = \n{mat_c}\n")

a = 
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

b = 
tensor([[7., 9.],
        [4., 3.],
        [8., 9.]])

c = 
tensor([[ 7.,  9.],
        [11., 12.],
        [19., 21.]])



Normalizing each row of the triangular matrix so it sums to 1 turns the running sum into a running average of the current and previous values.

In [7]:
# Distinct names keep the earlier batch size `b` and channel count `c` intact.
mat_a = torch.tril(torch.ones((3, 3)))
mat_b = torch.randint(0, 10, (3, 2)).float()

# Divide each row by its number of ones so every row sums to 1 (out-of-place, so
# re-running the cell is idempotent).
mat_a = mat_a / mat_a.sum(dim=1, keepdim=True)

# Row i of mat_c is now the mean of rows 0..i of mat_b.
mat_c = mat_a @ mat_b

print(f"a = \n{mat_a}\n")
print(f"b = \n{mat_b}\n")
print(f"c = \n{mat_c}\n")

a = 
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

b = 
tensor([[3., 7.],
        [8., 1.],
        [4., 1.]])

c = 
tensor([[3.0000, 7.0000],
        [5.5000, 4.0000],
        [5.0000, 3.0000]])



## Vectorized Averaging by Weighted Sum

In [8]:
# Lower-triangular ones, each row divided by its count of ones, so row i
# assigns equal weight to positions 0..i and zero to later positions.
row_mask = torch.tril(torch.ones((t, t)))
weights = row_mask / row_mask.sum(dim=1, keepdim=True)
print(weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [9]:
# weights (t, t) is broadcast to (b, t, t) and batch-multiplied with x (b, t, c),
# producing the running averages for every sequence at once: (b, t, c).
xbow_vec = torch.matmul(weights, x)
print(xbow_vec.shape)
# The vectorized result should match the explicit double loop.
print(torch.allclose(xbow, xbow_vec))
for averaged in (xbow[0], xbow_vec[0]):
    print(averaged)

torch.Size([4, 8, 2])
True
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])


## Softmax Version for Previous Value Averaging

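Filling the future positions with $-\infty$ gives $e^{-\infty} = 0$ inside the softmax. Those positions therefore receive zero weight, and each row becomes a uniform average over the current and previous values only.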

In [10]:
from torch.nn import functional as F

# Causal mask: 1 on and below the diagonal, 0 above it.
tril = torch.tril(torch.ones(t, t))
# Start from uniform (all-zero) logits and block future positions with -inf.
weights = torch.zeros((t, t)).masked_fill(tril == 0, float('-inf'))
print(weights)
# Row-wise softmax: exp(-inf) == 0, so each row spreads its weight evenly
# over the allowed (current and earlier) positions.
weights = weights.softmax(dim=-1)
xbow_soft = weights @ x
print(weights)
print(tril)
print(torch.allclose(xbow, xbow_soft))

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[1., 0

## Self Attention

A batch of 4 sequences, each with 8 tokens, where every token has a 32-dimensional embedding.

In [15]:
torch.manual_seed(42)
b, t, c = 4, 8, 32
x = torch.randn((b, t, c))

# Uniform causal "attention": each position attends equally to itself and all
# earlier positions, and never to later ones.
tril = torch.tril(torch.ones((t, t)))
weights = F.softmax(
    torch.zeros((t, t)).masked_fill(tril == 0, float('-inf')),
    dim=-1,
)
out = weights @ x  # (t, t) broadcast against (b, t, c) -> (b, t, c)

print(out.shape)

torch.Size([4, 8, 32])


Single-head self-attention. The causal mask is only used in left-to-right, decoder-only LMs like GPT. It is omitted in bidirectional, encoder-only models like BERT, where every token may attend to every other token.

In [22]:
from torch import nn

torch.manual_seed(42)
b, t, c = 4, 8, 32
x = torch.randn((b, t, c))

head_dim = 16
to_key = nn.Linear(c, head_dim, bias=False)
to_query = nn.Linear(c, head_dim, bias=False)
to_value = nn.Linear(c, head_dim, bias=False)

# Project every token into separate key, query and value spaces.
k = to_key(x)    # (b, t, head_dim)
q = to_query(x)  # (b, t, head_dim)
v = to_value(x)  # (b, t, head_dim) -- must use the value projection, not the query one

# Transpose only the last two dims (keeping the batch dim) so q @ k^T is a
# (b, t, t) token-to-token interaction (attention) map.
# Scaling by 1/sqrt(head_dim) keeps the score variance from growing with head_dim
# (assuming q and k entries of roughly unit variance). Without it the softmax saturates
# towards a one-hot vector and each token aggregates information almost entirely from
# a single other token instead of a spread of tokens.
weights = q @ k.transpose(-1, -2) * (head_dim ** -0.5)
print(weights.var())

# Causal mask for a left-to-right LM: position i may only attend to positions <= i,
# so earlier tokens never see future tokens.
tril = torch.tril(torch.ones((t, t)))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ v  # (b, t, head_dim)

print(out.shape)
print(weights.shape)
print(weights[0])

tensor(0.1231, grad_fn=<VarBackward0>)
torch.Size([4, 8, 16])
torch.Size([4, 8, 8])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4106, 0.5894, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3657, 0.2283, 0.4061, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2168, 0.2759, 0.2204, 0.2870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2553, 0.1697, 0.1548, 0.2341, 0.1861, 0.0000, 0.0000, 0.0000],
        [0.1318, 0.2060, 0.1405, 0.1917, 0.1949, 0.1351, 0.0000, 0.0000],
        [0.2137, 0.0978, 0.2374, 0.1025, 0.1418, 0.0838, 0.1230, 0.0000],
        [0.0852, 0.1047, 0.0824, 0.1376, 0.1015, 0.1900, 0.1780, 0.1206]],
       grad_fn=<SelectBackward0>)


## Softmax Variance Problem Example.

Small-magnitude logits give a smooth, nearly uniform distribution.

In [26]:
# Softmax of small-magnitude logits: probabilities stay close to uniform.
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
sm_example = logits.softmax(dim=-1)
print(sm_example)
print(sm_example.var())

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor(0.0039)


Scaling the same logits by 8 makes the softmax much sharper: it approaches a one-hot vector on the largest logit.

In [29]:
# Same logits scaled up by 8: the probability mass concentrates on the largest logit.
scale = 8
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
sm_example = torch.softmax(logits * scale, dim=-1)
print(sm_example)
print(sm_example.var())

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])
tensor(0.1168)
