llama_utils.py
"""
This module provides utility functions for loading the LLaMA model and its tokenizer.

The module includes functions to:
- Load a pretrained LLaMA model from checkpoint files.
- Load a SentencePiece tokenizer for encoding and decoding text.

Functions:
- `load_llama(checkpoints_dir: str, vocab_size: int, max_seq_len: int)`: Loads the LLaMA model from checkpoint files.
- `load_tokenizer(tokenizer_path: str)`: Loads the SentencePiece tokenizer.

Example usage:
    model = load_llama('checkpoints/', vocab_size=32000, max_seq_len=2048)
    tokenizer = load_tokenizer('tokenizer.model')
"""
import json
import time
from pathlib import Path

import torch
from sentencepiece import SentencePieceProcessor

from model import ModelArgs, Transformer

# Run everything in bfloat16 on the CPU by default.
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cpu')
def load_llama(checkpoints_dir: str, vocab_size: int, max_seq_len: int):
    prev_time = time.time()

    # Locate the checkpoint file(s) shipped with the model.
    checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
    ckpt_path = checkpoints[0]
    print(f'Loading checkpoint "{ckpt_path}"')
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
    prev_time = time.time()

    # Read the model hyperparameters stored alongside the weights.
    with open(Path(checkpoints_dir) / "params.json", "r") as f:
        params = json.loads(f.read())
    print(f"params: {params}")

    # Check that the checkpoint matches the architecture defined in model.py.
    model_args = ModelArgs()
    model_args.max_seq_len = max_seq_len
    assert model_args.dim == params['dim']
    assert model_args.n_layers == params['n_layers']
    assert model_args.n_heads == params['n_heads']
    assert model_args.vocab_size == vocab_size
    model_args.vocab_size = vocab_size
    print(f"model_args: {model_args}")

    model = Transformer(model_args)
    # The official LLaMA checkpoints store precomputed RoPE frequencies under
    # 'rope.freqs', which is not part of this Transformer's state dict, so drop
    # it before loading the weights strictly.
    del checkpoint['rope.freqs']
    model.load_state_dict(checkpoint, strict=True)
    print(f"Loaded model in {time.time() - prev_time:.2f}s")

    return model
def load_tokenizer(tokenizer_path: str):
    tokenizer = SentencePieceProcessor()
    tokenizer.load(tokenizer_path)
    return tokenizer
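
# Example usage (a minimal sketch, not part of the original module): load the
# model and tokenizer, encode a prompt, and run a single forward pass. The
# checkpoint directory, tokenizer path, and the `model(tokens, start_pos)`
# call signature are assumptions about the local `Transformer` in model.py;
# adjust them to match the actual setup.
if __name__ == "__main__":
    model = load_llama('checkpoints/', vocab_size=32000, max_seq_len=2048)
    tokenizer = load_tokenizer('tokenizer.model')

    # Encode a prompt with the BOS token prepended, as LLaMA expects.
    prompt = "Hello, world"
    tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)  # shape: (1, seq_len)

    with torch.no_grad():
        logits = model(tokens, 0)  # assumed signature: forward(tokens, start_pos)
    print(f"logits shape: {tuple(logits.shape)}")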