# example_torch.py
# Forked from meta-llama/llama.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from llama import Tokenizer, LLaMA, ModelArgs
from llama import Transformer2 as Transformer
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build a ``LLaMA`` generator from a single-file checkpoint directory.

    The first ``*.pth`` file (in sorted order) found in ``ckpt_dir`` is loaded
    onto the CPU, ``params.json`` from the same directory supplies the extra
    ``ModelArgs`` keyword arguments, and the tokenizer at ``tokenizer_path``
    determines ``vocab_size``.

    Args:
        ckpt_dir: Directory holding the ``*.pth`` checkpoint and ``params.json``.
        tokenizer_path: Path handed to ``Tokenizer`` as ``model_path``.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.

    Returns:
        A ``LLaMA`` instance wrapping the constructed model and the tokenizer.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` contains no ``*.pth`` file, or if
            ``params.json`` cannot be found when it is opened.
    """
    start_time = time.time()
    ckpt_root = Path(ckpt_dir)
    checkpoints = sorted(ckpt_root.glob("*.pth"))
    if not checkpoints:
        # Fail with an actionable message instead of a bare IndexError below.
        raise FileNotFoundError(f"no *.pth checkpoint found in {ckpt_dir!r}")
    ckpt_path = checkpoints[0]  # only have one file, for 7B
    print("Loading")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(ckpt_root / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # The default tensor type is process-global state. Switch it to CUDA half
    # only while the model is constructed and always switch it back, even if
    # construction raises, so later tensor creation is not silently affected.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = Transformer(model_args)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)

    # strict=False: checkpoint keys and model keys are not required to match
    # exactly (which keys differ is not visible from this file).
    model.load_state_dict(checkpoint, strict=False)
    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
if __name__ == "__main__":
    # Script entry point: load the 7B checkpoint and print sampled
    # continuations for a fixed list of demo prompts.
    ckpt_dir = "ckpts/7B"
    tokenizer_path = "ckpts/tokenizer.model"
    max_seq_len: int = 512
    max_batch_size: int = 32

    # Call load() exactly once: every call re-reads the checkpoint and builds
    # another model, so a second call only wastes time and memory.
    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)

    # Sampling settings forwarded to generator.generate().
    temperature: float = 0.8
    top_p: float = 0.95

    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        "Building a website can be done in 10 simple steps:\n",
        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
        """Tweet: "I hate it when my phone battery dies."
Sentiment: Negative
###
Tweet: "My day has been 👍"
Sentiment: Positive
###
Tweet: "This is the link to the article"
Sentiment: Neutral
###
Tweet: "This new music video was incredibile"
Sentiment:""",
        """Translate English to French:
sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>""",
    ]

    results = generator.generate(
        prompts, max_gen_len=256, temperature=temperature, top_p=top_p
    )

    # Print each prompt followed by its generated text and a separator.
    for prompt, result in zip(prompts, results):
        print(prompt, result.text)
        print("\n==================================\n")