In [1]:
import sys
sys.path.append('/accounts/grad/zhangyunzhe2023/tlide')

from typing import List, Optional
from llama import Dialog, Llama

import os
import torch
import torch.distributed as dist

# Single-process "distributed" environment so the model can be built from a
# notebook kernel. setdefault() keeps any value that a launcher has already
# exported, so these are only fallbacks.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")  # choose any free port
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")

# if you have a CUDA GPU, use 'nccl'; otherwise use 'gloo'
backend = "nccl" if torch.cuda.is_available() else "gloo"
if backend == "nccl":
    # Bind to the device named by LOCAL_RANK (defaults to 0 above) instead of
    # a hard-coded index, so an externally provided LOCAL_RANK is respected.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# init_process_group() raises if the default group already exists, which is
# exactly what happens when this cell is executed a second time in the same
# kernel. Guarding on is_initialized() makes the cell safe to re-run.
if not dist.is_initialized():
    dist.init_process_group(backend=backend)

In [2]:
# Build the generator from the local Llama3.2-1B checkpoint. The batch size is
# capped at 1 and the sequence length at 8192 tokens.
generator = Llama.build(
    ckpt_dir='/accounts/grad/zhangyunzhe2023/.llama/checkpoints/Llama3.2-1B',
    tokenizer_path='/accounts/grad/zhangyunzhe2023/.llama/checkpoints/Llama3.2-1B/tokenizer.model',
    max_batch_size=1,
    max_seq_len=8192,
)

# Convenience handles used by the cells below.
params = generator.model.params
model = generator.model.eval()  # switch the model to eval mode
tokenizer = generator.tokenizer

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 7.42 seconds


In [3]:
from datasets import load_dataset

# Load the 'train' split of the THUDM/LongBench-v2 dataset.
dataset = load_dataset(path='THUDM/LongBench-v2', split='train')

In [4]:
# Take the first sample's context and drop its first line, i.e. everything up
# to and including the first newline. If there is no newline at all, the
# context is kept unchanged.
text = dataset[0]['context'].split('\n', 1)[-1]

# One-element batch of prompts.
prompts: List[str] = [text]

In [10]:
# Encode the first prompt (BOS prepended, no EOS appended) and wrap it in a
# list so the resulting tensor has a leading batch axis of size 1.
prompt_tokens = [tokenizer.encode(prompts[0], bos=True, eos=False)]
bsz = 1
pad_id = tokenizer.pad_id
tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")

# Reset the model's cache, then feed the first 10 tokens one at a time,
# passing each token's index as the start position and collecting the
# output returned for every step.
model.clear_cache()
result = [
    model.forward(tokens[:, step:step + 1], step, return_last_hidden=True)
    for step in range(10)
]

In [9]:
# Run the single token at index 1 with start position 1 and display the tensor
# returned when return_last_hidden=True.
# NOTE(review): presumably this reads/writes the model's cache, so the value
# shown depends on what earlier cells left there — confirm against model.forward.
model.forward(tokens[:, 1:2], 1, return_last_hidden=True)

tensor([[[ 1.4498,  1.9637,  1.7737,  ..., -5.3823, -4.5258,  0.7892]]])

In [None]:
# Upper bound on the number of tokens to append after the prompt.
max_gen_len = 100

# Encode every prompt with a leading BOS and no trailing EOS.
prompt_tokens = [tokenizer.encode(p, bos=True, eos=False) for p in prompts]

params = generator.model.params
bsz = len(prompt_tokens)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

# Shortest / longest prompt; the longest one must fit in the model's window.
prompt_lens = [len(seq) for seq in prompt_tokens]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)
assert max_prompt_len <= params.max_seq_len
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

# Token buffer pre-filled with pad_id; each row then receives its prompt at
# the front, leaving pad_id in the remaining slots.
pad_id = tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for row, seq in enumerate(prompt_tokens):
    tokens[row, : len(seq)] = torch.tensor(seq, dtype=torch.long, device="cuda")

# True wherever the buffer already holds a prompt token (anything != pad_id).
input_text_mask = tokens != pad_id
# One "finished" flag per sequence, all initially False.
eos_reached = torch.zeros(bsz, dtype=torch.bool, device="cuda")
prev_pos = 0

stop_tokens = torch.tensor(list(tokenizer.stop_tokens))

In [5]:
# Run the model once over positions [prev_pos, cur_pos), where cur_pos is the
# length of the shortest prompt.
cur_pos = min_prompt_len
prefill_window = tokens[:, prev_pos:cur_pos]
logits = model.forward(prefill_window, prev_pos)

# Greedy pick: argmax over the last axis of the final position's output.
next_token = logits[:, -1].argmax(dim=-1)

In [23]:
# Number of slots in layer 0's key cache that hold data: a slot counts as used
# when any element across its last two axes is non-zero. Counting used slots
# directly (rather than subtracting the all-zero slots from a hard-coded 8192)
# gives the same value for the current configuration and stays correct if
# max_seq_len or max_batch_size is changed.
# NOTE(review): assumes cache_k is laid out as (batch, seq, kv_heads, head_dim) — confirm.
(model.layers[0].attention.cache_k != 0).any(dim=-1).any(dim=-1).sum()

tensor(1)

In [12]:
# Display the ModelArgs held by the loaded model (dimensions, head counts,
# max_batch_size, max_seq_len, RoPE settings).
params

ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=256, ffn_dim_multiplier=1.5, norm_eps=1e-05, rope_theta=500000.0, max_batch_size=1, max_seq_len=8192, use_scaled_rope=True)