## This is a basic implimentainion of the Llama 3 without crying 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# vocabulary length. Llama's real vocab size is 128256. Here let's just use an absurdly small number
v = 10

# Llama's maximum sequence length is 8192, but for inference they cache 3/4 of it and only use an effective length of 2048. more on that later
seq_len = 5

# we'll use a batch size of 1 for simplicity when visualizing our tensors
b = 1

# now let's make ourselves a list of token indices. Each represents somewhere between a letter and a word
tokens = torch.randint(v, (b, seq_len))
tokens.shape, tokens

(torch.Size([1, 5]), tensor([[5, 3, 9, 6, 9]]))

## 1 Initilizing the first residual state

In [3]:
# embedding dimention for toy llama 3b but Llama 3 8b uses 4096
d = 16

# initilizing token embeddings matrix
embedding = nn.Embedding(v, d)
embedding.weight.shape, embedding.weight
# each row in this embedding is high dimentinal representation of its corresponding token

(torch.Size([10, 16]),
 Parameter containing:
 tensor([[ 0.6211, -0.3046,  1.2291,  0.0654,  0.0206,  0.3026,  0.5760, -0.9168,
          -1.3130, -0.6738, -2.5882, -2.4276, -1.4069,  1.1768, -0.9213, -0.4535],
         [ 2.4885,  0.1856, -0.0444, -0.9305, -0.8918,  1.7270, -0.7148,  0.3457,
           1.2337,  1.3811,  1.0488,  0.0633,  0.1826,  1.4127, -1.6101,  0.1325],
         [-0.8538, -0.9467,  1.0717, -0.8660, -0.0293,  0.2173, -2.4219,  0.1731,
           0.9239, -0.1767,  0.4690,  1.0902, -1.0578, -0.8805,  0.9535,  0.2885],
         [ 1.0089, -0.0534, -0.7067,  0.1578, -0.3137, -0.5377,  0.1424, -0.1772,
          -1.6768,  0.7786,  0.5583,  0.3120,  1.4882,  1.0504, -0.8779,  0.6144],
         [-0.3958, -0.3728,  2.0452,  0.4676,  2.3109, -0.7293,  2.8377, -1.5640,
           0.6458, -0.0217,  1.0460, -1.1739, -0.2615, -0.1555, -0.4311, -0.5748],
         [-0.0508, -0.6860,  0.2296,  1.1494,  0.7252,  2.0619,  0.1963, -0.1768,
          -0.3589, -1.4255, -0.0347,  1.7100, -

In [4]:
# grambbing the embeddings that correspond to our sequence of token indices
x = embedding(tokens)
x.shape, x
# at this points many models would multiply the embeddings by the square root of the embedding dimension, but Llama 3 foregoes that strategy

(torch.Size([1, 5, 16]),
 tensor([[[-0.0508, -0.6860,  0.2296,  1.1494,  0.7252,  2.0619,  0.1963,
           -0.1768, -0.3589, -1.4255, -0.0347,  1.7100, -3.1712,  0.1504,
           -0.8984,  0.5605],
          [ 1.0089, -0.0534, -0.7067,  0.1578, -0.3137, -0.5377,  0.1424,
           -0.1772, -1.6768,  0.7786,  0.5583,  0.3120,  1.4882,  1.0504,
           -0.8779,  0.6144],
          [ 0.5823, -1.7120, -1.5994, -0.5196,  0.5663,  0.2875,  0.3768,
            0.3237,  0.5569, -1.7216, -1.8272,  0.2705, -0.3602,  0.5064,
           -0.1078,  0.6862],
          [-0.2742, -0.5251, -1.4861, -2.5345,  0.6212, -1.4692, -0.5230,
           -0.4949, -0.3542,  0.3015, -1.0694, -1.5897, -0.7937, -1.3195,
            1.8830,  0.4581],
          [ 0.5823, -1.7120, -1.5994, -0.5196,  0.5663,  0.2875,  0.3768,
            0.3237,  0.5569, -1.7216, -1.8272,  0.2705, -0.3602,  0.5064,
           -0.1078,  0.6862]]], grad_fn=<EmbeddingBackward0>))

## Precompute our RoPE Frequencies

In [5]:
theta = 10000 # 10,000 is the most common value but Llama 3 uses 50,000. In theory smaller models should use a smaller value
num_heads = 4 # Llama 3 8b has 32 total attention heads
head_dim = d // num_heads # Llama 3 ties its head dimension to the embedding dimension. This value comes out to 128 in Llama 3, which is purposeful to

freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
print(f'freqs: {freqs.shape}\n{freqs}\n')

t = torch.arange(seq_len * 2, device=freqs.device, dtype=torch.float32)
print(f't: {t.shape}\n{t}\n')

freqs = torch.outer(t, freqs)
print(f'freqs: {freqs.shape}\n{freqs}\n')

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)[:seq_len]  # complex64
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}') 

freqs: torch.Size([2])
tensor([1.0000, 0.0100])

t: torch.Size([10])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

freqs: torch.Size([10, 2])
tensor([[0.0000, 0.0000],
        [1.0000, 0.0100],
        [2.0000, 0.0200],
        [3.0000, 0.0300],
        [4.0000, 0.0400],
        [5.0000, 0.0500],
        [6.0000, 0.0600],
        [7.0000, 0.0700],
        [8.0000, 0.0800],
        [9.0000, 0.0900]])

freqs_cis: torch.Size([5, 2])
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9999+0.0100j],
        [-0.4161+0.9093j,  0.9998+0.0200j],
        [-0.9900+0.1411j,  0.9996+0.0300j],
        [-0.6536-0.7568j,  0.9992+0.0400j]])


