This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 96b51be

Merge branch 'main' into rm-modle.model
2 parents: bcb414f + 72d2d20

File tree

2 files changed: +37, -32 lines


torchchat/cli/builder.py

Lines changed: 32 additions & 32 deletions
@@ -4,9 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 import sys
-import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
@@ -21,12 +21,7 @@
 except ImportError:
     pass
 
-from distributed import (
-    init_distributed,
-    launch_distributed,
-    ParallelDims,
-    parallelize_llama,
-)
+from distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torch.distributed.device_mesh import DeviceMesh
 
@@ -101,7 +96,7 @@ def __post_init__(self):
         self.prefill_possible = True
 
     @classmethod
-    def from_args(cls, args): # -> BuilderArgs:
+    def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         # Handle disabled checkpoint_dir option
         checkpoint_dir = None
         if hasattr(args, "checkpoint_dir"):
@@ -183,7 +178,7 @@ def from_args(cls, args): # -> BuilderArgs:
         )
 
     @classmethod
-    def from_speculative_args(cls, args): # -> BuilderArgs:
+    def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args = BuilderArgs.from_args(args)
         # let's limit multi-checkpoint to checker
         speculative_builder_args.checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):
 
     def validate_model(
         self,
-        model: Model,
+        model: Optional[Model],
         model_description: str = "model",
     ) -> None:
         if model is None:
@@ -250,10 +245,21 @@ def validate_model(
         return
 
     @classmethod
-    def from_args(cls, args): # -> TokenizerArgs:
-        is_sentencepiece = False
-        is_tiktoken = False
-
+    def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
+        """
+        Create a TokenizerArgs object from command line arguments.
+        Specifically, `tokenizer_path` is resolved with precedence:
+          * From Explicitly provided tokenizer_path
+          * Resolve via model_config identified by args.model
+          * Look in the directory of args.checkpoint_path for tokenizer.model
+          * Look in the directory of args.checkpoint_dir for tokenizer.model
+
+        Args:
+            args (argparse.Namespace): The command line arguments.
+
+        Returns:
+            TokenizerArgs: A TokenizerArgs object.
+        """
         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
         elif args.model:  # Using a named, well-known model
@@ -263,7 +269,6 @@ def from_args(cls, args): # -> TokenizerArgs:
                 / model_config.name
                 / model_config.tokenizer_file
             )
-
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
         elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
@@ -276,12 +281,7 @@ def from_args(cls, args): # -> TokenizerArgs:
                 f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
             )
 
-        return cls(
-            tokenizer_path=tokenizer_path,
-            is_sentencepiece=is_sentencepiece,
-            is_tiktoken=is_tiktoken,
-            t=None,
-        )
+        return cls(tokenizer_path=tokenizer_path)
 
 
 def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
@@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
 
 
 # TODO: remove these once ET supports _weight_int4pack_mm
-def _set_gguf_kwargs(builder_args, is_et, context: str):
+def _set_gguf_kwargs(builder_args: BuilderArgs, is_et: bool, context: str) -> None:
     assert context in ["export", "generate"]
     assert builder_args.gguf_kwargs is None
 
@@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
     builder_args.gguf_kwargs["load_as_quantized"] = False
 
 
-def _unset_gguf_kwargs(builder_args):
+def _unset_gguf_kwargs(builder_args: BuilderArgs) -> None:
     builder_args.gguf_kwargs = None
 
 
-def _init_model_on_meta_device(builder_args):
+def _init_model_on_meta_device(builder_args: BuilderArgs) -> Model:
     with torch.device("meta"):
         if builder_args.params_path:
             return Model.from_params(builder_args.params_path)
@@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args):
             return Model.from_name(builder_args.checkpoint_path.parent.name)
 
 
-def _load_model_gguf(builder_args, only_config=False):
+def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:
         kwargs = {}
@@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False):
     return model
 
 
-def _load_model_default(builder_args, only_config=False):
+def _load_model_default(builder_args: BuilderArgs) -> Model:
    assert not builder_args.gguf_path
 
-    model = _init_model_on_meta_device(builder_args)
+    model: Model = _init_model_on_meta_device(builder_args)
 
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
@@ -457,7 +457,7 @@ def _maybe_parellelize_model(
     return load_checkpoints_to_model(model, builder_args, world_mesh)
 
 
-def _load_model(builder_args, only_config=False):
+def _load_model(builder_args: BuilderArgs) -> Model:
     world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
@@ -472,12 +472,12 @@ def _load_model(builder_args, only_config=False):
 
 
 def _initialize_model(
-    builder_args,
+    builder_args: BuilderArgs,
     quantize,
     tokenizer=None,
     max_seq_length=None,
     support_tensor_subclass: bool = True,
-):
+) -> Model:
     print("Loading model...")
 
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
@@ -503,7 +503,7 @@ def _initialize_model(
        # ), "quantize not valid for exported DSO model. Specify quantization during export."
 
        with measure_time("Time to load model: {time:.02f} seconds"):
-            model = _load_model(builder_args, only_config=True)
+            model = _load_model(builder_args)
        device_sync(device=builder_args.device)
 
        try:
@@ -530,7 +530,7 @@ def _initialize_model(
        # ), "quantize not valid for exported PTE model. Specify quantization during export."
 
        with measure_time("Time to load model: {time:.02f} seconds"):
-            model = _load_model(builder_args, only_config=True)
+            model = _load_model(builder_args)
        device_sync(device=builder_args.device)
 
        try:
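
The new docstring on TokenizerArgs.from_args spells out how tokenizer_path is resolved. A rough standalone sketch of that precedence follows; the helper name resolve_tokenizer_path is illustrative only and the model_config lookup is elided, so this is not torchchat's actual implementation:

import argparse
from pathlib import Path
from typing import Optional

def resolve_tokenizer_path(args: argparse.Namespace) -> Optional[Path]:
    # 1. An explicitly provided tokenizer_path always wins.
    if getattr(args, "tokenizer_path", None):
        return Path(args.tokenizer_path)
    # 2. A named, well-known model would resolve via its model_config
    #    (the resolve_model_config lookup is elided in this sketch).
    # 3. Otherwise look for tokenizer.model next to the checkpoint file ...
    if getattr(args, "checkpoint_path", None):
        return Path(args.checkpoint_path).parent / "tokenizer.model"
    # 4. ... or inside the checkpoint directory.
    if getattr(args, "checkpoint_dir", None):
        return Path(args.checkpoint_dir) / "tokenizer.model"
    return None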

torchchat/model.py

Lines changed: 5 additions & 0 deletions
@@ -982,6 +982,11 @@ def forward(self, x, input_pos):
         # the first element to get the tensor
         assert len(logits) == 1
         logits = logits[0]
+
+        # Add a batch dimension, if it's missing (e.g. some pte's
+        # exported from the ExecuTorch repo)
+        if logits.dim() == 2:
+            logits = logits.unsqueeze(0)
         return logits
 
     def setup_caches(self, max_batch_size, max_seq_length):
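
A minimal sketch of what the added branch does to the logits shape; this is a standalone tensor example with made-up dimensions, not the module's forward itself:

import torch

# Some .pte programs exported from the ExecuTorch repo return logits of
# shape (seq_len, vocab_size); the added check restores the batch dimension.
logits = torch.randn(16, 32000)       # missing batch dimension
if logits.dim() == 2:
    logits = logits.unsqueeze(0)      # now (1, 16, 32000)
assert logits.shape == (1, 16, 32000)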

0 commit comments
