diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 3c5414ceac..5f34b761cd 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1104,7 +1104,6 @@ def test_weight_only_quant(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skip("This test is flaky, we'll enable later") def test_weight_only_quant_force_mixed_mm(self, device, dtype): if device != "cuda": self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") @@ -1127,7 +1126,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): sqnr = compute_error(y_ref, y_wo) self.assertGreaterEqual(sqnr, 42.75) if device == "cuda": - self.assertTrue("mixed_mm" in code) + self.assertTrue("mixed_mm" in code, f"got code: {code}") @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")