From e54c8d1ffd948eba0088fe1fdf239adec6243c82 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 3 Oct 2024 00:59:50 -0700
Subject: [PATCH] [Distributed] Fix tiktokenizer decoding

---
 dist_run.py                               | 10 +++++++---
 torchchat/distributed/safetensor_utils.py |  6 ++++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index ceb18bf37..bc3badbc4 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -442,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
-    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
 
     # Store the generated tokens
     res = []
@@ -519,7 +518,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
         # Decode the output
         if pp_rank == last_pp_rank:
-            # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
             new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
             res.append(new_token)
             if not args.disable_in_flight_decode:
@@ -541,7 +539,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        responses = tokenizer.decode(res_list)
+        if isinstance(tokenizer, TiktokenTokenizer):
+            # For TiktokenTokenizer, we need to decode prompt by prompt.
+            # TODO: is there a better way to do this?
+            responses = [tokenizer.decode(sequence) for sequence in res_list]
+        else:  # SentencePieceProcessor
+            # For SentencePieceProcessor, we can decode the entire 2D list at once.
+            responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

diff --git a/torchchat/distributed/safetensor_utils.py b/torchchat/distributed/safetensor_utils.py
index 01c5091b1..80ae6b585 100644
--- a/torchchat/distributed/safetensor_utils.py
+++ b/torchchat/distributed/safetensor_utils.py
@@ -88,6 +88,9 @@ def get_hf_weight_map_and_path(
         raise FileNotFoundError(
             f"Weight index file for {model_id} does not exist in HF cache."
         )
+    logger.info(
+        f"Loading weight map from: {index_file}"
+    )
     weight_map = read_weights_from_json(index_file)
     if weight_map is None:
         raise ValueError(f"Weight map not found in config file {index_file}")
@@ -95,6 +98,9 @@ def get_hf_weight_map_and_path(
     weight_path = os.path.dirname(index_file)
     if not os.path.exists(weight_path):
         raise FileNotFoundError(f"Weight path {weight_path} does not exist")
+    logger.info(
+        f"Loading weights from: {weight_path}"
+    )
     return weight_map, weight_path, new_to_old_keymap
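
Note on the decode split above: tiktoken's Encoding.decode takes a single
flat list of token ids, while SentencePieceProcessor.decode also accepts a
2D (batched) list, which is why only the Tiktoken path needs the per-row
loop. Below is a minimal standalone sketch of the same behavior using
tiktoken's public API directly; the "cl100k_base" encoding and direct use
of tiktoken.Encoding are illustrative assumptions (torchchat's
TiktokenTokenizer is a wrapper around a model-specific encoding, and its
exact interface may differ):

    # Sketch, assuming `pip install tiktoken`; the encoding choice is
    # illustrative only.
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")

    # A batch of generated token-id rows, shaped like res.tolist() above.
    res_list = [enc.encode("hello world"), enc.encode("fix tiktokenizer decoding")]

    # Encoding.decode expects one flat list[int], so a 2D batch is decoded
    # row by row -- the same shape of loop the patch adds.
    responses = [enc.decode(row) for row in res_list]

    # tiktoken itself also provides Encoding.decode_batch, which could be a
    # possible answer to the TODO in the patch if the wrapper exposed it.
    assert responses == enc.decode_batch(res_list)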