diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 9f77ad020d..4871b48849 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -80,6 +80,7 @@ def forward(self, x): class TestFloat8Tensor(TorchAOIntegrationTestCase): def setUp(self): self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + torch.set_grad_enabled(False) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(