In [1]:
from tokenizer import Tokenizer
import torch
import json
import math

In [2]:
# Directory that holds the checkpoint, tokenizer model and params.json.
# Kept in one place so the three loads below cannot drift apart.
MODEL_DIR = "Meta-Llama-3-8B"

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point MODEL_DIR at checkpoints obtained from a trusted source.
# map_location="cpu" keeps the weights on the same device as the caches,
# masks and rotary tables that the rest of this notebook allocates without
# specifying a device.
model = torch.load(f"{MODEL_DIR}/consolidated.00.pth", map_location="cpu")
tokenizer = Tokenizer(f"{MODEL_DIR}/tokenizer.model")

# Hyper-parameters shipped next to the checkpoint.
with open(f"{MODEL_DIR}/params.json", "r") as f:
    config = json.load(f)

# Module-level copies read by the helper code below: rope_theta feeds the
# rotary frequency table and norm_eps is the epsilon used by rms_norm.
rope_theta = torch.tensor(config["rope_theta"])
norm_eps = config["norm_eps"]

In [3]:
def rms_norm(tensor, norm_weights, eps=None):
    """Scale `tensor` by the reciprocal root-mean-square of its last axis.

    Args:
        tensor: Input tensor; the mean of squares is taken over dim -1.
        norm_weights: Gain multiplied element-wise onto the normalized
            tensor (must broadcast against `tensor`).
        eps: Value added to the mean square before the reciprocal square
            root. When None (the default) the module-level `norm_eps` is
            used, which is the behaviour existing callers rely on.

    Returns:
        `tensor * rsqrt(mean(tensor**2, -1) + eps) * norm_weights`.
    """
    if eps is None:
        # Fall back to the notebook-wide epsilon loaded from params.json.
        eps = norm_eps
    mean_square = tensor.pow(2).mean(-1, keepdim=True)
    return (tensor * torch.rsqrt(mean_square + eps)) * norm_weights

# Rotary position embedding table, built once at module level.
# Exponents 0/64, 1/64, ..., 63/64 - one per rotated pair of channels.
zero_to_one_split_into_64_parts = torch.arange(64) / 64
# Per-pair angular frequency: rope_theta ** (-exponent).
freqs = 1.0 / torch.pow(rope_theta, zero_to_one_split_into_64_parts)
# Row p holds the rotation angles for position p (1024 positions).
freqs_for_each_token = torch.outer(torch.arange(1024), freqs)
# Unit-magnitude complex numbers exp(i * angle), shape (1024, 64).
freqs_cis = torch.polar(
    torch.ones_like(freqs_for_each_token),
    freqs_for_each_token,
)

