118 changes: 56 additions & 62 deletions dist_run.py
@@ -187,30 +187,23 @@ def _create_padded_prompts(

def _batch_decode_next_tokens(
output: torch.Tensor,
tokenizer,
prompt_lengths: Optional[List[int]] = None,
) -> List[Tuple[int, str]]:
pos: int,
) -> torch.Tensor:
"""
Decode the next token for each prompt in the batch.
Args:
output (torch.Tensor): The output tensor to decode.
pos: the position of the `output` to decode in the sequence length dimension.

Returns:
List[Tuple[int, str]]: List of tuples containing the next token id and its
decoded string for each prompt in the batch.
Decoded token ids.
"""
batch_size = output.shape[0]
results = []

for i in range(batch_size):
pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
next_token_logits = output[i, pos, :]

# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)

next_token_decoded = tokenizer.decode([next_token.item()])
results.append((next_token.item(), next_token_decoded))

return results
# Take the next token logits for each prompt
next_token_logits = output[:, pos, :]
# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)
# Token ids in int tensor form
return next_token
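For reference, a minimal standalone sketch of what the vectorized helper computes; the shapes and values below are made up for illustration and are not part of this diff:

import torch

# Fake pipeline output: [batch, seqlen, vocab] logits
batch_size, seqlen, vocab_size = 4, 16, 32000
output = torch.randn(batch_size, seqlen, vocab_size)

# Greedy next-token id for every prompt, read at sequence position `pos`
pos = seqlen - 1
next_token = torch.argmax(output[:, pos, :], dim=-1)  # int64 tensor of shape [batch_size]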


def _update_padded_sequence(
@@ -399,11 +392,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)

# New token generated each iteration
total_prompts = len(prompt_lengths)
# need a new token dimension (row) for each prompt in the batch
new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
res = [[] for _ in range(total_prompts)]
num_tokens = 40
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
Contributor review comment:

General comment: do we really need int64 for the dtype, given that int32 holds up to positive 2B? (Meant to ask this earlier.)
Seems like a minor saving, and I don't see any vocabs getting to 2B anytime soon.
We can just leave it as is for consistency, and maybe do a single PR in the future to sweep this up if we agree int32 is fine.
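For scale, a rough illustration of the saving under discussion (illustrative sizes, not part of this change):

import torch

batch_size, ntokens = 8, 40
tok64 = torch.zeros(batch_size, ntokens, dtype=torch.int64)
tok32 = torch.zeros(batch_size, ntokens, dtype=torch.int32)
print(tok64.nelement() * tok64.element_size())  # 2560 bytes
print(tok32.nelement() * tok32.element_size())  # 1280 bytes
# int32 tops out at about 2.1B, far above any current vocab size, so the dtype
# choice here is about minor memory savings rather than correctness.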

# Store the generated tokens
res = []

# Prefill phase
# Run context input through pipeline
@@ -422,23 +414,24 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
decode_results = _batch_decode_next_tokens(
output=output,
tokenizer=tokenizer,
prompt_lengths=prompt_lengths,
)
for i in range(len(decode_results)):
new_token[i, 0] = torch.tensor(
[decode_results[i][0]], device=device
) # token_id in int form
# Decode token id into string and print it
def decode_in_flight(token):
# Make a 2D tensor with ids on row dimension
unsqueezed = torch.unsqueeze(token, 1)
token_str = tokenizer.decode(unsqueezed.tolist())
if tp_rank == 0:
logger.info(
f"{color.green} {'* Prefill *'} "
f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
f"{color.green} responses ====>>>> "
f"{color.blue} {token_str} {color.reset}"
)

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
res.append(new_token)
if not args.disable_in_flight_decode:
decode_in_flight(new_token)
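The unsqueeze in decode_in_flight turns the 1-D batch of token ids into one row per prompt before decoding; a toy sketch with a stand-in for the tokenizer (the real tokenizer API may differ):

import torch

def fake_decode(rows):  # stand-in for tokenizer.decode on a list of id lists
    return [f"<id {row[0]}>" for row in rows]

token = torch.tensor([101, 202, 303])      # one freshly generated id per prompt
rows = torch.unsqueeze(token, 1).tolist()  # [[101], [202], [303]]
print(fake_decode(rows))                   # ['<id 101>', '<id 202>', '<id 303>']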

# seqlen = 1 now
seqlen_decode = 1
input_pos = torch.tensor([prompt_lengths[0]], device=device)
@@ -460,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Decoding
with torch.no_grad(), CUDATrackTime() as timer:
for step in range(num_tokens - 1):
for step in range(args.ntokens - 1):
kwargs = {"input_pos": input_pos, "cache_lane": lane}
# sendrecv between last and first ranks, only if:
# first_pp_rank != last_pp_rank.
@@ -487,21 +480,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Decode the output
if pp_rank == last_pp_rank:
decode_results = _batch_decode_next_tokens(
output=output, tokenizer=tokenizer
)
if tp_rank == 0:
logger.info(
f"{color.green} {'* Decode *'} "
f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
)
# decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
for i in range(len(decode_results)):
res[i].append(decode_results[i][1])
new_token[i, 0] = torch.tensor(
[decode_results[i][0]], device=device
) # decode_results[i][0]
new_token = _batch_decode_next_tokens(output, 0)
res.append(new_token)
if not args.disable_in_flight_decode:
decode_in_flight(new_token)

# Increment input position
input_pos += 1
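Stripped of the pipeline and KV-cache plumbing, the bookkeeping in this loop amounts to the toy version below; the model output is faked and only the control flow is the point:

import torch

batch_size, vocab_size, ntokens = 2, 100, 5
res = []
input_pos = torch.tensor([10])                        # first position after the prompt
new_token = torch.randint(vocab_size, (batch_size,))  # stand-in for the prefill token
res.append(new_token)
for _ in range(ntokens - 1):
    logits = torch.randn(batch_size, 1, vocab_size)    # stand-in for the pipeline output
    new_token = torch.argmax(logits[:, 0, :], dim=-1)  # _batch_decode_next_tokens(output, 0)
    res.append(new_token)
    input_pos += 1                                      # advance the cache position by one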

logger.info(
@@ -512,21 +496,19 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# output formatted response via last pp group and tp rank 0
if pp_rank == last_pp_rank and tp_rank == 0:
for i in range(len(prompt_lengths)):
logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")

# TODO: resolve issue with llama2-7b-chat model and "".join
if model_name != "llama2-7b-chat":
formatted_response = "".join(res[i])
else:
formatted_response = " ".join(res[i])
logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
# `res` is a list of tensors, each being a batch of generated token ids
res = torch.stack(res, dim=1)
res_list = res.tolist()
response = tokenizer.decode(res_list)
for i in range(len(response)):
logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
logger.info(f"Response: {color.red}{response[i]} {color.reset}")
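Concretely, the stacking step turns the per-step token tensors into one row of ids per prompt; toy values below, with a string join standing in for tokenizer.decode:

import torch

res = [torch.tensor([1, 4]), torch.tensor([2, 5]), torch.tensor([3, 6])]  # one [batch] tensor per step
stacked = torch.stack(res, dim=1)  # shape [batch, steps]: [[1, 2, 3], [4, 5, 6]]
ids_per_prompt = stacked.tolist()
responses = [" ".join(str(i) for i in row) for row in ids_per_prompt]  # stand-in for tokenizer.decode
print(responses)                   # ['1 2 3', '4 5 6']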

# Cleanup
_cleanup()
logger.info(
f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
)
_cleanup()


if __name__ == "__main__":
@@ -538,6 +520,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
)
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
parser.add_argument(
"--ntokens",
type=int,
default=40,
help="Number of tokens to generate",
)
parser.add_argument(
"--disable-in-flight-decode",
action="store_true",
default=False,
help="Whether to decode token into string in flight",
)
args = parser.parse_args()

main(args)
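A hypothetical launch exercising the new options; the torchrun invocation and the way the model is selected are assumptions about the surrounding setup, not something this diff prescribes:

torchrun --nproc_per_node 4 dist_run.py llama2-7b-chat --pp 2 --ntokens 64 --disable-in-flight-decode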