In [None]:
# Download the model weights
import os
import urllib.request

# Files to fetch: local destination path and the pinned-revision source URL.
downloads = [
    {
        "filename": "data/llama3-8b/tokenizer.model",
        "url": "https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/1460c22666392e470910ce3d44ffeb2ab7dbd4df/original/tokenizer.model",
    },
    {
        "filename": "data/llama3-8b/consolidated.00.pth",
        "url": "https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/1460c22666392e470910ce3d44ffeb2ab7dbd4df/original/consolidated.00.pth",
    },
]


def _fetch(url, filename):
    """Download `url` to `filename`, creating the parent directory if needed.

    The data is first written to `filename + ".part"` and renamed only once the
    transfer has finished, so an interrupted download never leaves a truncated
    file at `filename` that a later run would mistake for a complete one.

    Raises whatever `urllib.request.urlretrieve` raises on network failure.
    """
    parent = os.path.dirname(filename)
    if parent:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)
    partial = filename + ".part"
    urllib.request.urlretrieve(url, partial)
    os.replace(partial, filename)


for download in downloads:
    # Pull the fields into locals: nesting double quotes inside a double-quoted
    # f-string is a SyntaxError on Python versions before 3.12.
    filename, url = download["filename"], download["url"]
    if os.path.isfile(filename):
        print(f"File {filename} already found, skipping download")
    else:
        print(f"Downloading {url} to {filename}")
        _fetch(url, filename)

In [None]:
# Load the Tiktoken tokenizer
import torch
import micro_llama

# The directory name must match the download cell exactly ("llama3-8b", with a
# lowercase "b"); a different case is a different path on case-sensitive
# filesystems and the file would not be found.
tokenizer = micro_llama.make_tokenizer("data/llama3-8b/tokenizer.model")

In [None]:
# Demonstrate the Tiktoken tokenizer: text -> token ids -> text round trip
prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = tokenizer.encode(prompt)
prompt_ = tokenizer.decode(tokens)

# Single newline-separated print: the original text, its token ids, and the
# text recovered by decoding those ids.
print(prompt, tokens, prompt_, sep="\n")

In [None]:
# Demonstrate the RoPE positional embedding
from matplotlib import pyplot as plt

N = 64   # number of sequence positions
D = 256  # width of each embedding vector
# RoPE base used for this demo. This cell previously assigned 500_000 and then
# immediately overwrote it with 5, so 5 was the only value ever used; the dead
# assignment is removed. Set theta = 500_000 here to try the larger base
# (presumably the one the real model uses -- confirm against micro_llama).
theta = 5

# A private, seeded generator makes the figure reproducible under
# "Restart & Run All" without touching torch's global RNG state.
rng = torch.Generator().manual_seed(0)

# N nearly identical unit vectors: one shared random direction plus a little
# per-row noise, then normalised to unit length.
x = torch.randn(1, D, generator=rng)
x = x.expand(N, D) + torch.randn(N, D, generator=rng) * 0.01
x = x / x.norm(dim=-1, keepdim=True)

# rope is fed a 4-D tensor -- presumably (batch, seq, heads, dim); confirm
# against micro_llama.rope.
y = micro_llama.rope(x.reshape(1, N, 1, D), theta=theta)
y = y.reshape(N, D)

# Pairwise dot products between all rows, before and after applying RoPE.
M = x @ x.transpose(-2, -1)
M_ = y @ y.transpose(-2, -1)

# Explicit axes instead of the pyplot state machine; plt.show() keeps the
# cell from echoing the AxesImage repr.
fig, (ax_plain, ax_rope) = plt.subplots(1, 2)
ax_plain.imshow(M.detach().numpy())
ax_plain.set_title("Without RoPE")
ax_rope.imshow(M_.detach().numpy())
ax_rope.set_title("With RoPE")
plt.show()

In [None]:
# Load the LLAMA3 8B model
llama = micro_llama.Llama()
# The path must match the download cell's "data/llama3-8b" directory exactly
# (lowercase "b"); a different case is a different path on case-sensitive
# filesystems.
# map_location="cpu" lets the checkpoint load on machines without the device
# its tensors were saved from; load_state_dict then copies the values into the
# model's own parameters.
# weights_only=True restricts unpickling to tensors and plain containers.
params = torch.load(
    'data/llama3-8b/consolidated.00.pth',
    map_location="cpu",
    weights_only=True,
)
llama.load_state_dict(params)
# Switch to evaluation mode; as the last expression this also displays the
# module summary in the cell output.
llama.eval()

In [None]:
# Demonstrate the LLAMA3 model
# Token id prepended to the prompt by hand -- expected to be Llama 3's
# "<|begin_of_text|>" special token (confirm against the tokenizer's
# special-token table).
BOS_TOKEN_ID = 128000

prompt = "the answer to the ultimate question of life, the universe, and everything is "
x = torch.tensor([BOS_TOKEN_ID] + tokenizer.encode(prompt))
# .tolist() hands the tokenizer plain Python ints rather than 0-d tensors.
print(tokenizer.decode(x.tolist()))

# Inference only: disable autograd so no graph or activations are retained.
with torch.no_grad():
    y = llama(x.unsqueeze(0))  # unsqueeze adds a leading batch axis of size 1

# argmax over the last axis picks the highest-scoring id at each position
# (assumes the model returns per-position scores over the vocabulary);
# [0] selects the single batch entry.
print(tokenizer.decode(y.argmax(dim=-1)[0].tolist()))