From 0e3efeee00f9302dba1b2d6e917dc1fd4ac9c56e Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 3 Oct 2024 16:20:44 -0700
Subject: [PATCH 1/3] add TokenizerType enum, update decoder spelling

---
 dist_run.py | 40 +++++++++++++++++++++++++++++-----------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index bc3badbc4..02809ca03 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -59,6 +60,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -79,7 +85,7 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> SentencePieceProcessor | TiktokenTokenizer:
+) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
     """Builds a tokenizer for the given model name."""
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,7 +113,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
-    return tokenizer
+    if isinstance(tokenizer, TiktokenTokenizer):
+        tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {tokenizer_type}")
+    return tokenizer, tokenizer_type
 
 
 def _load_model_weights(stage_module, distribution, device, model_config):
@@ -269,8 +283,9 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
-    "Who is Santa Claus?",
-    "Where does Santa live?",
+    "Can you explain what is the purpose of back propagation in neural networks?",
+    # "Who is Santa Claus?",
+    # "Where does Santa live?",
     # "Who is Abraham Lincoln?",
     # "How are models trained?",
 ]
@@ -294,7 +309,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer = _build_chat_tokenizer(model_name)
+    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -487,7 +502,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +525,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -539,13 +554,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

From 31988b20a83d6b4d130859f73c57aa65b2ec9a6d Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 3 Oct 2024 22:02:17 -0700
Subject: [PATCH 2/3] revert prompts to same length

---
 dist_run.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index 02809ca03..0f912b0a9 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -283,9 +283,9 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
-    "Can you explain what is the purpose of back propagation in neural networks?",
-    # "Who is Santa Claus?",
-    # "Where does Santa live?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
+    "Who is Santa Claus?",
+    "Where does Santa live?",
     # "Who is Abraham Lincoln?",
     # "How are models trained?",
 ]

From 31fb1cf418b94c3a2aaf4d6b2df8124bebbdfbf9 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Sat, 5 Oct 2024 16:37:29 -0700
Subject: [PATCH 3/3] PR comment, update _tokenizer_type to global

---
 dist_run.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index 0f912b0a9..1580f19b1 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -50,6 +50,7 @@
 
 logger = SingletonLogger.get_logger()
 
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -85,8 +86,11 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
-    """Builds a tokenizer for the given model name."""
+) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -113,15 +117,16 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
     if isinstance(tokenizer, TiktokenTokenizer):
-        tokenizer_type = TokenizerType.Tiktoken
+        _tokenizer_type = TokenizerType.Tiktoken
     elif isinstance(tokenizer, SentencePieceProcessor):
-        tokenizer_type = TokenizerType.SentencePiece
+        _tokenizer_type = TokenizerType.SentencePiece
     else:
         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
-    logger.info(f"tokenizer type = {tokenizer_type}")
-    return tokenizer, tokenizer_type
+    logger.info(f"tokenizer type = {_tokenizer_type}")
+    return tokenizer
 
 
 def _load_model_weights(stage_module, distribution, device, model_config):
@@ -309,7 +314,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -554,15 +559,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if tokenizer_type == TokenizerType.Tiktoken:
+        if _tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        elif tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
         else:
-            raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
 
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):