From 416c4f49bdef966fc37b6f3d66500cd38f04dc58 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 23 Jul 2024 17:04:09 -0700 Subject: [PATCH 01/12] add llama 3.1 8b support --- .../known_model_params/Meta-Llama-3.1-8B.json | 1 + build/model.py | 32 ++++++++++++++++++- config/data/models.json | 6 ++++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 build/known_model_params/Meta-Llama-3.1-8B.json diff --git a/build/known_model_params/Meta-Llama-3.1-8B.json b/build/known_model_params/Meta-Llama-3.1-8B.json new file mode 100644 index 000000000..0d3808205 --- /dev/null +++ b/build/known_model_params/Meta-Llama-3.1-8B.json @@ -0,0 +1 @@ +{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true} diff --git a/build/model.py b/build/model.py index 0405e3683..0e9673df3 100644 --- a/build/model.py +++ b/build/model.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import json import os +import math + from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional @@ -38,6 +40,7 @@ class ModelArgs: ffn_dim_multiplier: Optional[int] = None use_tiktoken: bool = False max_seq_length: int = 8192 + use_scaled_rope: bool = False def __post_init__(self): if self.n_local_heads == -1: @@ -178,6 +181,7 @@ def setup_caches(self, max_batch_size, max_seq_length): self.config.dim // self.config.n_heads, self.config.block_size * 2, self.config.rope_base, + use_scaled = self.config.use_scaled_rope, ) self.register_buffer("freqs_cis", freqs_cis, persistent=True) causal_mask = torch.tril( @@ -361,8 +365,32 @@ def forward(self, x: Tensor) -> Tensor: return output * self.weight +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + def precompute_freqs_cis( - n_elem: int, seq_len: int, base: int = 10000, dtype=None + n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False ) -> Tensor: if not dtype: dtype = get_precision() @@ -370,6 +398,8 @@ def precompute_freqs_cis( base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) ) t = torch.arange(seq_len, device=freqs.device) + if use_scaled: + freqs = apply_scaling(freqs) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) diff --git a/config/data/models.json b/config/data/models.json index 483f6c35f..35e7c2061 100644 --- a/config/data/models.json +++ b/config/data/models.json @@ -40,6 +40,12 @@ "distribution_path": "meta-llama/Meta-Llama-3-70B-Instruct", "transformer_params_key": "Meta-Llama-3-70B" }, + "meta-llama/Meta-Llama-3.1-8B-Instruct": { + "aliases": ["llama3.1", "llama3.1-chat", "llama3.1-instruct"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "transformer_params_key": "Meta-Llama-3.1-8B" + }, "meta-llama/CodeLlama-7b-Python-hf": { "aliases": ["codellama", "codellama-7b"], "distribution_channel": "HuggingFaceSnapshot", From db4853b81fd75824998844fd90dfc017f4832037 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 29 Jul 2024 18:38:01 -0700 Subject: [PATCH 02/12] make Model and ModelArgs as model definition entrance --- build/builder.py | 46 ++++++++------ build/gguf_loader.py | 25 +++++--- build/model.py | 134 ++++++++++++++++++++++++++++++++------- docs/ADVANCED-USERS.md | 12 ++-- eval.py | 10 +-- export_util/export_et.py | 16 ++++- generate.py | 16 ++--- 7 files changed, 186 insertions(+), 73 deletions(-) diff --git a/build/builder.py b/build/builder.py index b69fcaf20..ed0fcba06 100644 --- a/build/builder.py +++ b/build/builder.py @@ -12,19 +12,23 @@ from typing import Any, Dict, Optional, Tuple, Union import torch -import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh import torch._dynamo.config import torch._inductor.config +import torch.nn as nn from config.model_config import resolve_model_config -from distributed import init_distributed, ParallelDims, parallelize_llama +from distributed import ( + init_distributed, + launch_distributed, + ParallelDims, + parallelize_llama, +) from quantization.quantize import quantize_model +from torch.distributed.device_mesh import DeviceMesh from utils.measure_time import measure_time -from build.model import Transformer +from build.model import Model from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype -from distributed import launch_distributed @dataclass @@ -200,7 +204,7 @@ def __post_init__(self): def validate_model( self, - model: Transformer, + model: Model, model_description: str = "model", ) -> None: if model is None: @@ -288,11 +292,11 @@ def _unset_gguf_kwargs(builder_args): def _init_model_on_meta_device(builder_args): with torch.device("meta"): if builder_args.params_path: - return Transformer.from_params(builder_args.params_path) + return Model.from_params(builder_args.params_path) elif builder_args.params_table: - return Transformer.from_table(builder_args.params_table) + return Model.from_table(builder_args.params_table) else: - return Transformer.from_name(builder_args.checkpoint_path.parent.name) + return Model.from_name(builder_args.checkpoint_path.parent.name) def _load_model_gguf(builder_args, only_config=False): @@ -301,7 +305,7 @@ def _load_model_gguf(builder_args, only_config=False): kwargs = {} else: kwargs = builder_args.gguf_kwargs - model = Transformer.from_gguf(builder_args.gguf_path, **kwargs) + model = Model.from_gguf(builder_args.gguf_path, **kwargs) return model @@ -355,27 +359,29 @@ def _maybe_init_distributed( builder_args: BuilderArgs, ) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: """ - Initialize distributed related setups if the user specified + Initialize distributed related setups if the user specified using distributed inference. If not, this is a no-op. Args: builder_args (:class:`BuilderArgs`): Command args for model building. Returns: - Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: - - The first element is an optional DeviceMesh object, + Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: + - The first element is an optional DeviceMesh object, which which describes the mesh topology of devices for the DTensor. - - The second element is an optional ParallelDims object, + - The second element is an optional ParallelDims object, which represents the parallel dimensions configuration. """ if not builder_args.use_distributed: return None, None - dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line - - world_mesh, parallel_dims = launch_distributed(dist_config) - - assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}" - + dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line + + world_mesh, parallel_dims = launch_distributed(dist_config) + + assert ( + world_mesh is not None and parallel_dims is not None + ), f"failed to launch distributed using {dist_config}" + return world_mesh, parallel_dims diff --git a/build/gguf_loader.py b/build/gguf_loader.py index 4047d29ef..68c5e2814 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -16,8 +16,9 @@ from gguf import GGUFValueType from quantization.qops import LinearInt4 as WeightOnlyInt4Linear from quantization.quantize import pack_scales_and_zeros + from build.gguf_util import Q4_0, to_float -from build.model import TransformerArgs, Transformer +from build.model import Model, ModelArgs, TransformerArgs logger: logging.Logger = logging.getLogger(__name__) @@ -107,14 +108,18 @@ def load_model(gguf_file: str) -> torch.nn.Module: arch = metadata["general.architecture"] assert arch == "llama", "Only LLaMa models are supported by this converter." - model_args = TransformerArgs( - dim=metadata[f"{arch}.embedding_length"], - n_layers=metadata[f"{arch}.block_count"], - n_heads=metadata[f"{arch}.attention.head_count"], - n_local_heads=metadata[f"{arch}.attention.head_count_kv"], - vocab_size=len(metadata["tokenizer.ggml.tokens"]), - norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], - hidden_dim=metadata[f"{arch}.feed_forward_length"], + model_args = ModelArgs( + { + "default": TransformerArgs( + dim=metadata[f"{arch}.embedding_length"], + n_layers=metadata[f"{arch}.block_count"], + n_heads=metadata[f"{arch}.attention.head_count"], + n_local_heads=metadata[f"{arch}.attention.head_count_kv"], + vocab_size=len(metadata["tokenizer.ggml.tokens"]), + norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], + hidden_dim=metadata[f"{arch}.feed_forward_length"], + ) + } ) # TODO: what to do with rope args like @@ -122,7 +127,7 @@ def load_model(gguf_file: str) -> torch.nn.Module: # metadata.get(f"{arch}.rope.dimension_count", None) with torch.device("meta"): - model = Transformer(model_args) + model = Model(model_args) return model diff --git a/build/model.py b/build/model.py index 2401e5724..9bb846ded 100644 --- a/build/model.py +++ b/build/model.py @@ -21,6 +21,76 @@ config_path = Path(f"{str(Path(__file__).parent)}/known_model_params") +@dataclass +class ModelArgs: + transformer_args_maps: Dict[str, TransformerArgs] + + @classmethod + def from_params(cls, params_path): + with open(params_path, "r") as f: + loaded_params = json.loads(f.read()) + + transformer_args_maps: Dict[str, TransformerArgs] = {} + + try: + # try to interpret as a single transformer config + transformer_args_maps["default"] = TransformerArgs.from_params( + loaded_params + ) + except TypeError: + # try to interpret as a dict of transformer configs + for name, params in loaded_params.items(): + transformer_args_maps[name] = TransformerArgs.from_params(params) + + return cls(transformer_args_maps) + + @classmethod + def from_table(cls, name: str): + json_path = config_path / f"{name}.json" + if json_path.is_file(): + return ModelArgs.from_params(json_path) + else: + known_model_params = [ + config.replace(".json", "") for config in os.listdir(config_path) + ] + raise RuntimeError( + f"unknown table index {name} for transformer config, must be from {known_model_params}" + ) + + @classmethod + def from_name(cls, name: str): + json_path = config_path / f"{name}.json" + if Path(json_path).is_file(): + return ModelArgs.from_params(json_path) + + known_model_params = [ + config.replace(".json", "") for config in os.listdir(config_path) + ] + + print(f"known configs: {known_model_params}") + # Fuzzy search by name (e.g. "7B" and "Mistral-7B") + config = [ + config + for config in known_model_params + if config in str(name).upper() or config in str(name) + ] + + # We may have two or more configs matched (e.g., "7B" and + # "Mistral-7B"). Find the best config match: take longer + # name (as it have more symbols matched) + if len(config) > 1: + config.sort(key=len, reverse=True) + assert len(config[0]) != len( + config[1] + ), name # make sure only one 'best' match + elif len(config) == 0: + raise ValueError( + f"Unknown model directory name {name}. Must be one of {known_model_params}." + ) + + return ModelArgs.from_params(config_path / f"{config[0]}.json") + + @dataclass class TransformerArgs: block_size: int = 2048 @@ -144,6 +214,46 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out +class Model(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + self.transformer_map: Dict[str, Transformer] = {} + for name, config in self.config.transformer_args_maps.items(): + self.transformer_map[name] = Transformer(config) + + assert ( + len(self.transformer_map) == 1 + ), "Only support one transformer model for now" + assert ( + "default" in self.transformer_map + ), '"default" not in self.transformer_map' + + def forward(self, *args, **kwargs) -> Tensor: + return self.transformer_map["default"](*args, **kwargs) + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + @classmethod + def from_table(cls, name: str): + return cls(ModelArgs.from_table(name)) + + @classmethod + def from_params(cls, params_path: str): + return cls(ModelArgs.from_params(params_path)) + + @classmethod + def from_gguf(cls, gguf_path: str, **kwargs): + from build.gguf_loader import load_model_and_state_dict + + model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) + if state_dict != {}: + model.load_state_dict(state_dict, assign=True) + return model + + class Transformer(nn.Module): def __init__(self, config: TransformerArgs) -> None: super().__init__() @@ -180,7 +290,7 @@ def setup_caches(self, max_batch_size, max_seq_length): self.config.dim // self.config.n_heads, self.config.block_size * 2, self.config.rope_base, - use_scaled = self.config.use_scaled_rope, + use_scaled=self.config.use_scaled_rope, ) self.register_buffer("freqs_cis", freqs_cis, persistent=True) causal_mask = torch.tril( @@ -201,27 +311,6 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: # print(f"logits shape: {logits.shape}") return logits - @classmethod - def from_name(cls, name: str): - return cls(TransformerArgs.from_name(name)) - - @classmethod - def from_table(cls, name: str): - return cls(TransformerArgs.from_table(name)) - - @classmethod - def from_params(cls, params_path: str): - return cls(TransformerArgs.from_params(params_path)) - - @classmethod - def from_gguf(cls, gguf_path: str, **kwargs): - from build.gguf_loader import load_model_and_state_dict - - model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) - if state_dict != {}: - model.load_state_dict(state_dict, assign=True) - return model - class TransformerBlock(nn.Module): def __init__(self, config: TransformerArgs) -> None: @@ -388,6 +477,7 @@ def apply_scaling(freqs: torch.Tensor): new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + def precompute_freqs_cis( n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False ) -> Tensor: diff --git a/docs/ADVANCED-USERS.md b/docs/ADVANCED-USERS.md index 12ad5f229..4c4f8bcb7 100644 --- a/docs/ADVANCED-USERS.md +++ b/docs/ADVANCED-USERS.md @@ -112,24 +112,24 @@ architecture, provided you have the model weights in llama format, the model parameters and the tokenizer model used by your language model. Some common models are recognized by torchchat based on their filename -through `Transformer.from_name()` to perform a fuzzy match against a +through `Model.from_name()` to perform a fuzzy match against a table of known model architectures. Alternatively, you can specify the index into that table with the option `--params-table ${INDEX}` where the index is the lookup key key in the [the list of known pconfigurations](https://github.com/pytorch/torchchat/tree/main/build/known_model_params) For example, for the stories15M model, this would be expressed as `--params-table stories15M`. (We use the model constructor -`Transformer.from_table()`) +`Model.from_table()`) For models using a configuration not in the list of known configurations, you can construct the model by initializing the -`TransformerArgs` dataclass that controls model construction from a +`ModelArgs` dataclass that controls model construction from a parameter json using the `params-path ${PARAMS_PATH}` containing the -appropriate model parameters to initialize the `TransformerArgs` for the -model. (We use the model constructor `Transformer.from_params()`). +appropriate model parameters to initialize the `ModelArgs` for the +model. (We use the model constructor `Model.from_params()`). The parameter file should be in JSON format specifying these -parameters. You can find the `TransformerArgs` data class in +parameters. You can find the `ModelArgs` data class in [`model.py`](https://github.com/pytorch/torchchat/blob/main/model.py#L22). The final way to initialize a torchchat model is from GGUF. You load a diff --git a/eval.py b/eval.py index 76aa25d31..f1db656ed 100644 --- a/eval.py +++ b/eval.py @@ -16,7 +16,7 @@ TokenizerArgs, ) -from build.model import Transformer +from build.model import Model from build.utils import set_precision from cli import add_arguments_for_verb, arg_init from utils.measure_time import measure_time @@ -35,7 +35,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - model: Transformer, + model: Model, prompt: torch.Tensor, max_new_tokens: int, max_seq_length: Optional[int] = None, @@ -81,7 +81,7 @@ class GPTFastEvalWrapper(eval_wrapper): def __init__( self, - model: Transformer, + model: Model, tokenizer, model_forward: Optional[Callable] = None, max_seq_length: Optional[int] = None, @@ -153,7 +153,7 @@ def _model_generate(self, context, max_length, eos_token_id): @torch.no_grad() def eval( - model: Transformer, + model: Model, model_forward: Callable, tokenizer, tasks: Optional[list] = None, @@ -165,7 +165,7 @@ def eval( Evaluates a language model on a specified task using the lm-evaluation-harness library. Args: - model (Transformer): The pre-trained language model to evaluate. + model (Model): The pre-trained language model to evaluate. tokenizer: The tokenizer to use for encoding/decoding text. task (str): The name of the evaluation task to perform. limit (Optional[int]): The maximum number of samples to evaluate (None for all available). diff --git a/export_util/export_et.py b/export_util/export_et.py index 606d08688..c27c431a7 100644 --- a/export_util/export_et.py +++ b/export_util/export_et.py @@ -6,7 +6,7 @@ import torch -from build.model import Transformer +from build.model import Model, Transformer from build.utils import get_precision from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( @@ -26,7 +26,7 @@ default_device = "cpu" -def materialze_broadcast_of_rope_freq_cis( +def materialze_broadcast_of_rope_freq_cis_transformer( module: torch.nn.Module, ): assert isinstance(module, Transformer) @@ -52,6 +52,18 @@ def materialze_broadcast_of_rope_freq_cis( return module +def materialze_broadcast_of_rope_freq_cis( + module: torch.nn.Module, +): + assert instance(module, Model) + + for k in module.transformer_map.keys(): + module.transformer_map[k] = materialze_broadcast_of_rope_freq_cis_transformer( + module.transformer_map[k] + ) + return module + + def export_model(model, device, output_path, args=None) -> str: # noqa: C901 input = ( diff --git a/generate.py b/generate.py index 21d54373c..58843797f 100644 --- a/generate.py +++ b/generate.py @@ -23,7 +23,7 @@ BuilderArgs, TokenizerArgs, ) -from build.model import Transformer +from build.model import Model from build.utils import device_sync, set_precision from cli import add_arguments_for_verb, arg_init, check_args from utils.device_info import get_device_info @@ -259,7 +259,7 @@ def sample( def prefill( self, - model: Transformer, + model: Model, x: torch.Tensor, input_pos: torch.Tensor, *, @@ -285,7 +285,7 @@ def prefill( def decode_one_token( self, - model: Transformer, + model: Model, x: torch.Tensor, input_pos: torch.Tensor, need_probs: bool, @@ -305,7 +305,7 @@ def decode_one_token( def decode_n_tokens( self, - model: Transformer, + model: Model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, @@ -374,8 +374,8 @@ def model_forward(self, model, x, input_pos): def speculative_decode( self, - model: Transformer, - draft_model: Transformer, + model: Model, + draft_model: Model, cur_token: torch.Tensor, input_pos: int, speculate_k: int, @@ -439,13 +439,13 @@ def speculative_decode( @torch.no_grad() def generate( self, - model: Transformer, + model: Model, prompt: torch.Tensor, max_new_tokens: int, *, chat_mode: bool, start_pos: int = 0, - draft_model: Transformer, + draft_model: Model, speculate_k: Optional[int] = 8, sequential_prefill=True, callback=lambda x: x, From f0f760e2691331522811bf1af74342a4c59e4b39 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 1 Aug 2024 11:15:48 -0700 Subject: [PATCH 03/12] make model definition support multiple transformer --- api/api.py | 4 +- build/builder.py | 14 ++-- build/model.py | 124 ++++++++----------------------- distributed/parallelize_llama.py | 2 +- eval.py | 4 +- generate.py | 6 +- 6 files changed, 48 insertions(+), 106 deletions(-) diff --git a/api/api.py b/api/api.py index e52870d60..8494edc18 100644 --- a/api/api.py +++ b/api/api.py @@ -214,11 +214,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.start_pos = 0 self.max_seq_length = ( - self.model.config.max_seq_length + self.model.text_transformer.config.max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.config.max_seq_length + else self.model.text_transformer.config.max_seq_length ) def completion(self, completion_request: CompletionRequest): diff --git a/build/builder.py b/build/builder.py index ed0fcba06..5a3696e3a 100644 --- a/build/builder.py +++ b/build/builder.py @@ -215,7 +215,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - use_tiktoken = model.config.use_tiktoken + use_tiktoken = model.text_transformer.config.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( @@ -328,7 +328,6 @@ def _load_model_default(builder_args, only_config=False): mmap=True, ) ) - checkpoint = {} for key in cps[0].keys(): if not torch.allclose(cps[0][key], cps[1][key]): @@ -349,9 +348,10 @@ def _load_model_default(builder_args, only_config=False): if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] + + checkpoint = {'text_transformer.' + k: v for k, v in checkpoint.items()} - model.load_state_dict(checkpoint, assign=True, strict=False) - + model.load_state_dict(checkpoint, assign=True, strict=True) return model @@ -494,7 +494,7 @@ def _initialize_model( try: from build.model_et import PTEModel - model = PTEModel(model.config, builder_args.pte_path) + model = PTEModel(model.text_transformer.config, builder_args.pte_path) except Exception: raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") else: @@ -510,8 +510,8 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): - model.setup_caches( - max_batch_size=1, max_seq_length=model.config.max_seq_length + model.text_transformer.setup_caches( + max_batch_size=1, max_seq_length=model.text_transformer.config.max_seq_length ) model.to(dtype=builder_args.precision) diff --git a/build/model.py b/build/model.py index 9bb846ded..75772b807 100644 --- a/build/model.py +++ b/build/model.py @@ -21,76 +21,6 @@ config_path = Path(f"{str(Path(__file__).parent)}/known_model_params") -@dataclass -class ModelArgs: - transformer_args_maps: Dict[str, TransformerArgs] - - @classmethod - def from_params(cls, params_path): - with open(params_path, "r") as f: - loaded_params = json.loads(f.read()) - - transformer_args_maps: Dict[str, TransformerArgs] = {} - - try: - # try to interpret as a single transformer config - transformer_args_maps["default"] = TransformerArgs.from_params( - loaded_params - ) - except TypeError: - # try to interpret as a dict of transformer configs - for name, params in loaded_params.items(): - transformer_args_maps[name] = TransformerArgs.from_params(params) - - return cls(transformer_args_maps) - - @classmethod - def from_table(cls, name: str): - json_path = config_path / f"{name}.json" - if json_path.is_file(): - return ModelArgs.from_params(json_path) - else: - known_model_params = [ - config.replace(".json", "") for config in os.listdir(config_path) - ] - raise RuntimeError( - f"unknown table index {name} for transformer config, must be from {known_model_params}" - ) - - @classmethod - def from_name(cls, name: str): - json_path = config_path / f"{name}.json" - if Path(json_path).is_file(): - return ModelArgs.from_params(json_path) - - known_model_params = [ - config.replace(".json", "") for config in os.listdir(config_path) - ] - - print(f"known configs: {known_model_params}") - # Fuzzy search by name (e.g. "7B" and "Mistral-7B") - config = [ - config - for config in known_model_params - if config in str(name).upper() or config in str(name) - ] - - # We may have two or more configs matched (e.g., "7B" and - # "Mistral-7B"). Find the best config match: take longer - # name (as it have more symbols matched) - if len(config) > 1: - config.sort(key=len, reverse=True) - assert len(config[0]) != len( - config[1] - ), name # make sure only one 'best' match - elif len(config) == 0: - raise ValueError( - f"Unknown model directory name {name}. Must be one of {known_model_params}." - ) - - return ModelArgs.from_params(config_path / f"{config[0]}.json") - - @dataclass class TransformerArgs: block_size: int = 2048 @@ -129,21 +59,42 @@ def __post_init__(self): self.use_tiktoken = self.use_tiktoken == "True" @classmethod - def from_params(cls, params_path): + def from_params(cls, params): replace = [("rope_theta", "rope_base"), ("n_kv_heads", "n_local_heads")] - with open(params_path, "r") as f: - params = json.loads(f.read()) - # Patch for llama3 - for _from, _to in replace: - if _from in params: - params[_to] = params.pop(_from) + for _from, _to in replace: + if _from in params: + params[_to] = params.pop(_from) return cls(**params) +@dataclass +class ModelArgs: + text_transformer_args: TransformerArgs + + @classmethod + def from_params(cls, params_path): + with open(params_path, "r") as f: + loaded_params = json.loads(f.read()) + + try: + # try to interpret as a single transformer config + text_transformer_args = TransformerArgs.from_params( + loaded_params + ) + except TypeError: + # try to interpret as a dict of transformer configs + for name, params in loaded_params.items(): + if name == "text": + text_transformer_args = TransformerArgs.from_params(params) + else: + raise ValueError(f"Unknown transformer name {name}") + + return cls(text_transformer_args) + @classmethod def from_table(cls, name: str): json_path = config_path / f"{name}.json" if json_path.is_file(): - return TransformerArgs.from_params(json_path) + return ModelArgs.from_params(json_path) else: known_model_params = [ config.replace(".json", "") for config in os.listdir(config_path) @@ -156,7 +107,7 @@ def from_table(cls, name: str): def from_name(cls, name: str): json_path = config_path / f"{name}.json" if Path(json_path).is_file(): - return TransformerArgs.from_params(json_path) + return ModelArgs.from_params(json_path) known_model_params = [ config.replace(".json", "") for config in os.listdir(config_path) @@ -183,7 +134,7 @@ def from_name(cls, name: str): f"Unknown model directory name {name}. Must be one of {known_model_params}." ) - return TransformerArgs.from_params(config_path / f"{config[0]}.json") + return ModelArgs.from_params(config_path / f"{config[0]}.json") class KVCache(nn.Module): @@ -218,19 +169,10 @@ class Model(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config - self.transformer_map: Dict[str, Transformer] = {} - for name, config in self.config.transformer_args_maps.items(): - self.transformer_map[name] = Transformer(config) - - assert ( - len(self.transformer_map) == 1 - ), "Only support one transformer model for now" - assert ( - "default" in self.transformer_map - ), '"default" not in self.transformer_map' + self.text_transformer = Transformer(config.text_transformer_args) def forward(self, *args, **kwargs) -> Tensor: - return self.transformer_map["default"](*args, **kwargs) + return self.text_transformer(*args, **kwargs) @classmethod def from_name(cls, name: str): diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index c4eb17658..b0e7477b0 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -59,7 +59,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. - model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size() + model.text_transformer.config.n_local_heads = model.text_transformer.config.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/eval.py b/eval.py index f1db656ed..f90ce4857 100644 --- a/eval.py +++ b/eval.py @@ -58,7 +58,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.config.block_size) + max_seq_length = min(T_new, model.text_transformer.config.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and @@ -69,7 +69,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( input_pos = torch.arange(0, T, device=device) with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.text_transformer.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) return seq, input_pos, max_seq_length diff --git a/generate.py b/generate.py index 58843797f..053ee175f 100644 --- a/generate.py +++ b/generate.py @@ -467,7 +467,7 @@ def generate( if start_pos == 0: model = model.to(device=device) with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.text_transformer.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if is_speculative and draft_model is not model: draft_model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length @@ -628,7 +628,7 @@ def chat( self.system_prompt = None # Set up our max_seq_length if generator_args.chat_mode: - max_seq_length = self.model.config.max_seq_length + max_seq_length = self.model.text_transformer.config.max_seq_length print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -642,7 +642,7 @@ def chat( else: max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, - self.model.config.block_size, + self.model.text_transformer.config.block_size, ) max_seq_length = ( From 2cb794ed963af4e36b8097289940004cf34abed4 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 1 Aug 2024 13:13:29 -0700 Subject: [PATCH 04/12] make model definition support multiple transformer --- build/builder.py | 2 +- build/model.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/build/builder.py b/build/builder.py index 5a3696e3a..1ee5ae805 100644 --- a/build/builder.py +++ b/build/builder.py @@ -215,7 +215,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - use_tiktoken = model.text_transformer.config.use_tiktoken + use_tiktoken = model.config.text_transformer_args.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( diff --git a/build/model.py b/build/model.py index 75772b807..604f23b47 100644 --- a/build/model.py +++ b/build/model.py @@ -70,6 +70,10 @@ def from_params(cls, params): class ModelArgs: text_transformer_args: TransformerArgs + def __post_init__(self): + assert self.text_transformer_args is not None + assert type(self.text_transformer_args) == TransformerArgs + @classmethod def from_params(cls, params_path): with open(params_path, "r") as f: From 7b757e1f4c7d6a36c5aa97c7b37138f49743396d Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 1 Aug 2024 13:18:35 -0700 Subject: [PATCH 05/12] make model definition support multiple transformer --- build/gguf_loader.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/build/gguf_loader.py b/build/gguf_loader.py index 68c5e2814..bcc4950c5 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -13,13 +13,13 @@ import torch +from build.gguf_util import Q4_0, to_float +from build.model import Model, ModelArgs, TransformerArgs + from gguf import GGUFValueType from quantization.qops import LinearInt4 as WeightOnlyInt4Linear from quantization.quantize import pack_scales_and_zeros -from build.gguf_util import Q4_0, to_float -from build.model import Model, ModelArgs, TransformerArgs - logger: logging.Logger = logging.getLogger(__name__) @@ -109,17 +109,15 @@ def load_model(gguf_file: str) -> torch.nn.Module: assert arch == "llama", "Only LLaMa models are supported by this converter." model_args = ModelArgs( - { - "default": TransformerArgs( - dim=metadata[f"{arch}.embedding_length"], - n_layers=metadata[f"{arch}.block_count"], - n_heads=metadata[f"{arch}.attention.head_count"], - n_local_heads=metadata[f"{arch}.attention.head_count_kv"], - vocab_size=len(metadata["tokenizer.ggml.tokens"]), - norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], - hidden_dim=metadata[f"{arch}.feed_forward_length"], - ) - } + TransformerArgs( + dim=metadata[f"{arch}.embedding_length"], + n_layers=metadata[f"{arch}.block_count"], + n_heads=metadata[f"{arch}.attention.head_count"], + n_local_heads=metadata[f"{arch}.attention.head_count_kv"], + vocab_size=len(metadata["tokenizer.ggml.tokens"]), + norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], + hidden_dim=metadata[f"{arch}.feed_forward_length"], + ) ) # TODO: what to do with rope args like From 0992784b8e79df23497ee2f8cc021d3481901dac Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 1 Aug 2024 14:30:42 -0700 Subject: [PATCH 06/12] make input arg static in Model to support export --- build/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build/model.py b/build/model.py index 604f23b47..88080e83b 100644 --- a/build/model.py +++ b/build/model.py @@ -175,8 +175,8 @@ def __init__(self, config: ModelArgs) -> None: self.config = config self.text_transformer = Transformer(config.text_transformer_args) - def forward(self, *args, **kwargs) -> Tensor: - return self.text_transformer(*args, **kwargs) + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + return self.text_transformer(idx, input_pos) @classmethod def from_name(cls, name: str): From 665df2e20a3b06c5bdc70bd822dd1f08c82a67be Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 1 Aug 2024 14:57:08 -0700 Subject: [PATCH 07/12] fix bugs for gguf and et in new model definition architecture --- build/builder.py | 4 ++-- build/gguf_loader.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build/builder.py b/build/builder.py index 1ee5ae805..1ed015ca8 100644 --- a/build/builder.py +++ b/build/builder.py @@ -349,7 +349,7 @@ def _load_model_default(builder_args, only_config=False): if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] - checkpoint = {'text_transformer.' + k: v for k, v in checkpoint.items()} + checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()} model.load_state_dict(checkpoint, assign=True, strict=True) return model @@ -494,7 +494,7 @@ def _initialize_model( try: from build.model_et import PTEModel - model = PTEModel(model.text_transformer.config, builder_args.pte_path) + model = PTEModel(model.config, builder_args.pte_path) except Exception: raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") else: diff --git a/build/gguf_loader.py b/build/gguf_loader.py index bcc4950c5..2769fad3f 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -41,7 +41,7 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: result = copy.deepcopy(gguf_name) for gguf_string, replacement in _name_replacements: - result = result.replace(gguf_string, replacement) + result = "text_transformer." + result.replace(gguf_string, replacement) return result From 479bb44681958237808a48a4dc44339d05cc2173 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 5 Aug 2024 17:24:49 -0700 Subject: [PATCH 08/12] retrieve text transformer arg from modelargs --- api/api.py | 4 ++-- build/builder.py | 2 +- build/gguf_loader.py | 3 ++- distributed/parallelize_llama.py | 2 +- eval.py | 2 +- export.py | 2 +- generate.py | 4 ++-- 7 files changed, 10 insertions(+), 9 deletions(-) diff --git a/api/api.py b/api/api.py index 8494edc18..c5ef4bbf6 100644 --- a/api/api.py +++ b/api/api.py @@ -214,11 +214,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.start_pos = 0 self.max_seq_length = ( - self.model.text_transformer.config.max_seq_length + self.model.config.text_transformer_args.max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.text_transformer.config.max_seq_length + else self.model.config.text_transformer_args.max_seq_length ) def completion(self, completion_request: CompletionRequest): diff --git a/build/builder.py b/build/builder.py index 49434dd94..5ec057faf 100644 --- a/build/builder.py +++ b/build/builder.py @@ -517,7 +517,7 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): model.text_transformer.setup_caches( - max_batch_size=1, max_seq_length=model.text_transformer.config.max_seq_length + max_batch_size=1, max_seq_length=model.config.text_transformer_args.max_seq_length ) model.to(dtype=builder_args.precision) diff --git a/build/gguf_loader.py b/build/gguf_loader.py index 2769fad3f..fac7529f9 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -41,7 +41,8 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: result = copy.deepcopy(gguf_name) for gguf_string, replacement in _name_replacements: - result = "text_transformer." + result.replace(gguf_string, replacement) + result = result.replace(gguf_string, replacement) + result = "text_transformer." + result return result diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index b0e7477b0..cb9d7f860 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -59,7 +59,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. - model.text_transformer.config.n_local_heads = model.text_transformer.config.n_local_heads // tp_mesh.size() + model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/eval.py b/eval.py index f90ce4857..3d0152c17 100644 --- a/eval.py +++ b/eval.py @@ -58,7 +58,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.text_transformer.config.block_size) + max_seq_length = min(T_new, model.config.text_transformer_args.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and diff --git a/export.py b/export.py index 0d1285eab..0ad5dd061 100644 --- a/export.py +++ b/export.py @@ -54,7 +54,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.config.max_seq_length) + seq = Dim("seq", min=1, max=model.config.text_transformer_args.max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} diff --git a/generate.py b/generate.py index 51a3fbce0..acb1d3844 100644 --- a/generate.py +++ b/generate.py @@ -630,7 +630,7 @@ def chat( self.system_prompt = None # Set up our max_seq_length if generator_args.chat_mode: - max_seq_length = self.model.text_transformer.config.max_seq_length + max_seq_length = self.model.config.text_transformer_args.max_seq_length print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -644,7 +644,7 @@ def chat( else: max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, - self.model.text_transformer.config.block_size, + self.model.config.text_transformer_args.block_size, ) max_seq_length = ( From 0edf73ab3cb2f9418cf3fafffcddea7344d2b4ab Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 5 Aug 2024 18:01:43 -0700 Subject: [PATCH 09/12] add set_cache funtion to Model to work around PTEModel issue --- build/builder.py | 2 +- build/model.py | 3 +++ eval.py | 2 +- generate.py | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/build/builder.py b/build/builder.py index 5ec057faf..f79e4ac39 100644 --- a/build/builder.py +++ b/build/builder.py @@ -516,7 +516,7 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): - model.text_transformer.setup_caches( + model.setup_caches( max_batch_size=1, max_seq_length=model.config.text_transformer_args.max_seq_length ) diff --git a/build/model.py b/build/model.py index 88080e83b..0d2d342bb 100644 --- a/build/model.py +++ b/build/model.py @@ -177,6 +177,9 @@ def __init__(self, config: ModelArgs) -> None: def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.text_transformer(idx, input_pos) + + def setup_caches(self, max_batch_size, max_seq_length): + self.text_transformer.setup_caches(max_batch_size, max_seq_length) @classmethod def from_name(cls, name: str): diff --git a/eval.py b/eval.py index 3d0152c17..e8f487513 100644 --- a/eval.py +++ b/eval.py @@ -69,7 +69,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( input_pos = torch.arange(0, T, device=device) with torch.device(device): - model.text_transformer.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) return seq, input_pos, max_seq_length diff --git a/generate.py b/generate.py index acb1d3844..335d6f58a 100644 --- a/generate.py +++ b/generate.py @@ -469,7 +469,7 @@ def generate( if start_pos == 0: model = model.to(device=device) with torch.device(device): - model.text_transformer.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if is_speculative and draft_model is not model: draft_model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length From 960aae5494984e54df7b4491f43ebd16f34785eb Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 20 Aug 2024 11:57:07 -0700 Subject: [PATCH 10/12] make torchchat rely on torchtune --- install_requirements.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/install_requirements.sh b/install_requirements.sh index 925872ec4..69ca36de5 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -77,6 +77,11 @@ REQUIREMENTS_TO_INSTALL=( $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \ "${REQUIREMENTS_TO_INSTALL[@]}" ) +# Install torchtune separately with the --pre flag +( + set -x + $PIP_EXECUTABLE install --pre torchtune --extra-index-url "${TORCH_NIGHTLY_URL}" --no-cache-dir +) # For torchao need to install from github since nightly build doesn't have macos build. # TODO: Remove this and install nightly build, once it supports macos From d52aac41577078770980d363cbfdb7c61ab50534 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 09:45:20 -0700 Subject: [PATCH 11/12] remove export_util --- export_util/export_et.py | 135 --------------------------------------- 1 file changed, 135 deletions(-) delete mode 100644 export_util/export_et.py diff --git a/export_util/export_et.py b/export_util/export_et.py deleted file mode 100644 index c27c431a7..000000000 --- a/export_util/export_et.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from build.model import Model, Transformer -from build.utils import get_precision - -from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackDynamicallyQuantizedPartitioner, -) - -# TODO: change back to executorch.examples.portable.utils -# when executorch installs correctly - -from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig -from executorch.exir.passes.quant_fusion_pass import QuantFusionPass -from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from export_util.executorch_portable_utils import export_to_edge -from export_util.export_et_util import replace_attention_with_custom_sdpa_attention -from torch._export import capture_pre_autograd_graph - -default_device = "cpu" - - -def materialze_broadcast_of_rope_freq_cis_transformer( - module: torch.nn.Module, -): - assert isinstance(module, Transformer) - assert module.freqs_cos.dim() == 2 - dim0 = module.freqs_cos.size(0) - dim1 = module.freqs_cos.size(1) - assert ( - module.layers[0].attention.n_local_kv_heads - == module.layers[0].attention.n_local_heads - ), f"For rope freqs to be materialzed for broadcast q, k, v num heads must match. For q got {module.attention.n_kv_heads} for k got {module.attention.n_local_heads} and v got {module.attention.n_local_kv_heads}" - num_heads = module.layers[0].attention.n_local_heads - module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1) - module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous() - assert module.freqs_sin.dim() == 2 - assert dim0 == module.freqs_sin.size( - 0 - ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}" - assert dim1 == module.freqs_sin.size( - 1 - ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}" - module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1) - module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous() - return module - - -def materialze_broadcast_of_rope_freq_cis( - module: torch.nn.Module, -): - assert instance(module, Model) - - for k in module.transformer_map.keys(): - module.transformer_map[k] = materialze_broadcast_of_rope_freq_cis_transformer( - module.transformer_map[k] - ) - return module - - -def export_model(model, device, output_path, args=None) -> str: # noqa: C901 - - input = ( - torch.tensor([[1]], dtype=torch.long, device=device), - torch.tensor([0], dtype=torch.long, device=device), - ) - - state_dict = model.state_dict() - state_dict_dtype = state_dict[next(iter(state_dict))].dtype - target_precision = get_precision() - dynamic_shapes = None - - # TODO: need to use kv sdpa? - edge_config = EdgeCompileConfig( - _check_ir_validity=False, - _skip_type_promotion=bool(target_precision == torch.float16), - ) - - if target_precision == torch.float16 or target_precision == torch.bfloat16: - if state_dict_dtype != torch.float16: - print("model.to torch.float16") - model = model.to(dtype=torch.float16) - state_dict_dtype = torch.float16 - elif target_precision == torch.float32: - if state_dict_dtype != torch.float32: - print("model.to torch.float32") - model = model.to(dtype=torch.float32) - elif target_precision == torch.bfloat16: - print("model.to torch.bfloat16") - model = model.to(dtype=torch.bfloat16) - else: - raise ValueError(f"Unsupported dtype for ET export: {target_precision}") - - replace_attention_with_custom_sdpa_attention(model) - with torch.nn.attention.sdpa_kernel( - [torch.nn.attention.SDPBackend.MATH] - ), torch.no_grad(): - m = capture_pre_autograd_graph(model, input, dynamic_shapes=dynamic_shapes) - - edge_manager = export_to_edge( - m, - input, - dynamic_shapes=dynamic_shapes, - edge_compile_config=edge_config, - ) - edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner()) - # Delegation visualization APIs: https://pytorch.org/executorch/main/debug-backend-delegate.html - # from executorch.exir.backend.utils import get_delegation_info, format_delegated_graph - # from tabulate import tabulate - # graph_module = edge_manager.exported_program().graph_module - # delegation_info = get_delegation_info(graph_module) - # print(delegation_info.get_summary()) - # print(format_delegated_graph(graph_module)) - export_program = edge_manager.to_executorch( - ExecutorchBackendConfig( - extract_constant_segment=True, - extract_delegate_segments=True, - passes=[ - QuantFusionPass(), - ], - sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), - ) - ) - - print("The methods are: ", export_program.methods) - with open(output_path, "wb") as f: - export_program.write_to_file(f) - - return output_path From 11113ff322b63327f0d19f6ab84bfb9ce29c9119 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 27 Aug 2024 09:47:03 -0700 Subject: [PATCH 12/12] extra torchtune dependency --- install_requirements.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/install_requirements.sh b/install_requirements.sh index 2048025bb..9baac5ec0 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -77,11 +77,6 @@ REQUIREMENTS_TO_INSTALL=( $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \ "${REQUIREMENTS_TO_INSTALL[@]}" ) -# Install torchtune separately with the --pre flag -( - set -x - $PIP_EXECUTABLE install --pre torchtune --extra-index-url "${TORCH_NIGHTLY_URL}" --no-cache-dir -) # For torchao need to install from github since nightly build doesn't have macos build. # TODO: Remove this and install nightly build, once it supports macos