### 1d. Precomputing the Causal Mask
<a id='d'></a>

Similar to RoPE embeddings, the causal mask is another part of the attention mechanism that we can create ahead of time to then be reused in every layer.

The basic idea of a causal mask is that by default, attention mechanisms allow every single token to pay attention to every single other token. This is okay or even preferable for some model types, but Llama is auto-regressive, meaning it would be bad if a given token to be predicted was able to see itself and future tokens during training but not during inference. The negative infinity's in the upper-triangle prevent the model from attending to the corresponding token; how this works will be more clear later when we do the attention softmax

In [6]:
mask = torch.full(
    (seq_len, seq_len),
    float("-inf")
)
mask = torch.triu(mask, diagonal=1)
mask

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

## Normalization (RMS Norm)
<a id='e'></a>

Root Mean Square Normalization has also been the norm for quite awhile. Like its predecessor LayerNorm, RMSNorm restricts the variability of the entries in each embedding vector such that the vector lies on a hypersphere with radius $\sqrt{d}$. However unlike LayerNorm which centers that hypersphere with a mean of zero, RMSNorm does not mess with the mean, which is an important source of data for networks that utilize residual connections.

In [7]:
# first setup the residual connection that will be used later
h = x
print(f"h: {h.shape}\n{h}\n")

h: torch.Size([1, 5, 16])
tensor([[[-0.0508, -0.6860,  0.2296,  1.1494,  0.7252,  2.0619,  0.1963,
          -0.1768, -0.3589, -1.4255, -0.0347,  1.7100, -3.1712,  0.1504,
          -0.8984,  0.5605],
         [ 1.0089, -0.0534, -0.7067,  0.1578, -0.3137, -0.5377,  0.1424,
          -0.1772, -1.6768,  0.7786,  0.5583,  0.3120,  1.4882,  1.0504,
          -0.8779,  0.6144],
         [ 0.5823, -1.7120, -1.5994, -0.5196,  0.5663,  0.2875,  0.3768,
           0.3237,  0.5569, -1.7216, -1.8272,  0.2705, -0.3602,  0.5064,
          -0.1078,  0.6862],
         [-0.2742, -0.5251, -1.4861, -2.5345,  0.6212, -1.4692, -0.5230,
          -0.4949, -0.3542,  0.3015, -1.0694, -1.5897, -0.7937, -1.3195,
           1.8830,  0.4581],
         [ 0.5823, -1.7120, -1.5994, -0.5196,  0.5663,  0.2875,  0.3768,
           0.3237,  0.5569, -1.7216, -1.8272,  0.2705, -0.3602,  0.5064,
          -0.1078,  0.6862]]], grad_fn=<EmbeddingBackward0>)



In [8]:
# perfroming first normalization
# first squash each entry in x and then take the mean of those values across each embedding
mean_squared = x.pow(2).mean(dim=-1,  keepdim=True)
mean_squared

tensor([[[1.4363],
         [0.6424],
         [0.8939],
         [1.3787],
         [0.8939]]], grad_fn=<MeanBackward1>)

In [9]:
# then multiply x by the recirocal of the square roots of mean_squared
# 1e-6 is very small number added for stability just in case an entry happens to be equal to 0 (since you can't divide by 0)
x_normed = x / torch.rsqrt(mean_squared + 1e-6)
print(f'x_normed: {x_normed.shape}\n{x_normed}\n')

x_normed: torch.Size([1, 5, 16])
tensor([[[-0.0609, -0.8221,  0.2751,  1.3775,  0.8692,  2.4711,  0.2352,
          -0.2119, -0.4301, -1.7084, -0.0416,  2.0494, -3.8005,  0.1803,
          -1.0767,  0.6718],
         [ 0.8086, -0.0428, -0.5664,  0.1265, -0.2514, -0.4310,  0.1142,
          -0.1420, -1.3439,  0.6241,  0.4475,  0.2501,  1.1927,  0.8419,
          -0.7036,  0.4925],
         [ 0.5505, -1.6187, -1.5122, -0.4913,  0.5354,  0.2718,  0.3563,
           0.3061,  0.5265, -1.6278, -1.7276,  0.2557, -0.3405,  0.4788,
          -0.1019,  0.6488],
         [-0.3220, -0.6166, -1.7449, -2.9760,  0.7294, -1.7251, -0.6141,
          -0.5811, -0.4159,  0.3540, -1.2557, -1.8666, -0.9319, -1.5493,
           2.2110,  0.5379],
         [ 0.5505, -1.6187, -1.5122, -0.4913,  0.5354,  0.2718,  0.3563,
           0.3061,  0.5265, -1.6278, -1.7276,  0.2557, -0.3405,  0.4788,
          -0.1019,  0.6488]]], grad_fn=<DivBackward0>)



In [10]:
# now time to multiply the learnable parameter the gamma and beta 
# This scale is initialized to 1's but if we were to train then these values will change
rms_scale = torch.ones(d)
print(f'rms_scale: {rms_scale.shape}\n{rms_scale}\n')

