## Left-to-Right Language Model by Masking Future Tokens, Like GPT

In [2]:
import torch
# Seed the global RNG so the random tensors generated below are reproducible.
torch.manual_seed(42)


<torch._C.Generator at 0x1d02ad9d3f0>

In [3]:
# Toy problem sizes: batch, time steps, channels.
b, t, c = 4, 8, 2
# Random input tensor of shape (b, t, c).
x = torch.randn(b, t, c)
print(x.shape)

torch.Size([4, 8, 2])


## Average previous values in time dimension

Each channel is averaged independently of the other channels. For every time step we compute a "bag of words": the plain average of the features at that step and at all previous steps.

In [6]:
# Naive reference implementation: for every batch element and every time
# step, average the feature vectors from the start of the sequence up to and
# including that step. Each channel is averaged independently.
#
# The sizes are read from `x` itself, so this cell does not depend on the
# globals `b`, `t`, `c` still holding the tensor dimensions and can be re-run
# at any point in the notebook.
n_batch, n_steps = x.shape[:2]
xbow = torch.zeros_like(x)
for i in range(n_batch):
    for j in range(n_steps):
        xprev = x[i, : j + 1]  # steps 0..j of batch element i
        xbow[i, j] = xprev.mean(dim=0)

print(x[0])
print(xbow[0])

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])


## Efficient Matrix Multiplication Version

In [11]:
# Multiplying by an all-ones matrix: every output row is the sum of all rows
# of the right-hand matrix.
#
# The right-hand matrix and the product are named `vals` and `out` rather than
# `b` and `c`, so they do not overwrite the dimension variables `b` and `c`
# defined at the top of the notebook.
a = torch.ones((3, 3))
vals = torch.randint(0, 10, (3, 2)).float()

out = a @ vals

# The printed labels keep the a/b/c naming of the equation c = a @ b.
print(f"a = \n{a}\n")
print(f"b = \n{vals}\n")
print(f"c = \n{out}\n")

a = 
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

b = 
tensor([[5., 3.],
        [7., 7.],
        [5., 9.]])

c = 
tensor([[17., 19.],
        [17., 19.],
        [17., 19.]])



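Triangular matrix trick: multiplying by a lower-triangular matrix of ones gives, for each row, the sum of that row and all previous rows. This is exactly what an autoregressive (causal, left-to-right) model needs, because no position can see future tokens or values.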

In [13]:
# Multiplying by a lower-triangular matrix of ones: output row i is the sum of
# rows 0..i of the right-hand matrix (a running sum over rows).
#
# The right-hand matrix and the product are named `vals` and `out` rather than
# `b` and `c`, so they do not overwrite the dimension variables `b` and `c`
# defined at the top of the notebook.
a = torch.tril(torch.ones((3, 3)))
vals = torch.randint(0, 10, (3, 2)).float()

out = a @ vals

# The printed labels keep the a/b/c naming of the equation c = a @ b.
print(f"a = \n{a}\n")
print(f"b = \n{vals}\n")
print(f"c = \n{out}\n")

a = 
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

b = 
tensor([[0., 3.],
        [7., 5.],
        [7., 1.]])

c = 
tensor([[ 0.,  3.],
        [ 7.,  8.],
        [14.,  9.]])



To average the previous values instead of summing them, normalize the triangular matrix row-wise so that every row sums to 1.

In [17]:
# Row-normalized lower-triangular matrix: output row i is the mean of rows
# 0..i of the right-hand matrix.
#
# The right-hand matrix and the product are named `vals` and `out` rather than
# `b` and `c`, so they do not overwrite the dimension variables `b` and `c`
# defined at the top of the notebook.
a = torch.tril(torch.ones((3, 3)))
vals = torch.randint(0, 10, (3, 2)).float()

# Divide each row by its number of ones so that every row sums to 1.
a /= torch.sum(a, dim=1, keepdim=True)

out = a @ vals

# The printed labels keep the a/b/c naming of the equation c = a @ b.
print(f"a = \n{a}\n")
print(f"b = \n{vals}\n")
print(f"c = \n{out}\n")

a = 
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

b = 
tensor([[0., 9.],
        [5., 2.],
        [9., 1.]])

c = 
tensor([[0.0000, 9.0000],
        [2.5000, 5.5000],
        [4.6667, 4.0000]])



## Vectorized Averaging by Weighted Sum

In [19]:
# Averaging matrix: start from lower-triangular ones, then scale each row by
# its row sum so that row i puts equal weight on time steps 0..i.
weights = torch.ones(t, t).tril()
weights = weights / weights.sum(dim=1, keepdim=True)
print(weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [23]:
# `weights` is (t, t) and `x` is (b, t, c). The matmul broadcasts `weights`
# across the batch dimension: (b, t, t) @ (b, t, c) -> (b, t, c).
xbow_vec = torch.matmul(weights, x)
print(xbow_vec.shape)
# The vectorized result must agree with the loop-based `xbow`.
print(torch.allclose(xbow, xbow_vec))
print(xbow[0])
print(xbow_vec[0])

torch.Size([4, 8, 2])
True
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])


## Softmax Version for Previous Value Averaging

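Filling the future positions with $-\infty$ before the softmax gives them zero weight, because $e^{-\infty} = 0$. The remaining positions all hold the same score (0), so the softmax spreads the weight evenly across them, which reproduces the plain average over the current and previous values.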

In [30]:
from torch.nn import functional as F

# Lower-triangular mask: 1 where a position is visible (current and earlier
# time steps), 0 for future steps.
tril = torch.ones(t, t).tril()
# Start from all-zero scores and set every future position to -inf.
weights = torch.zeros(t, t).masked_fill(tril == 0, float('-inf'))
print(weights)
# Row-wise softmax: the -inf entries become 0 and the equal (zero) scores
# share the weight uniformly.
weights = F.softmax(weights, dim=-1)
xbow_soft = torch.matmul(weights, x)
print(weights)
print(tril)
# The softmax-based result must agree with the loop-based `xbow`.
print(torch.allclose(xbow, xbow_soft))

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[1., 0