This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
37 changes: 30 additions & 7 deletions dist_run.py
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
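Taken together, the hunks above introduce a detect-once, dispatch-later pattern: the tokenizer flavor is classified a single time when the tokenizer is built, cached in the module-level `_tokenizer_type`, and consulted later instead of repeating `isinstance` checks. A minimal standalone sketch of that pattern (the class-name matching in `record_tokenizer_type` is an illustrative assumption; the PR itself uses `isinstance` against the imported `TiktokenTokenizer` and `SentencePieceProcessor` classes):

```python
from enum import auto, Enum


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


_tokenizer_type = None  # module-level cache of the active tokenizer flavor


def record_tokenizer_type(tokenizer) -> None:
    """Classify `tokenizer` once and cache the result globally.

    Matching on the class name here is a stand-in for the PR's
    isinstance checks against the real tokenizer classes.
    """
    global _tokenizer_type
    name = type(tokenizer).__name__
    if name == "TiktokenTokenizer":
        _tokenizer_type = TokenizerType.Tiktoken
    elif name == "SentencePieceProcessor":
        _tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {type(tokenizer)}")
```

Caching the flavor as an enum lets later code branch without touching the tokenizer classes again, and the final `else` turns an unexpected tokenizer into a loud failure instead of a silent misdecode.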
@@ -269,6 +288,7 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
     "Who is Santa Claus?",
     "Where does Santa live?",
     # "Who is Abraham Lincoln?",
@@ -487,7 +507,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +530,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
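For context on the renamed schedule: only the first pipeline stage consumes the freshly sampled token, only the last stage receives output from `step`, and middle stages merely relay activations. A toy sketch of that per-rank control flow (`StubSchedule` is an assumption standing in for `ScheduleGPipe`; the real schedule comes from torch's pipelining support):

```python
class StubSchedule:
    """Toy stand-in for ScheduleGPipe: echoes whatever it was fed."""

    def step(self, *args, **kwargs):
        return args[0] if args else None


def step_one_rank(decoder, pp_rank, first_pp_rank, last_pp_rank, new_token, **kwargs):
    if pp_rank == first_pp_rank:
        # First stage feeds the newly sampled token into the pipeline.
        return decoder.step(new_token, **kwargs)
    elif pp_rank == last_pp_rank:
        # Last stage is the only one whose step returns real output.
        return decoder.step(**kwargs)
    else:
        # Middle stages just advance the schedule.
        decoder.step(**kwargs)
        return None
```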
@@ -539,13 +559,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
 
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
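The decode-side hunk leans on an API difference between the two tokenizers: `TiktokenTokenizer.decode` takes one sequence of token ids at a time, while `SentencePieceProcessor.decode` accepts a whole 2D list of sequences in a single call. A small runnable sketch of that dispatch, reusing the `TokenizerType` enum from the diff above (the stub classes are stand-ins, not the real tokenizers):

```python
from typing import List


class StubTiktoken:
    """Stand-in mimicking TiktokenTokenizer's per-sequence decode."""

    def decode(self, ids: List[int]) -> str:
        return " ".join(f"<{i}>" for i in ids)


class StubSentencePiece:
    """Stand-in mimicking SentencePieceProcessor's batched decode."""

    def decode(self, batch: List[List[int]]) -> List[str]:
        return [" ".join(f"<{i}>" for i in ids) for ids in batch]


def decode_responses(tokenizer, tokenizer_type, res_list: List[List[int]]) -> List[str]:
    if tokenizer_type == TokenizerType.Tiktoken:
        # Tiktoken-style: decode prompt by prompt.
        return [tokenizer.decode(sequence) for sequence in res_list]
    elif tokenizer_type == TokenizerType.SentencePiece:
        # SentencePiece-style: decode the entire 2D list at once.
        return tokenizer.decode(res_list)
    raise ValueError(f"Unknown tokenizer type {tokenizer_type}")


# Both paths yield one string per prompt.
assert (decode_responses(StubTiktoken(), TokenizerType.Tiktoken, [[1, 2], [3]])
        == decode_responses(StubSentencePiece(), TokenizerType.SentencePiece, [[1, 2], [3]]))
```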