From f41156cc823ad0af06f986c0d1828e23d37618f2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 9 Aug 2024 17:27:01 -0700 Subject: [PATCH] Re-enable dequant save load test Summary: Fixes: https://github.com/pytorch/ao/issues/594 Test Plan: python test/integration/test_integration.py -k test_save_load_dqtensors Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 4c911f18da..9d904a4474 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1022,7 +1022,6 @@ def forward(self, x): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'") - @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Can't save local lambda function for tensor subclass") @torch.no_grad() def test_save_load_dqtensors(self, device, dtype): if device == "cpu": @@ -1226,7 +1225,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.skipTest(f"bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # This test fails on v0.4.0 and torch 2.4, so skipping for now. + # This test fails on v0.4.0 and torch 2.4, so skipping for now. if m1 == 1 or m2 == 1 and not TORCH_VERSION_AFTER_2_5: self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") model = torch.nn.Sequential( @@ -1299,7 +1298,7 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): self.skipTest(f"bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # This test fails on v0.4.0 and torch 2.4, so skipping for now. + # This test fails on v0.4.0 and torch 2.4, so skipping for now. if m1 == 1 or m2 == 1 and not TORCH_VERSION_AFTER_2_5: self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4")