Merge pull request meta-llama#26 from pytorch-tpu/llama2_tpu_optimized
Llama2 TPU optimizations
miladm committed Jul 25, 2023
2 parents 6c7fe27 + fe2d932 commit d2b8802
Showing 4 changed files with 868 additions and 109 deletions.
169 changes: 169 additions & 0 deletions example_xla.py
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Tuple
import os
import sys
import torch
import fire
import time
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import json
from pathlib import Path

from llama import ModelArgs, Transformer, Tokenizer, Llama
from llama.xla_model_parallel import get_model_parallel_rank, get_model_parallel_world_size


def setup_model_parallel() -> Tuple[int, int]:
    # assuming model parallelism over the whole world size
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()

    # seed must be the same in all processes
    torch.manual_seed(1)
    device = xm.xla_device()
    xm.set_rng_state(1, device=device)
    return rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
) -> Llama:
    start_time = time.time()
    print("Loading")
    if ckpt_dir:
        # load checkpoint if ckpt_dir is provided
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert world_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
        ckpt_path = checkpoints[rank]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
    else:
        params = {"dim": dim,
                  "n_layers": n_layers,
                  "n_heads": n_heads,
                  }

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = Transformer(model_args)
    if ckpt_dir:
        model.load_state_dict(checkpoint, strict=False)
    device = xm.xla_device()
    model = model.to(device)
    for i in range(len(model.cache_kvs)):
        model.cache_kvs[i] = tuple(t.to(device) for t in model.cache_kvs[i])
    torch.set_default_tensor_type(torch.FloatTensor)

    generator = Llama(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def main(
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
    ckpt_dir: str = '',
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
):
    rank, world_size = setup_model_parallel()
    if rank > 0:
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
    )

    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        # "Simply put, the theory of relativity states that ",
        # "Building a website can be done in 10 simple steps:\n",
        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
        # """Tweet: "I hate it when my phone battery dies."
#Sentiment: Negative
####
#Tweet: "My day has been 👍"
#Sentiment: Positive
####
#Tweet: "This is the link to the article"
#Sentiment: Neutral
####
#Tweet: "This new music video was incredibile"
#Sentiment:""",
        # """Translate English to French:
#
#sea otter => loutre de mer
#
#peppermint => menthe poivrée
#
#plush girafe => girafe peluche
#
#cheese =>""",
    ]
    for _ in range(2):
        generation_tokens = generator.text_completion(
            prompts, temperature=temperature, top_p=top_p, max_gen_len=256
        )

        for result in generation_tokens:
            print(result)
            print("\n==================================\n")


def _fn(
    idx,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
    ckpt_dir: str = '',
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
):
    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)

def mp_main(
    mp: bool,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
    ckpt_dir: str = '',
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
):
    if mp:
        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads))
    else:
        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)


if __name__ == "__main__":
    fire.Fire(mp_main)
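
Since the script hands mp_main to fire.Fire, each parameter above also doubles as a command-line flag (--mp, --tokenizer_path, --ckpt_dir, and so on). A minimal usage sketch, assuming placeholder tokenizer and checkpoint paths that are not part of this PR:

# Hypothetical usage sketch; the paths below are placeholders.
from example_xla import mp_main

mp_main(
    mp=True,                                    # spawn one process per local XLA device via xmp.spawn
    tokenizer_path="/path/to/tokenizer.model",  # SentencePiece model consumed by llama.Tokenizer
    ckpt_dir="/path/to/llama-2-7b",             # empty string skips loading and uses a randomly initialized model
    max_seq_len=512,
    max_batch_size=32,
)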
122 changes: 86 additions & 36 deletions llama/generation.py
@@ -9,6 +9,8 @@
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch_xla.core.xla_model as xm

import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
@@ -44,7 +46,7 @@ class ChatPrediction(TypedDict, total=False):
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

@@ -102,17 +104,57 @@ def build(
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self._generate_one_token_fn = self._generate_one_token
        self._generate_one_token_fn = torch.compile(self._generate_one_token_fn,
                                                    backend="torchxla_trace_once", fullgraph=True)


    def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor,
                            input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p, logprobs, prev_pos):
        logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs)
        if logprobs:
            token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens[:, prev_pos + 1 : cur_pos + 1],
                reduction="none",
                ignore_index=pad_id,
            )

        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        input_text_mask_tmp = input_text_mask.index_select(1, cur_pos_tensor).squeeze(dim=1)
        tokens_tmp = tokens.index_select(1, cur_pos_tensor).squeeze(dim=1)
        next_token = torch.where(
            input_text_mask_tmp, tokens_tmp, next_token
        )
        next_token = next_token.unsqueeze(1)
        tokens = tokens.index_copy(1, cur_pos_tensor, next_token)
        input_pos_tensor = input_pos_tensor[-1:] + 1
        cur_pos_tensor = cur_pos_tensor + 1
        output_pos_tensor = cur_pos_tensor - 1
        input_tokens = tokens.index_select(1, input_pos_tensor)

        #TODO: optimize and bring back
        eos_reached = False #temp variable
        #eos_reached |= (~input_text_mask[:, cur_pos]) & (
        #    next_token == self.tokenizer.eos_id
        #)
        return tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, eos_reached

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        echo: bool = False,#NEW LINE
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: #NEW OUTPUT FORMAT
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
@@ -123,47 +165,54 @@ def generate(
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        tokens = torch.full((params.max_batch_size, total_len), pad_id, dtype=torch.long)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
            token_logprobs.to(device)

        device = xm.xla_device()
        tokens = tokens.to(device)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        start_pos = 1
        cur_pos_tensor = torch.tensor(start_pos).to(device)
        input_pos_tensor = torch.arange(0, start_pos).to(device)
        output_pos_tensor = cur_pos_tensor - 1
        input_tokens = tokens.index_select(1, input_pos_tensor)
        cache_kvs = self.model.cache_kvs #TODO: revisit the cache implementation between the two models

        prev_pos = 0 #TODO: drop this parameter
        eos_reached = torch.tensor([False] * bsz, device=device)
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
        xm.mark_step(wait=True)

        decoding_start_time = time.time()
        for _ in range(start_pos, total_len):
            #logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, eos_reached \
                = self._generate_one_token_fn(
                    tokens, input_tokens, input_text_mask, cur_pos_tensor,
                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p, logprobs, prev_pos
                )
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break
            xm.mark_step()
            ###TODO: optimize this block of code
            #prev_pos = cur_pos
            #if all(eos_reached):
            #    xm.mark_step()
            #    break
            ####
        self.model.cache_kvs = cache_kvs
        print(f"Decoded in {time.time() - decoding_start_time:.5f} seconds")

        if logprobs:
            token_logprobs = token_logprobs.tolist()

        # TODO: this block is different from llama1; it's best to re-optimize it as needed. no decode() call here
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            if i >= len(prompt_tokens): #TODO: brought in from llama1 optimization
                break
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
@@ -191,6 +240,7 @@ def text_completion(
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        print("prompt_tokens ", prompt_tokens)
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
@@ -199,6 +249,7 @@
            logprobs=logprobs,
            echo=echo,
        )
        print("generation_tokens ", generation_tokens)
        if logprobs:
            return [
                {
@@ -292,12 +343,11 @@ def chat_completion(
            for t in generation_tokens
        ]


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort = torch.where(mask, 0.0, probs_sort)
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
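
The recurring pattern in the generation.py changes above is replacing data-dependent, in-place mutation from the original decoding loop (tokens[:, cur_pos] = next_token, probs_sort[mask] = 0.0) with functional updates driven by position tensors (index_select, index_copy, torch.where), so that _generate_one_token can be captured as one fixed graph by torch.compile with the torchxla_trace_once backend. This likely also explains why the eos_reached early exit is parked behind a TODO: a data-dependent break would vary the traced graph from step to step. Below is a minimal CPU-only sketch of that update style, with made-up shapes and names rather than code from the PR:

import torch

bsz, total_len, vocab_size, top_p = 2, 8, 16, 0.9
pad_id = -1
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
cur_pos_tensor = torch.tensor([3])      # position kept as a tensor, not a Python int

logits = torch.randn(bsz, vocab_size)   # stand-in for one decoding step's logits
next_token = torch.argmax(logits, dim=-1)

# Functional write: returns a new tensor instead of mutating tokens[:, 3] in place.
tokens = tokens.index_copy(1, cur_pos_tensor, next_token.unsqueeze(1))

# Functional masking, mirroring the reworked sample_top_p: no boolean-index assignment.
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
mask = torch.cumsum(probs_sort, dim=-1) - probs_sort > top_p
probs_sort = torch.where(mask, torch.zeros_like(probs_sort), probs_sort)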