From 0f255058ff7b3110aa8e9a0f54dbea50a91d6ac9 Mon Sep 17 00:00:00 2001
From: Grace Ho
Date: Wed, 12 Jun 2024 14:40:41 -0700
Subject: [PATCH] improve benchmark tput by moving prompt preparation outside of loop

---
 src/llmperf/utils.py   |  6 ++----
 token_benchmark_ray.py | 28 +++++++++++++++++-----------
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/src/llmperf/utils.py b/src/llmperf/utils.py
index 4e3b2e9..3b0a32b 100644
--- a/src/llmperf/utils.py
+++ b/src/llmperf/utils.py
@@ -60,6 +60,8 @@ def randomly_sample_sonnet_lines_prompt(
     prompt_tokens_mean: int = 550,
     prompt_tokens_stddev: int = 250,
     expect_output_tokens: int = 150,
+    tokenizer = LlamaTokenizerFast.from_pretrained(
+        "hf-internal-testing/llama-tokenizer")
 ) -> Tuple[str, int]:
     """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt.
 
@@ -80,10 +82,6 @@ def randomly_sample_sonnet_lines_prompt(
         A tuple of the prompt and the length of the prompt.
 
     """
-    tokenizer = LlamaTokenizerFast.from_pretrained(
-        "hf-internal-testing/llama-tokenizer"
-    )
-
     get_token_length = lambda text: len(tokenizer.encode(text))
 
     prompt = (
diff --git a/token_benchmark_ray.py b/token_benchmark_ray.py
index bdaa537..a8c7754 100644
--- a/token_benchmark_ray.py
+++ b/token_benchmark_ray.py
@@ -71,6 +71,21 @@ def get_token_throughput_latencies(
     req_launcher = RequestsLauncher(clients)
     completed_requests = []
     num_completed_requests = 0
+    # make up prompts outside of send loop for faster benchmarking loop
+    num_output_tokens_list = []
+    prompts = []
+    for i in range(max_num_completed_requests):
+        num_output_tokens = (sample_random_positive_int(
+            mean_output_tokens, stddev_output_tokens
+        ))
+        num_output_tokens_list.append(num_output_tokens)
+
+        prompts.append(randomly_sample_sonnet_lines_prompt(
+            prompt_tokens_mean=mean_input_tokens,
+            prompt_tokens_stddev=stddev_input_tokens,
+            expect_output_tokens=num_output_tokens,
+            tokenizer=tokenizer
+        ))
     start_time = time.monotonic()
     iter = 0
     pbar = tqdm(total=max_num_completed_requests)
@@ -79,21 +94,12 @@ def get_token_throughput_latencies(
         and len(completed_requests) < max_num_completed_requests
     ):
         iter += 1
-        num_output_tokens = sample_random_positive_int(
-            mean_output_tokens, stddev_output_tokens
-        )
-
-        prompt = randomly_sample_sonnet_lines_prompt(
-            prompt_tokens_mean=mean_input_tokens,
-            prompt_tokens_stddev=stddev_input_tokens,
-            expect_output_tokens=num_output_tokens,
-        )
-        default_sampling_params = {"max_tokens": num_output_tokens}
+        default_sampling_params = {"max_tokens": num_output_tokens_list.pop()}
         default_sampling_params.update(additional_sampling_params)
         request_config = RequestConfig(
             model=model,
-            prompt=prompt,
+            prompt=prompts.pop(),
             sampling_params=default_sampling_params,
             llm_api=llm_api,
         )
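
Note (reviewer sketch, not part of the patch): the change hoists prompt generation and output-length sampling out of the timed while loop, so the loop body only pops precomputed values and the measured throughput reflects request dispatch rather than tokenizer work. A minimal illustration of the same pattern, using hypothetical stand-ins (make_prompt, send_request) in place of randomly_sample_sonnet_lines_prompt and the Ray request launcher:

    import random
    import time


    def make_prompt(i: int) -> str:
        # Hypothetical prompt generator; the real benchmark samples sonnet
        # lines and counts tokens with a Llama tokenizer, which is the
        # relatively expensive step being hoisted.
        return f"prompt {i}: " + " ".join(str(random.random()) for _ in range(20))


    def send_request(prompt: str, max_tokens: int) -> None:
        # Hypothetical stand-in for dispatching one request to the LLM API.
        pass


    num_requests = 100

    # Hoisted preparation: all prompt generation and output-length sampling
    # happen before the timed loop starts.
    prompts = [make_prompt(i) for i in range(num_requests)]
    max_tokens_list = [random.randint(50, 250) for _ in range(num_requests)]

    start = time.monotonic()
    for prompt, max_tokens in zip(prompts, max_tokens_list):
        send_request(prompt, max_tokens)
    elapsed = time.monotonic() - start
    print(f"prepared {num_requests} prompts; timed send loop took {elapsed:.4f}s")

Because the patch draws items with list.pop(), requests are consumed in reverse order of preparation; since the prompts are randomly sampled, that ordering has no effect on the benchmark.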