In [43]:
import torch

In [44]:
torch.arange(0, 10, 2).float()

tensor([0., 2., 4., 6., 8.])

In [45]:
theta = torch.arange(0, 4, 2).float()

# build m (position parameter), shape = (seq_len)
m = torch.arange(4)

freqs = torch.outer(m, theta)
freqs.shape, freqs

(torch.Size([4, 2]),
 tensor([[0., 0.],
         [0., 2.],
         [0., 4.],
         [0., 6.]]))

In [46]:
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
freqs_complex

tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 1.0000+0.0000j, -0.4161+0.9093j],
        [ 1.0000+0.0000j, -0.6536-0.7568j],
        [ 1.0000+0.0000j,  0.9602-0.2794j]])

In [47]:
x = torch.arange(0, 16).reshape(4, 4)
x

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [48]:
# (batch_size, seq_len, n_heads, head_size)
x = torch.randn(3, 4, 2, 4)

In [49]:
torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2).float()).shape

torch.Size([3, 4, 2, 2])

In [50]:
x = torch.arange(0, 10).unsqueeze(0).repeat((2, 1))
x

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [51]:
x[..., : x.shape[-1] // 2]

tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])

## RMS Norm

In [52]:
x = torch.arange(16).reshape(2, 4, -1)
x.shape

torch.Size([2, 4, 2])

In [53]:
print(f'Shape with keepdims = True: {x.pow(2).mean(-1, keepdim=True, dtype=torch.float).shape}')
print(f'Shape with keepdims = False: {x.pow(2).mean(-1, keepdim=False, dtype=torch.float).shape}')

Shape with keepdims = True: torch.Size([2, 4, 1])
Shape with keepdims = False: torch.Size([2, 4])


In [54]:
(torch.nn.Parameter(torch.ones(2)) * x).shape

torch.Size([2, 4, 2])

In [55]:
a = torch.arange(10).long()
b = torch.arange(10, 13).long()
a, b

(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([10, 11, 12]))

In [56]:
a[-len(b):] = b
a

tensor([ 0,  1,  2,  3,  4,  5,  6, 10, 11, 12])

In [57]:
torch.arange(10).reshape(2, 5)[:, 0]

tensor([0, 5])

In [90]:
probs = torch.tensor([[0.8, 0.2, 0.3, 0.4], [0.2, 0.6, 0.4, 0.1]])
probs = torch.softmax(probs, dim=-1)
print(f'Probs: {probs}')

probs_sort, idx = torch.sort(probs, dim=-1, descending=True)
probs_cum = torch.cumsum(probs_sort, dim=-1)
print(f'Probs sort: {probs_sort}')
print(f'Probs cum: {probs_cum}')

probs_shifted = torch.roll(probs_cum, 1, dims=-1)
probs_shifted[..., 0] = 0
print(f'Probs shifted: {probs_shifted}')

Probs: tensor([[0.3539, 0.1942, 0.2147, 0.2372],
        [0.2165, 0.3230, 0.2645, 0.1959]])
Probs sort: tensor([[0.3539, 0.2372, 0.2147, 0.1942],
        [0.3230, 0.2645, 0.2165, 0.1959]])
Probs cum: tensor([[0.3539, 0.5911, 0.8058, 1.0000],
        [0.3230, 0.5875, 0.8041, 1.0000]])
Probs shifted: tensor([[0.0000, 0.3539, 0.5911, 0.8058],
        [0.0000, 0.3230, 0.5875, 0.8041]])


In [91]:
mask = probs_shifted > 0.7
mask

tensor([[False, False, False,  True],
        [False, False, False,  True]])

In [92]:
probs[mask] = 0
probs

tensor([[0.3539, 0.1942, 0.2147, 0.0000],
        [0.2165, 0.3230, 0.2645, 0.0000]])

In [99]:
probs = probs / probs.sum(dim=-1, keepdim=True)
probs

tensor([[0.4640, 0.2546, 0.2814, 0.0000],
        [0.2693, 0.4018, 0.3289, 0.0000]])

In [101]:
idx_sort = torch.multinomial(probs, num_samples=1)
print(f'Idx sort: {idx_sort}')

idx = torch.gather(idx, -1, idx_sort)
idx

Idx sort: tensor([[2],
        [1]])


tensor([[2],
        [2]])