import os
from pathlib import Path
from types import SimpleNamespace
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
import torch
import torch.distributed as dist
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-

from distributed.logging_utils import SingletonLogger

    get_hf_weight_map_and_path,
    load_safetensor_weights,
)
-
from distributed.utils import (
+    bytes_to_readable,
    Color as color,
-    GPUMemoryMonitor,
+    CUDATrackTime,
    get_module_size,
    get_num_params,
-    bytes_to_readable,
-    TrackTime,
-    CUDATrackTime,
+    GPUMemoryMonitor,
)
-
from distributed.verification_utils import find_cpu_tensors
-from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
from torchchat.model import ModelArgs, Transformer
from torchchat.utils.build_utils import set_precision

@@ -136,6 +132,99 @@ def _load_model_weights(stage_module, distribution, device, model_config):
        raise ValueError(f"Missing {num_missing_weights} weights")


+def _encode_strings(
+    strings: List[str],
+    tokenizer,
+    bos: bool = True,
+    device: torch.device = "cuda:0",
+    dtype=torch.int64,
+) -> List[torch.Tensor]:
+    """Encode a list of prompt strings into a list of tensor token ids."""
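+    # e.g. ["What is snow?"] -> [tensor([bos_id, tok_1, ..., tok_n], device="cuda:0")],
+    # one 1-D int64 tensor per prompt (token ids here are illustrative only).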
+    encoded_list = []
+    for string in strings:
+        tokens = tokenizer.encode(string)
+        if bos:
+            tokens = [tokenizer.bos_id()] + tokens
+        encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
+    return encoded_list
+
+
+def _create_padded_prompts(
+    input_ids_list: List[torch.Tensor],
+    tokenizer,
+    seqlen: int,
+    start_pos: int,
+    device: torch.device,
+    pad_token_id: Optional[int] = None,
+) -> Tuple[torch.Tensor, List[int]]:
+    """
+    Create a padded tensor for multiple encoded input prompts.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
+    """
+    pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()
+
+    # Find the maximum prompt length
+    max_prompt_len = max(ids.size(0) for ids in input_ids_list)
+
+    # Calculate the buffer size
+    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
+    token_buffer_size = max_prompt_len + max_new_tokens
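+    # e.g. seqlen=4096, start_pos=0, max_prompt_len=6 -> max_new_tokens=4090 and
+    # token_buffer_size=4096, so each row has room for its prompt plus generated tokens.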
+
+    # Create the padded batch tensor
+    batch_size = len(input_ids_list)
+    batch_seq = torch.full(
+        (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
+    )
+
+    prompt_lengths = []
+    for i, input_ids in enumerate(input_ids_list):
+        prompt_len = input_ids.size(0)
+        batch_seq[i, :prompt_len] = input_ids
+        prompt_lengths.append(prompt_len)
+
+    return batch_seq, prompt_lengths
+
+
+def _batch_decode_next_tokens(
+    output: torch.Tensor,
+    prompt_lengths: List[int],
+    tokenizer,
+) -> List[Tuple[int, str]]:
+    """
+    Decode the next token for each prompt in the batch.
+
+    Returns:
+        List[Tuple[int, str]]: List of tuples containing the next token id and its
+        decoded string for each prompt in the batch.
+    """
+    batch_size = output.shape[0]
+    results = []
+
+    for i in range(batch_size):
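+        # Logits at position prompt_lengths[i] - 1 (the last prompt token) predict the next token.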
+        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+
+        # Argmax (deterministic) TODO: add temperature
+        next_token = torch.argmax(next_token_logits, dim=-1)
+
+        next_token_decoded = tokenizer.decode([next_token.item()])
+        results.append((next_token.item(), next_token_decoded))
+
+    return results
+
+
+def _update_padded_sequence(
+    padded_sequence: torch.Tensor,
+    x_recv: torch.Tensor,
+    res,
+    prompt_lengths: List[int],
+) -> None:
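+    # Write the token just received from the last stage into each row's next free slot
+    # and advance that row's prompt length for the following decoding step.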
+    for i in range(len(prompt_lengths)):
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+
+
def _cleanup():
    dist.barrier()
    dist.destroy_process_group()
@@ -180,6 +269,17 @@ def main(args):
    pp_mesh = device_mesh["pp"]
    tp_rank = tp_mesh.get_local_rank()
    pp_rank = pp_mesh.get_local_rank()
+    tp_group = tp_mesh.get_group()
+    pp_group = pp_mesh.get_group()
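+    # Process groups along each mesh dimension; pp_group is used below for the direct
+    # token send/recv between the last and first pipeline stages.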
+
+    logger.info(f"review: {pp_group=}, {tp_group=}")
+
+    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}\n")
+    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
+    pp_group_size = pp_mesh.size()
+    tp_group_size = tp_mesh.size()
+
+    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")

    # Assuming same number of GPUs per node
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -198,9 +298,10 @@ def main(args):
    if rank == 0:
        logger.info(f"Model: {model}")

-    mbs = 2  # number of micro-batches
+    mbs = 1  # number of micro-batches
    mb_size = 1  # micro-batch size
    batch_size = mbs * mb_size  # total batch size
+
    seqlen = 4096  # sequence length
    dim = 4096  # embedding dimension
    assert seqlen % sp_degree == 0
@@ -213,8 +314,10 @@ def main(args):

    # Load weights
    logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with TrackTime("cuda") as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+
    logger.info(
        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
    )
@@ -226,9 +329,8 @@ def main(args):
    logger.info(
        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
    )
-
-    # Setup input position
-    # input_pos for prefill: a list of increasing integers from 0 to seqlen
+
+    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
    input_pos = torch.arange(seqlen, device=device)
    model.setup_input_pos(input_pos)
    model.eval()
@@ -249,41 +351,129 @@ def main(args):
    if len(cpu_tensors) > 0:
        raise ValueError("Found cpu tensors in stage")

-    # TODO: this can likely be removed after we prove out a few more models
-    # verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
-    # dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
-    # logger.info(
-    #     f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
-    # )
-    # assert (
-    #     len(dtype_count) == 2
-    # ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"
+    prompt = [
+        "What is snow?",
+    ]

-    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
-    logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")
+    """
+    "What is the capital of France?",
+    "What is your name?",
+    "What is the capital of Japan?",
+    "When is Christmas?",
+    "Where does Santa Claus live?",
+    "What is the capital of the United States?",
+    "What is the capital of China?",
+    "What is the capital of Russia?",
+    "What is PyTorch?",
+    "What is the capital of India?",
+    "What is an LLM?",
+    "What is the capital of Brazil?",
+    "What is the capital of Mexico?",
+    "What is the capital of Argentina?",
+    "What is the capital of Canada?",
+    ]
+    """
-    schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")

-    with torch.no_grad():  # .inference_mode():
-        if pp_rank == 0:
-            output = schedule.step(input_ids)
-        else:
-            output = schedule.step()
+    start_pos = 0

-    if pp_rank == pp_degree - 1 and tp_rank == 0:
-        logger.info(f"Output: {output}")
+    # encode the prompt
+    input_ids = _encode_strings(
+        prompt, tokenizer, bos=True, device=device, dtype=torch.int64
+    )
+    logger.info(f"{input_ids[0:8]=}")

-    # show peak memory stats for this stage
-    res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
-    logger.info(
-        f"{color.blue}Memory used: {color.green}{res_mem_pct:.3f}%, {color.magenta}{res_mem_gib:.3f} GB{color.reset}"
+    # create a padded tensor for the input prompt
+    padded_sequence, prompt_lengths = _create_padded_prompts(
+        input_ids, tokenizer, seqlen, start_pos, device
    )
+    logger.info(f"{prompt_lengths=}")
+    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
+    if len(prompt_lengths) > 1:
+        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
+
+    schedule = ScheduleGPipe(stage, mbs)
+    logger.info(f"Created schedule: {schedule}")
+
+    # with CUDATrackTime() as timer:
+    first_pp_group = 0
+    last_pp_group = pp_group_size - 1
+
+    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
+    logger.info(f"{x_recv.shape=}")
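+    # x_recv is a one-element int64 buffer; the first stage fills it via dist.recv with the
+    # token decoded by the last stage, then extends padded_sequence before the next pass.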
+
+    last_global_rank = world_size - 1
+    res = []
+    dst = None
+    src = None
+
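+    # The last pipeline stage sends each newly decoded token back to the first stage, so
+    # dst/src below map pp-group ranks 0 and last_pp_group to their global ranks.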
+    if pp_rank == last_pp_group:
+        dst = dist.get_global_rank(pp_group, 0)
+    elif pp_rank == 0:
+        src = dist.get_global_rank(pp_group, last_pp_group)
+
+    # Decoding
+    num_tokens = 40
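+    # Greedy decoding loop: each iteration runs one full pipeline pass over the padded
+    # sequence and produces one new token, so num_tokens passes yield num_tokens tokens.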
+
+    with torch.no_grad():
+        for step in range(num_tokens):
+            # first
+            if pp_rank == 0:
+                schedule.step(padded_sequence)
+                # only receive if not last step
+                if step < num_tokens - 1:
+                    dist.recv(
+                        x_recv,
+                        src,
+                        group=pp_group,
+                    )
+                    _update_padded_sequence(
+                        padded_sequence, x_recv, res, prompt_lengths
+                    )
+
+            # last
+            elif pp_rank == last_pp_group:
+                output = schedule.step()
+                # need to decode the output
+                decode_results = _batch_decode_next_tokens(
+                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                )
+                if tp_rank == 0:
+                    logger.info(
+                        f"\n\n{color.green}{'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue}{decode_results=}\n{color.reset}"
+                    )
+
+                next_token = torch.tensor([decode_results[0][0]], device=device)
+                res.append(decode_results[0][1])
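+                # Only the first prompt's token is recorded and sent back here, which
+                # matches the single-prompt batch used above.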
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1
+                    # logger.info(
+                    #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
+                    # )
+
+                # only send if not last step
+                if step < (num_tokens - 1):
+                    dist.send(
+                        next_token,
+                        dst,
+                        pp_group,
+                    )
+
+            # middle pp ranks
+            else:
+                schedule.step()
+
+    # output formatted response via last pp group and tp rank 0
+    if pp_rank == last_pp_group and tp_rank == 0:
+        logger.info(f"\nPrompt:{color.green}{prompt[0]}{color.reset}")
+        formatted_response = "".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")

    logger.info(
        f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
    )
-
    _cleanup()

