This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
10 changes: 7 additions & 3 deletions dist_run.py
@@ -442,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# New token generated each iteration
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
# Store the generated tokens
res = []

@@ -519,7 +518,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Decode the output
if pp_rank == last_pp_rank:
# logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
res.append(new_token)
if not args.disable_in_flight_decode:
@@ -541,7 +539,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# token ids. Thus cat'ing along dim 1.
res = torch.cat(res, dim=1)
res_list = res.tolist()
-responses = tokenizer.decode(res_list)
+if isinstance(tokenizer, TiktokenTokenizer):
Contributor:
nit - I think this type of check is better done directly in the _build_chat_tokenizer function: have it return an enum for the tokenizer type and err out there if the tokenizer is not recognized.
The reason for that is twofold:
a - we future-proof ourselves: if llama4 ships a new tokenizer, we are not going back through the code trying to find and update every 'tokenizer' type check.
b - we check only once, at a single logical point; all other code uses the enum, and we have a single upfront failure point if we hit an unrecognized tokenizer.
As written, this code assumes that anything that is not tiktoken must be sentencepiece, which is a brittle assumption long term.
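A rough sketch of that suggestion, for illustration only: `TokenizerType` and `classify_tokenizer` are hypothetical names that do not exist in this PR, and `_build_chat_tokenizer` could return this enum alongside the tokenizer it builds. The class-name comparison just keeps the sketch import-free; a real version would use `isinstance` against the imported `TiktokenTokenizer` / `SentencePieceProcessor` classes.

```python
from enum import Enum, auto


class TokenizerType(Enum):
    TIKTOKEN = auto()
    SENTENCEPIECE = auto()


def classify_tokenizer(tokenizer) -> TokenizerType:
    """Single decision point: map a concrete tokenizer to an enum, or fail loudly.

    Checking the class name keeps this sketch import-free; a real version would
    use isinstance() against the imported TiktokenTokenizer / SentencePieceProcessor.
    """
    name = type(tokenizer).__name__
    if name == "TiktokenTokenizer":
        return TokenizerType.TIKTOKEN
    if name == "SentencePieceProcessor":
        return TokenizerType.SENTENCEPIECE
    raise ValueError(f"Unrecognized tokenizer class: {name}")


# Downstream code then branches on the enum instead of repeating isinstance checks:
#
#     if tokenizer_type is TokenizerType.TIKTOKEN:
#         responses = [tokenizer.decode(seq) for seq in res_list]
#     else:  # TokenizerType.SENTENCEPIECE
#         responses = tokenizer.decode(res_list)
```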

+    # For TiktokenTokenizer, we need to decode prompt by prompt.
+    # TODO: is there a better way to do this?
+    responses = [tokenizer.decode(sequence) for sequence in res_list]
+else: # SentencePieceProcessor
+    # For SentencePieceProcessor, we can decode the entire 2D list at once.
+    responses = tokenizer.decode(res_list)
# Show prompts and responses
for prompt_text, response_text in zip(prompt, responses):
logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
6 changes: 6 additions & 0 deletions torchchat/distributed/safetensor_utils.py
@@ -88,13 +88,19 @@ def get_hf_weight_map_and_path(
raise FileNotFoundError(
f"Weight index file for {model_id} does not exist in HF cache."
)
+logger.info(
+    f"Loading weight map from: {index_file}"
+)
weight_map = read_weights_from_json(index_file)
if weight_map is None:
raise ValueError(f"Weight map not found in config file {index_file}")
weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
weight_path = os.path.dirname(index_file)
if not os.path.exists(weight_path):
raise FileNotFoundError(f"Weight path {weight_path} does not exist")
+logger.info(
+    f"Loading weights from: {weight_path}"
+)
return weight_map, weight_path, new_to_old_keymap

