diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index fafb69d878b..9bd16fa7c07 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -84,7 +84,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
-        "source_transformation/prune_output.py",
+        "source_transformation/prune_vocab.py",
         "source_transformation/quantize.py",
         "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 04bd5bddaaf..a0b44fb9652 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -437,6 +437,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="path to the output pruning token mapping file (token_map.json)",
     )
+
+    parser.add_argument(
+        "--input_prune_map",
+        default=None,
+        help="path to the input pruning token mapping file (token_map.json)",
+    )
     return parser
@@ -525,6 +531,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             tokenizer_path=args.tokenizer_path,
             verbose=args.verbose,
             max_seq_len=args.max_seq_length,
+            input_prune_map_path=args.input_prune_map,
             output_prune_map_path=args.output_prune_map,
             metadata_str=args.metadata,
             dtype_override=dtype_override,
@@ -766,6 +773,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    input_prune_map_path: Optional[str] = None,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
     dtype_override: Optional[DType] = None,
@@ -795,6 +803,7 @@ def _load_llama_model(
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
         enable_dynamic_shape=enable_dynamic_shape,
+        input_prune_map_path=input_prune_map_path,
         output_prune_map_path=output_prune_map_path,
         args=args,
     )
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 5a96e49ef1b..3f93498fbab 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -103,6 +103,8 @@ class ModelArgs:
     generate_full_logits: bool = False
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
     # A dictionary mapping from pruned token-id to original token-id
+    input_prune_map: Optional[Dict[int, int]] = None
+    # A dictionary mapping from pruned token-id to original token-id
     output_prune_map: Optional[Dict[int, int]] = None
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (
@@ -461,6 +463,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index e6f39e0cad5..0f83e404a3c 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -49,6 +49,7 @@ def __init__(self, **kwargs):
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
@@ -126,6 +127,12 @@ def __init__(self, **kwargs):
                 output_prune_map = json.load(f)
             # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
+        input_prune_map = None
+        if self.input_prune_map_path is not None:
+            with open(self.input_prune_map_path, "r") as f:
+                input_prune_map = json.load(f)
+            # Change keys from string to int (json only supports string keys).
+            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

         model_args: ModelArgs = ModelArgs(
             max_seq_len=self.max_seq_len,
@@ -133,6 +140,7 @@ def __init__(self, **kwargs):
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
+            input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
@@ -209,9 +217,15 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")

+        # Prune the input layer if input_prune_map is provided
+        if input_prune_map is not None:
+            from .source_transformation.prune_vocab import prune_input_vocab
+
+            self.model_ = prune_input_vocab(self.model_, input_prune_map)
+
         # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
-            from .source_transformation.prune_output import prune_output_vocab
+            from .source_transformation.prune_vocab import prune_output_vocab

             self.model_ = prune_output_vocab(self.model_, output_prune_map)
diff --git a/examples/models/llama/source_transformation/prune_output.py b/examples/models/llama/source_transformation/prune_vocab.py
similarity index 59%
rename from examples/models/llama/source_transformation/prune_output.py
rename to examples/models/llama/source_transformation/prune_vocab.py
index 6d02d52fa5c..1751059f5dc 100644
--- a/examples/models/llama/source_transformation/prune_output.py
+++ b/examples/models/llama/source_transformation/prune_vocab.py
@@ -69,3 +69,51 @@ def prune_output_vocab(
     setattr(model, output_layer_name, pruned_layer)

     return model
+
+
+def prune_input_vocab(
+    model: torch.nn.Module,
+    token_map: Dict[int, int],
+    input_layer_name: str = "tok_embeddings",
+) -> torch.nn.Module:
+    """Prune the model input embedding layer while keeping the tokens in the token map.
+
+    Note: Pruning is performed in-place.
+
+    Args:
+        model: The model to prune.
+        token_map: A dictionary mapping from new token ids to the old token ids to preserve.
+            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
+        input_layer_name: name of the input embedding layer to prune
+
+    Returns:
+        The pruned model.
+ """ + assert hasattr( + model, imput_layer_name + ), f"Model does not have {imput_layer_name} layer" + input_layer = getattr(model, imput_layer_name) + assert isinstance( + input_layer, torch.nn.Embedding + ), "Input layer is not an Embedding layer" + original_shape = input_layer.weight.shape + num_pruned_tokens = len(token_map) + weight_dtype = input_layer.weight.dtype + pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1]) + pruned_layer.to(dtype=weight_dtype) + pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32) + for i, token_id in token_map.items(): + # Copy the weights from the original layer to the pruned layer + pruned_wt = input_layer.weight[token_id].detach() + if weight_dtype == torch.bfloat16: + pruned_wt = pruned_wt.float() + pruned_layer_weights[i] = pruned_wt.numpy() + with torch.no_grad(): + pruned_layer.weight.copy_( + torch.tensor(pruned_layer_weights, dtype=weight_dtype) + ) + + # Replace the original layer with the pruned layer + setattr(model, imput_layer_name, pruned_layer) + + return model