In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
v = 10
seq_len = 5
b = 1
tokens = torch.randint(v, (b, seq_len))
tokens.shape, tokens

(torch.Size([1, 5]), tensor([[1, 7, 1, 9, 0]]))

In [3]:
d = 16
embedding = nn.Embedding(v, d)
embedding.weight.shape, embedding.weight

(torch.Size([10, 16]),
 Parameter containing:
 tensor([[ 9.1225e-02,  2.0366e+00,  1.9383e-01, -1.9696e+00,  1.1776e+00,
          -5.1408e-01,  3.2720e-01,  5.9374e-02,  5.5193e-01, -7.6948e-01,
          -5.3258e-01, -8.3575e-02,  6.4998e-01, -5.8064e-01,  6.1310e-01,
           8.6126e-01],
         [ 4.8720e-01, -8.7291e-01,  1.2219e+00, -1.1459e+00, -2.1013e+00,
           3.9205e-01, -3.1630e-01, -2.9296e-01, -1.8151e+00, -1.3193e+00,
           3.1812e-01, -2.6833e-01, -1.4210e+00,  9.5136e-01, -9.5099e-01,
          -9.1863e-01],
         [-4.7831e-01, -2.3978e+00,  2.5013e-02,  6.6511e-01,  1.4395e+00,
          -4.7337e-01, -1.4695e+00,  8.1029e-02,  9.8985e-01,  5.7672e-01,
          -1.3192e+00,  1.1745e+00, -2.2438e+00, -5.2124e-01,  3.0874e-01,
           1.7272e-01],
         [ 1.3685e-01,  9.5971e-01, -4.0739e-01,  3.1784e+00,  9.9384e-02,
           5.2807e-02,  3.8992e-02,  1.8404e+00, -1.0096e+00, -4.5549e-01,
          -6.2814e-01, -3.0709e-02,  4.1311e-01, -5.5749e

In [4]:
x = embedding(tokens)
x.shape, x

(torch.Size([1, 5, 16]),
 tensor([[[ 0.4872, -0.8729,  1.2219, -1.1459, -2.1013,  0.3920, -0.3163,
           -0.2930, -1.8151, -1.3193,  0.3181, -0.2683, -1.4210,  0.9514,
           -0.9510, -0.9186],
          [ 1.8185, -0.1426, -0.3489, -0.2103, -0.6678, -0.9676,  0.2713,
           -1.5767,  1.6402, -1.8149, -0.9517,  1.2237, -0.8839,  0.5745,
           -0.5461,  0.5711],
          [ 0.4872, -0.8729,  1.2219, -1.1459, -2.1013,  0.3920, -0.3163,
           -0.2930, -1.8151, -1.3193,  0.3181, -0.2683, -1.4210,  0.9514,
           -0.9510, -0.9186],
          [ 2.0418, -0.3707, -0.4163,  1.2082,  0.1269,  1.1172,  0.2135,
            1.0372,  0.6899,  2.7725,  0.6110,  1.0755, -1.0651,  0.3827,
           -0.3130, -1.2047],
          [ 0.0912,  2.0366,  0.1938, -1.9696,  1.1776, -0.5141,  0.3272,
            0.0594,  0.5519, -0.7695, -0.5326, -0.0836,  0.6500, -0.5806,
            0.6131,  0.8613]]], grad_fn=<EmbeddingBackward0>))

In [5]:
theta = 10000
num_heads = 4
head_dim = d // num_heads

freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
print(f'freqs: {freqs.shape}\n{freqs}\n')

t = torch.arange(seq_len * 2, device=freqs.device, dtype=torch.float32)
print(f't: {t.shape}\n{t}\n')

freqs = torch.outer(t, freqs)
print(f'freqs: {freqs.shape}\n{freqs}\n')

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)[:seq_len]
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

freqs: torch.Size([2])
tensor([1.0000, 0.0100])

t: torch.Size([10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

freqs: torch.Size([10, 2])
tensor([[0.0000, 0.0000],
        [1.0000, 0.0100],
        [2.0000, 0.0200],
        [3.0000, 0.0300],
        [4.0000, 0.0400],
        [5.0000, 0.0500],
        [6.0000, 0.0600],
        [7.0000, 0.0700],
        [8.0000, 0.0800],
        [9.0000, 0.0900]])

freqs_cis: torch.Size([5, 2])
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9999+0.0100j],
        [-0.4161+0.9093j,  0.9998+0.0200j],
        [-0.9900+0.1411j,  0.9996+0.0300j],
        [-0.6536-0.7568j,  0.9992+0.0400j]])


In [6]:
mask = torch.full(
    (seq_len, seq_len),
    float("-inf")
)
mask = torch.triu(mask, diagonal=1)
mask

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [7]:
h = x
print(f'h: {h.shape}\n{h}')

h: torch.Size([1, 5, 16])
tensor([[[ 0.4872, -0.8729,  1.2219, -1.1459, -2.1013,  0.3920, -0.3163,
          -0.2930, -1.8151, -1.3193,  0.3181, -0.2683, -1.4210,  0.9514,
          -0.9510, -0.9186],
         [ 1.8185, -0.1426, -0.3489, -0.2103, -0.6678, -0.9676,  0.2713,
          -1.5767,  1.6402, -1.8149, -0.9517,  1.2237, -0.8839,  0.5745,
          -0.5461,  0.5711],
         [ 0.4872, -0.8729,  1.2219, -1.1459, -2.1013,  0.3920, -0.3163,
          -0.2930, -1.8151, -1.3193,  0.3181, -0.2683, -1.4210,  0.9514,
          -0.9510, -0.9186],
         [ 2.0418, -0.3707, -0.4163,  1.2082,  0.1269,  1.1172,  0.2135,
           1.0372,  0.6899,  2.7725,  0.6110,  1.0755, -1.0651,  0.3827,
          -0.3130, -1.2047],
         [ 0.0912,  2.0366,  0.1938, -1.9696,  1.1776, -0.5141,  0.3272,
           0.0594,  0.5519, -0.7695, -0.5326, -0.0836,  0.6500, -0.5806,
           0.6131,  0.8613]]], grad_fn=<EmbeddingBackward0>)


In [8]:
mean_squared = x.pow(2).mean(dim=-1, keepdim=True)
mean_squared

tensor([[[1.1526],
         [1.0974],
         [1.1526],
         [1.3030],
         [0.8062]]], grad_fn=<MeanBackward1>)

In [9]:
x_normed = x * torch.rsqrt(mean_squared + 1e-6)
print(f'x_normed: {x_normed.shape}\n{x_normed}')

x_normed: torch.Size([1, 5, 16])
tensor([[[ 0.4538, -0.8131,  1.1381, -1.0674, -1.9573,  0.3652, -0.2946,
          -0.2729, -1.6907, -1.2288,  0.2963, -0.2499, -1.3236,  0.8862,
          -0.8858, -0.8557],
         [ 1.7359, -0.1361, -0.3330, -0.2008, -0.6375, -0.9237,  0.2590,
          -1.5051,  1.5658, -1.7325, -0.9085,  1.1682, -0.8437,  0.5484,
          -0.5213,  0.5451],
         [ 0.4538, -0.8131,  1.1381, -1.0674, -1.9573,  0.3652, -0.2946,
          -0.2729, -1.6907, -1.2288,  0.2963, -0.2499, -1.3236,  0.8862,
          -0.8858, -0.8557],
         [ 1.7888, -0.3247, -0.3647,  1.0584,  0.1112,  0.9787,  0.1870,
           0.9086,  0.6044,  2.4289,  0.5353,  0.9422, -0.9330,  0.3353,
          -0.2742, -1.0553],
         [ 0.1016,  2.2682,  0.2159, -2.1936,  1.3116, -0.5725,  0.3644,
           0.0661,  0.6147, -0.8570, -0.5932, -0.0931,  0.7239, -0.6467,
           0.6828,  0.9592]]], grad_fn=<MulBackward0>)


In [10]:
rms_scale = torch.ones(d)
print(f'rms_scale: {rms_scale.shape}\n{rms_scale}\n')

x_normed *= rms_scale
print(f'x_normed: {x_normed.shape}\n{x_normed}')

rms_scale: torch.Size([16])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

x_normed: torch.Size([1, 5, 16])
tensor([[[ 0.4538, -0.8131,  1.1381, -1.0674, -1.9573,  0.3652, -0.2946,
          -0.2729, -1.6907, -1.2288,  0.2963, -0.2499, -1.3236,  0.8862,
          -0.8858, -0.8557],
         [ 1.7359, -0.1361, -0.3330, -0.2008, -0.6375, -0.9237,  0.2590,
          -1.5051,  1.5658, -1.7325, -0.9085,  1.1682, -0.8437,  0.5484,
          -0.5213,  0.5451],
         [ 0.4538, -0.8131,  1.1381, -1.0674, -1.9573,  0.3652, -0.2946,
          -0.2729, -1.6907, -1.2288,  0.2963, -0.2499, -1.3236,  0.8862,
          -0.8858, -0.8557],
         [ 1.7888, -0.3247, -0.3647,  1.0584,  0.1112,  0.9787,  0.1870,
           0.9086,  0.6044,  2.4289,  0.5353,  0.9422, -0.9330,  0.3353,
          -0.2742, -1.0553],
         [ 0.1016,  2.2682,  0.2159, -2.1936,  1.3116, -0.5725,  0.3644,
           0.0661,  0.6147, -0.8570, -0.5932, -0.0931,  0.7239, -0.6467,
           0.6828, 

In [11]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

In [12]:
h, x_normed

(tensor([[[ 0.4872, -0.8729,  1.2219, -1.1459, -2.1013,  0.3920, -0.3163,
           -0.2930, -1.8151, -1.3193,  0.3181, -0.2683, -1.4210,  0.9514,
           -0.9510, -0.9186],
          [ 1.8185, -0.1426, -0.3489, -0.2103, -0.6678, -0.9676,  0.2713,
           -1.5767,  1.6402, -1.8149, -0.9517,  1.2237, -0.8839,  0.5745,
           -0.5461,  0.5711],
          [ 0.4872, -0.8729,  1.2219, -1.1459, -2.1013,  0.3920, -0.3163,
           -0.2930, -1.8151, -1.3193,  0.3181, -0.2683, -1.4210,  0.9514,
           -0.9510, -0.9186],
          [ 2.0418, -0.3707, -0.4163,  1.2082,  0.1269,  1.1172,  0.2135,
            1.0372,  0.6899,  2.7725,  0.6110,  1.0755, -1.0651,  0.3827,
           -0.3130, -1.2047],
          [ 0.0912,  2.0366,  0.1938, -1.9696,  1.1776, -0.5141,  0.3272,
            0.0594,  0.5519, -0.7695, -0.5326, -0.0836,  0.6500, -0.5806,
            0.6131,  0.8613]]], grad_fn=<EmbeddingBackward0>),
 tensor([[[ 0.4538, -0.8131,  1.1381, -1.0674, -1.9573,  0.3652, -0.2946,
   

In [13]:
num_kv_heads = 2
assert num_heads % num_kv_heads == 0
print(f"as a reminder: num_heads = {num_heads}, head_dim = {head_dim}")

as a reminder: num_heads = 4, head_dim = 4


In [14]:
wq = nn.Linear(d, num_heads * head_dim, bias=False)
wk = nn.Linear(d, num_kv_heads * head_dim, bias=False)
wv = nn.Linear(d, num_kv_heads * head_dim, bias=False)
print("Attention weights: ", wq.weight.shape, wk.weight.shape, wv.weight.shape)

xq = wq(x_normed)
xk = wk(x_normed)
xv = wv(x_normed)
print("Attention projections: ", xq.shape, xk.shape, xv.shape)

xq = xq.view(b, seq_len, num_heads, head_dim)
xk = xk.view(b, seq_len, num_kv_heads, head_dim)
xv = xv.view(b, seq_len, num_kv_heads, head_dim)
print("Reshaped: ", xq.shape, xk.shape, xv.shape)

Attention weights:  torch.Size([16, 16]) torch.Size([8, 16]) torch.Size([8, 16])
Attention projections:  torch.Size([1, 5, 16]) torch.Size([1, 5, 8]) torch.Size([1, 5, 8])
Reshaped:  torch.Size([1, 5, 4, 4]) torch.Size([1, 5, 2, 4]) torch.Size([1, 5, 2, 4])


In [15]:
xq = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')

xq: torch.Size([1, 5, 4, 2])
tensor([[[[-0.4291+0.1962j,  0.5342-0.4326j],
          [ 0.7279-0.0046j,  0.6127-1.0880j],
          [-0.6104+0.3684j,  0.4189+0.2181j],
          [-0.3447+0.0437j,  0.0332+1.4734j]],

         [[ 0.3149+0.0414j,  0.1118-0.6002j],
          [ 0.0803+0.4188j, -0.5217+0.3098j],
          [ 0.3095-0.1339j, -0.8670-0.7433j],
          [-0.6699+0.1433j, -0.2096-0.9590j]],

         [[-0.4291+0.1962j,  0.5342-0.4326j],
          [ 0.7279-0.0046j,  0.6127-1.0880j],
          [-0.6104+0.3684j,  0.4189+0.2181j],
          [-0.3447+0.0437j,  0.0332+1.4734j]],

         [[-0.0009+1.2107j, -0.8004+0.2490j],
          [ 0.1001+0.2970j,  0.1697-0.0677j],
          [ 0.5461+0.8427j, -0.3851-0.2928j],
          [-0.1326-0.3128j,  0.6575-0.4605j]],

         [[ 0.3308+0.3233j,  0.7262-0.3959j],
          [ 0.6365-0.1578j,  0.3134+0.2363j],
          [-0.2056-0.9324j,  0.6066+0.4276j],
          [-0.4668+0.9308j, -0.0943-0.1814j]]]],
       grad_fn=<ViewAsComplexBackward0>)

In [16]:
ndim = xq.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (xq.shape[1], xq.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != xq.shape[1], xq.shape[-1] {(xq.shape[1], xq.shape[-1])}'

shape = [d if i == 1 or i == xq.ndim - 1 else 1 for i, d in enumerate(xq.shape)]
print(f'shape: {shape}\n')

freqs_cis = freqs_cis.view(*shape)
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

shape: [1, 5, 1, 2]

freqs_cis: torch.Size([1, 5, 1, 2])
tensor([[[[ 1.0000+0.0000j,  1.0000+0.0000j]],

         [[ 0.5403+0.8415j,  0.9999+0.0100j]],

         [[-0.4161+0.9093j,  0.9998+0.0200j]],

         [[-0.9900+0.1411j,  0.9996+0.0300j]],

         [[-0.6536-0.7568j,  0.9992+0.0400j]]]])


In [17]:
xq = torch.view_as_real(xq * freqs_cis).flatten(3).type_as(xv)
xk = torch.view_as_real(xk * freqs_cis).flatten(3).type_as(xv)
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')

xq: torch.Size([1, 5, 4, 4])
tensor([[[[-4.2913e-01,  1.9618e-01,  5.3422e-01, -4.3260e-01],
          [ 7.2795e-01, -4.5819e-03,  6.1274e-01, -1.0880e+00],
          [-6.1038e-01,  3.6835e-01,  4.1887e-01,  2.1810e-01],
          [-3.4472e-01,  4.3680e-02,  3.3236e-02,  1.4734e+00]],

         [[ 1.3534e-01,  2.8732e-01,  1.1781e-01, -5.9904e-01],
          [-3.0900e-01,  2.9386e-01, -5.2476e-01,  3.0456e-01],
          [ 2.7990e-01,  1.8814e-01, -8.5950e-01, -7.5190e-01],
          [-4.8248e-01, -4.8626e-01, -2.0000e-01, -9.6103e-01]],

         [[ 1.9200e-04, -4.7184e-01,  5.4276e-01, -4.2183e-01],
          [-2.9877e-01,  6.6383e-01,  6.3438e-01, -1.0755e+00],
          [-8.0936e-02, -7.0831e-01,  4.1443e-01,  2.2643e-01],
          [ 1.0373e-01, -3.3163e-01,  3.7640e-03,  1.4737e+00]],

         [[-1.6993e-01, -1.1987e+00, -8.0754e-01,  2.2493e-01],
          [-1.4101e-01, -2.7986e-01,  1.7163e-01, -6.2605e-02],
          [-6.5951e-01, -7.5718e-01, -3.7610e-01, -3.0421e-01],
     

In [18]:
if num_kv_heads != num_heads:
  num_queries_per_kv = num_heads // num_kv_heads
  xk = torch.repeat_interleave(xk, num_queries_per_kv, dim=2)
  xv = torch.repeat_interleave(xv, num_queries_per_kv, dim=2)

xq.shape, xk.shape, xv.shape

(torch.Size([1, 5, 4, 4]), torch.Size([1, 5, 4, 4]), torch.Size([1, 5, 4, 4]))

In [19]:
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

xq.shape, xk.shape, xv.shape

(torch.Size([1, 4, 5, 4]), torch.Size([1, 4, 5, 4]), torch.Size([1, 4, 5, 4]))

In [20]:
scores = torch.matmul(xq, xk.transpose(2, 3))

scores = scores / math.sqrt(head_dim)

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[-1.2882e-01,  3.9863e-01, -2.6431e-01, -1.6605e-01, -1.2974e-01],
           [-4.0777e-01,  2.8193e-01, -1.6883e-01,  6.5865e-03, -4.1105e-02],
           [ 1.5198e-01, -1.0384e-02, -1.2882e-01,  1.4967e-01,  1.7464e-01],
           [ 4.9163e-01, -1.0504e+00, -2.9046e-01,  4.2492e-01,  4.6320e-01],
           [ 2.1969e-01,  6.7667e-02, -4.0950e-02,  1.2049e-01,  1.5443e-01]],
 
          [[-4.4131e-01,  3.4228e-01, -4.4563e-02,  2.4650e-01,  1.5564e-01],
           [-9.0590e-02, -6.0891e-02, -8.7912e-02, -1.4631e-01, -1.3989e-01],
           [-6.4649e-01,  7.8515e-01, -4.4131e-01, -2.1182e-01, -2.3316e-01],
           [ 1.5718e-01, -8.2928e-02, -8.6603e-02,  4.8466e-02,  7.8325e-02],
           [ 4.2356e-01, -1.0886e-01, -9.9706e-02, -6.8874e-02,  2.2981e-02]],
 
          [[-2.4204e-01,  2.2885e-01,  1.4402e-01,  1.5196e-01, -1.6548e-02],
           [ 1.6082e-01, -6.6600e-02,  6.8155e-02,  6.4296e-02,  3.4095e-02],
           [-1.2597e-01, -5.748

In [21]:
scores = scores + mask

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[-1.2882e-01,        -inf,        -inf,        -inf,        -inf],
           [-4.0777e-01,  2.8193e-01,        -inf,        -inf,        -inf],
           [ 1.5198e-01, -1.0384e-02, -1.2882e-01,        -inf,        -inf],
           [ 4.9163e-01, -1.0504e+00, -2.9046e-01,  4.2492e-01,        -inf],
           [ 2.1969e-01,  6.7667e-02, -4.0950e-02,  1.2049e-01,  1.5443e-01]],
 
          [[-4.4131e-01,        -inf,        -inf,        -inf,        -inf],
           [-9.0590e-02, -6.0891e-02,        -inf,        -inf,        -inf],
           [-6.4649e-01,  7.8515e-01, -4.4131e-01,        -inf,        -inf],
           [ 1.5718e-01, -8.2928e-02, -8.6603e-02,  4.8466e-02,        -inf],
           [ 4.2356e-01, -1.0886e-01, -9.9706e-02, -6.8874e-02,  2.2981e-02]],
 
          [[-2.4204e-01,        -inf,        -inf,        -inf,        -inf],
           [ 1.6082e-01, -6.6600e-02,        -inf,        -inf,        -inf],
           [-1.2597e-01, -5.748

In [22]:
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3341, 0.6659, 0.0000, 0.0000, 0.0000],
          [0.3838, 0.3263, 0.2899, 0.0000, 0.0000],
          [0.3836, 0.0821, 0.1755, 0.3588, 0.0000],
          [0.2236, 0.1921, 0.1723, 0.2025, 0.2095]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4926, 0.5074, 0.0000, 0.0000, 0.0000],
          [0.1559, 0.6526, 0.1914, 0.0000, 0.0000],
          [0.2884, 0.2269, 0.2260, 0.2587, 0.0000],
          [0.2889, 0.1697, 0.1712, 0.1766, 0.1936]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5566, 0.4434, 0.0000, 0.0000, 0.0000],
          [0.3306, 0.3750, 0.2944, 0.0000, 0.0000],
          [0.1980, 0.2926, 0.2354, 0.2740, 0.0000],
          [0.1516, 0.2317, 0.2379, 0.2189, 0.1599]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3556, 0.6444, 0.0000, 0.0000, 0.0000],
          [0.4146, 0.2223, 0.3630, 0.0000, 0.0000],
          [0.2187, 0.2996, 0.2130, 0.2686, 0.0000],
      

In [23]:
output = torch.matmul(scores, xv)
output.shape, output

(torch.Size([1, 4, 5, 4]),
 tensor([[[[ 0.6480, -0.5029, -1.0608, -0.2744],
           [-0.3112, -0.0214, -0.3650, -0.3829],
           [ 0.1780, -0.2669, -0.7198, -0.3276],
           [ 0.4929, -0.1775, -0.1519,  0.0971],
           [-0.0167, -0.0672, -0.2497, -0.1531]],
 
          [[ 0.6480, -0.5029, -1.0608, -0.2744],
           [-0.0829, -0.1360, -0.5306, -0.3571],
           [-0.2921, -0.0310, -0.3789, -0.3808],
           [ 0.2946, -0.1471, -0.2303, -0.0339],
           [ 0.0462, -0.1138, -0.3436, -0.1723]],
 
          [[-0.2318, -0.3154, -0.1987, -0.1365],
           [-0.0119,  0.1223, -0.0089, -0.1421],
           [-0.0458,  0.0548, -0.0382, -0.1412],
           [-0.1699,  0.0096, -0.0344, -0.2951],
           [ 0.0055,  0.0258, -0.0597, -0.3398]],
 
          [[-0.2318, -0.3154, -0.1987, -0.1365],
           [ 0.0878,  0.3207,  0.0772, -0.1446],
           [-0.1215, -0.0959, -0.1035, -0.1393],
           [-0.1648,  0.0158, -0.0321, -0.2921],
           [ 0.0513, -0.0238, -0.

In [24]:
output = output.transpose(1, 2).contiguous().view(b, seq_len, -1)
output.shape, output

(torch.Size([1, 5, 16]),
 tensor([[[ 0.6480, -0.5029, -1.0608, -0.2744,  0.6480, -0.5029, -1.0608,
           -0.2744, -0.2318, -0.3154, -0.1987, -0.1365, -0.2318, -0.3154,
           -0.1987, -0.1365],
          [-0.3112, -0.0214, -0.3650, -0.3829, -0.0829, -0.1360, -0.5306,
           -0.3571, -0.0119,  0.1223, -0.0089, -0.1421,  0.0878,  0.3207,
            0.0772, -0.1446],
          [ 0.1780, -0.2669, -0.7198, -0.3276, -0.2921, -0.0310, -0.3789,
           -0.3808, -0.0458,  0.0548, -0.0382, -0.1412, -0.1215, -0.0959,
           -0.1035, -0.1393],
          [ 0.4929, -0.1775, -0.1519,  0.0971,  0.2946, -0.1471, -0.2303,
           -0.0339, -0.1699,  0.0096, -0.0344, -0.2951, -0.1648,  0.0158,
           -0.0321, -0.2921],
          [-0.0167, -0.0672, -0.2497, -0.1531,  0.0462, -0.1138, -0.3436,
           -0.1723,  0.0055,  0.0258, -0.0597, -0.3398,  0.0513, -0.0238,
           -0.0953, -0.3375]]], grad_fn=<ViewBackward0>))

In [25]:
wo = nn.Linear(num_heads * head_dim, d, bias=False)
Xout = wo(output)
Xout.shape, Xout

(torch.Size([1, 5, 16]),
 tensor([[[-0.6617, -0.0299, -0.4058,  0.2500,  0.0009,  0.1392, -0.4186,
            0.3082,  0.7888,  0.2122, -0.1651, -0.1097, -0.5544, -0.1342,
            0.0739,  0.1340],
          [ 0.0026,  0.0206,  0.0619, -0.2046,  0.1187, -0.0892, -0.2912,
            0.1807,  0.1706,  0.2084,  0.0763, -0.0542, -0.1686,  0.0852,
           -0.1680, -0.0985],
          [-0.2085,  0.0279, -0.1022, -0.0251,  0.1744,  0.0978, -0.2955,
            0.1807,  0.4125,  0.1452, -0.0240,  0.0638, -0.0503,  0.0967,
           -0.0959,  0.0883],
          [-0.1453, -0.0315, -0.1794,  0.0930,  0.0528,  0.1274, -0.1429,
            0.1725,  0.2777, -0.0251, -0.1649, -0.0332, -0.2049, -0.1167,
           -0.0285,  0.0384],
          [-0.0028, -0.0288, -0.0941, -0.0902,  0.1460,  0.0550, -0.1392,
            0.1928,  0.2829,  0.1459,  0.0012, -0.0146, -0.2165, -0.0160,
           -0.0584,  0.0478]]], grad_fn=<UnsafeViewBackward0>))

In [26]:
h += Xout
h.shape, h

(torch.Size([1, 5, 16]),
 tensor([[[-0.1745, -0.9029,  0.8161, -0.8959, -2.1005,  0.5312, -0.7349,
            0.0153, -1.0263, -1.1071,  0.1530, -0.3781, -1.9754,  0.8171,
           -0.8771, -0.7846],
          [ 1.8211, -0.1220, -0.2870, -0.4149, -0.5491, -1.0568, -0.0199,
           -1.3961,  1.8108, -1.6065, -0.8753,  1.1695, -1.0525,  0.6597,
           -0.7141,  0.4726],
          [ 0.2787, -0.8450,  1.1197, -1.1710, -1.9269,  0.4898, -0.6118,
           -0.1123, -1.4026, -1.1741,  0.2941, -0.2045, -1.4713,  1.0480,
           -1.0469, -0.8303],
          [ 1.8966, -0.4021, -0.5957,  1.3011,  0.1797,  1.2446,  0.0706,
            1.2096,  0.9675,  2.7474,  0.4461,  1.0423, -1.2699,  0.2661,
           -0.3415, -1.1662],
          [ 0.0884,  2.0078,  0.0998, -2.0598,  1.3237, -0.4591,  0.1880,
            0.2522,  0.8348, -0.6235, -0.5314, -0.0982,  0.4335, -0.5966,
            0.5547,  0.9091]]], grad_fn=<AddBackward0>))

In [27]:
pre_ffwd_norm = RMSNorm(d)
h_normed = pre_ffwd_norm(h)

In [28]:
hidden_dim = 4 * d
print(hidden_dim)
hidden_dim = int(2 * hidden_dim / 3)
print(hidden_dim)
multiple_of = 256
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)

64
42
256


In [29]:
up = nn.Linear(d, hidden_dim, bias=False)
gate = nn.Linear(d, hidden_dim, bias=False)
down = nn.Linear(hidden_dim, d, bias=False)

In [30]:
up_proj = up(h_normed)
print(up_proj.shape, up_proj)

torch.Size([1, 5, 256]) tensor([[[-0.0770,  0.1202, -1.1196,  ..., -0.5964,  0.7209, -0.0904],
         [-0.7331,  0.2328, -0.6200,  ..., -0.7544, -0.5187, -0.0947],
         [-0.2734,  0.0072, -1.0260,  ..., -0.3971,  0.9436, -0.1348],
         [ 0.5879,  0.1090, -0.7035,  ..., -0.4524, -0.1750, -0.0806],
         [-1.0225, -0.5696,  0.3746,  ...,  0.7557, -0.6852,  0.9189]]],
       grad_fn=<UnsafeViewBackward0>)


In [31]:
gate_proj = F.silu(gate(h_normed))
print(gate_proj.shape, gate_proj)

torch.Size([1, 5, 256]) tensor([[[ 0.1571,  0.2550,  0.1297,  ...,  0.3237,  0.0146, -0.0486],
         [ 0.3652, -0.2275, -0.1862,  ..., -0.2351, -0.1624, -0.1733],
         [ 0.1122,  0.3358,  0.2703,  ...,  0.4038, -0.0508,  0.0257],
         [-0.2461, -0.2777,  0.1492,  ...,  0.1976, -0.1919, -0.1336],
         [-0.1061, -0.0429, -0.2663,  ..., -0.2517,  0.4077, -0.2083]]],
       grad_fn=<SiluBackward0>)


In [32]:
ffwd_output = down(up_proj * gate_proj)
print(ffwd_output.shape, ffwd_output)

torch.Size([1, 5, 16]) tensor([[[-0.0668,  0.0026, -0.0692, -0.0277,  0.2288,  0.0302,  0.0701,
           0.0597, -0.0955, -0.2088, -0.1602, -0.1281, -0.1132, -0.0271,
           0.0646, -0.0286],
         [ 0.2062,  0.1321, -0.0504, -0.0388, -0.0484, -0.0084, -0.0651,
          -0.0323, -0.0965,  0.0505, -0.1811, -0.1295, -0.1527,  0.0595,
           0.0565,  0.0684],
         [-0.0629, -0.0679, -0.0845, -0.0030,  0.1321,  0.0697,  0.0152,
          -0.0110, -0.0517, -0.2454, -0.1528, -0.1100, -0.1216, -0.0202,
           0.0851, -0.0105],
         [-0.0197, -0.0460, -0.0145,  0.0194, -0.0450,  0.1390,  0.1077,
          -0.1112, -0.0831, -0.0597, -0.0115,  0.0719,  0.0805, -0.1357,
           0.0141,  0.1476],
         [-0.0068, -0.1018,  0.2256,  0.0273,  0.1913,  0.0437,  0.1456,
          -0.1769,  0.0084, -0.0047,  0.0938, -0.0595, -0.0561, -0.1378,
           0.0841,  0.1004]]], grad_fn=<UnsafeViewBackward0>)


In [33]:
out = h + ffwd_output
print(out.shape, out)

torch.Size([1, 5, 16]) tensor([[[-0.2413, -0.9002,  0.7469, -0.9236, -1.8717,  0.5614, -0.6648,
           0.0750, -1.1218, -1.3160, -0.0072, -0.5061, -2.0886,  0.7900,
          -0.8125, -0.8132],
         [ 2.0273,  0.0101, -0.3374, -0.4538, -0.5975, -1.0653, -0.0850,
          -1.4284,  1.7143, -1.5559, -1.0565,  1.0400, -1.2051,  0.7191,
          -0.6577,  0.5410],
         [ 0.2158, -0.9129,  1.0352, -1.1740, -1.7948,  0.5595, -0.5966,
          -0.1233, -1.4543, -1.4194,  0.1413, -0.3145, -1.5930,  1.0279,
          -0.9618, -0.8408],
         [ 1.8769, -0.4482, -0.6103,  1.3206,  0.1347,  1.3836,  0.1783,
           1.0985,  0.8845,  2.6877,  0.4346,  1.1142, -1.1894,  0.1304,
          -0.3275, -1.0187],
         [ 0.0816,  1.9060,  0.3254, -2.0325,  1.5149, -0.4154,  0.3336,
           0.0754,  0.8432, -0.6283, -0.4377, -0.1577,  0.3773, -0.7344,
           0.6388,  1.0094]]], grad_fn=<AddBackward0>)


In [34]:
final_norm = RMSNorm(d)
out_normed = final_norm(out)

In [35]:
final_output = nn.Linear(d, v, bias=False)
logits = final_output(out_normed).float()
logits.shape, logits

(torch.Size([1, 5, 10]),
 tensor([[[-1.0293,  0.8460,  0.4050,  0.4700, -0.3143,  0.7373, -0.3887,
            0.4419,  0.4680, -0.2215],
          [-0.0776, -0.0209,  0.0788, -0.1343,  0.1066,  0.6125,  0.6652,
            0.7374, -0.6187,  0.0756],
          [-0.9305,  0.6992,  0.2725,  0.1775, -0.5326,  0.5436, -0.5076,
            0.5103,  0.1687, -0.4812],
          [ 1.3699, -0.9603,  0.4626,  0.4300, -0.1564,  0.5303, -1.4380,
           -0.1278, -0.3727,  0.4631],
          [-0.2596, -0.1003,  0.0659, -0.2814,  0.3217, -0.0557,  0.6066,
           -0.0484, -0.8193, -0.1996]]], grad_fn=<UnsafeViewBackward0>))

In [36]:
probs = F.softmax(logits, dim=-1)
probs

tensor([[[0.0270, 0.1760, 0.1132, 0.1208, 0.0552, 0.1579, 0.0512, 0.1175,
          0.1206, 0.0605],
         [0.0742, 0.0785, 0.0868, 0.0701, 0.0892, 0.1479, 0.1559, 0.1676,
          0.0432, 0.0865],
         [0.0349, 0.1782, 0.1163, 0.1058, 0.0520, 0.1525, 0.0533, 0.1475,
          0.1048, 0.0547],
         [0.2938, 0.0286, 0.1186, 0.1148, 0.0639, 0.1269, 0.0177, 0.0657,
          0.0514, 0.1186],
         [0.0781, 0.0916, 0.1082, 0.0765, 0.1398, 0.0958, 0.1858, 0.0965,
          0.0447, 0.0830]]], grad_fn=<SoftmaxBackward0>)

In [37]:
greedy_indices = torch.argmax(probs, dim=-1)
greedy_indices

tensor([[1, 7, 1, 0, 6]])

In [38]:
target_token_indices = torch.randint(0, v, greedy_indices.shape)
print(target_token_indices)

loss_fn = nn.CrossEntropyLoss()

loss = loss_fn(logits.view(1,v,seq_len), target_token_indices)
print(loss)

tensor([[3, 8, 4, 3, 7]])
tensor(2.3506, grad_fn=<NllLoss2DBackward0>)


In [39]:
!wget https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/tiny_shakespeare_tokenizer.py
# and the tokenizer model
!wget https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/tokenizers/tiny_shakespeare_tokenizer_512.model
!mkdir -p tokenizers
!mv tiny_shakespeare_tokenizer_512.model tokenizers/
from tiny_shakespeare_tokenizer import *
tokenizer = get_tokenizer(size = 512)

--2025-04-22 19:05:01--  https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/tiny_shakespeare_tokenizer.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2180 (2.1K) [text/plain]
Saving to: ‘tiny_shakespeare_tokenizer.py’


2025-04-22 19:05:01 (42.8 MB/s) - ‘tiny_shakespeare_tokenizer.py’ saved [2180/2180]

--2025-04-22 19:05:01--  https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/tokenizers/tiny_shakespeare_tokenizer_512.model
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
L

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import time
import os
import json

In [41]:
@dataclass
class ModelArgs:
    dim: int = 128
    n_layers: int = 8
    n_heads: int = 4
    n_kv_heads: Optional[int] = 1
    vocab_size: int = tokenizer.vocab_len
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000
    max_batch_size: int = 32
    max_seq_len: int = 512
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    dropout_rate: float = 0.1

In [42]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

In [43]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis.to(params.device)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != (x.shape[1], x.shape[-1]) {(x.shape[1], x.shape[-1])}'
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [44]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_rep = args.n_heads // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim),
            requires_grad = False
        ).to(args.device)
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim),
            requires_grad = False
        ).to(args.device)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        start_pos: int = None,
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if start_pos is not None:
            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

            keys = self.cache_k[:bsz, : start_pos + seqlen]
            values = self.cache_v[:bsz, : start_pos + seqlen]
        else:
            keys, values = xk, xv

        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        output = torch.matmul(scores, values)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

### 2e. Ffwd
<a id='twoe'></a>

In [45]:
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()

        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

### 2f. Residual Layers
<a id='twof'></a>

In [46]:
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.dropout_rate = args.dropout_rate

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        start_pos: int = None,
        training = False,
    ):

        h = x + F.dropout(self.attention(self.attention_norm(x), freqs_cis, mask, start_pos), p=self.dropout_rate, training=training)
        out = h + F.dropout(self.feed_forward(self.ffn_norm(h)), p=self.dropout_rate, training=training)
        return out

In [47]:
class Llama3(nn.Module):
    def __init__(self, params: ModelArgs, tokenizer):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.max_seq_len = params.max_seq_len
        self.tokenizer = tokenizer

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim,
            params.vocab_size,
            bias=False)

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,)

        mask = torch.full((params.max_seq_len, params.max_seq_len),
                          float("-inf"),
                          device=params.device)
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer('mask', mask)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, tokens: torch.Tensor, targets: torch.Tensor):
        bsz, seqlen = tokens.shape
        assert tokens.shape == targets.shape
        assert seqlen == self.max_seq_len

        h = self.tok_embeddings(tokens)

        freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[:seqlen]

        for layer in self.layers:
            h = layer(
                h,
                freqs_cis,
                self.mask,
                start_pos = None,
                training = True
            )

        h = self.norm(h)
        logits = self.output(h).float()

        loss = self.criterion(
            logits.view(bsz * seqlen, self.vocab_size),
            targets.reshape(bsz * seqlen))

        return logits, loss

    @torch.inference_mode()
    def forward_inference(self,
                          tokens: torch.Tensor,
                          start_pos: int,
                          max_context_window: int,
                         ):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = self.mask[:seqlen, :seqlen]

        mask = torch.hstack(
            [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
        ).type_as(h)

        for layer in self.layers:
            h = layer(
                h,
                freqs_cis,
                mask,
                start_pos = start_pos
            )
        h = self.norm(h)
        logits = self.output(h).float()
        return logits

    @torch.inference_mode()
    def Sampler(
        self,
        logits: torch.Tensor,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> torch.Tensor:
        logits = logits[:,-1,:]
        logits.div_(temperature)
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_p
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_k
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        next_token_id = torch.multinomial(probs, num_samples=1)
        return next_token_id

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_gen_len: int = None,
        memory_saver_div: int = 1,
        temperature: float = 0.6,
        top_p: float = 0.9,
        top_k: int = tokenizer.vocab_len,
    ) -> str:
        assert ((memory_saver_div & (memory_saver_div-1)) == 0) & (memory_saver_div > 0), f'memory_saver_div {memory_saver_div} must be power of 2'
        max_context_window = self.max_seq_len // memory_saver_div
        if max_context_window < self.max_seq_len:
            print(f'maximum attention matrix size will be {max_context_window}x{self.max_seq_len} rather than {self.max_seq_len}x{self.max_seq_len}\n')
        tokens = self.tokenizer.encode(prompt)

        if max_gen_len is None:
            max_gen_len = self.max_seq_len - len(tokens)
        elif max_gen_len + len(tokens) > self.max_seq_len:
            print(f'capping max_gen_len at max_seq_len={self.max_seq_len} including input\n')
            max_gen_len = self.max_seq_len - len(tokens)
        tokens = torch.tensor(tokens, device=self.params.device)
        tokens = tokens.unsqueeze(0) if len(tokens.shape)==1 else tokens # jic we need to add a batch dimension
        start_pos = max(tokens.shape[1] - max_context_window, 0)

        for i in range(max_gen_len):
            logits = self.forward_inference(
                tokens[:,-max_context_window:],
                start_pos = start_pos,
                max_context_window = max_context_window
            )
            next_token = self.Sampler(
                logits = logits,
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            tokens = torch.cat((tokens, next_token), dim=1)
            if tokens.shape[1] >= max_context_window:
                start_pos += 1
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())
        return output

In [48]:
!wget -O input.txt https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(text[:200])
chars = sorted(list(set(text)))
v = len(chars)
print(chars)
print(v)


data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

--2025-04-22 19:05:02--  https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-04-22 19:05:02 (29.6 MB/s) - ‘input.txt’ saved [1115394/1115394]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', '

In [49]:
params = ModelArgs()
print(params)
model = Llama3(params, tokenizer).to(params.device)
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')
print(model)

ModelArgs(dim=128, n_layers=8, n_heads=4, n_kv_heads=1, vocab_size=512, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=10000, max_batch_size=32, max_seq_len=512, device='cuda', dropout_rate=0.1)
2033.792 K parameters
Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output

In [50]:

def get_batch(split, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - params.max_seq_len, (batch_size,))
    x = torch.stack([data[i:i+params.max_seq_len] for i in ix])
    y = torch.stack([data[i+1:i+params.max_seq_len+1] for i in ix])
    x, y = x.to(params.device), y.to(params.device)
    return x, y

In [51]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, targets=Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [52]:
lr_init = 1e-2
weight_decay = 0.02
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_init, weight_decay=weight_decay)
max_iters = 1000
eval_interval = 50
warmup_iters = 10
warmup_factor = 1e-3
lr_final = 1e-5

def lr_lambda(current_iter):
    if current_iter < warmup_iters:
        return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters
    else:
        decay_iters = max_iters - warmup_iters
        cosine_decay = 0.5 * (1 + math.cos(math.pi * (current_iter - warmup_iters) / decay_iters))
        return max(cosine_decay, lr_final / lr_init)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [53]:
start_time = time.time()
for iter in range(max_iters):
    xb, yb = get_batch('train', params.max_batch_size)
    logits, loss = model(xb, targets=yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    scheduler.step()
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, params.max_batch_size)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {iter:04d}: lr {current_lr:.6f}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)



step 0000: lr 0.001009, train loss 6.3484, val loss 6.3537, time elapsed: 1.03 seconds
step 0050: lr 0.009958, train loss 6.3550, val loss 6.3538, time elapsed: 11.60 seconds
step 0100: lr 0.009793, train loss 6.3520, val loss 6.3527, time elapsed: 22.37 seconds
step 0150: lr 0.009508, train loss 6.3531, val loss 6.3533, time elapsed: 33.21 seconds
step 0200: lr 0.009109, train loss 6.3545, val loss 6.3479, time elapsed: 44.09 seconds
step 0250: lr 0.008608, train loss 6.3530, val loss 6.3566, time elapsed: 55.02 seconds
step 0300: lr 0.008015, train loss 6.3558, val loss 6.3513, time elapsed: 66.04 seconds
step 0350: lr 0.007347, train loss 6.3567, val loss 6.3524, time elapsed: 77.11 seconds
step 0400: lr 0.006620, train loss 6.3551, val loss 6.3543, time elapsed: 88.26 seconds
step 0450: lr 0.005853, train loss 6.3548, val loss 6.3526, time elapsed: 99.49 seconds
step 0500: lr 0.005063, train loss 6.3538, val loss 6.3530, time elapsed: 110.81 seconds
step 0550: lr 0.004273, train lo

In [66]:
!wget https://github.com/prathamesh-mandavkar/llama3-implementation/raw/main/models/Llama3_2024-10-05_23-12-27.pth
!wget https://github.com/prathamesh-mandavkar/llama3-implementation/raw/main/models/Llama3_2024-10-05_23-12-27.json
name = 'Llama3_2024-10-05_23-12-27'

with open(f'{name}.json', 'r', encoding='utf-8') as f:
    params_dict = json.load(f)

params = ModelArgs(**params_dict)
params.device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Llama3(params, tokenizer).to(params.device)
path = f'{name}.pth'
model.load_state_dict(torch.load(path))
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')
model.eval()

--2025-04-22 19:28:52--  https://github.com/prathamesh-mandavkar/llama3-implementation/raw/main/models/Llama3_2024-10-05_23-12-27.pth
Resolving github.com (github.com)... 140.82.116.4
Connecting to github.com (github.com)|140.82.116.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/models/Llama3_2024-10-05_23-12-27.pth [following]
--2025-04-22 19:28:53--  https://raw.githubusercontent.com/prathamesh-mandavkar/llama3-implementation/main/models/Llama3_2024-10-05_23-12-27.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9213026 (8.8M) [application/octet-stream]
Saving to: ‘Llama3_2024-10-05_23-12-27.pth’


2025-04-22 19:28:53 (115 MB/s) - ‘Llama3_

Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)

In [67]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"

In [68]:
print(model.generate(input_str))

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

JULIET:
I will be king, and I shall be not be too much:
I am so many to a temple of me.

Nurse:
Well, so look to be so, thou art revenged on,
The time hath I could been to death.

Nurse:
My lord, I would have I will be all great kiss,
And so I pray thee, and in my true heart.

Nurse:
The morning calls greet me in my life,
I will be spoke with him to the face.

Nurse:
We shall be so, thou art a bark, woman!

JULIET:
I am forgot to thee to the corse.

JULIET:
And thou art good madam; let me be gone?

Nurse:
I am my true, and go intent; and there is not.

JULIET:
I would it be a courage of me:
Then let them hear me to follow us with a tear.

JULIET:
What say you do you?

JULIET:
I will not bid you that I can not talk of.

Nurse:
No more than you have of a man to be done:
I am a servant is far in


In [69]:
output = model.generate(
    input_str,
    max_gen_len = params.max_seq_len - len(input_str),
    memory_saver_div = 8,
    temperature = 0.6,
    top_p = 0.9,
    top_k = 32,
)
print(output)

maximum attention matrix size will be 64x512 rather than 512x512

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

Nurse:
I will be satisfied, for I am so speak too.

JULIET:
I will confess to me with him to the king.

Nurse:
I am so farther to my heart as we can.

Nurse:

Nurse:
No, if it be so, sir.

Nurse:
Good friend, I will. Friar, ho!

Nurse:
I am the devil of this man in a fear,
That it will not what you shall see him hate,
That I will have done to be a man's death.

JULIET:
Go, then, what news?

Nurse:
I have been gone.

Nurse:
What says the care?

JULIET:
Ay, for I will content it, I would not know.

JULIET:
I know my son, my lord.

JULIET:
I do not go to be home.

JULIET:
I am a little day of my bosom.

Nurse:
Then then, and says the king hath been my face?

Nurse:
I am so hour to me; and my fault is in mine.

Nurse:
I will be her hence shall I be so draw there.

Nurse:
That is my
