44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import argparse
78import os
89import sys
9- import time
1010from dataclasses import dataclass
1111from pathlib import Path
1212from typing import Any , Dict , Optional , Tuple , Union
2121except ImportError :
2222 pass
2323
24- from distributed import (
25- init_distributed ,
26- launch_distributed ,
27- ParallelDims ,
28- parallelize_llama ,
29- )
24+ from distributed import launch_distributed , ParallelDims , parallelize_llama
3025
3126from torch .distributed .device_mesh import DeviceMesh
3227
@@ -101,7 +96,7 @@ def __post_init__(self):
10196 self .prefill_possible = True
10297
10398 @classmethod
104- def from_args (cls , args ): # -> BuilderArgs:
99+ def from_args (cls , args : argparse . Namespace ) -> "BuilderArgs" :
105100 # Handle disabled checkpoint_dir option
106101 checkpoint_dir = None
107102 if hasattr (args , "checkpoint_dir" ):
@@ -183,7 +178,7 @@ def from_args(cls, args): # -> BuilderArgs:
183178 )
184179
185180 @classmethod
186- def from_speculative_args (cls , args ): # -> BuilderArgs:
181+ def from_speculative_args (cls , args : argparse . Namespace ) -> "BuilderArgs" :
187182 speculative_builder_args = BuilderArgs .from_args (args )
188183 # let's limit multi-checkpoint to checker
189184 speculative_builder_args .checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):
229224
230225 def validate_model (
231226 self ,
232- model : Model ,
227+ model : Optional [ Model ] ,
233228 model_description : str = "model" ,
234229 ) -> None :
235230 if model is None :
@@ -250,10 +245,21 @@ def validate_model(
250245 return
251246
252247 @classmethod
253- def from_args (cls , args ): # -> TokenizerArgs:
254- is_sentencepiece = False
255- is_tiktoken = False
256-
248+ def from_args (cls , args : argparse .Namespace ) -> "TokenizerArgs" :
249+ """
250+ Create a TokenizerArgs object from command line arguments.
251+ Specifically, `tokenizer_path` is resolved with precedence:
252+ * Explicitly provided tokenizer_path
253+ * Resolve via model_config identified by args.model
254+ * Look in the directory of args.checkpoint_path for tokenizer.model
255+ * Look in the directory of args.checkpoint_dir for tokenizer.model
256+
257+ Args:
258+ args (argparse.Namespace): The command line arguments.
259+
260+ Returns:
261+ TokenizerArgs: A TokenizerArgs object.
262+ """
257263 if args .tokenizer_path :
258264 tokenizer_path = args .tokenizer_path
259265 elif args .model : # Using a named, well-known model
@@ -263,7 +269,6 @@ def from_args(cls, args): # -> TokenizerArgs:
263269 / model_config .name
264270 / model_config .tokenizer_file
265271 )
266-
267272 elif args .checkpoint_path :
268273 tokenizer_path = args .checkpoint_path .parent / "tokenizer.model"
269274 elif hasattr (args , "checkpoint_dir" ) and args .checkpoint_dir :
@@ -276,12 +281,7 @@ def from_args(cls, args): # -> TokenizerArgs:
276281 f"did not find tokenizer at { os .path .abspath (tokenizer_path )} "
277282 )
278283
279- return cls (
280- tokenizer_path = tokenizer_path ,
281- is_sentencepiece = is_sentencepiece ,
282- is_tiktoken = is_tiktoken ,
283- t = None ,
284- )
284+ return cls (tokenizer_path = tokenizer_path )
285285
286286
287287def _initialize_tokenizer (tokenizer_args : TokenizerArgs ):
@@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
299299
300300
301301# TODO: remove these once ET supports _weight_int4pack_mm
302- def _set_gguf_kwargs (builder_args , is_et , context : str ):
302+ def _set_gguf_kwargs (builder_args : BuilderArgs , is_et : bool , context : str ) -> None :
303303 assert context in ["export" , "generate" ]
304304 assert builder_args .gguf_kwargs is None
305305
@@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
312312 builder_args .gguf_kwargs ["load_as_quantized" ] = False
313313
314314
315- def _unset_gguf_kwargs (builder_args ) :
315+ def _unset_gguf_kwargs (builder_args : BuilderArgs ) -> None :
316316 builder_args .gguf_kwargs = None
317317
318318
319- def _init_model_on_meta_device (builder_args ) :
319+ def _init_model_on_meta_device (builder_args : BuilderArgs ) -> Model :
320320 with torch .device ("meta" ):
321321 if builder_args .params_path :
322322 return Model .from_params (builder_args .params_path )
@@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args):
326326 return Model .from_name (builder_args .checkpoint_path .parent .name )
327327
328328
329- def _load_model_gguf (builder_args , only_config = False ) :
329+ def _load_model_gguf (builder_args : BuilderArgs ) -> Model :
330330 assert builder_args .gguf_path
331331 if builder_args .gguf_kwargs is None :
332332 kwargs = {}
@@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False):
336336 return model
337337
338338
339- def _load_model_default (builder_args , only_config = False ) :
339+ def _load_model_default (builder_args : BuilderArgs ) -> Model :
340340 assert not builder_args .gguf_path
341341
342- model = _init_model_on_meta_device (builder_args )
342+ model : Model = _init_model_on_meta_device (builder_args )
343343
344344 if builder_args .params_table and builder_args .params_table .endswith ("Tune" ):
345345 print ("Loading Tune checkpoint" )
@@ -457,7 +457,7 @@ def _maybe_parellelize_model(
457457 return load_checkpoints_to_model (model , builder_args , world_mesh )
458458
459459
460- def _load_model (builder_args , only_config = False ) :
460+ def _load_model (builder_args : BuilderArgs ) -> Model :
461461 world_mesh , parallel_dims = _maybe_init_distributed (builder_args )
462462 if builder_args .gguf_path :
463463 model = _load_model_gguf (builder_args )
@@ -472,12 +472,12 @@ def _load_model(builder_args, only_config=False):
472472
473473
474474def _initialize_model (
475- builder_args ,
475+ builder_args : BuilderArgs ,
476476 quantize ,
477477 tokenizer = None ,
478478 max_seq_length = None ,
479479 support_tensor_subclass : bool = True ,
480- ):
480+ ) -> Model :
481481 print ("Loading model..." )
482482
483483 if builder_args .gguf_path and (builder_args .dso_path or builder_args .pte_path ):
@@ -503,7 +503,7 @@ def _initialize_model(
503503 # ), "quantize not valid for exported DSO model. Specify quantization during export."
504504
505505 with measure_time ("Time to load model: {time:.02f} seconds" ):
506- model = _load_model (builder_args , only_config = True )
506+ model = _load_model (builder_args )
507507 device_sync (device = builder_args .device )
508508
509509 try :
@@ -530,7 +530,7 @@ def _initialize_model(
530530 # ), "quantize not valid for exported PTE model. Specify quantization during export."
531531
532532 with measure_time ("Time to load model: {time:.02f} seconds" ):
533- model = _load_model (builder_args , only_config = True )
533+ model = _load_model (builder_args )
534534 device_sync (device = builder_args .device )
535535
536536 try :
0 commit comments