@@ -187,8 +187,8 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    prompt_lengths: List[int],
     tokenizer,
+    prompt_lengths: Optional[List[int]] = None,
 ) -> List[Tuple[int, str]]:
     """
     Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
     results = []
 
     for i in range(batch_size):
-        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
+        next_token_logits = output[i, pos, :]
 
         # Argmax (deterministic) TODO: add temperature
         next_token = torch.argmax(next_token_logits, dim=-1)
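
Note on the signature change above: prompt_lengths becomes optional because the helper is now called in two phases. During prefill the output spans the whole padded sequence, so the next token for prompt i is read from the logits at position prompt_lengths[i] - 1; during decode the model sees exactly one token per prompt, so position 0 is the only one. A minimal standalone sketch of the two call modes (ToyTokenizer and the unprefixed batch_decode_next_tokens name are hypothetical stand-ins, greedy decoding only):

    import torch
    from typing import List, Optional, Tuple

    class ToyTokenizer:
        def decode(self, ids: List[int]) -> str:
            return f"<tok {ids[0]}>"

    def batch_decode_next_tokens(
        output: torch.Tensor,          # (batch, seqlen, vocab)
        tokenizer,
        prompt_lengths: Optional[List[int]] = None,
    ) -> List[Tuple[int, str]]:
        results = []
        for i in range(output.size(0)):
            # Prefill: the logits at the last real prompt position predict the
            # first new token. Decode: seqlen is 1, so position 0 is all there is.
            pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
            token_id = torch.argmax(output[i, pos, :], dim=-1).item()
            results.append((token_id, tokenizer.decode([token_id])))
        return results

    logits = torch.rand(2, 8, 100)
    print(batch_decode_next_tokens(logits, ToyTokenizer(), prompt_lengths=[5, 3]))  # prefill
    print(batch_decode_next_tokens(logits[:, :1, :], ToyTokenizer()))               # decode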
@@ -276,6 +277,10 @@ def main(args):
     tp_group_size = tp_group.size()
     logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
+    # Convenience variables
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
 
@@ -293,29 +298,23 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 5  # micro-batch size
+    mb_size = 4  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
-    seqlen = 4096  # sequence length
+    seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
     # multiple micro-batches, we need the KV cache in the model to be aware of
     # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen)
-
-    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
-    activation = torch.rand(
-        mb_size, seqlen, dim, device=device, dtype=model_dtype
-    )
-    example_args = mb_ids if pp_rank == 0 else activation
+    model.setup_caches(mb_size, seqlen_prefill)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-
     with CUDATrackTime() as timer:
         _load_model_weights(model, distribution, device=device, model_config=config)
+        model.to(device)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
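
For context on the seqlen 4096 -> seqlen_prefill 1024 change above: setup_caches sizes the KV cache by this length, so a smaller prefill window shrinks each stage's cache memory. A back-of-envelope estimate, assuming illustrative Llama-style dimensions (n_layers, n_kv_heads, and head_dim below are assumed values, not read from the repo's config):

    # Rough per-stage KV-cache footprint under the assumed dimensions.
    mb_size, seqlen_prefill = 4, 1024
    n_layers, n_kv_heads, head_dim = 32, 8, 128
    bytes_per_elem = 2                              # bf16
    k_and_v = 2                                     # one K and one V tensor per layer
    cache_bytes = (
        mb_size * seqlen_prefill * n_layers * k_and_v
        * n_kv_heads * head_dim * bytes_per_elem
    )
    print(f"{cache_bytes / 2**20:.0f} MiB")         # 512 MiB at seqlen 1024; 4x that at 4096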
@@ -330,53 +329,47 @@ def main(args):
     )
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
-    input_pos = torch.arange(seqlen, device=device)
+    input_pos = torch.arange(seqlen_prefill, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
 
-    logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
-    stage = PipelineStage(
+    # Helper function to get example inputs and outputs for the stages.
+    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        activation = torch.rand(
+            mb_size, seqlen, dim, device=device, dtype=model_dtype
+        )
+        logits = torch.rand(
+            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+        )
+        example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
+        example_outputs = (logits if pp_rank == last_pp_rank else activation,)
+        return example_inputs, example_outputs
+
+    # Create prefill stage
+    logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill)
+    prefill_stage = PipelineStage(
         model,
         pp_rank,
         pp_degree,
         device,
-        input_args=(example_args,),
+        input_args=example_inputs,
+        output_args=example_outputs,
         group=pp_group,
     )
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
 
     prompt = [
-        "What is snow?",
-        "Where does Santa Claus live?",
-        "What is PyTorch?",
-        "Write a poem about the beauty of the night sky.",
-        "What is the capital of France, Germany and Switzerland?",
-    ]
-
-    """
-    "What is the capital of France?",
-    "What is your name?",
-    "What is the capital of Japan?",
-    "When is Christmas?",
-    "Where does Santa Claus live?",
-    "What is the capital of the United States?",
-    "What is the capital of China?",
-    "What is the capital of Russia?",
-    "What is PyTorch?",
-    "What is the capital of India?",
-    "What is an LLM?",
-    "What is the capital of Brazil?",
-    "What is the capital of Mexico?",
-    "What is the capital of Argentina?",
-    "What is the capital of Canada?",
+        "What is a computer?",
+        "Where does Santa live?",
+        "Who is Abraham Lincoln?",
+        "How are models trained?",
     ]
-    """
 
     start_pos = 0
 
-    # pipeline comms setup
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
-
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
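
Why the new get_example_ins_outs helper and the added output_args: the same model object is now wrapped in two stages whose tensors differ only in sequence length, and the stage presumably uses these example tensors as shape/dtype metadata for its communication buffers. A sketch of the shapes the helper would produce per rank, with illustrative mb_size, dim, and vocab_size values (example_shapes is a hypothetical mirror of the helper, not the repo's code):

    # First rank consumes token ids, inner ranks pass activations,
    # and the last rank emits logits.
    mb_size, dim, vocab_size = 4, 4096, 32000

    def example_shapes(pp_rank: int, pp_degree: int, seqlen: int):
        first, last = 0, pp_degree - 1
        ins = (mb_size, seqlen) if pp_rank == first else (mb_size, seqlen, dim)
        outs = (mb_size, seqlen, vocab_size) if pp_rank == last else (mb_size, seqlen, dim)
        return ins, outs

    for rank in range(4):
        print("prefill", rank, example_shapes(rank, 4, 1024))
        print("decode ", rank, example_shapes(rank, 4, 1))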
@@ -388,15 +381,14 @@ def main(args):
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
-        input_ids, tokenizer, seqlen, start_pos, device
+        input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-
-    # create schedule
-    schedule = ScheduleGPipe(stage, mbs)
+    # TODO: figure out how to set input_pos for each prompt in the batch then we
+    # can remove this limitation.
+    s = set(prompt_lengths)
+    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # with CUDATrackTime() as timer:
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
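
The new assertion guards a real limitation: decode feeds one shared input_pos to the whole batch, so ragged prompt lengths would index the KV cache at the wrong positions for some rows. A toy illustration of the failure mode the guard prevents (values are hypothetical):

    prompt_lengths = [5, 7]
    shared_pos = prompt_lengths[0]      # decode would write new KV entries at pos 5
    for i, plen in enumerate(prompt_lengths):
        if shared_pos < plen:
            print(f"prompt {i}: pos {shared_pos} still holds prompt context")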
@@ -408,25 +400,87 @@ def main(args):
     res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
+    # Prefill phase
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
+        if pp_rank == first_pp_rank:
+            output = prefill_schedule.step(padded_sequence)
+        elif pp_rank == last_pp_rank:
+            output = prefill_schedule.step()
+        else:  # middle pp ranks
+            prefill_schedule.step()
+
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        decode_results = _batch_decode_next_tokens(
+            output=output,
+            tokenizer=tokenizer,
+            prompt_lengths=prompt_lengths,
+        )
+        for i in range(len(decode_results)):
+            new_token[i, 0] = torch.tensor(
+                [decode_results[i][0]], device=device
+            )  # token_id in int form
+        if tp_rank == 0:
+            logger.info(
+                f"{color.green}{'* Prefill *'} "
+                f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
+            )
+
+    # seqlen = 1 now
+    seqlen_decode = 1
+    input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
+
+    # Create decode stage
+    logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_decode)
+    decode_stage = PipelineStage(
+        model,
+        pp_rank,
+        pp_degree,
+        device,
+        input_args=example_inputs,
+        output_args=example_outputs,
+        group=pp_group,
+    )
+    # create schedule
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+
     # Decoding
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens - 1):
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=first_pp_rank_global_id,
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=last_pp_rank_global_id,
+                    group=pp_group,
+                )
+
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = schedule.step(padded_sequence)
+                output = decode_schedule.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = schedule.step()
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                schedule.step()
+                decode_schedule.step()
 
             # Decode the output
             if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output=output, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"{color.green}{'Prefill' if step == 0 else '* Decode *'} "
+                        f"{color.green}{'* Decode *'} "
                         f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
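
The restructuring above splits generation into a one-shot prefill over the padded prompts and a per-token decode loop; the loop runs num_tokens - 1 times because prefill already produced the first token, and each iteration now starts by shipping the previous token from the last stage back to the first. A single-process sketch of that control flow with a toy stand-in for the pipelined model (forward and its modules are hypothetical, greedy decoding only):

    import torch

    vocab, dim = 100, 16
    emb = torch.nn.Embedding(vocab, dim)
    head = torch.nn.Linear(dim, vocab)

    def forward(ids: torch.Tensor) -> torch.Tensor:
        # stands in for one full pass through all pipeline stages
        return head(emb(ids))  # (batch, seqlen, vocab)

    prompt = torch.randint(0, vocab, (1, 7))
    num_tokens = 40
    with torch.no_grad():
        logits = forward(prompt)                    # prefill: one pass over the prompt
        new_token = logits[:, -1:, :].argmax(dim=-1)
        tokens = [new_token]
        for _ in range(num_tokens - 1):             # decode: one token per step
            logits = forward(new_token)             # seqlen == 1 from here on
            new_token = logits[:, -1:, :].argmax(dim=-1)
            tokens.append(new_token)
    print(torch.cat(tokens, dim=1).shape)           # torch.Size([1, 40])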
@@ -436,28 +490,8 @@ def main(args):
                     [decode_results[i][0]], device=device
                 )  # decode_results[i][0]
 
-            # sendrecv between last and first ranks, only if:
-            # first_pp_rank != last_pp_rank.
-            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
-                dist.send(
-                    new_token,
-                    dst=first_pp_rank_global_id,
-                    group=pp_group,
-                )
-            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
-                dist.recv(
-                    new_token,
-                    src=last_pp_rank_global_id,
-                    group=pp_group,
-                )
-
-            # Update input sequence with new token
-            if pp_rank == first_pp_rank:
-                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
+            input_pos += 1
+            model.setup_input_pos(input_pos)
 
     # Display the decoding results
 
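Finally, the input_pos += 1 bookkeeping replaces the old padded-sequence splicing: instead of updating the full prompt tensor and re-running full-length forwards, the first stage feeds only the 1-token input while the KV-cache write position advances by one per generated token. A tiny sketch of that position bookkeeping (values are illustrative):

    import torch

    prompt_lengths = [7, 7, 7, 7]                    # equal lengths, per the assert
    input_pos = torch.tensor([prompt_lengths[0]])    # first decode writes KV at pos 7
    for step in range(3):
        # the real loop calls model.setup_input_pos(input_pos) here
        input_pos += 1
    print(input_pos)                                 # tensor([10])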