diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 504be3c6343..88d2bc0cab9 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -34,7 +34,7 @@ from executorch.extension.export_util.utils import export_to_edge, save_pte_program -from executorch.extension.llm.export.export_passes import RemoveRedundantPermutes +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from executorch.extension.llm.tokenizer.utils import get_tokenizer from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer import Quantizer @@ -113,7 +113,7 @@ def __init__( self.calibration_seq_length = calibration_seq_length self.calibration_data = calibration_data self.tokenizer_path = tokenizer_path - self.canonical_passes = [RemoveRedundantPermutes()] + self.canonical_passes = [RemoveRedundantTransposes()] def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": """ diff --git a/extension/llm/export/export_passes.py b/extension/llm/export/export_passes.py index 942de095805..154c9eb862f 100644 --- a/extension/llm/export/export_passes.py +++ b/extension/llm/export/export_passes.py @@ -19,7 +19,7 @@ def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int): return dim_0, dim_1 -class RemoveRedundantPermutes(ExportPass): +class RemoveRedundantTransposes(ExportPass): """ This pass removes redundant transpose nodes in the graph. It checks if the next node is also a transpose node and if the two transpose nodes undo each other. diff --git a/pytest.ini b/pytest.ini index 1502f1749f2..da96469d1e5 100644 --- a/pytest.ini +++ b/pytest.ini @@ -49,6 +49,7 @@ addopts = backends/xnnpack/test/serialization # extension/ extension/llm/modules/test + extension/llm/export extension/pybindings/test # Runtime runtime