From 8a6aa2d9d5a58b20035845ade8d8a19aa34f95c1 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 23 Sep 2025 13:34:44 -0700 Subject: [PATCH] Add conversion options Summary: Add convert_linear/convert_tied_embedding options to _convert_model_for_aarch64 so that certain module conversions can be disabled. Differential Revision: D83087813 --- torchao/prototype/tensor_conversion/api.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/tensor_conversion/api.py b/torchao/prototype/tensor_conversion/api.py index 6533e5de2d..cbba955bc1 100644 --- a/torchao/prototype/tensor_conversion/api.py +++ b/torchao/prototype/tensor_conversion/api.py @@ -124,9 +124,16 @@ def _find_tied_params(model): def _convert_model_for_aarch64( - model, *, tensor_type="auto", intx_packing_format="opaque_torchao_auto" + model, + *, + tensor_type="auto", + intx_packing_format="opaque_torchao_auto", + convert_tied_embedding=True, + convert_linear=True, ): - module_name_to_tied_param = _find_tied_params(model) + module_name_to_tied_param = ( + _find_tied_params(model) if convert_tied_embedding else {} + ) # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor for name, module in model.named_modules(): @@ -138,7 +145,7 @@ def _convert_model_for_aarch64( print("Skipping converting nn.Embedding {name} because it is not tied") continue - if not isinstance(module, nn.Linear): + if not (convert_linear and isinstance(module, nn.Linear)): continue weight = module.weight