From 62a2e61ebbda6607142a6c58492d30cc96054a51 Mon Sep 17 00:00:00 2001
From: Dark Knight
Date: Wed, 8 Oct 2025 10:44:28 -0700
Subject: [PATCH] Revert D82355346

Summary: This diff reverts D82355346.

D82355346 breaks cache hits for the PT2 cache; as a result, we saw a jump in
PT2 compilation time.

Depends on D82355346

Reviewed By: HugeEngine

Differential Revision: D84169566
---
 test/quantization/test_da8w4_cpu.py                 | 15 +--------------
 .../uintx/dyn_int8_act_int4_wei_cpu_layout.py       |  4 ++--
 .../fx_passes/da8w4_concat_linear_fusion_cpu.py     | 12 +-----------
 3 files changed, 4 insertions(+), 27 deletions(-)

diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py
index 80094beb2d..d4f68c4333 100644
--- a/test/quantization/test_da8w4_cpu.py
+++ b/test/quantization/test_da8w4_cpu.py
@@ -8,7 +8,6 @@
 import unittest
 
 import torch
-from torch._dynamo.utils import counters
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -121,6 +120,7 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     def test_8da4w_concat_linear_cpu(self, x_dim, bias):
+        self.skipTest("Disabled for now")
         N, K = 64, 128
 
         class Mod(torch.nn.Module):
@@ -163,15 +163,6 @@ def forward(self, x):
             # ensure the expected op occurs only once in the code after fusion
             # The trailing "(" is to avoid matching the op in the comment
             assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
-
-            # Ensure that when concat linear is enabled, fxgraph cache works
-            # without being bypassed (fxgraph_cache_bypass = 0), indicating that
-            # DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass
-            # interface and uuid() function, allowing fxgraph to be saved and hit
-            # on subsequent runs (fxgraph_cache_hit > 0).
-            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
-            assert fx_cache_bypass_count == 0
-
         with torch._inductor.config.patch(
             {"freezing": True, "cpp.enable_concat_linear": False}
         ):
@@ -181,10 +172,6 @@ def forward(self, x):
             )
             assert torch.allclose(y, y_ref)
 
-            # Ensure that the fxgraph cache is also not bypassed when concat linear is disabled
-            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
-            assert fx_cache_bypass_count == 0
-
 
 common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
 

diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
index c0f2fcdfe5..8d0cfaddeb 100644
--- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
+++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
@@ -314,6 +314,6 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
 
 
 # Register the concat linear fusion pass
-from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
 
-register_da8w4_concat_linear_cpu_pass()
+# register_da8w4_concat_linear_cpu_pass()

diff --git a/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py b/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py
index 8e39826f4c..12b1a4696b 100644
--- a/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py
+++ b/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py
@@ -7,15 +7,6 @@
 import operator
 
 import torch
-from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
-
-
-class DA8W4ConcatLinearCPUPass(CustomGraphPass):
-    def __call__(self, graph: torch.fx.Graph):
-        _concat_linear_dq8w4_cpu(graph)
-
-    def uuid(self):
-        return get_hash_for_files((__file__,))
 
 
 # Inductor FX passes for concat linear for DA8W4
@@ -222,5 +213,4 @@ def ...
 def register_da8w4_concat_linear_cpu_pass():
     from torch._inductor import config as inductor_config
 
-    da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass()
-    inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass
+    inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
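Reviewer note (not part of the patch): per the removed test comments, the reverted change wrapped the fusion pass in a CustomGraphPass so its uuid() could feed the FX-graph cache key; this revert goes back to assigning the raw function to post_grad_custom_post_pass. Below is a minimal sketch of the two registration styles for context. The _concat_linear_dq8w4_cpu body here is only a placeholder for the real pass in da8w4_concat_linear_fusion_cpu.py, and the sketch assumes the torch._inductor.custom_graph_pass API referenced in the removed code.

# Sketch only; not part of this patch.
import torch
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


def _concat_linear_dq8w4_cpu(graph: torch.fx.Graph) -> None:
    # Placeholder for the real fusion pass defined in
    # torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py.
    pass


# Style removed by this revert: a CustomGraphPass subclass whose uuid() lets the
# FX-graph cache key account for the pass implementation.
class DA8W4ConcatLinearCPUPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.Graph):
        _concat_linear_dq8w4_cpu(graph)

    def uuid(self):
        return get_hash_for_files((__file__,))


# Style restored by this revert: register the bare function as the
# post-grad custom pass.
inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu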