def apply_rotation(x: torch.Tensor, freqs_cis_: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive channel pairs of `x` by complex unit factors.

    The last axis of `x` is split into pairs, each pair is treated as one
    complex number (real, imag) and multiplied by the matching entry of
    `freqs_cis_`; the result is flattened back and cast to `x`'s dtype.

    Args:
        x: Tensor whose last dimension is even. The `unsqueeze(1)` /
            `flatten(2)` calls below mean a 3-D layout is expected -
            presumably (seq_len, n_heads, head_dim); confirm with callers.
        freqs_cis_: Complex tensor with one row per position and one column
            per channel pair; an axis is inserted at dim 1 so it broadcasts
            over dim 1 of `x`.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    # Arithmetic is done in float32 regardless of the input dtype.
    x_in_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(x_in_pairs)
    x_rotated = torch.view_as_real(x_complex * freqs_cis_.unsqueeze(1)).flatten(2)
    return x_rotated.type_as(x)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every head of `x` `n_rep` times along the head axis.

    `x` has shape (seq_len, n_kv_head, head_dim). The result has shape
    (seq_len, n_kv_head * n_rep, head_dim), with the copies of each head
    placed next to each other (h0, h0, ..., h1, h1, ...).
    """
    tokens, kv_heads, width = x.shape
    # Insert a length-1 axis after the head axis and broadcast it to n_rep.
    widened = x.unsqueeze(2).expand(tokens, kv_heads, n_rep, width)
    # Fold the repeat axis into the head axis.
    return widened.reshape(tokens, kv_heads * n_rep, width)

class Llama3():
  """Single-sequence Llama 3 decoder driven directly by a checkpoint dict.

  `model` is indexed with keys such as "tok_embeddings.weight",
  "layers.{i}.attention.wq.weight", "layers.{i}.feed_forward.w1.weight",
  "norm.weight" and "output.weight". Keys and values are stored per layer in
  preallocated bfloat16 buffers of `max_seq_len` positions so that decoding
  can proceed one token at a time.

  The class relies on the module-level helpers `rms_norm`, `apply_rotation`
  and `repeat_kv`, and `generate` uses the module-level `tokenizer`.
  """

  def __init__(self, config, model, max_seq_len=1024):
    """Read hyper-parameters and allocate the caches and rotary table.

    Args:
      config: Mapping providing "dim", "n_layers", "n_heads", "n_kv_heads",
        "vocab_size", "norm_eps" and "rope_theta".
      model: Mapping from parameter name to weight tensor.
      max_seq_len: Number of positions the KV caches and the rotary table
        can hold.
    """
    self.model = model
    self.dim = config["dim"]
    self.n_layers = config["n_layers"]
    self.n_heads = config["n_heads"]
    self.head_dim = self.dim // self.n_heads
    self.n_kv_heads = config["n_kv_heads"]
    self.vocab_size = config["vocab_size"]
    # Stored for reference; the module-level rms_norm helper that forward()
    # calls takes its epsilon from the module-level `norm_eps`.
    self.norm_eps = config["norm_eps"]
    self.n_reps = self.n_heads // self.n_kv_heads
    self.rope_theta = torch.tensor(config["rope_theta"])
    self.max_seq_len = max_seq_len
    self.embedding_layer = torch.nn.Embedding(self.vocab_size, self.dim)
    self.embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
    self.k_caches = [torch.zeros(max_seq_len, self.n_kv_heads, self.head_dim, dtype=torch.bfloat16) for _ in range(self.n_layers)]
    self.v_caches = [torch.zeros(max_seq_len, self.n_kv_heads, self.head_dim, dtype=torch.bfloat16) for _ in range(self.n_layers)]
    # Rotary table sized from this instance's own head_dim, rope_theta and
    # max_seq_len, so the class does not depend on a module-level table
    # that was built for fixed sizes.
    n_pairs = self.head_dim // 2
    pair_fractions = torch.arange(n_pairs) / n_pairs
    pair_freqs = 1.0 / (self.rope_theta ** pair_fractions)
    angles = torch.outer(torch.arange(max_seq_len), pair_freqs)
    self.freqs_cis = torch.polar(torch.ones_like(angles), angles)
    self.last_xq = None
    self.last_keys = None
    self.last_values = None
    self.last_scores = None

  @staticmethod
  def _causal_mask(seq_len, pos):
    """Additive attention mask for `seq_len` new tokens starting at `pos`.

    Returns None for a single token (nothing to hide). Otherwise returns a
    (seq_len, pos + seq_len) tensor holding 0 where new token i may look at
    key j (j <= pos + i) and -inf elsewhere, so it lines up with scores
    computed against the whole cache prefix and not just the new tokens.
    """
    if seq_len <= 1:
      return None
    blocked = torch.full((seq_len, pos + seq_len), float("-inf"))
    return torch.triu(blocked, diagonal=pos + 1)

  @torch.no_grad()
  def forward(self, tokens: torch.Tensor, pos: int):
    """Run the decoder over `tokens` and return logits for the last one.

    Args:
      tokens: 1-D tensor of token ids to append to the sequence.
      pos: Position of `tokens[0]` in the sequence; cache rows
        pos .. pos + len(tokens) - 1 are overwritten.

    Returns:
      Logits of the final input token (product of its normalized hidden
      state with `model["output.weight"].T`).

    Raises:
      ValueError: If the tokens would not fit in the preallocated caches.
    """
    seq_len = len(tokens)
    if pos < 0 or pos + seq_len > self.max_seq_len:
      raise ValueError(
        f"positions {pos}..{pos + seq_len - 1} do not fit in max_seq_len={self.max_seq_len}"
      )
    weights = self.model
    embeddings = self.embedding_layer(tokens).to(torch.bfloat16)
    mask = self._causal_mask(seq_len, pos)
    rotations = self.freqs_cis[pos:pos + seq_len]
    for layer in range(self.n_layers):
      layer_embedding_norm = rms_norm(embeddings, weights[f"layers.{layer}.attention_norm.weight"])
      q_w = weights[f"layers.{layer}.attention.wq.weight"]
      k_w = weights[f"layers.{layer}.attention.wk.weight"]
      v_w = weights[f"layers.{layer}.attention.wv.weight"]

      xq = torch.matmul(layer_embedding_norm, q_w.T)
      xk = torch.matmul(layer_embedding_norm, k_w.T)
      xv = torch.matmul(layer_embedding_norm, v_w.T)

      # Split the projections into heads.
      xq = xq.view(seq_len, self.n_heads, self.head_dim)
      xk = xk.view(seq_len, self.n_kv_heads, self.head_dim)
      xv = xv.view(seq_len, self.n_kv_heads, self.head_dim)

      # Rotary position encoding is applied to queries and keys only.
      xq = apply_rotation(xq, rotations)
      xk = apply_rotation(xk, rotations)

      # Store the new keys/values, then attend over everything cached so far.
      k_cache = self.k_caches[layer]
      v_cache = self.v_caches[layer]
      k_cache[pos: pos + seq_len] = xk
      v_cache[pos: pos + seq_len] = xv

      keys = repeat_kv(k_cache[:pos + seq_len], self.n_reps)
      values = repeat_kv(v_cache[:pos + seq_len], self.n_reps)

      # Move the head axis first: (heads, tokens, head_dim).
      xq = xq.transpose(0, 1)
      keys = keys.transpose(0, 1)
      values = values.transpose(0, 1)

      scores = torch.matmul(xq, keys.transpose(1, 2)) / math.sqrt(self.head_dim)
      if mask is not None:
        scores += mask

      # Softmax in float32, then back to the query dtype.
      scores = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(xq)
      output = torch.matmul(scores, values)
      wo = weights[f"layers.{layer}.attention.wo.weight"]
      output = output.transpose(0, 1).contiguous().view(seq_len, -1)
      embedding_delta = torch.matmul(output, wo.T)
      embedding_with_attention = embeddings + embedding_delta
      embedding_normalized = rms_norm(embedding_with_attention, weights[f"layers.{layer}.ffn_norm.weight"])

      w1 = weights[f"layers.{layer}.feed_forward.w1.weight"]
      w2 = weights[f"layers.{layer}.feed_forward.w2.weight"]
      w3 = weights[f"layers.{layer}.feed_forward.w3.weight"]

      # Gated feed-forward: silu(x @ w1.T) * (x @ w3.T), projected by w2.
      fc_gate = torch.nn.functional.silu(torch.matmul(embedding_normalized, w1.T))
      fc_up = torch.matmul(embedding_normalized, w3.T)
      output_ffn = torch.matmul(fc_gate * fc_up, w2.T)
      embeddings = embedding_with_attention + output_ffn
    embeddings = rms_norm(embeddings, weights["norm.weight"])
    logits = torch.matmul(embeddings[-1], weights["output.weight"].T)
    return logits

  def generate(self, prompt: str, max_len: int):
    """Greedily extend `prompt` and return the decoded text.

    The prompt is encoded with the module-level `tokenizer` (BOS added, no
    EOS) and processed in one forward pass; after that one token is fed per
    step, each time picking the arg-max logit. One token is always
    generated; further tokens are added until the sequence holds `max_len`
    tokens. The decoded text is printed after every step of the loop.

    Args:
      prompt: Text to continue.
      max_len: Target total number of tokens, prompt included.

    Returns:
      The decoded token sequence as a string.
    """
    tokens = tokenizer.encode(prompt, bos=True, eos=False)
    logits = self.forward(torch.tensor(tokens), 0)
    tokens.append(torch.argmax(logits).item())
    for _ in range(max_len - len(tokens)):
      # Only the newest token is fed; earlier ones live in the KV caches.
      logits = self.forward(torch.tensor(tokens[-1:]), len(tokens) - 1)
      tokens.append(torch.argmax(logits).item())
      print(tokenizer.decode(tokens))
    return tokenizer.decode(tokens)


In [5]:
# Greedy-decode a continuation of the prompt, up to 30 tokens in total
# (prompt tokens included). generate() prints the text after each step.
prompt = "Richard Feynman was a "
llama = Llama3(config, model)
res = llama.generate(prompt, max_len=30)


<|begin_of_text|>Richard Feynman was a 20th
<|begin_of_text|>Richard Feynman was a 20th century
<|begin_of_text|>Richard Feynman was a 20th century physicist
<|begin_of_text|>Richard Feynman was a 20th century physicist who
<|begin_of_text|>Richard Feynman was a 20th century physicist who won
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel Prize
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel Prize in
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel Prize in 
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel Prize in 196
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel Prize in 1965
<|begin_of_text|>Richard Feynman was a 20th century physicist who won the Nobel Prize in 1965 for
<|