
Commit 8ade3c4
Merge branch 'main' into add_quant_saving
2 parents: 763a9ce + 70260eb

10 files changed: +1013 -40 lines

dist_run.py

Lines changed: 7 additions & 3 deletions
@@ -20,14 +20,14 @@
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

-from torchchat.distributed.logging_utils import SingletonLogger
-
 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
     load_weights_from_hf_format,
     load_weights_from_torchchat_format,
 )
+
+from torchchat.distributed.logging_utils import SingletonLogger
 from torchchat.distributed.utils import (
     bytes_to_readable,
     Color as color,
@@ -153,7 +153,9 @@ def _load_model_weights(
         # This format stands for:
         # single binary file, OR
         # multiple binary files without index files.
-        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+        load_weights_from_torchchat_format(
+            stage_module, distribution, device, model_config
+        )
     else:
         raise ValueError(f"Unknown checkpoint format: {chpt_from}")

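For context, a minimal sketch of how a format switch like the one in this hunk can be written. It is not the code in dist_run.py itself, and it assumes load_weights_from_hf_format takes the same arguments as the torchchat-format loader (both imports appear in the first hunk above).

```python
from torchchat.distributed.checkpoint_utils import (
    load_weights_from_hf_format,
    load_weights_from_torchchat_format,
)


def load_weights_by_format(stage_module, distribution, device, model_config, chpt_from="hf"):
    # Hypothetical helper: pick a weight loader based on the checkpoint format string.
    if chpt_from == "hf":
        # Assumed to mirror the torchchat-format loader's signature.
        load_weights_from_hf_format(stage_module, distribution, device, model_config)
    elif chpt_from == "torchchat":
        load_weights_from_torchchat_format(
            stage_module, distribution, device, model_config
        )
    else:
        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
```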

@@ -593,9 +595,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     parser.add_argument(
         "model_name",
         type=str,
+        default="llama3",
         help="Name of the model to load",
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
+
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     parser.add_argument(
         "--ntokens",

docs/quantization.md

Lines changed: 2 additions & 0 deletions
@@ -120,6 +120,8 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n

 ## Experimental TorchAO lowbit kernels

+WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+
 ### Use

 #### linear:a8wxdq

install/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -30,3 +30,6 @@ streamlit

 # Server mode
 flask
+
+# eval
+lm_eval==0.4.2

torchchat/cli/builder.py

Lines changed: 19 additions & 8 deletions
@@ -17,6 +17,8 @@
 import torch.nn as nn

 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.elastic.utils.distributed import get_free_port

 from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
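The hunk above only pulls in these torchelastic helpers. For context, a brief sketch of their typical use (not code from this commit): record is a decorator that captures and reports failures from the entrypoint it wraps, and get_free_port() returns an unused local TCP port.

```python
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.distributed import get_free_port


@record  # record failures so torchelastic can surface the traceback
def main() -> None:
    port = get_free_port()  # an unused local TCP port, e.g. for MASTER_PORT
    print(f"picked free port {port}")


if __name__ == "__main__":
    main()
```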

@@ -55,7 +57,10 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_distributed: bool = False
+    distributed: bool = False
+    pp: int = 1
+    tp: int = 1
+    chpt_from: str = "hf"
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -157,7 +162,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dtype = torch.float16
         else:
             dtype = name_to_dtype(args.dtype, args.device)
-
+        # distributed args
+        distributed = getattr(args, "distributed", False)
+        pp = getattr(args, "pp", 1)
+        tp = getattr(args, "tp", 1)
+        chpt_from = getattr(args, "chpt_from", "hf")
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -171,7 +180,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             device=args.device,
             precision=dtype,
             setup_caches=(output_dso_path or output_pte_path),
-            use_distributed=args.distributed,
+            distributed=distributed,
+            pp=pp,
+            tp=tp,
+            chpt_from=chpt_from,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
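Putting the two from_args hunks together with the dataclass hunk, a trimmed-down sketch of the new plumbing. The class name here is hypothetical and carries only the four new fields; the real BuilderArgs has many more.

```python
import argparse
from dataclasses import dataclass


@dataclass
class DistributedBuilderArgs:  # hypothetical, trimmed-down stand-in for BuilderArgs
    distributed: bool = False
    pp: int = 1
    tp: int = 1
    chpt_from: str = "hf"

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "DistributedBuilderArgs":
        # getattr with a default keeps construction working for entry points
        # whose parsers never register the distributed flags.
        return cls(
            distributed=getattr(args, "distributed", False),
            pp=getattr(args, "pp", 1),
            tp=getattr(args, "tp", 1),
            chpt_from=getattr(args, "chpt_from", "hf"),
        )


print(DistributedBuilderArgs.from_args(argparse.Namespace(distributed=True, pp=2, tp=2)))
```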
@@ -481,14 +493,14 @@ def _maybe_parallelize_model(


 def _load_model(builder_args: BuilderArgs) -> Model:
-    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
+    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    elif builder_args.use_distributed:
-        model = _init_model_on_meta_device(builder_args)
+    # elif builder_args.use_distributed:
+    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
+    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
@@ -502,7 +514,6 @@ def _initialize_model(
     support_tensor_subclass: bool = True,
 ) -> Model:
     print("Loading model...")
-
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None

torchchat/cli/cli.py

Lines changed: 22 additions & 2 deletions
@@ -405,8 +405,7 @@ def _add_distributed_args(parser) -> None:
     parser.add_argument(
         "--distributed",
         action="store_true",
-        help=argparse.SUPPRESS,
-        # "Whether to enable distributed inference",
+        help="Whether to enable distributed inference",
     )
     parser.add_argument(
         "--dcp-dir",
@@ -415,6 +414,27 @@
         help=argparse.SUPPRESS,
         # "Use the specified model checkpoint directory",
     )
+    parser.add_argument(
+        "--pp",
+        "--pipeline-parallel",
+        type=int,
+        default=1,
+        help="Pipeline parallel degree",
+    )
+    parser.add_argument(
+        "--tp",
+        "--tensor-parallel",
+        type=int,
+        default=2,
+        help="Tensor parallel degree",
+    )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )


 # Add CLI Args related to custom model inputs
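A minimal reproduction of the flags this hunk adds, showing how the long aliases and defaults resolve when parsed. Illustrative only; the real parser lives in torchchat/cli/cli.py.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--distributed", action="store_true",
                    help="Whether to enable distributed inference")
parser.add_argument("--pp", "--pipeline-parallel", type=int, default=1,
                    help="Pipeline parallel degree")
parser.add_argument("--tp", "--tensor-parallel", type=int, default=2,
                    help="Tensor parallel degree")
parser.add_argument("--chpt-from", type=str, default="hf",
                    choices=["hf", "torchchat"],
                    help="Checkpoint format to load from")

# argparse derives dest names from the first long option, so the values land
# on args.pp, args.tp, and args.chpt_from.
args = parser.parse_args(["--distributed", "--pipeline-parallel", "2", "--tp", "4"])
print(args.distributed, args.pp, args.tp, args.chpt_from)  # True 2 4 hf
```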
