diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index e602210ee5..dd07c31172 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -39,6 +39,8 @@ def setUp(self): .half() .cuda() ) + for param in self.model.parameters(): + param.requires_grad = False @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm enablement in progress")