dist_run.py: 10 changes (7 additions, 3 deletions)
@@ -33,7 +33,7 @@
 )
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
-from torchchat.model import ModelArgs, Transformer
+from torchchat.model import ModelArgs, Transformer, TransformerArgs
 from torchchat.utils.build_utils import set_precision

 try:
@@ -239,8 +239,11 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")

-    config = ModelArgs.from_name(distribution).transformer_args["text"]
-    logger.info(f"Chat Model Config: {config}")
+    # Model-level config
+    model_config = ModelArgs.from_name(distribution)
+    # Transformer-level config
+    config = TransformerArgs.from_params(model_config.transformer_args["text"])
+    logger.info(f"Transformer Config: {config}")

     tokenizer = _build_chat_tokenizer(model_name)

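For context, a minimal sketch of the new two-level config lookup above, using the names from the diff; the "llama3" distribution string is an illustrative assumption, not taken from dist_run.py:

# Sketch: resolve a transformer config in two steps, as the diff now does.
from torchchat.model import ModelArgs, TransformerArgs

model_config = ModelArgs.from_name("llama3")  # model-level config (illustrative name)
# transformer_args["text"] holds the raw parameters for the text transformer
config = TransformerArgs.from_params(model_config.transformer_args["text"])
print(config)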
@@ -282,6 +285,7 @@ def main(args):
     config.n_stages = pp_degree

     with device:
+        # TODO: we should create model instead of Transformer
         model = Transformer(config)

     # Distribute model on TP mesh
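For context, a minimal sketch of how a stage built from this Transformer can be driven with the GPipe schedule imported at the top of the file; stage_index, pp_degree, device, and mb_input are illustrative assumptions, not code from dist_run.py:

# Sketch: wrap the local model shard in a PipelineStage and run a GPipe schedule.
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

stage = PipelineStage(model, stage_index, pp_degree, device)  # one stage per pp rank (assumed wiring)
schedule = ScheduleGPipe(stage, n_microbatches=4)             # microbatch count is an assumption
if stage_index == 0:
    schedule.step(mb_input)   # the first stage feeds the input microbatches
else:
    output = schedule.step()  # later stages receive activations from the previous stage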