diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml
index 3da0929ea7..575aca6df0 100644
--- a/.github/workflows/regression_test.yml
+++ b/.github/workflows/regression_test.yml
@@ -67,7 +67,7 @@ jobs:
             dev-requirements-overrides: ""
           - name: CUDA 2.7
             runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: 'torch==2.7.0'
+            torch-spec: 'torch==2.7.1'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.6"
             dev-requirements-overrides: ""
@@ -77,6 +77,12 @@ jobs:
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.6"
             dev-requirements-overrides: ""
+          - name: CUDA 2.9
+            runs-on: linux.g5.12xlarge.nvidia.gpu
+            torch-spec: 'torch==2.9.1'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.6"
+            dev-requirements-overrides: ""
           - name: CPU 2.6
             runs-on: linux.4xlarge
             torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu'
@@ -86,7 +92,7 @@ jobs:
             dev-requirements-overrides: ""
           - name: CPU 2.7
             runs-on: linux.4xlarge
-            torch-spec: 'torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu'
+            torch-spec: 'torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu'
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
             dev-requirements-overrides: ""
@@ -96,6 +102,12 @@ jobs:
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
             dev-requirements-overrides: ""
+          - name: CPU 2.9
+            runs-on: linux.4xlarge
+            torch-spec: 'torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu'
+            gpu-arch-type: "cpu"
+            gpu-arch-version: ""
+            dev-requirements-overrides: ""

     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py
index 7ff1dbc619..f3cbffa430 100644
--- a/torchao/quantization/pt2e/utils.py
+++ b/torchao/quantization/pt2e/utils.py
@@ -859,6 +859,16 @@ def _get_aten_graph_module_for_pattern(
         ):
             aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]

+    if torch.__version__.startswith("2.9"):
+        # PyTorch 2.9 adds _guards_fn nodes to exported graphs; these nodes
+        # cause errors on torch 2.9.0 and 2.9.1, so strip them from the pattern.
+        for node in list(aten_pattern.graph.nodes):  # type: ignore[union-attr]
+            if node.op == "call_module" and node.name == "_guards_fn":
+                aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]
+        # Also remove the _guards_fn submodule from the graph module if it exists.
+        if hasattr(aten_pattern, "_guards_fn"):
+            delattr(aten_pattern, "_guards_fn")
+
     aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
     aten_pattern.recompile()  # type: ignore[operator]
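
For reviewers who want to see the graph cleanup in isolation: below is a minimal, self-contained sketch that applies the same `_guards_fn` removal to a hand-built `torch.fx` graph rather than a real `torch.export` output. The `strip_guards_fn` helper and the toy graph are illustrative only (not torchao APIs); the helper's body mirrors the logic added to `_get_aten_graph_module_for_pattern` above.

```python
import torch
import torch.fx as fx

# Build a GraphModule that mimics an exported graph containing a
# side-effect-only "_guards_fn" call_module node whose result is unused,
# which is the shape PyTorch 2.9's export gives to guard checks.
root = torch.nn.Module()
root._guards_fn = torch.nn.Identity()  # stand-in for the real guards module

graph = fx.Graph()
x = graph.placeholder("x")
graph.call_module("_guards_fn", (x,))  # result unused -> node has zero users
out = graph.call_function(torch.add, (x, 1))
graph.output(out)
gm = fx.GraphModule(root, graph)


def strip_guards_fn(gm: fx.GraphModule) -> fx.GraphModule:
    """Hypothetical helper mirroring the patch above."""
    # Erase call_module nodes named "_guards_fn"; erase_node is safe here
    # because the node has no users.
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and node.name == "_guards_fn":
            gm.graph.erase_node(node)
    # Drop the now-unreferenced submodule so recompile() leaves nothing dangling.
    if hasattr(gm, "_guards_fn"):
        delattr(gm, "_guards_fn")
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm


print("before:", [n.name for n in gm.graph.nodes])
# -> ['x', '_guards_fn', 'add', 'output']
strip_guards_fn(gm)
print("after: ", [n.name for n in gm.graph.nodes])
# -> ['x', 'add', 'output']
assert torch.equal(gm(torch.zeros(2)), torch.ones(2))
```

Note the ordering in the helper: the node is erased before the submodule is deleted, since `erase_node` only requires the node to be user-free, while `recompile()` would fail if the regenerated code still referenced a deleted submodule.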