@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 from pathlib import Path
 from types import SimpleNamespace
@@ -48,11 +49,12 @@
 
 logger = SingletonLogger.get_logger()
 
-MODEL_NAME = "Meta-Llama-3-8B"
-
-NAME_TO_HF_MODEL_ID_AND_DTYPE = {
-    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+# Using model name to identify the model to load, for example "llama2-7b-chat".
+# You can change it to other values listed below.
+# For details on the name-to-distribution mapping, see README.md or models.json.
+NAME_TO_DISTRIBUTION_AND_DTYPE = {
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
 CACHE_PRECISION = torch.bfloat16
 
@@ -75,8 +77,19 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 
 
 def _build_chat_tokenizer(
-    model_base_name: str = "llama3",
+    model_name: str,
+    model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name."""
+    # Try to infer the model base name from the model name:
+    # e.g. "llama2-7b-chat" -> "llama2"
+    if model_base_name is None:
+        model_base_name = model_name.split("-")[0]
+        logger.info(
+            f"Using model base name '{model_base_name}' to build tokenizer. "
+            "If not found, please specify it using the `model_base_name` argument."
+        )
+
     # Create base args for tokenizer
     default_model_dir = Path(
         os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -97,12 +110,12 @@ def _build_chat_tokenizer(
     return tokenizer
 
 
-def _load_model_weights(stage_module, hf_model_name, device, model_config):
+def _load_model_weights(stage_module, distribution, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
     """
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)
+    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
 
     num_loaded_weights, num_missing_weights = load_safetensor_weights(
         stage_module,
@@ -217,32 +230,31 @@ def _cleanup():
     dist.destroy_process_group()
 
 
-def main():
+def main(args):
+    model_name = args.model_name
+    pp_degree = args.pp
+
     rank, world_size = _init_distributed()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
 
-    config = ModelArgs.from_name(MODEL_NAME).transformer_args["text"]
-    logger.info(f"Chat Model Name: {MODEL_NAME}\nModel Config: {config}")
+    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
+    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    tokenizer = _build_chat_tokenizer()
-    logger.info(f"built tokenizer {tokenizer=}")
+    config = ModelArgs.from_name(distribution).transformer_args['text']
+    logger.info(f"Chat Model Config: {config}")
 
-    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
-    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(CACHE_PRECISION)
     logger.info(f"Using cache precision {CACHE_PRECISION}")
 
-    hf_config = get_hf_config_file(hf_model_name)
+    hf_config = get_hf_config_file(distribution)
     if hf_config is None:
-        raise ValueError(f"Config file not found for model id {hf_model_name}")
-    logger.info(f"Using HF model weights from {hf_model_name}")
+        raise ValueError(f"Config file not found for model id {distribution}")
 
-    # Assuming 2 pipeline stages, feel free to change this as long as the
-    # asserts are satisfied
-    pp_degree = 2
+    # Validate pipeline degree
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
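A quick check of the arithmetic behind the new asserts (a sketch; the 8-rank world size is an assumption, and the tensor-parallel degree is inferred from the `world_size` check):

    # Hypothetical: launched on 8 ranks with "--pp 2" for a 32-layer model.
    world_size, pp_degree, n_layers = 8, 2, 32
    assert world_size % pp_degree == 0   # 8 % 2 == 0: ranks split evenly
    assert n_layers % pp_degree == 0     # 32 % 2 == 0: 16 layers per stage
    tp_degree = world_size // pp_degree  # leaves 4-way tensor parallel per stage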
@@ -283,7 +295,8 @@ def main():
 
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    logger.info(f"Model: {model}")
+    if rank == 0:
+        logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
     mb_size = 1  # micro-batch size
@@ -301,8 +314,10 @@ def main():
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
+
     with CUDATrackTime() as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
+
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()}{timer.unit} for stage {rank}{color.reset}"
     )
@@ -359,6 +374,7 @@ def main():
     ]
     """
 
+
     start_pos = 0
 
     # encode the prompt
@@ -462,4 +478,9 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    args = parser.parse_args()
+
+    main(args)
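For reference, how the new entry point resolves its arguments (a minimal sketch; the file name dist_run.py and the torchrun flags are assumptions, while the positional model_name, the --pp flag, and the mapping all come from the diff above):

    # Hypothetical launch on 4 GPUs, assuming this script is dist_run.py:
    #   torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2
    import torch

    NAME_TO_DISTRIBUTION_AND_DTYPE = {
        "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
        "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
    }

    # "llama3" resolves to the HF distribution and dtype used for weight loading:
    distribution, dtype = NAME_TO_DISTRIBUTION_AND_DTYPE["llama3"]
    print(distribution, dtype)  # meta-llama/Meta-Llama-3-8B-Instruct torch.bfloat16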