From 6c35e6c98f8db8636d72c02535e4715d12339a0f Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 23 Sep 2024 09:59:43 -0700
Subject: [PATCH 1/3] Replace total_prompts with batch_size

---
 dist_run.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index 3fbb857c7..ca510c324 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -399,10 +399,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    total_prompts = len(prompt_lengths)
-    # need a new token dimension (row) for each prompt in the batch
-    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
-    res = [[] for _ in range(total_prompts)]
+    # need a row dimension for each prompt in the batch
+    new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(batch_size)]
     num_tokens = 40
 
     # Prefill phase

From 9d0ed06e012c45a676512bebe82dab1fbe5303a1 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 23 Sep 2024 12:55:42 -0700
Subject: [PATCH 2/3] Make in-flight decoding optional

---
 dist_run.py | 112 +++++++++++++++++++++++++---------------------------
 1 file changed, 53 insertions(+), 59 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index ca510c324..5a31a8e2a 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -187,30 +187,23 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    tokenizer,
-    prompt_lengths: Optional[List[int]] = None,
-) -> List[Tuple[int, str]]:
+    pos: int,
+) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
+    Args:
+        output (torch.Tensor): The output tensor to decode.
+        pos: the position of the `output` to decode in the sequence length dimension.
 
     Returns:
-        List[Tuple[int, str]]: List of tuples containing the next token id and its
-        decoded string for each prompt in the batch.
+        Decoded token ids.
""" - batch_size = output.shape[0] - results = [] - - for i in range(batch_size): - pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0 - next_token_logits = output[i, pos, :] - - # Argmax (deterministic) TODO: add temperature - next_token = torch.argmax(next_token_logits, dim=-1) - - next_token_decoded = tokenizer.decode([next_token.item()]) - results.append((next_token.item(), next_token_decoded)) - - return results + # Take the next token logits for each prompt + next_token_logits = output[:, pos, :] + # Argmax (deterministic) TODO: add temperature + next_token = torch.argmax(next_token_logits, dim=-1) + # Token ids in int tensor form + return next_token def _update_padded_sequence( @@ -401,8 +394,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # New token generated each iteration # need a row dimension for each prompt in the batch new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64) - res = [[] for _ in range(batch_size)] - num_tokens = 40 + # Store the generated tokens + res = [] # Prefill phase # Run context input through pipeline @@ -421,23 +414,24 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" ) - # Decode the output -- first generated token - if pp_rank == last_pp_rank: - decode_results = _batch_decode_next_tokens( - output=output, - tokenizer=tokenizer, - prompt_lengths=prompt_lengths, - ) - for i in range(len(decode_results)): - new_token[i, 0] = torch.tensor( - [decode_results[i][0]], device=device - ) # token_id in int form + # Decode token id into string and print it + def decode_in_flight(token): + # Make a 2D tensor with ids on row dimension + unsqueezed = torch.unsqueeze(token, 1) + token_str = tokenizer.decode(unsqueezed.tolist()) if tp_rank == 0: logger.info( - f"{color.green} {'* Prefill *'} " - f"responses ====>>>> {color.blue} {decode_results=}{color.reset}" + f"{color.green} responses ====>>>> " + f"{color.blue} {token_str} {color.reset}" ) + # Decode the output -- first generated token + if pp_rank == last_pp_rank: + new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1) + res.append(new_token) + if not args.disable_in_flight_decode: + decode_in_flight(new_token) + # seqlen = 1 now seqlen_decode = 1 input_pos = torch.tensor([prompt_lengths[0]], device=device) @@ -459,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # Decoding with torch.no_grad(), CUDATrackTime() as timer: - for step in range(num_tokens - 1): + for step in range(args.ntokens - 1): kwargs = {"input_pos": input_pos, "cache_lane": lane} # sendrecv between last and first ranks, only if: # first_pp_rank != last_pp_rank. 
@@ -486,21 +480,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Decode the output
             if pp_rank == last_pp_rank:
-                decode_results = _batch_decode_next_tokens(
-                    output=output, tokenizer=tokenizer
-                )
-                if tp_rank == 0:
-                    logger.info(
-                        f"{color.green} {'* Decode *'} "
-                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
-                    )
-                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                for i in range(len(decode_results)):
-                    res[i].append(decode_results[i][1])
-                    new_token[i, 0] = torch.tensor(
-                        [decode_results[i][0]], device=device
-                    ) # decode_results[i][0]
+                new_token = _batch_decode_next_tokens(output, 0)
+                res.append(new_token)
+                if not args.disable_in_flight_decode:
+                    decode_in_flight(new_token)
+
             # Increment input position
             input_pos += 1
 
     logger.info(
@@ -511,21 +496,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        for i in range(len(prompt_lengths)):
-            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")
-
-            # TODO: resolve issue with llama2-7b-chat model and "".join
-            if model_name != "llama2-7b-chat":
-                formatted_response = "".join(res[i])
-            else:
-                formatted_response = " ".join(res[i])
-            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+        # `res` is a list of tensors, each being a batch of generated token ids
+        res = torch.stack(res, dim=1)
+        res_list = res.tolist()
+        response = tokenizer.decode(res_list)
+        for i in range(len(response)):
+            logger.info(f"$$ {color.red}{response[i]} {color.reset} $$\n")
 
     # Cleanup
+    _cleanup()
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
-    _cleanup()
 
 
 if __name__ == "__main__":
@@ -537,6 +519,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    parser.add_argument(
+        "--ntokens",
+        type=int,
+        default=40,
+        help="Number of tokens to generate",
+    )
+    parser.add_argument(
+        "--disable-in-flight-decode",
+        action="store_true",
+        default=False,
+        help="Whether to decode token into string in flight",
+    )
     args = parser.parse_args()
 
     main(args)

From 3364c72571551e5b90f15712ba34484095ed1833 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 23 Sep 2024 13:28:05 -0700
Subject: [PATCH 3/3] Add back prompt print

---
 dist_run.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dist_run.py b/dist_run.py
index 5a31a8e2a..f8597c563 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -501,7 +501,8 @@ def decode_in_flight(token):
         res_list = res.tolist()
         response = tokenizer.decode(res_list)
         for i in range(len(response)):
-            logger.info(f"$$ {color.red}{response[i]} {color.reset} $$\n")
+            logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
+            logger.info(f"Response: {color.red}{response[i]} {color.reset}")
 
     # Cleanup
     _cleanup()
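
The decoding flow these patches converge on can be sketched in isolation as follows. This is a minimal illustration, not the dist_run.py code itself: ToyTokenizer is a hypothetical stand-in for the model tokenizer, the sizes and the random logits are arbitrary, and there is no pipeline or tensor parallelism. What it mirrors from the patches is the tensor mechanics only: a greedy argmax over the logits at a chosen position yields a [batch_size] tensor of token ids per step, the per-step tensors are collected in `res`, each step can optionally be decoded in flight for logging (the behavior --disable-in-flight-decode turns off), and at the end the tensors are stacked into [batch_size, ntokens] and decoded one row per prompt.

import torch


class ToyTokenizer:
    # Hypothetical stand-in for the real tokenizer: turns a batch of id lists
    # into "tok<id>" strings so the decode calls below have something to do.
    def decode(self, batch_of_ids):
        return [" ".join(f"tok{i}" for i in ids) for ids in batch_of_ids]


def batch_decode_next_tokens(output: torch.Tensor, pos: int) -> torch.Tensor:
    # output: [batch_size, seqlen, vocab_size]; greedy argmax at position `pos`.
    return torch.argmax(output[:, pos, :], dim=-1)  # [batch_size] token ids


batch_size, seqlen, vocab_size, ntokens = 2, 8, 32, 4
tokenizer = ToyTokenizer()
res = []  # one [batch_size] tensor of token ids per generated step
disable_in_flight_decode = False  # the toggle --disable-in-flight-decode controls

for step in range(ntokens):
    # Stand-in for the pipeline's output logits at this step.
    output = torch.randn(batch_size, seqlen, vocab_size)
    # Prefill would pass the last prompt position; later decode steps pass 0.
    new_token = batch_decode_next_tokens(output, pos=0)
    res.append(new_token)
    if not disable_in_flight_decode:
        # Decode the ids to strings as they are generated (in-flight decode).
        print(f"step {step}:", tokenizer.decode(new_token.unsqueeze(1).tolist()))

# Final output: stack into [batch_size, ntokens] and decode one row per prompt.
responses = tokenizer.decode(torch.stack(res, dim=1).tolist())
for response in responses:
    print(response)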