@@ -187,8 +187,8 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    prompt_lengths: List[int],
     tokenizer,
+    prompt_lengths: Optional[List[int]] = None,
 ) -> List[Tuple[int, str]]:
     """
     Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
     results = []
 
     for i in range(batch_size):
-        next_token_logits = output[i, prompt_lengths[i] - 1, :]
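+        # In the decode phase (prompt_lengths is None) the sequence length is 1,
+        # so the next-token logits sit at position 0.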
+        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
+        next_token_logits = output[i, pos, :]
 
         # Argmax (deterministic) TODO: add temperature
         next_token = torch.argmax(next_token_logits, dim=-1)
@@ -293,7 +294,7 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 5  # micro-batch size
+    mb_size = 4  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
     seqlen = 4096  # sequence length
@@ -309,13 +310,14 @@ def main(args):
     activation = torch.rand(
         mb_size, seqlen, dim, device=device, dtype=model_dtype
     )
-    example_args = mb_ids if pp_rank == 0 else activation
+    example_inputs = (mb_ids if pp_rank == 0 else activation,)
+    example_outputs = (activation,)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-
     with CUDATrackTime() as timer:
         _load_model_weights(model, distribution, device=device, model_config=config)
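+    # Ensure all parameters end up on the target device after loading.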
+    model.to(device)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -334,22 +336,24 @@ def main(args):
     model.setup_input_pos(input_pos)
     model.eval()
 
-    logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
-    stage = PipelineStage(
+    logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}")
+    prefill_stage = PipelineStage(
         model,
         pp_rank,
         pp_degree,
         device,
-        input_args=(example_args,),
+        input_args=example_inputs,
+        output_args=example_outputs,
         group=pp_group,
     )
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
 
     prompt = [
-        "What is snow?",
-        "Where does Santa Claus live?",
-        "What is PyTorch?",
-        "Write a poem about the beauty of the night sky.",
-        "What is the capital of France, Germany and Switzerland?",
+        "What is a computer?",
+        "Where does Santa live?",
+        "Who is Abraham Lincoln?",
+        "How are models trained?",
     ]
 
     """
@@ -390,9 +394,10 @@ def main(args):
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-
-    # create schedule
-    schedule = ScheduleGPipe(stage, mbs)
+    # TODO: figure out how to set input_pos for each prompt in the batch then we
+    # can remove this limitation.
+    s = set(prompt_lengths)
+    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
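+    # input_pos is a single scalar shared by the whole batch, so every prompt
+    # must have the same length for the cache positions to line up.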
 
     # with CUDATrackTime() as timer:
     first_pp_rank = 0
@@ -408,25 +413,92 @@ def main(args):
     res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
+    # Prefill phase
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
+        if pp_rank == first_pp_rank:
+            output = prefill_schedule.step(padded_sequence)
+        elif pp_rank == last_pp_rank:
+            output = prefill_schedule.step()
+        else:  # middle pp ranks
+            prefill_schedule.step()
+
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        decode_results = _batch_decode_next_tokens(
+            output=output,
+            tokenizer=tokenizer,
+            prompt_lengths=prompt_lengths,
+        )
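+        # Prefill passes prompt_lengths so logits are read at each prompt's
+        # last real (non-padded) position.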
+        for i in range(len(decode_results)):
+            new_token[i, 0] = torch.tensor(
+                [decode_results[i][0]], device=device
+            )  # token_id in int form
+        if tp_rank == 0:
+            logger.info(
+                f"{color.green}{'* Prefill *'} "
+                f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
+            )
+
+    # seqlen = 1 now
+    seqlen_decode = 1
+    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen_decode), device=device)
+    activation = torch.rand(
+        mb_size, seqlen_decode, dim, device=device, dtype=model_dtype
+    )
+    example_inputs = (mb_ids if pp_rank == 0 else activation,)
+    example_outputs = (activation,)
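+    # The decode stage is traced with seqlen-1 example shapes, since each
+    # decode step feeds exactly one new token per prompt.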
+
+    input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
+
+    logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
+    decode_stage = PipelineStage(
+        model,
+        pp_rank,
+        pp_degree,
+        device,
+        input_args=example_inputs,
+        output_args=example_outputs,
+        group=pp_group,
+    )
+    # create schedule
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+
     # Decoding
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens - 1):
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
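+            # On the first iteration new_token holds the prefill result;
+            # afterwards it holds the token produced by the previous step.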
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=first_pp_rank_global_id,
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=last_pp_rank_global_id,
+                    group=pp_group,
+                )
+
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = schedule.step(padded_sequence)
+                output = decode_schedule.step(new_token)
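+                # Only the new token is fed forward; earlier positions are
+                # served from the model's cache via input_pos.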
             elif pp_rank == last_pp_rank:
-                output = schedule.step()
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                schedule.step()
+                decode_schedule.step()
 
             # Decode the output
             if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output=output, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"{color.green}{'Prefill' if step == 0 else '* Decode *'} "
+                        f"{color.green}{'* Decode *'} "
                         f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
@@ -436,28 +508,8 @@ def main(args):
                         [decode_results[i][0]], device=device
                     )  # decode_results[i][0]
 
-            # sendrecv between last and first ranks, only if:
-            # first_pp_rank != last_pp_rank.
-            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
-                dist.send(
-                    new_token,
-                    dst=first_pp_rank_global_id,
-                    group=pp_group,
-                )
-            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
-                dist.recv(
-                    new_token,
-                    src=last_pp_rank_global_id,
-                    group=pp_group,
-                )
-
-            # Update input sequence with new token
-            if pp_rank == first_pp_rank:
-                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
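+            # Advance the shared cache position past the token just generated,
+            # so the next step attends over the full history.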
+            input_pos += 1
+            model.setup_input_pos(input_pos)
 
     # Display the decoding results
 