In [1]:
import os
import sys
import gc
import numpy as np

from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import torch.nn.functional as F

In [2]:
# Make the parent directory importable so the `models.*` imports below resolve.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
from models.llama3.transformer import Transformer
from models.llama3.tokenizer import Tokenizer
from models.llama3.config import LlamaConfig
from models.llama3.load import build

In [4]:
def hf_main(prompts, device="mps", model_id="meta-llama/Llama-3.2-3B", max_length=30):
    """Run greedy generation with the Hugging Face reference implementation.

    Args:
        prompts: List of prompt strings. They are tokenized together without
            padding, so multiple prompts must tokenize to the same length.
        device: Torch device the inputs and the model are moved to.
        model_id: Hugging Face Hub identifier of the checkpoint to load.
        max_length: Total sequence length (prompt + generated tokens) passed
            to `model.generate`.

    Returns:
        A 7-tuple `(model, tokenizer, model.config, model.generation_config,
        device, inputs, text)` where `inputs` is the dict of tokenizer tensors
        already moved to `device` and `text` is the decoded output for the
        first prompt.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The tokenizer needs a pad token; reuse EOS for it.
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
    print(prompts[0])
    inputs = tokenizer(prompts, return_tensors="pt")
    print(inputs["input_ids"].numpy().tolist())
    # e.g. for "The theory of relativity states that":
    # [[128000, 791, 10334, 315, 1375, 44515, 5415, 430]]
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = LlamaForCausalLM.from_pretrained(model_id).to(device)
    model.generation_config.pad_token_id = model.config.eos_token_id
    # The checkpoint's default generation config enables sampling
    # (temperature=0.6, top_p=0.9). Greedy decoding is forced per call with
    # do_sample=False; temperature/top_p are passed as None because otherwise
    # transformers emits a UserWarning that those flags are only used in
    # sample-based generation modes.
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=False,
        temperature=None,
        top_p=None,
    )
    print(model.generation_config)
    print(outputs.cpu().numpy().tolist())
    # e.g. [[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418, 13, 1115, 3445, 430, 422, 499, 527, 7366, 520, 264, 6926, 4732]]
    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(texts[0])
    # e.g. The theory of relativity states that the speed of light is constant in all reference frames. This means that if you are moving at a constant speed
    return model, tokenizer, model.config, model.generation_config, device, inputs, texts[0]

In [5]:
@torch.inference_mode()
def generate(prompt_tokens: list[list[int]], model: "Transformer", tokenizer: "Tokenizer", config: "LlamaConfig", device, logprobs: bool=False):
  """Greedy (argmax) decoding for a batch of tokenized prompts.

  Args:
    prompt_tokens: One list of token ids per prompt.
    model: Object whose `forward(tokens, start_pos)` returns logits indexed as
      (batch, position, vocab).
    tokenizer: Provides `pad_id` and an iterable `stop_tokens`.
    config: Provides `max_batch_size` and `max_seq_len`; `max_seq_len` also
      caps the number of generated tokens.
    device: Torch device on which the working tensors are allocated.
    logprobs: When True, also return per-token log-probabilities.

  Returns:
    `(out_tokens, out_logprobs)`. `out_tokens` holds, per prompt, the prompt
    followed by the generated ids, cut before the first stop token found
    anywhere in the sequence. `out_logprobs` is the matching list of
    log-probability lists, or None when `logprobs` is False.
  """
  max_batch_size, max_seq_len = config.max_batch_size, config.max_seq_len
  bsz = len(prompt_tokens)
  assert bsz <= max_batch_size, (bsz, max_batch_size)
  if bsz == 0:
    # Nothing to decode; min()/max() below would fail on an empty batch.
    return [], ([] if logprobs else None)
  max_gen_len = config.max_seq_len
  min_prompt_len = min(len(t) for t in prompt_tokens)
  max_prompt_len = max(len(t) for t in prompt_tokens)
  assert max_prompt_len <= max_seq_len
  total_len = min(max_seq_len, max_gen_len + max_prompt_len)
  pad_id = tokenizer.pad_id
  # Working buffer: every row starts as padding, then receives its prompt.
  tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
  for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
  token_logprobs = None
  if logprobs:
    token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
  prev_pos = 0
  eos_reached = torch.tensor([False] * bsz, device=device)
  # True wherever the buffer holds a prompt token rather than padding.
  input_text_mask = tokens != pad_id
  if logprobs and min_prompt_len == total_len:
    # Every prompt already fills the window, so the decode loop below will not
    # run; score the prompts in a single pass. Skipped when logprobs are not
    # requested because the result would be discarded.
    logits = model.forward(tokens, prev_pos)
    token_logprobs = -F.cross_entropy(
      input=logits.transpose(1, 2),
      target=tokens,
      reduction="none",
      ignore_index=pad_id,
    )
  stop_tokens = torch.tensor(list(tokenizer.stop_tokens), device=device)

  for cur_pos in range(min_prompt_len, total_len):
    # Only the tokens added since the previous step are fed to the model.
    logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
    next_token = torch.argmax(logits[:, -1], dim=-1)
    next_token = next_token.reshape(-1)
    # only replace token if prompt has already been generated
    next_token = torch.where(
      input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
    )
    tokens[:, cur_pos] = next_token
    if logprobs:
      token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
        input=logits.transpose(1, 2),
        target=tokens[:, prev_pos + 1 : cur_pos + 1],
        reduction="none",
        ignore_index=pad_id,
      )
    # A row is finished once it emits a stop token outside of its prompt.
    eos_reached |= (~input_text_mask[:, cur_pos]) & (
      torch.isin(next_token, stop_tokens)
    )
    prev_pos = cur_pos
    if eos_reached.all():
      break

  logprob_rows = token_logprobs.tolist() if logprobs else None
  out_tokens, out_logprobs = [], []
  for i, toks in enumerate(tokens.tolist()):
    # cut to max gen len
    end = len(prompt_tokens[i]) + max_gen_len
    toks = toks[:end]
    # Always bound, so the append below is safe even when the decode loop
    # never ran and logprobs were not requested.
    probs = logprob_rows[i][:end] if logprobs else None
    # cut to after eos tok if any
    for stop_token in tokenizer.stop_tokens:
      try:
        eos_idx = toks.index(stop_token)
      except ValueError:
        continue
      toks = toks[:eos_idx]
      if logprobs:
        probs = probs[:eos_idx]
    out_tokens.append(toks)
    out_logprobs.append(probs)
  return out_tokens, (out_logprobs if logprobs else None)

In [6]:
import functools

def my_main(prompts, safetensors=False, device="mps", max_seq_len=30):
    """Run greedy generation with the custom Llama implementation.

    Args:
        prompts: List of prompt strings.
        safetensors: Forwarded to `build` to select which weights are loaded.
        device: Torch device the model is moved to and generation runs on.
        max_seq_len: Forwarded to `build`; `generate` uses it as the cap on
            prompt + generated tokens.

    Returns:
        A 7-tuple `(model, tokenizer, config, None, device, inputs, text)`
        where `inputs` is the list of encoded prompts and `text` is the decoded
        output for the first prompt. The `None` fills the generation-config
        slot so the tuple has the same layout as `hf_main`'s.
    """
    model, tokenizer, config = build(
        max_seq_len=max_seq_len,
        # generate() asserts len(prompts) <= max_batch_size, so size the batch
        # to the prompts instead of hard-coding 1.
        max_batch_size=max(1, len(prompts)),
        model_desc="3B",
        version=2,
        safetensors=safetensors,
    )
    model = model.to(device)
    print(prompts[0])
    # Pad with EOS, matching the Hugging Face run.
    tokenizer.pad_id = tokenizer.eos_id
    inputs = [tokenizer.encode(s, bos=True, eos=False) for s in prompts]
    print(inputs)
    # e.g. for "The theory of relativity states that":
    # [[128000, 791, 10334, 315, 1375, 44515, 5415, 430]]
    outputs, logprobs = generate(inputs, model, tokenizer, config, device, logprobs=False)
    print(outputs)
    # Before rope fix:
    # [[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418, 315, 5905, 13, 578, 10334, 315, 1375, 44515, 374, 264, 10334, 315]]
    # After rope fix:
    # [[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418, 13, 1115, 3445, 430, 422, 499, 527, 7366, 520, 264, 6926, 4732]]
    # Drop everything up to and including BOS before decoding.
    outputs = [toks[toks.index(tokenizer.bos_id) + 1:] for toks in outputs]
    texts = [tokenizer.decode(toks) for toks in outputs]
    print(texts[0])
    # print(np.exp(np.array(logprobs, dtype=np.float32)).tolist())
    return model, tokenizer, config, None, device, inputs, texts[0]


# Two pre-bound entry points for the custom implementation: one with
# safetensors=False and one with safetensors=True.
my_main_torch, my_main_safetensors = (
    functools.partial(my_main, safetensors=use_safetensors)
    for use_safetensors in (False, True)
)

In [7]:
from dataclasses import dataclass
from typing import Any

@dataclass
class Benchmark:
    """Bundle of everything one benchmark run returns.

    Field order matches the 7-tuples returned by `hf_main` and `my_main`, so an
    instance can be built with `Benchmark(*run(prompts))`.
    """
    model: torch.nn.Module
    tokenizer: Any
    config: Any
    # None for the custom implementation, which has no generation config.
    generation_config: Any
    device: Any
    # Tokenized prompts: a dict of tensors from the Hugging Face run, a list of
    # token id lists from the custom run.
    inp: Any
    # Decoded text for the first prompt.
    out: str

In [8]:
# Prompt batch shared by both implementations.
# A shorter alternative is "The theory of relativity states that".
prompts = [
    "The theory of relativity states that the speed of light is constant in all reference frames",
]

In [9]:
# Run the Hugging Face reference implementation and bundle its returned 7-tuple
# (model, tokenizer, config, generation config, device, inputs, decoded text).
bench_hf = Benchmark(*hf_main(prompts))

The theory of relativity states that the speed of light is constant in all reference frames
[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418]]


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GenerationConfig {
  "bos_token_id": 128000,
  "do_sample": true,
  "eos_token_id": 128001,
  "pad_token_id": 128001,
  "temperature": 0.6,
  "top_p": 0.9
}

[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418, 13, 1115, 3445, 430, 422, 499, 527, 7366, 520, 264, 6926, 4732]]
The theory of relativity states that the speed of light is constant in all reference frames. This means that if you are moving at a constant speed


In [10]:
# Run the custom implementation (safetensors=True variant) on the same prompts
# and bundle its returned 7-tuple the same way.
bench_cu = Benchmark(*my_main_safetensors(prompts))

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Reloaded tiktoken model from /Users/jyotirmaya.mahanta/.cache/huggingface/hub/models--meta-llama--Llama-3.2-3B/snapshots/13afe5124825b4f3751f836b40dafda64c1ed062/original/tokenizer.model
#words: 128256 - BOS ID: 128000 - EOS ID: 128001


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

number of parameters: 3.21B
The theory of relativity states that the speed of light is constant in all reference frames
[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418]]
[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418, 13, 1115, 3445, 430, 422, 499, 527, 7366, 520, 264, 6926, 4732]]
The theory of relativity states that the speed of light is constant in all reference frames. This means that if you are moving at a constant speed


In [11]:
# Both runs decode greedily (do_sample=False / argmax), so the decoded text for
# the first prompt must be identical.
assert bench_hf.out == bench_cu.out, "output mismatch"

In [12]:
# Drop the benchmark bundles (each holds a loaded model) and force a garbage
# collection pass; the cell's displayed value is gc.collect()'s return value.
del bench_hf
del bench_cu
gc.collect()

39