diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index e2d256ea396..940887e3f1f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -794,10 +794,6 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: ) ) - # Now cast to the dtype override after quantization, so non-quantized - # components use the desired computation dtype. - edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype()) - return edge_manager @@ -1857,6 +1853,12 @@ def _get_source_transforms( # noqa ) ) + # Cast to dtype_override after quantization transforms, so non-quantized + # components use the desired computation dtype. This must happen before + # _convert_model_for_aarch64 which converts IntxUnpackedToInt8Tensor to + # IntxOpaqueTensor (which doesn't support .to()). + transforms.append(lambda m: m.to(dtype=dtype_override.to_torch_dtype())) + if any([use_torchao_kernels_linear, use_torchao_kernels_tied_embedding]): from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64