###### In this notebook, we convert the original GPT architecture into a Llama 2 model step by step (note that GPT and GPT-2 share the same architecture)
###### Why not Llama 1 or Llama 3?
###### The Llama 1 architecture is similar to Llama 2, except that Llama 2 has a larger context window (which is nice). The Llama 1 weights are also not readily available and come with more usage restrictions, so it makes more sense to focus on Llama 2
###### Regarding Llama 3, I will share a separate notebook to convert Llama 2 to Llama 3 (there are only a few small additional changes)

###### Packages that are being used in this notebook:

In [None]:
from importlib.metadata import version

pkgs = [
    "huggingface_hub",  # to download pretrained weights
    "sentencepiece",    # to implement the tokenizer
    "torch",            # to implement the model
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

 
#### 1. Convert the GPT model implementation step by step
###### In this section, we go through the GPT model code from chapter 4 and modify it step by step to implement the Llama 2 architecture
###### Later, we load the original Llama 2 weights shared by Meta AI

#### 1.1 Replace LayerNorm with RMSNorm
###### RMSNorm rescales the inputs by their root mean square and, unlike LayerNorm, neither subtracts the mean nor adds a learnable shift

In [None]:
import torch
import torch.nn as nn


#####################################
# Chapter 4
#####################################

# class LayerNorm(nn.Module):
#     def __init__(self, emb_dim):
#         super().__init__()
#         self.eps = 1e-5
#         self.scale = nn.Parameter(torch.ones(emb_dim))
#         self.shift = nn.Parameter(torch.zeros(emb_dim))

#     def forward(self, x):
#         mean = x.mean(dim=-1, keepdim=True)
#         var = x.var(dim=-1, keepdim=True, unbiased=False)
#         norm_x = (x - mean) / torch.sqrt(var + self.eps)
#         return self.scale * norm_x + self.shift


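class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        # Learnable per-feature scale, kept in float32 (no shift term, unlike LayerNorm)
        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))

    def forward(self, x):
        # Mean of the squared features along the embedding dimension
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        # Divide by the root mean square; eps guards against division by zero
        x_normed = x * torch.rsqrt(mean_sq + self.eps)
        return (x_normed * self.weight).to(dtype=x.dtype)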

###### The following code cell checks that this implementation produces the same outputs as PyTorch's built-in torch.nn.RMSNorm (available in PyTorch 2.4 and newer):

In [None]:
torch.manual_seed(123)

example_batch = torch.randn(2, 3, 4)

rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
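
###### To make the arithmetic concrete, here is a minimal plain-Python sketch (no PyTorch) that normalizes one small vector with both schemes. The helper names rms_norm and layer_norm and the input values are chosen for illustration only; the weights are fixed to ones and the shift to zeros, which matches the initial state of the modules above

```python
import math

def rms_norm(x, weight, eps=1e-5):
    # Scale by the root mean square of the features (no mean subtraction)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def layer_norm(x, scale, shift, eps=1e-5):
    # Subtract the mean, then divide by the (biased) standard deviation
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, shift)]

x = [1.0, 2.0, 3.0, 4.0]            # illustrative input vector
ones, zeros = [1.0] * 4, [0.0] * 4  # initial weight/scale and shift values

rms_out = rms_norm(x, ones)
ln_out = layer_norm(x, ones, zeros)

print("RMSNorm:  ", [round(v, 4) for v in rms_out])
print("LayerNorm:", [round(v, 4) for v in ln_out])
```

###### The RMSNorm output keeps the sign and relative ordering of the inputs and has a mean square of roughly 1, whereas the LayerNorm output is additionally centered around zero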

#### 1.2 Replace GELU with SiLU activation

In [None]:
#####################################
# Chapter 4
#####################################

# class GELU(nn.Module):
#     def __init__(self):
#         super().__init__()

#     def forward(self, x):
#         return 0.5 * x * (1 + torch.tanh(
#             torch.sqrt(torch.tensor(2.0 / torch.pi)) *
#             (x + 0.044715 * torch.pow(x, 3))
#         ))


class SiLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

In [None]:
silu = SiLU()

assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))
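
###### As a rough comparison of the two activations, the following plain-Python sketch evaluates SiLU and the tanh-approximated GELU from chapter 4 at a few hand-picked points. The function names and sample points are illustrative assumptions, not part of the original notebook

```python
import math

def silu(x):
    # x * sigmoid(x), written as x / (1 + exp(-x))
    return x / (1.0 + math.exp(-x))

def gelu_tanh(x):
    # Same tanh approximation as the commented-out GELU class above
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  SiLU={silu(x):+.4f}  GELU={gelu_tanh(x):+.4f}")
```

###### Both functions are smooth, pass through zero, and approach the identity for large positive inputs; SiLU only needs a sigmoid, which makes it slightly simpler to compute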

#### 1.3 Update the FeedForward module
###### In fact, Llama uses a "Gated Linear Unit" (GLU) variant of SiLU called SwiGLU, which results in a slightly differently structured FeedForward module
###### SwiGLU uses a gating mechanism in the feedforward layer, with the formula:

$$\text{SwiGLU}(x) = \text{SiLU}(\text{Linear}_1(x)) * \text{Linear}_2(x)$$

###### Here, $\text{Linear}_1$ and $\text{Linear}_2$ are two linear layers, and $*$ denotes element-wise multiplication

###### The third linear layer, $\text{Linear}_3$, is applied after this gated activation

###### For more information, see the SwiGLU paper: GLU Variants Improve Transformer (2020)
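
###### The gated computation can be traced by hand on a tiny example. The following plain-Python sketch assumes an embedding size of 2 and a hidden size of 3 with hand-picked weight matrices W1, W2, and W3 (all hypothetical values chosen for readability); it mirrors the order of operations described above without any bias terms

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def linear(x, W):
    # W has shape (out_features, in_features); no bias term
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def swiglu_ffn(x, W1, W2, W3):
    gate = [silu(v) for v in linear(x, W1)]      # SiLU(Linear_1(x))
    up = linear(x, W2)                           # Linear_2(x)
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise gating
    return linear(hidden, W3)                    # Linear_3 projects back

# Illustrative weights: emb_dim=2, hidden_dim=3
W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W2 = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
W3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]

x = [1.0, -2.0]
out = swiglu_ffn(x, W1, W2, W3)
print([round(v, 4) for v in out])
```

###### Note that the output has the same size as the input, so the module can be dropped into the transformer block in place of the GELU-based feedforward layer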

In [None]:
#####################################
# Chapter 4
#####################################
# class FeedForward(nn.Module):
#     def __init__(self, cfg):
#         super().__init__()
#         self.layers = nn.Sequential(
#             nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
#             GELU(),
#             nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
#         )

#     def forward(self, x):
#         return self.layers(x)

In [None]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
        self.silu = SiLU()

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = self.silu(x_fc1) * x_fc2
        return self.fc3(x)

###### Note that we also added a dtype=cfg["dtype"] setting above, which will allow us to load the model directly in lower precision formats later to reduce memory usage (versus instantiating it in the original 32-bit precision format and then converting it)
###### We also set bias=False since Llama doesn't use any bias units
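
###### To get a feeling for why the dtype setting matters, here is a back-of-the-envelope calculation for a single FeedForward module. The sizes emb_dim=4096 and hidden_dim=11008 are assumptions (they correspond to the 7B variant of Llama 2 and do not appear in the text above)

```python
emb_dim = 4096       # assumed: Llama 2 7B embedding size
hidden_dim = 11008   # assumed: Llama 2 7B feedforward hidden size

# fc1 and fc2 map emb_dim -> hidden_dim, fc3 maps hidden_dim -> emb_dim; no bias units
num_params = 3 * emb_dim * hidden_dim

bytes_per_param = {"float32": 4, "bfloat16": 2}
for name, nbytes in bytes_per_param.items():
    mib = num_params * nbytes / 1024**2
    print(f"{name}: {mib:.0f} MiB per FeedForward module")
```

###### Under these assumptions, instantiating the module directly in a 16-bit format halves its memory footprint and avoids temporarily holding the 32-bit copy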

#### 1.4 Implement RoPE
###### In the GPT model, the positional embeddings are implemented as follows:
###### self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
###### Instead of traditional absolute positional embeddings, Llama uses rotary position embeddings (RoPE), which capture both absolute and relative positional information
###### The reference paper for RoPE is RoFormer: Enhanced Transformer with Rotary Position Embedding (2021)
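
###### The "relative" part can be illustrated with a single pair of features. In the plain-Python sketch below, each vector is rotated by an angle proportional to its position; the dot product between a rotated query and a rotated key then depends only on the distance between the two positions. The helper names, the 2D vectors, and the frequency theta=1.0 are simplifying assumptions for illustration

```python
import math

def rotate(vec, pos, theta=1.0):
    # Rotate a 2D vector by the angle pos * theta
    angle = pos * theta
    c, s = math.cos(angle), math.sin(angle)
    x, y = vec
    return (x * c - y * s, x * s + y * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (1.0, 2.0), (0.5, -1.5)  # illustrative query and key

score_a = dot(rotate(q, 3), rotate(k, 1))    # positions 3 and 1, offset 2
score_b = dot(rotate(q, 10), rotate(k, 8))   # positions 10 and 8, offset 2
score_c = dot(rotate(q, 10), rotate(k, 5))   # positions 10 and 5, offset 5

print(f"offset 2: {score_a:.4f} and {score_b:.4f}")
print(f"offset 5: {score_c:.4f}")
```

###### The first two scores agree because both pairs are two positions apart, while the third differs; the absolute position is still encoded in each individual rotated vector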

In [None]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Compute the inverse frequencies, one per pair of features
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)

###### The following is an example of applying RoPE to the q and k tensors:

In [None]:
# Settings
batch_size = 2
context_len = 5
num_heads = 4
head_dim = 16

# Instantiate RoPE parameters
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# Apply rotary position embeddings
queries_rot = compute_rope(queries, cos, sin)
keys_rot = compute_rope(keys, cos, sin)
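
###### As a sanity check on what the rotation does, the following plain-Python sketch mirrors the split-halves layout of compute_rope for a single vector at a single position. The helpers rope_angles and apply_rope and the input values are hypothetical and only meant to show two properties: position 0 leaves the vector unchanged, and the rotation preserves the vector's length

```python
import math

def rope_angles(pos, head_dim, theta_base=10_000):
    # One angle per feature pair: pos * theta_base^(-2i / head_dim)
    half = head_dim // 2
    return [pos * theta_base ** (-(2 * i) / head_dim) for i in range(half)]

def apply_rope(x, pos, theta_base=10_000):
    # Feature i is paired with feature i + head_dim // 2
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    angles = rope_angles(pos, len(x), theta_base)
    first = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    second = [b * math.cos(t) + a * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    return first + second

def sq_norm(v):
    return sum(e * e for e in v)

x = [0.5, -1.0, 2.0, 0.25]   # illustrative vector with head_dim=4

at_pos0 = apply_rope(x, pos=0)
at_pos7 = apply_rope(x, pos=7)

print("position 0:", at_pos0)
print("position 7:", [round(v, 4) for v in at_pos7])
print("squared norms:", sq_norm(x), round(sq_norm(at_pos7), 6))
```

###### The first feature pair rotates quickly (angle equal to the position), while later pairs rotate more and more slowly as the inverse frequency shrinks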

#### 1.5 Add RoPE to MultiHeadAttention module
###### It's important to note that GPT applies the positional embeddings to the inputs, whereas Llama applies rotations to the query and key vectors in the self-attention mechanism itself
###### Here, we modify the MultiHeadAttention class with the appropriate RoPE code
###### In addition, we remove the qkv_bias option and hardcode the bias=False setting
###### Also, we add a dtype setting to be able to instantiate the model with a lower precision later
###### Tip: since the TransformerBlocks (in the next section) are repeated exactly, we could simplify the code and initialize the buffers only once instead of once per MultiHeadAttention module; however, we add the precomputed RoPE parameters to the MultiHeadAttention class so that it can function as a standalone module
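
###### To estimate what the per-module buffers cost, here is a rough calculation. All sizes below are assumptions (they correspond to the 7B variant of Llama 2 and are not stated in the text above), and the buffers are assumed to be stored as float32

```python
context_length = 4096   # assumed: Llama 2 context window
emb_dim = 4096          # assumed: Llama 2 7B embedding size
n_heads = 32            # assumed: Llama 2 7B attention heads
n_layers = 32           # assumed: Llama 2 7B transformer blocks

head_dim = emb_dim // n_heads

# cos and sin each have shape (context_length, head_dim)
elements_per_module = 2 * context_length * head_dim
mib_per_module = elements_per_module * 4 / 1024**2   # 4 bytes per float32 value
mib_all_modules = mib_per_module * n_layers

print(f"head_dim: {head_dim}")
print(f"per module: {mib_per_module:.0f} MiB, all {n_layers} modules: {mib_all_modules:.0f} MiB")
```

###### Under these assumptions, the duplicated buffers are small compared to the model weights, which is why keeping each MultiHeadAttention module self-contained is a reasonable trade-off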