x_normed *= rms_scale
print(f'x_normed: {x_normed.shape}\n{x_normed}\n')

rms_scale: torch.Size([16])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

x_normed: torch.Size([1, 5, 16])
tensor([[[-0.0609, -0.8221,  0.2751,  1.3775,  0.8692,  2.4711,  0.2352,
          -0.2119, -0.4301, -1.7084, -0.0416,  2.0494, -3.8005,  0.1803,
          -1.0767,  0.6718],
         [ 0.8086, -0.0428, -0.5664,  0.1265, -0.2514, -0.4310,  0.1142,
          -0.1420, -1.3439,  0.6241,  0.4475,  0.2501,  1.1927,  0.8419,
          -0.7036,  0.4925],
         [ 0.5505, -1.6187, -1.5122, -0.4913,  0.5354,  0.2718,  0.3563,
           0.3061,  0.5265, -1.6278, -1.7276,  0.2557, -0.3405,  0.4788,
          -0.1019,  0.6488],
         [-0.3220, -0.6166, -1.7449, -2.9760,  0.7294, -1.7251, -0.6141,
          -0.5811, -0.4159,  0.3540, -1.2557, -1.8666, -0.9319, -1.5493,
           2.2110,  0.5379],
         [ 0.5505, -1.6187, -1.5122, -0.4913,  0.5354,  0.2718,  0.3563,
           0.3061,  0.5265, -1.6278, -1.7276,  0.2557, -0.3405,  0.4788,
          -0.1019, 

In [11]:
# RMS function
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight 

## Initilize Multi-Query Attention
<a id='f'></a>
[multi-query attention](https://arxiv.org/abs/1911.02150) is the de facto standard for saving on parameter counts in order to get a bigger model. The idea is that the model can make multiple queries to the residual state and have those many queries be answered by shared keys & values.

In [12]:
# x is for the residual connection and x_normed will go into our Attention calculation
h, x_normed

(tensor([[[-0.0508, -0.6860,  0.2296,  1.1494,  0.7252,  2.0619,  0.1963,
           -0.1768, -0.3589, -1.4255, -0.0347,  1.7100, -3.1712,  0.1504,
           -0.8984,  0.5605],
          [ 1.0089, -0.0534, -0.7067,  0.1578, -0.3137, -0.5377,  0.1424,
           -0.1772, -1.6768,  0.7786,  0.5583,  0.3120,  1.4882,  1.0504,
           -0.8779,  0.6144],
          [ 0.5823, -1.7120, -1.5994, -0.5196,  0.5663,  0.2875,  0.3768,
            0.3237,  0.5569, -1.7216, -1.8272,  0.2705, -0.3602,  0.5064,
           -0.1078,  0.6862],
          [-0.2742, -0.5251, -1.4861, -2.5345,  0.6212, -1.4692, -0.5230,
           -0.4949, -0.3542,  0.3015, -1.0694, -1.5897, -0.7937, -1.3195,
            1.8830,  0.4581],
          [ 0.5823, -1.7120, -1.5994, -0.5196,  0.5663,  0.2875,  0.3768,
            0.3237,  0.5569, -1.7216, -1.8272,  0.2705, -0.3602,  0.5064,
           -0.1078,  0.6862]]], grad_fn=<EmbeddingBackward0>),
 tensor([[[-0.0609, -0.8221,  0.2751,  1.3775,  0.8692,  2.4711,  0.2352,
   

In [13]:
# let's define the hyperparameters of MQA
num_kv_heads = 2 # Llama uses 8 key and value heads per layer
assert num_heads % num_kv_heads == 0 # each q needs to match up to a kv
print(f"as a reminder: num_heads = {num_heads}, head_dim = {head_dim}")

as a reminder: num_heads = 4, head_dim = 4


In [14]:
# self-attention weight matrices
wq = nn.Linear(d, num_heads * head_dim, bias=False)
wk = nn.Linear(d, num_kv_heads * head_dim, bias=False)
wv = nn.Linear(d, num_kv_heads * head_dim, bias=False)
print("Attention weights: ", wq.weight.shape, wk.weight.shape, wv.weight.shape)

# and project x_normed out to get our queries, keys and values
xq = wq(x_normed)
xk = wk(x_normed)
xv = wv(x_normed)
print("Attention projections: ", xq.shape, xk.shape, xv.shape)

# then reshape them to separate out by head
xq = xq.view(b, seq_len, num_heads, head_dim)
xk = xk.view(b, seq_len, num_kv_heads, head_dim)
xv = xv.view(b, seq_len, num_kv_heads, head_dim)
print("Reshaped: ", xq.shape, xk.shape, xv.shape)

Attention weights:  torch.Size([16, 16]) torch.Size([8, 16]) torch.Size([8, 16])
Attention projections:  torch.Size([1, 5, 16]) torch.Size([1, 5, 8]) torch.Size([1, 5, 8])
Reshaped:  torch.Size([1, 5, 4, 4]) torch.Size([1, 5, 2, 4]) torch.Size([1, 5, 2, 4])


## RoPE 

In [15]:
xq = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')

xq: torch.Size([1, 5, 4, 2])
tensor([[[[ 0.0605+0.4496j, -0.1676+0.9707j],
          [-0.8873-0.5318j,  0.1048-0.2221j],
          [ 0.8685-1.4743j, -1.2008+0.8642j],
          [ 1.2625+0.1612j,  1.0077+0.7350j]],

         [[-0.3397-0.1445j,  0.3150+0.8894j],
          [-0.6021-0.1736j, -0.3982+0.0303j],
          [-0.3812+0.3620j, -0.2967+0.2477j],
          [-0.0091+0.2755j, -0.0595-0.0617j]],

         [[-0.0409+0.1300j,  0.2640-0.4343j],
          [-0.2360+0.1575j,  0.1715+0.1449j],
          [ 0.7209-0.6044j, -0.5415+0.8812j],
          [ 0.3178-1.5954j,  0.2347+0.1981j]],

         [[ 0.7335+0.4946j,  0.1636-1.6794j],
          [ 1.0972-0.1658j,  1.0036+0.3853j],
          [ 0.7999+0.4196j,  0.2458-0.8357j],
          [-0.7795-1.3275j, -1.5838+0.9322j]],

         [[-0.0409+0.1300j,  0.2640-0.4343j],
          [-0.2360+0.1575j,  0.1715+0.1449j],
          [ 0.7209-0.6044j, -0.5415+0.8812j],
          [ 0.3178-1.5954j,  0.2347+0.1981j]]]],
       grad_fn=<ViewAsComplexBackward0>)

In [16]:
ndim = xq.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (xq.shape[1], xq.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != xq.shape[1], xq.shape[-1] {(xq.shape[1], xq.shape[-1])}'

# reshape our queries
shape = [d if i == 1 or i == xq.ndim - 1 else 1 for i, d in enumerate(xq.shape)]
print(f'shape: {shape}\n')

freqs_cis = freqs_cis.view(*shape)
print(f'freqs_cis: {freqs_cis.shape}\n{freqs_cis}')

shape: [1, 5, 1, 2]

freqs_cis: torch.Size([1, 5, 1, 2])
tensor([[[[ 1.0000+0.0000j,  1.0000+0.0000j]],

         [[ 0.5403+0.8415j,  0.9999+0.0100j]],

         [[-0.4161+0.9093j,  0.9998+0.0200j]],

         [[-0.9900+0.1411j,  0.9996+0.0300j]],

         [[-0.6536-0.7568j,  0.9992+0.0400j]]]])


In [17]:
# now multiply the data by the frequencies, turn them back into real numbers, revert the shape and make sure they're of the right type
xq = torch.view_as_real(xq * freqs_cis).flatten(3).type_as(xv)
xk = torch.view_as_real(xk * freqs_cis).flatten(3).type_as(xv)
print(f'xq: {xq.shape}\n{xq}\n')
print(f'xk: {xk.shape}\n{xk}')

xq: torch.Size([1, 5, 4, 4])
tensor([[[[ 0.0605,  0.4496, -0.1676,  0.9707],
          [-0.8873, -0.5318,  0.1048, -0.2221],
          [ 0.8685, -1.4743, -1.2008,  0.8642],
          [ 1.2625,  0.1612,  1.0077,  0.7350]],

         [[-0.0619, -0.3639,  0.3061,  0.8925],
          [-0.1792, -0.6004, -0.3985,  0.0263],
          [-0.5106, -0.1252, -0.2991,  0.2447],
          [-0.2367,  0.1412, -0.0589, -0.0623]],

         [[-0.1012, -0.0913,  0.2726, -0.4289],
          [-0.0450, -0.2801,  0.1686,  0.1483],
          [ 0.2495,  0.9070, -0.5590,  0.8702],
          [ 1.3185,  0.9529,  0.2307,  0.2028]],

         [[-0.7960, -0.3862,  0.2139, -1.6738],
          [-1.0629,  0.3190,  0.9916,  0.4152],
          [-0.8511, -0.3025,  0.2708, -0.8279],
          [ 0.9590,  1.2042, -1.6110,  0.8843]],

         [[ 0.1251, -0.0541,  0.2812, -0.4234],
          [ 0.2734,  0.0757,  0.1656,  0.1517],
          [-0.9286, -0.1506, -0.5763,  0.8588],
          [-1.4151,  0.8024,  0.2266,  0.2073]]]], 

## Self Attention Calculation

In [18]:
# If the number of K & V heads is different from the number of query heads, adjusts keys and values to match the query heads count.
if num_kv_heads != num_heads:
  num_queries_per_kv = num_heads // num_kv_heads
  xk = torch.repeat_interleave(xk, num_queries_per_kv, dim=2)
  xv = torch.repeat_interleave(xv, num_queries_per_kv, dim=2)

xq.shape, xk.shape, xv.shape

(torch.Size([1, 5, 4, 4]), torch.Size([1, 5, 4, 4]), torch.Size([1, 5, 4, 4]))

In [19]:
# Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

xq.shape, xk.shape, xv.shape

(torch.Size([1, 4, 5, 4]), torch.Size([1, 4, 5, 4]), torch.Size([1, 4, 5, 4]))

In [20]:
# Calculates attention logits by performing a batch matrix multiplication between queries and keys
scores = torch.matmul(xq, xk.transpose(2, 3))

# then we scale the logits by the reciprocal of the square root of the head dimension
scores = scores / math.sqrt(head_dim)

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[ 0.3148,  0.1697,  0.6111,  0.0161,  0.4098],
           [ 0.0668, -0.0682,  0.2327, -0.1408,  0.4025],
           [-0.2006, -0.0947, -0.2434, -0.1355, -0.1460],
           [-0.7688, -0.4630, -1.0483, -0.1572, -0.3864],
           [-0.1056, -0.0057, -0.1782, -0.1063, -0.2434]],
 
          [[-0.4835, -0.4430, -0.4939, -0.1309,  0.2818],
           [-0.0428, -0.1527, -0.2630,  0.2509,  0.0759],
           [-0.0498, -0.0744, -0.0547, -0.0613,  0.0749],
           [-0.5194, -0.3698,  0.0813, -0.7402,  0.6566],
           [ 0.1163,  0.1102,  0.1541, -0.0516, -0.0547]],
 
          [[-0.1182,  0.7197,  0.7664,  0.0950, -0.2429],
           [ 0.2152,  0.0422, -0.1120,  0.2388,  0.3269],
           [ 0.8286, -0.0604,  0.7606,  0.9111,  0.6772],
           [-0.3638, -0.0824, -0.8322, -0.6061, -0.1122],
           [ 0.5882,  0.0650, -0.0476,  0.7882,  0.7606]],
 
          [[-0.1723, -0.1338,  0.5848,  0.4465, -0.5340],
           [ 0.0964, -0.0457, -0.090

In [21]:
# use the mask that we precomputed earlier
scores = scores + mask

scores.shape, scores

(torch.Size([1, 4, 5, 5]),
 tensor([[[[ 0.3148,    -inf,    -inf,    -inf,    -inf],
           [ 0.0668, -0.0682,    -inf,    -inf,    -inf],
           [-0.2006, -0.0947, -0.2434,    -inf,    -inf],
           [-0.7688, -0.4630, -1.0483, -0.1572,    -inf],
           [-0.1056, -0.0057, -0.1782, -0.1063, -0.2434]],
 
          [[-0.4835,    -inf,    -inf,    -inf,    -inf],
           [-0.0428, -0.1527,    -inf,    -inf,    -inf],
           [-0.0498, -0.0744, -0.0547,    -inf,    -inf],
           [-0.5194, -0.3698,  0.0813, -0.7402,    -inf],
           [ 0.1163,  0.1102,  0.1541, -0.0516, -0.0547]],
 
          [[-0.1182,    -inf,    -inf,    -inf,    -inf],
           [ 0.2152,  0.0422,    -inf,    -inf,    -inf],
           [ 0.8286, -0.0604,  0.7606,    -inf,    -inf],
           [-0.3638, -0.0824, -0.8322, -0.6061,    -inf],
           [ 0.5882,  0.0650, -0.0476,  0.7882,  0.7606]],
 
          [[-0.1723,    -inf,    -inf,    -inf,    -inf],
           [ 0.0964, -0.0457,    -in

In [22]:
# now we perform the softmax operation to get our actual probabilities
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores
# notice that thanks to the causal mask, 0 probability is placed on future tokens

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5337, 0.4663, 0.0000, 0.0000, 0.0000],
          [0.3258, 0.3621, 0.3121, 0.0000, 0.0000],
          [0.2017, 0.2739, 0.1525, 0.3718, 0.0000],
          [0.2038, 0.2253, 0.1896, 0.2037, 0.1776]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5275, 0.4725, 0.0000, 0.0000, 0.0000],
          [0.3366, 0.3284, 0.3350, 0.0000, 0.0000],
          [0.2089, 0.2426, 0.3809, 0.1675, 0.0000],
          [0.2118, 0.2106, 0.2200, 0.1791, 0.1785]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5431, 0.4569, 0.0000, 0.0000, 0.0000],
          [0.4264, 0.1753, 0.3983, 0.0000, 0.0000],
          [0.2677, 0.3547, 0.1676, 0.2101, 0.0000],
          [0.2207, 0.1308, 0.1169, 0.2695, 0.2622]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5355, 0.4645, 0.0000, 0.0000, 0.0000],
          [0.2715, 0.1980, 0.5305, 0.0000, 0.0000],
          [0.2776, 0.1077, 0.4017, 0.2130, 0.0000],
      

In [23]:
# then matmul by our values projection
output = torch.matmul(scores, xv)
output.shape, output

(torch.Size([1, 4, 5, 4]),
 tensor([[[[-0.8451,  0.1792, -0.7642, -1.6837],
           [-0.6167, -0.1819, -0.2313, -0.8675],
           [-0.3947, -0.1515, -0.1678, -0.8775],
           [-0.0391, -0.2616,  0.2056, -0.3395],
           [-0.1186, -0.1662,  0.0189, -0.6591]],
 
          [[-0.8451,  0.1792, -0.7642, -1.6837],
           [-0.6137, -0.1867, -0.2242, -0.8566],
           [-0.3912, -0.1291, -0.1930, -0.9239],
           [-0.1505, -0.1620, -0.0085, -0.6970],
           [-0.1340, -0.1464, -0.0172, -0.7192]],
 
          [[ 0.1289,  0.1438,  0.1470, -2.1349],
           [-0.1190,  0.1893,  0.0623, -1.2832],
           [-0.0828,  0.0757,  0.1174, -1.2552],
           [ 0.0698,  0.0708, -0.1399, -0.6530],
           [ 0.1811, -0.0174, -0.1598, -0.6104]],
 
          [[ 0.1289,  0.1438,  0.1470, -2.1349],
           [-0.1231,  0.1901,  0.0609, -1.2689],
           [-0.1338,  0.0496,  0.1141, -1.0293],
           [ 0.1378, -0.0051, -0.0955, -0.7803],
           [ 0.1873, -0.0219, -0.

In [24]:
# and reshape to put the sequence length back into place and the outputs of our heads lined up
output = output.transpose(1, 2).contiguous().view(b, seq_len, -1)
output.shape, output

(torch.Size([1, 5, 16]),
 tensor([[[-0.8451,  0.1792, -0.7642, -1.6837, -0.8451,  0.1792, -0.7642,
           -1.6837,  0.1289,  0.1438,  0.1470, -2.1349,  0.1289,  0.1438,
            0.1470, -2.1349],
          [-0.6167, -0.1819, -0.2313, -0.8675, -0.6137, -0.1867, -0.2242,
           -0.8566, -0.1190,  0.1893,  0.0623, -1.2832, -0.1231,  0.1901,
            0.0609, -1.2689],
          [-0.3947, -0.1515, -0.1678, -0.8775, -0.3912, -0.1291, -0.1930,
           -0.9239, -0.0828,  0.0757,  0.1174, -1.2552, -0.1338,  0.0496,
            0.1141, -1.0293],
          [-0.0391, -0.2616,  0.2056, -0.3395, -0.1505, -0.1620, -0.0085,
           -0.6970,  0.0698,  0.0708, -0.1399, -0.6530,  0.1378, -0.0051,
           -0.0955, -0.7803],
          [-0.1186, -0.1662,  0.0189, -0.6591, -0.1340, -0.1464, -0.0172,
           -0.7192,  0.1811, -0.0174, -0.1598, -0.6104,  0.1873, -0.0219,
           -0.1411, -0.6860]]], grad_fn=<ViewBackward0>))

In [25]:
# finally initializing and apply output projection that mixes the information from the heads together
wo = nn.Linear(num_heads * head_dim, d, bias=False)
Xout = wo(output)
Xout.shape, Xout

(torch.Size([1, 5, 16]),
 tensor([[[ 4.7982e-01,  1.2011e+00, -3.2596e-01, -3.0273e-01,  1.2942e+00,
           -1.5138e+00,  4.5424e-01, -3.9018e-01,  1.1150e-01, -5.3451e-01,
            9.0415e-01, -1.0406e-01,  4.7493e-01,  9.8201e-02, -1.9394e-01,
           -7.4698e-01],
          [ 2.0381e-01,  4.9071e-01, -8.6177e-02, -1.0249e-01,  8.0522e-01,
           -8.4437e-01,  2.3502e-01, -9.6294e-02,  2.2250e-01, -2.2936e-01,
            4.1498e-01, -3.8364e-02,  1.6870e-01,  5.7585e-02, -1.7539e-01,
           -3.1306e-01],
          [ 2.8165e-01,  4.9939e-01, -1.2491e-01, -9.5168e-02,  6.8374e-01,
           -7.3566e-01,  1.7579e-01, -5.3414e-02,  1.0993e-01, -1.8658e-01,
            3.0999e-01,  8.2164e-04,  2.6105e-01,  1.4615e-02, -1.1898e-01,
           -2.1451e-01],
          [ 1.6622e-01,  2.2077e-01, -1.2813e-01, -6.7325e-02,  3.9521e-01,
           -4.2479e-01,  1.1963e-01,  5.0992e-02, -1.8103e-02,  5.8158e-02,
            2.5292e-01,  3.8821e-02,  1.7923e-01,  8.8306e-02, -

## Our First Residual Conenction

In [26]:
h += Xout
h.shape, h

(torch.Size([1, 5, 16]),
 tensor([[[ 0.4290,  0.5151, -0.0964,  0.8467,  2.0194,  0.5481,  0.6505,
           -0.5670, -0.2474, -1.9600,  0.8694,  1.6060, -2.6963,  0.2486,
           -1.0924, -0.1864],
          [ 1.2127,  0.4373, -0.7929,  0.0553,  0.4915, -1.3821,  0.3775,
           -0.2735, -1.4543,  0.5493,  0.9733,  0.2736,  1.6569,  1.1080,
           -1.0533,  0.3014],
          [ 0.8639, -1.2126, -1.7243, -0.6148,  1.2500, -0.4482,  0.5526,
            0.2703,  0.6668, -1.9082, -1.5172,  0.2713, -0.0991,  0.5210,
           -0.2268,  0.4716],
          [-0.1080, -0.3044, -1.6142, -2.6018,  1.0164, -1.8940, -0.4034,
           -0.4439, -0.3723,  0.3597, -0.8165, -1.5508, -0.6144, -1.2312,
            1.7155,  0.3459],
          [ 0.7302, -1.4391, -1.7590, -0.6543,  0.9951, -0.0999,  0.5189,
            0.2620,  0.5243, -1.8334, -1.6524,  0.3472, -0.1787,  0.6154,
           -0.2812,  0.5932]]], grad_fn=<AddBackward0>))

In [27]:
# normalize the current state of our residual for use in our MoE later
pre_ffwd_norm = RMSNorm(d)
h_normed = pre_ffwd_norm(h)
# so now we're working with x, which we'll use later for our next residual conenction, and x_normed which is used by our MoE MLP

### The SwiGLU Feedforward Network
<a id='j'></a>

Llama 3 models have surprisingly not opted for a mixture of experts strategy which i was assuming they'd go for by now. Their feedforward networks use the SwiGLU activation which basically uses the activation function as a gate that dynamically determines what information gets through

In [28]:
# first we need to define our actual hidden dimension, which Llama's code does in an unnecessarily complicated manner
hidden_dim = 4 * d # usually i would designate a hyperparameter for this 4, but in llama's code it was just there
print(hidden_dim)
hidden_dim = int(2 * hidden_dim / 3)
print(hidden_dim)
multiple_of = 256 # their description of this was "make SwiGLU hidden layer size multiple of large power of 2"
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)
# so basically this overly convoluted setup is designed to ensure that hidden_dim is a multiple of 256, likely for hardware efficiency reasons

64
42
256


In [29]:
up = nn.Linear(d, hidden_dim, bias=False)
gate = nn.Linear(d, hidden_dim, bias=False)
down = nn.Linear(hidden_dim, d, bias=False)

In [30]:
up_proj = up(h_normed)
print(up_proj.shape, up_proj)

torch.Size([1, 5, 256]) tensor([[[-0.0642, -0.1545, -0.2192,  ...,  0.9438,  0.0570, -0.3783],
         [ 0.4076, -0.3049,  0.1735,  ...,  0.2229,  0.2467, -0.5810],
         [-0.0606,  0.7448,  0.9840,  ...,  0.6000,  0.2958, -1.6931],
         [ 0.3540,  0.0034,  0.6070,  ..., -0.2506,  0.7748, -0.9145],
         [-0.2998,  0.8691,  1.0339,  ...,  0.3945,  0.1622, -1.5953]]],
       grad_fn=<UnsafeViewBackward0>)


In [31]:
gate_proj = F.silu(gate(h_normed))
print(gate_proj.shape, gate_proj)

torch.Size([1, 5, 256]) tensor([[[-0.0687, -0.1480, -0.1143,  ...,  0.0761, -0.1955,  0.3848],
         [ 0.1410,  0.2848, -0.1783,  ...,  0.4109, -0.2723,  0.3605],
         [ 0.1933,  0.6690,  0.1321,  ...,  0.3607, -0.1581,  0.1268],
         [ 0.0360,  0.3590,  0.5313,  ...,  0.3714,  0.0489,  0.0404],
         [ 0.1203,  0.6909,  0.1560,  ...,  0.3374, -0.1442,  0.0950]]],
       grad_fn=<SiluBackward0>)


In [32]:
ffwd_output = down(up_proj * gate_proj)
print(ffwd_output.shape, ffwd_output)

torch.Size([1, 5, 16]) tensor([[[ 0.0179, -0.0192, -0.1657, -0.0852,  0.0557, -0.0963, -0.1082,
          -0.0335,  0.0246,  0.0104, -0.0752,  0.0661, -0.1234, -0.0924,
           0.0812, -0.1590],
         [-0.0547,  0.0771,  0.1036, -0.1620, -0.0560,  0.0905, -0.3136,
           0.0094,  0.0883, -0.0112, -0.1757,  0.0637, -0.1961, -0.0808,
           0.1623, -0.0319],
         [-0.0080, -0.1601,  0.0289,  0.0709,  0.0767,  0.2686,  0.0853,
           0.0287, -0.1481, -0.0976, -0.1968, -0.0298, -0.2074,  0.0245,
           0.0166, -0.0764],
         [-0.0239, -0.1397, -0.0004, -0.0705,  0.1095,  0.1232, -0.0401,
           0.0346, -0.1058, -0.0731, -0.1919,  0.1236, -0.0659,  0.0992,
           0.0396, -0.1036],
         [ 0.0255, -0.1310,  0.0366,  0.0466,  0.0820,  0.2912,  0.0744,
           0.0321, -0.0764, -0.1134, -0.1699, -0.0249, -0.2140,  0.0253,
           0.0200, -0.1287]]], grad_fn=<UnsafeViewBackward0>)


In [33]:
# and then do our final residual connection of this layer
out = h + ffwd_output
print(out.shape, out)

torch.Size([1, 5, 16]) tensor([[[ 0.4469,  0.4960, -0.2621,  0.7615,  2.0751,  0.4517,  0.5423,
          -0.6005, -0.2227, -1.9495,  0.7942,  1.6721, -2.8197,  0.1563,
          -1.0112, -0.3454],
         [ 1.1579,  0.5144, -0.6892, -0.1067,  0.4355, -1.2916,  0.0639,
          -0.2641, -1.3660,  0.5381,  0.7976,  0.3373,  1.4607,  1.0272,
          -0.8911,  0.2695],
         [ 0.8559, -1.3728, -1.6954, -0.5439,  1.3268, -0.1796,  0.6378,
           0.2991,  0.5187, -2.0058, -1.7140,  0.2414, -0.3065,  0.5455,
          -0.2102,  0.3952],
         [-0.1319, -0.4440, -1.6146, -2.6724,  1.1260, -1.7708, -0.4435,
          -0.4093, -0.4781,  0.2866, -1.0083, -1.4273, -0.6804, -1.1320,
           1.7551,  0.2423],
         [ 0.7557, -1.5701, -1.7224, -0.6078,  1.0771,  0.1913,  0.5933,
           0.2941,  0.4479, -1.9468, -1.8223,  0.3222, -0.3927,  0.6407,
          -0.2613,  0.4646]]], grad_fn=<AddBackward0>)


### Output
<a id='k'></a>
So usually we'd run it back on steps 1e through 1j for however many layers our model has (Llama 3 8b uses 32) using different weight matrices but you get the point. Since our current `out` is of the same shape that it would be if we were to do more layers, let's go ahead and just see what Llama's output mechanism looks like. It's nothing interesting though, just a linear layer. Notably they chose to use a separate linear layer rather than re-using the embedding layer as is relatively common

In [35]:
# first we norm the residual state
final_norm = RMSNorm(d)
out_normed = final_norm(out)

In [36]:
# then multiply by the linear layer to get our final output logits
final_output = nn.Linear(d, v, bias=False)
logits = final_output(out_normed).float()
logits.shape, logits

(torch.Size([1, 5, 10]),
 tensor([[[-0.0716, -0.2248,  0.3506,  0.1224,  0.0167, -0.1844,  0.1912,
            0.6053,  0.5315, -0.3186],
          [ 0.5396, -1.4051,  0.7336, -0.1780,  0.2618, -0.5472,  0.2408,
           -0.6137,  0.3824, -1.2147],
          [-0.1733, -0.3616,  0.4563,  0.0155,  0.3227,  0.7008,  0.0223,
            0.2701, -0.3240,  0.1441],
          [-0.8815,  0.4323, -0.8879,  1.0441, -0.6592,  0.8555, -0.6688,
            0.5801, -1.1062,  0.8051],
          [-0.2352, -0.3454,  0.3589, -0.0648,  0.2186,  0.7198,  0.1076,
            0.2584, -0.2721,  0.1699]]], grad_fn=<UnsafeViewBackward0>))

In [37]:
# softmax the logits to get the probability for each token's prediction across every token in the sequence
probs = F.softmax(logits, dim=-1)
probs

tensor([[[0.0803, 0.0689, 0.1224, 0.0974, 0.0877, 0.0717, 0.1044, 0.1579,
          0.1467, 0.0627],
         [0.1660, 0.0237, 0.2015, 0.0810, 0.1257, 0.0560, 0.1231, 0.0524,
          0.1418, 0.0287],
         [0.0716, 0.0593, 0.1344, 0.0865, 0.1176, 0.1717, 0.0871, 0.1116,
          0.0616, 0.0984],
         [0.0320, 0.1190, 0.0318, 0.2195, 0.0400, 0.1818, 0.0396, 0.1380,
          0.0256, 0.1728],
         [0.0686, 0.0614, 0.1243, 0.0814, 0.1080, 0.1783, 0.0967, 0.1124,
          0.0661, 0.1029]]], grad_fn=<SoftmaxBackward0>)

In [38]:
# Greedily decode the probabilities to get our final predicted indices
greedy_indices = torch.argmax(probs, dim=-1)
greedy_indices
# if we were performing inference rather than training, that final token in the list would be the one to show the user

tensor([[7, 2, 5, 3, 5]])

### The loss functions
<a id='l'></a>

Of course we use [cross-entropy loss](https://machinelearningmastery.com/cross-entropy-for-machine-learning/) which should need no introduction if this isn't your first machine-learning rodeo, so we'll be skimming past it. Basically the idea is that the single correct value is rewarded and all other values are suppressed

In [39]:
# create some random fake target indices to train on
target_token_indices = torch.randint(0, v, greedy_indices.shape)
print(target_token_indices)

# initialize the loss function
loss_fn = nn.CrossEntropyLoss()

# reshape logits to be compatible and calculate loss
loss = loss_fn(logits.view(1,v,seq_len), target_token_indices)
print(loss)

tensor([[7, 6, 9, 4, 7]])
tensor(1.9709, grad_fn=<NllLoss2DBackward0>)


## Now let's code everything up the correct way into classes so that we can actually build a functioning model

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from dataclasses import dataclass
from typing import Optional, Tuple

import time

import os
import json

we'll be using a crazy small & simple tokenizer based on the TinyShakespeare dataset
Llama 3 8b's vocabulary size is 128256 including special tokens like <|endoftext|>

In [41]:
import pickle
import os

class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)

def load_tokenizer_data(size: int):
    file_name = f'./tokenizers/tiny_shakespeare_tokenizer_{size}.model'
    with open(file_name, 'rb') as f:
        tokenizer_data = pickle.load(f)
    return tokenizer_data

def get_tokenizer(size: int):
    tokenizer_data = load_tokenizer_data(size)
    loaded_stoi = tokenizer_data['stoi']
    loaded_merges = tokenizer_data['merges']
    return SimpleTokenizer(loaded_stoi, loaded_merges)

In [42]:
tokenizer = get_tokenizer(size = 512)

FileNotFoundError: [Errno 2] No such file or directory: './tokenizers/tiny_shakespeare_tokenizer_512.model'