diff --git a/dist_run.py b/dist_run.py
index bc3badbc4..1580f19b1 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@
 
 logger = SingletonLogger.get_logger()
 
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
@@ -269,6 +288,7 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
     "Who is Santa Claus?",
     "Where does Santa live?",
     # "Who is Abraham Lincoln?",
@@ -487,7 +507,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +530,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -539,13 +559,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # token ids. Thus cat'ing along dim 1.
     res = torch.cat(res, dim=1)
     res_list = res.tolist()
-    if isinstance(tokenizer, TiktokenTokenizer):
+    if _tokenizer_type == TokenizerType.Tiktoken:
         # For TiktokenTokenizer, we need to decode prompt by prompt.
         # TODO: is there a better way to do this?
         responses = [tokenizer.decode(sequence) for sequence in res_list]
-    else:  # SentencePieceProcessor
+    elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
         # For SentencePieceProcessor, we can decode the entire 2D list at once.
         responses = tokenizer.decode(res_list)
+    else:
+        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
     # Show prompts and responses
     for prompt_text, response_text in zip(prompt, responses):
         logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
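
For reference, below is a minimal, self-contained sketch of the pattern this patch introduces: cache the tokenizer flavor in a module-level enum when the tokenizer is built, then dispatch on that cached value at decode time instead of repeating isinstance checks. DummyTiktoken and DummySentencePiece are hypothetical stand-ins, not torchchat's real TiktokenTokenizer or SentencePieceProcessor; only the caching and dispatch logic mirrors the diff.

# sketch.py -- illustrative only; names marked "Dummy" are assumptions, not torchchat APIs
from enum import auto, Enum
from typing import List, Optional


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


# Cached once at tokenizer build time, mirroring _tokenizer_type in dist_run.py.
_tokenizer_type: Optional[TokenizerType] = None


class DummyTiktoken:
    def decode(self, ids: List[int]) -> str:
        # Decodes a single 1D sequence of token ids into one string.
        return " ".join(f"<{i}>" for i in ids)


class DummySentencePiece:
    def decode(self, ids: List[List[int]]) -> List[str]:
        # Decodes a whole 2D list of sequences in one call.
        return [" ".join(f"<{i}>" for i in row) for row in ids]


def build_tokenizer(use_tiktoken: bool):
    """Build a tokenizer and cache its type, as _build_chat_tokenizer does above."""
    global _tokenizer_type
    tokenizer = DummyTiktoken() if use_tiktoken else DummySentencePiece()
    if isinstance(tokenizer, DummyTiktoken):
        _tokenizer_type = TokenizerType.Tiktoken
    elif isinstance(tokenizer, DummySentencePiece):
        _tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
    return tokenizer


def decode_results(tokenizer, res_list: List[List[int]]) -> List[str]:
    """Dispatch on the cached enum rather than isinstance, as in the last hunk."""
    if _tokenizer_type == TokenizerType.Tiktoken:
        # Tiktoken-style API: decode one sequence at a time.
        return [tokenizer.decode(sequence) for sequence in res_list]
    elif _tokenizer_type == TokenizerType.SentencePiece:
        # SentencePiece-style API: decode the entire 2D list at once.
        return tokenizer.decode(res_list)
    else:
        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")


if __name__ == "__main__":
    tok = build_tokenizer(use_tiktoken=True)
    print(decode_results(tok, [[1, 2, 3], [4, 5]]))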