diff --git a/build/builder.py b/build/builder.py
index d8b803149..85773a6c1 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -400,7 +400,7 @@ def _maybe_parellelize_model(
     if the user specifies using distributed inference. If not, this is a no-op.
 
     Args:
-        module (:class:`nn.Module`):
+        model (:class:`nn.Module`):
            Module to be parallelized.
        builder_args (:class:`BuilderArgs`):
            Command args for model building.
diff --git a/distributed/checkpoint.py b/distributed/checkpoint.py
index 35f28b419..1830e3a75 100644
--- a/distributed/checkpoint.py
+++ b/distributed/checkpoint.py
@@ -108,7 +108,7 @@ def load_checkpoints_to_model(
     We parallelize the module and load the distributed checkpoint to the model.
 
     Args:
-        module (:class:`nn.Module`):
+        model (:class:`nn.Module`):
            Module to be parallelized.
        builder_args (:class:`BuilderArgs`):
            Command args for model building.
diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py
index c4eb17658..cbcb29b72 100644
--- a/distributed/parallelize_llama.py
+++ b/distributed/parallelize_llama.py
@@ -28,7 +28,7 @@ def apply_tp(
 
     Args:
-        module (:class:`nn.Module`):
+        model (:class:`nn.Module`):
            Module to be parallelized.
        world_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
@@ -104,7 +104,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
 
     Args:
-        module (:class:`nn.Module`):
+        model (:class:`nn.Module`):
            Module to be parallelized.
        world_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
diff --git a/distributed/world_maker.py b/distributed/world_maker.py
index 85de66128..4fe578741 100644
--- a/distributed/world_maker.py
+++ b/distributed/world_maker.py
@@ -24,7 +24,7 @@ def launch_distributed(
     using distributed inference. If not, this is a no-op.
 
     Args:
-        config: str:
+        toml_config: str:
            toml file for the inference config.
    Returns:
        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
diff --git a/eval.py b/eval.py
index 76aa25d31..0b107e2f3 100644
--- a/eval.py
+++ b/eval.py
@@ -167,7 +167,7 @@ def eval(
     Args:
         model (Transformer): The pre-trained language model to evaluate.
         tokenizer: The tokenizer to use for encoding/decoding text.
-        task (str): The name of the evaluation task to perform.
+        tasks (Optional[list]): The names of the evaluation tasks to perform.
         limit (Optional[int]): The maximum number of samples to evaluate
            (None for all available).
        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
@@ -210,7 +210,7 @@ def main(args) -> None:
     Args:
         checkpoint_path (Path): The path to the model checkpoint file to load.
         compile (bool): Whether or not to compile the model for optimization.
-        task (Optional[str]): The name of the evaluation task or a list of tasks to perform.
+        tasks (Optional[list]): The names of the evaluation tasks to perform.
         limit (Optional[int]): The maximum number of samples to evaluate
            (None for all available).
        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
diff --git a/tokenizer/tiktoken.py b/tokenizer/tiktoken.py
index c3a5fd607..9e9fe2264 100644
--- a/tokenizer/tiktoken.py
+++ b/tokenizer/tiktoken.py
@@ -116,8 +116,8 @@ def encode(
             s (str): The input string to be encoded.
             bos (bool): Whether to prepend the beginning-of-sequence token.
             eos (bool): Whether to append the end-of-sequence token.
-            allowed_tokens ("all"|set[str]): allowed special tokens in string
-            disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string
+            allowed_special ("all"|set[str]): allowed special tokens in string
+            disallowed_special ("all"|set[str]): special tokens that raise an error when in string
 
         Returns:
            list[int]: A list of token IDs.
@@ -125,7 +125,7 @@ def encode(
         By default, setting disallowed_special=() encodes a string by ignoring
         special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
-          to special tokens to be encoded as natural text (insteading of raising
+          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
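For context, here is a minimal sketch (not part of the patch) of the allowed_special / disallowed_special semantics that the corrected encode() docstring describes, shown with the tiktoken library directly; the encoding name and sample string are arbitrary illustrations:

# Illustrative sketch only -- not part of this diff. Uses the tiktoken library
# directly; the encoding name and sample text are arbitrary choices.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "hello <|endoftext|> world"

# disallowed_special=(): special-token text is encoded as ordinary text.
print(enc.encode(text, disallowed_special=()))

# allowed_special="all": special-token text is encoded as special tokens.
print(enc.encode(text, allowed_special="all"))

# Default: text that matches a special token raises an error.
try:
    enc.encode(text)
except ValueError:
    print("special token rejected by default")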