
Commit 9242dfb

Merge branch 'main' into lessw2020/prefill

2 parents 6bb2725 + 7a4f0d1


dist_run.py

Lines changed: 44 additions & 23 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 from pathlib import Path
 from types import SimpleNamespace
@@ -48,11 +49,12 @@
 
 logger = SingletonLogger.get_logger()
 
-MODEL_NAME = "Meta-Llama-3-8B"
-
-NAME_TO_HF_MODEL_ID_AND_DTYPE = {
-    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+# Using model name to identify the model to load, for example "llama2-7b-chat".
+# You can change it to other values listed below.
+# For details on the name-to-distribution mapping, see README.md or models.json.
+NAME_TO_DISTRIBUTION_AND_DTYPE = {
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
 CACHE_PRECISION = torch.bfloat16
 
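As a quick illustration of how the renamed table is consumed later in main() (a standalone sketch, not part of the commit; it assumes only that torch is installed):

import torch

# Mirrors the mapping introduced by this commit: a short model name keys
# a (HF distribution id, dtype) pair.
NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
}

model_name = "llama3"
distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
print(distribution)  # meta-llama/Meta-Llama-3-8B-Instruct
print(model_dtype)   # torch.bfloat16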

@@ -75,8 +77,19 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 
 
 def _build_chat_tokenizer(
-    model_base_name: str = "llama3",
+    model_name: str,
+    model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name."""
+    # Try to infer the model base name from the model name:
+    # e.g. "llama2-7b-chat" -> "llama2"
+    if model_base_name is None:
+        model_base_name = model_name.split("-")[0]
+        logger.info(
+            f"Using model base name '{model_base_name}' to build tokenizer. "
+            "If not found, please specify it using the `model_base_name` argument."
+        )
+
     # Create base args for tokenizer
     default_model_dir = Path(
         os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -97,12 +110,12 @@ def _build_chat_tokenizer(
     return tokenizer
 
 
-def _load_model_weights(stage_module, hf_model_name, device, model_config):
+def _load_model_weights(stage_module, distribution, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
     """
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)
+    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
 
     num_loaded_weights, num_missing_weights = load_safetensor_weights(
         stage_module,
stage_module,
@@ -217,32 +230,31 @@ def _cleanup():
     dist.destroy_process_group()
 
 
-def main():
+def main(args):
+    model_name = args.model_name
+    pp_degree = args.pp
+
     rank, world_size = _init_distributed()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
-    config = ModelArgs.from_name(MODEL_NAME).transformer_args["text"]
-    logger.info(f"Chat Model Name: {MODEL_NAME}\nModel Config: {config}")
+    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
+    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    tokenizer = _build_chat_tokenizer()
-    logger.info(f"built tokenizer {tokenizer=}")
+    config = ModelArgs.from_name(distribution).transformer_args['text']
+    logger.info(f"Chat Model Config: {config}")
 
-    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
-    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(CACHE_PRECISION)
     logger.info(f"Using cache precision {CACHE_PRECISION}")
 
-    hf_config = get_hf_config_file(hf_model_name)
+    hf_config = get_hf_config_file(distribution)
     if hf_config is None:
-        raise ValueError(f"Config file not found for model id {hf_model_name}")
-    logger.info(f"Using HF model weights from {hf_model_name}")
+        raise ValueError(f"Config file not found for model id {distribution}")
 
-    # Assuming 2 pipeline stages, feel free to change this as long as the
-    # asserts are satisfied
-    pp_degree = 2
+    # Validate pipeline degree
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
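To make the new divisibility checks concrete (a standalone sketch with assumed numbers, not from the commit): with 8 ranks and --pp 2, the remaining ranks form the tensor-parallel mesh, and the layer count must split evenly across pipeline stages.

world_size = 8   # total ranks, e.g. from the distributed launcher (assumed)
pp_degree = 2    # the new --pp argument
n_layers = 32    # config.n_layers; 32 is typical for an 8B Llama model (assumed)

assert world_size % pp_degree == 0  # ranks divide evenly into stages
assert n_layers % pp_degree == 0    # layers divide evenly into stages

tp_degree = world_size // pp_degree       # 4 ranks per TP mesh
layers_per_stage = n_layers // pp_degree  # 16 layers per stage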

@@ -283,7 +295,8 @@ def main():
 
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    logger.info(f"Model: {model}")
+    if rank == 0:
+        logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
     mb_size = 1  # micro-batch size
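The rank gate above keeps an N-rank launch from dumping the model summary N times; a generic sketch of the same pattern (hypothetical helper, assuming torch.distributed is initialized as in this script):

import torch.distributed as dist

def log_once(logger, msg: str) -> None:
    # Emit only on rank 0 so multi-rank runs print a message exactly once.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)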
@@ -301,8 +314,10 @@ def main():
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
+
     with CUDATrackTime() as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
@@ -359,6 +374,7 @@ def main():
     ]
     """
 
+
     start_pos = 0
 
     # encode the prompt
@@ -462,4 +478,9 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    args = parser.parse_args()
+
+    main(args)
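With the argparse block in place, the model name is a required positional argument chosen from NAME_TO_DISTRIBUTION_AND_DTYPE. A hypothetical launch on one 8-GPU host might look like the line below (the exact launcher and flags are assumptions; _init_distributed expects ranks from a distributed launcher such as torchrun):

torchrun --nproc-per-node 8 dist_run.py llama3 --pp 2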
