This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit b4630f2

[Distributed] Separate prefill and decode
1 parent 3b896db commit b4630f2

File tree: 1 file changed (+96 -44 lines)

dist_run.py

Lines changed: 96 additions & 44 deletions
@@ -187,8 +187,8 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    prompt_lengths: List[int],
     tokenizer,
+    prompt_lengths: Optional[List[int]] = None,
 ) -> List[Tuple[int, str]]:
     """
     Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
     results = []
 
     for i in range(batch_size):
-        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
+        next_token_logits = output[i, pos, :]
 
         # Argmax (deterministic) TODO: add temperature
         next_token = torch.argmax(next_token_logits, dim=-1)
@@ -293,7 +294,7 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 5  # micro-batch size
+    mb_size = 4  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -309,13 +310,14 @@ def main(args):
     activation = torch.rand(
         mb_size, seqlen, dim, device=device, dtype=model_dtype
     )
-    example_args = mb_ids if pp_rank == 0 else activation
+    example_inputs = (mb_ids if pp_rank == 0 else activation,)
+    example_outputs = (activation,)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-
     with CUDATrackTime() as timer:
         _load_model_weights(model, distribution, device=device, model_config=config)
+    model.to(device)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -334,22 +336,24 @@ def main(args):
     model.setup_input_pos(input_pos)
     model.eval()
 
-    logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
-    stage = PipelineStage(
+    logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}")
+    prefill_stage = PipelineStage(
         model,
         pp_rank,
         pp_degree,
         device,
-        input_args=(example_args,),
+        input_args=example_inputs,
+        output_args=example_outputs,
         group=pp_group,
     )
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
 
     prompt = [
-        "What is snow?",
-        "Where does Santa Claus live?",
-        "What is PyTorch?",
-        "Write a poem about the beauty of the night sky.",
-        "What is the capital of France, Germany and Switzerland?",
+        "What is a computer?",
+        "Where does Santa live?",
+        "Who is Abraham Lincoln?",
+        "How are models trained?",
     ]
 
     """
@@ -390,9 +394,10 @@ def main(args):
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-
-    # create schedule
-    schedule = ScheduleGPipe(stage, mbs)
+    # TODO: figure out how to set input_pos for each prompt in the batch then we
+    # can remove this limitation.
+    s = set(prompt_lengths)
+    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # with CUDATrackTime() as timer:
     first_pp_rank = 0
@@ -408,25 +413,92 @@ def main(args):
     res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
+    # Prefill phase
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
+        if pp_rank == first_pp_rank:
+            output = prefill_schedule.step(padded_sequence)
+        elif pp_rank == last_pp_rank:
+            output = prefill_schedule.step()
+        else:  # middle pp ranks
+            prefill_schedule.step()
+
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        decode_results = _batch_decode_next_tokens(
+            output=output,
+            tokenizer=tokenizer,
+            prompt_lengths=prompt_lengths,
+        )
+        for i in range(len(decode_results)):
+            new_token[i, 0] = torch.tensor(
+                [decode_results[i][0]], device=device
+            )  # token_id in int form
+        if tp_rank == 0:
+            logger.info(
+                f"{color.green} {'* Prefill *'} "
+                f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
+            )
+
+    # seqlen = 1 now
+    seqlen_decode = 1
+    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen_decode), device=device)
+    activation = torch.rand(
+        mb_size, seqlen_decode, dim, device=device, dtype=model_dtype
+    )
+    example_inputs = (mb_ids if pp_rank == 0 else activation,)
+    example_outputs = (activation,)
+
+    input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
+
+    logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
+    decode_stage = PipelineStage(
+        model,
+        pp_rank,
+        pp_degree,
+        device,
+        input_args=example_inputs,
+        output_args=example_outputs,
+        group=pp_group,
+    )
+    # create schedule
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+
     # Decoding
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens - 1):
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=first_pp_rank_global_id,
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=last_pp_rank_global_id,
+                    group=pp_group,
+                )
+
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = schedule.step(padded_sequence)
+                output = decode_schedule.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = schedule.step()
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                schedule.step()
+                decode_schedule.step()
 
             # Decode the output
             if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output=output, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"{color.green} {'* Decode *'} "
                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
@@ -436,28 +508,8 @@ def main(args):
                         [decode_results[i][0]], device=device
                     )  # decode_results[i][0]
 
-            # sendrecv between last and first ranks, only if:
-            # first_pp_rank != last_pp_rank.
-            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
-                dist.send(
-                    new_token,
-                    dst=first_pp_rank_global_id,
-                    group=pp_group,
-                )
-            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
-                dist.recv(
-                    new_token,
-                    src=last_pp_rank_global_id,
-                    group=pp_group,
-                )
-
-            # Update input sequence with new token
-            if pp_rank == first_pp_rank:
-                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
+            input_pos += 1
+            model.setup_input_pos(input_pos)
 
     # Display the decoding results
 
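For readers skimming the diff: the commit replaces the single pipeline schedule that previously served every step with two PipelineStage/ScheduleGPipe pairs, one traced at the full prompt length for prefill and a second traced at sequence length 1 for token-by-token decode, with input_pos advanced between decode steps. The sketch below is illustrative only, not part of the commit: the make_schedule helper and the placeholder variable names are hypothetical, and the import path plus the input_args/output_args keywords are assumed to match the torch.distributed.pipelining version this repo targets (they mirror the calls in dist_run.py above).

import torch
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe


def make_schedule(model, pp_rank, pp_degree, device,
                  example_inputs, example_outputs, pp_group, mbs):
    # One stage + schedule per phase; only the example tensor shapes differ
    # between the prefill pair (mb_size x seqlen) and the decode pair (mb_size x 1).
    stage = PipelineStage(
        model,
        pp_rank,
        pp_degree,
        device,
        input_args=example_inputs,    # tuple of example input tensors used for tracing
        output_args=example_outputs,  # tuple of example output tensors
        group=pp_group,
    )
    return ScheduleGPipe(stage, mbs)


# Prefill: the padded prompts flow through the pipeline once.
#   prefill_schedule = make_schedule(..., example_inputs=(padded_ids_or_activation,), ...)
#   output = prefill_schedule.step(padded_sequence)   # on the first pp rank
#
# Decode: a separate stage is traced with length-1 examples and stepped once per
# generated token, advancing the KV-cache position (input_pos) between steps.
#   decode_schedule = make_schedule(..., example_inputs=(new_token_or_activation,), ...)
#   output = decode_schedule.step(new_token)          # on the first pp rank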