In [1]:
import os
import sys
import numpy as np

from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import torch.nn.functional as F

In [2]:
# Make the parent of the kernel's working directory importable (presumably the
# repo root holding `models/`). Resolve it to an absolute path so a later
# os.chdir doesn't break imports, and only add it once so re-running is a no-op.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from models.llama3.transformer import Transformer
from models.llama3.tokenizer import Tokenizer
from models.llama3.config import LlamaConfig
from models.llama3.load import build

In [4]:
# Single prompt used throughout this walkthrough (batch size 1).
prompts = [
    "The theory of relativity states that the speed of light is constant in all reference frames",
]

In [5]:
# Pick the best available accelerator: Apple-silicon MPS first, then CUDA,
# then CPU, so the notebook is not tied to a single machine type.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, tokenizer, config = build(
    max_seq_len=19,  # the encoded prompt below is 18 tokens, leaving 1 slot to generate
    max_batch_size=1,
    model_desc="3B",
    version=2,
    safetensors=True,
)
model = model.to(device)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Reloaded tiktoken model from /Users/jyotirmaya.mahanta/.cache/huggingface/hub/models--meta-llama--Llama-3.2-3B/snapshots/13afe5124825b4f3751f836b40dafda64c1ed062/original/tokenizer.model
#words: 128256 - BOS ID: 128000 - EOS ID: 128001


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

number of parameters: 3.21B


In [6]:
# Padding reuses the EOS id (the tokenizer load only reports BOS/EOS ids;
# presumably there is no dedicated pad token — confirm against the tokenizer).
setattr(tokenizer, "pad_id", tokenizer.eos_id)

In [7]:
# Encode each prompt with a leading BOS and no trailing EOS, since generation
# continues from the end of the prompt.
inputs = []
for prompt_text in prompts:
    inputs.append(tokenizer.encode(prompt_text, bos=True, eos=False))
print(inputs)

[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418]]


In [8]:
# Build the padded token buffer used for step-by-step decoding: one row per
# prompt, `total_len` columns, prompt tokens copied in from the left.
prompt_tokens = inputs
max_batch_size, max_seq_len = config.max_batch_size, config.max_seq_len
bsz = len(prompt_tokens)
assert bsz <= max_batch_size, (bsz, max_batch_size)
max_gen_len = config.max_seq_len
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= max_seq_len, (max_prompt_len, max_seq_len)
# Capped at max_seq_len, so only max_seq_len - max_prompt_len positions remain
# for generated tokens.
total_len = min(max_seq_len, max_gen_len + max_prompt_len)
pad_id = tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prev_pos = 0
eos_reached = torch.zeros(bsz, dtype=torch.bool, device=device)
# pad_id is aliased to eos_id earlier in this notebook, so `tokens != pad_id`
# would also mark any real EOS inside a prompt as padding. Derive the
# "is prompt text" mask from the prompt lengths instead.
prompt_lens = torch.tensor([len(t) for t in prompt_tokens], dtype=torch.long, device=device)
input_text_mask = (
    torch.arange(total_len, dtype=torch.long, device=device)[None, :] < prompt_lens[:, None]
)

In [9]:
# Prefill: feed the prompt window [prev_pos, cur_pos) through the model once,
# then greedily pick the next token from the last position's logits.
prev_pos = 0
cur_pos = min_prompt_len
# Inference only: without no_grad autograd records the whole forward graph
# (module outputs carry grad_fn), holding activations we never backprop through.
with torch.no_grad():
    logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
# Distribution over the vocabulary at the last position, kept for inspection;
# argmax over logits is the greedy choice.
probs = torch.softmax(logits[:, -1], dim=-1)
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token

tensor([315], device='mps:0')

In [21]:
# The same prompt window that was passed to the model in the prefill step.
toks = tokens[:, slice(prev_pos, cur_pos)]
toks

tensor([[128000,    791,  10334,    315,   1375,  44515,   5415,    430,    279,
           4732,    315,   3177,    374,   6926,    304,    682,   5905,  14418]],
       device='mps:0')

In [11]:
# outputs, logprobs = generate(inputs, model, tokenizer, config, device, logprobs=False)
# print(outputs)

In [24]:
# Run only the token-embedding layer on the prompt window to inspect its output.
embed_layer = model.model.embed_tokens
h = embed_layer(toks)
print(h.size())
h

torch.Size([1, 18, 3072])


tensor([[[-0.0011, -0.0007, -0.0046,  ..., -0.0015, -0.0021,  0.0018],
         [ 0.0065, -0.0332, -0.0101,  ..., -0.0303,  0.0197, -0.0017],
         [-0.0264, -0.0152, -0.0183,  ...,  0.0131,  0.0369, -0.0364],
         ...,
         [ 0.0013,  0.0025, -0.0155,  ...,  0.0320,  0.0058, -0.0131],
         [ 0.0454, -0.0099,  0.0054,  ..., -0.0061, -0.0052, -0.0125],
         [-0.0330, -0.0001,  0.0117,  ..., -0.0094,  0.0117, -0.0187]]],
       device='mps:0', grad_fn=<EmbeddingBackward0>)

In [29]:
# Offset of the window within the sequence and its length (columns of `toks`).
start_pos = prev_pos
seqlen = toks.shape[1]

In [31]:
# Rows of the model's precomputed complex table for positions
# start_pos .. start_pos + seqlen - 1 (presumably the RoPE rotation factors;
# row 0 is all 1+0j in the output below).
rope_rows = slice(start_pos, start_pos + seqlen)
freqs_cis = model.freqs_cis[rope_rows]
print(freqs_cis.size())
freqs_cis

torch.Size([18, 64])


tensor([[ 1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
          ...,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
          1.0000+0.0000e+00j],
        [ 0.5403+8.4147e-01j,  0.6861+7.2746e-01j,  0.7878+6.1596e-01j,
          ...,  1.0000+3.6997e-06j,  1.0000+3.0139e-06j,
          1.0000+2.4551e-06j],
        [-0.4161+9.0930e-01j, -0.0584+9.9829e-01j,  0.2412+9.7048e-01j,
          ...,  1.0000+7.3994e-06j,  1.0000+6.0277e-06j,
          1.0000+4.9103e-06j],
        ...,
        [-0.7597+6.5029e-01j,  0.9404-3.4018e-01j, -0.8632-5.0488e-01j,
          ...,  1.0000+5.5496e-05j,  1.0000+4.5208e-05j,
          1.0000+3.6827e-05j],
        [-0.9577-2.8790e-01j,  0.8927+4.5066e-01j, -0.3690-9.2942e-01j,
          ...,  1.0000+5.9196e-05j,  1.0000+4.8222e-05j,
          1.0000+3.9282e-05j],
        [-0.2752-9.6140e-01j,  0.2847+9.5862e-01j,  0.2818-9.5948e-01j,
          ...,  1.0000+6.2895e-05j,  1.0000+5.1236e-05j,
          1.0000+4.1737e-05j]], device='mps:0')