From 5b0719fc026393bff7d5da5e4b06f6758d6e0aec Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 16 Sep 2025 14:19:11 -0700
Subject: [PATCH 1/6] [sparse] Add in missing op support for FP8 Sparse

Summary:

For ads, we are missing some op support in their lowering stack, namely
`.to(dtype=torch.float)` and `.clone()`.

This PR adds op support for these ops on `CutlassSemiSparseLayout`.

Test Plan:

```
pytest test/sparsity/test_sparse_api.py -k lowering
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/sparsity/test_sparse_api.py                 | 62 +++++++++++++++++++
 .../floatx/cutlass_semi_sparse_layout.py         |  6 ++
 2 files changed, 68 insertions(+)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 0bf0fe4d8c..b6091d0966 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -23,6 +23,12 @@
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
+from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    quantize_,
+)
+
 
 class TestSemiStructuredSparse(common_utils.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -121,6 +127,62 @@ def test_sparse_marlin(self, compile):
 
         torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_fp8_cutlass_sparse(self, compile):
+        input = torch.rand((256, 256)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(256, 1024),
+                nn.Linear(1024, 256),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        apply_fake_sparsity(model)
+        model_copy = copy.deepcopy(model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), Float8DynamicActivationFloat8WeightConfig())
+        dense_result = model_copy(input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+    def test_fp8_cutlass_sparse_lowering_op_clone(self):
+        with torch.inference_mode():
+            model = nn.Linear(256, 1024).half().cuda().eval()
+            apply_fake_sparsity(model)
+            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+
+            original = model.weight.original_weight_tensor.tensor_impl.get_plain()
+            cloned = model.weight.original_weight_tensor.tensor_impl.clone().get_plain()
+
+            for o, c in zip(original, cloned):
+                torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)
+
+    def test_fp8_cutlass_sparse_lowering_op_to(self):
+        # Need to run with inference mode to avoid dispatching to `aten.to_copy`
+        with torch.inference_mode():
+            model = nn.Linear(256, 1024).half().cuda().eval()
+            apply_fake_sparsity(model)
+            model_copy = copy.deepcopy(model)
+            expected = model_copy.weight.to(dtype=torch.float)
+
+            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
+
+            original = torch.ops.aten.to.dtype_layout(
+                model.weight.original_weight_tensor.tensor_impl, dtype=torch.float
+            )
+            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)
+
 
 class TestBlockSparseWeight(common_utils.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
index 45fe451712..4ac19eb7e0 100644
--- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
+++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
@@ -100,6 +100,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             raise ValueError(
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
             )
+        elif func is aten.clone.default:
+            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))
+        elif func is aten.to.dtype_layout:
+            dense, scale, _ = args[0].get_plain()
+            dense = dense.to(*args[1:], **kwargs)
+            return (scale * dense)
 
     raise NotImplementedError(
         f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"

From 8855420cb81d24c9831f33547d20ed35f8902d05 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 16 Sep 2025 14:24:19 -0700
Subject: [PATCH 2/6] update

---
 test/sparsity/test_sparse_api.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index b6091d0966..9527525a25 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -26,7 +26,6 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
-    quantize_,
 )
 
 

From 86bf380697c1e9bc1e7c3d271eb7bf175848ddcc Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 16 Sep 2025 14:25:35 -0700
Subject: [PATCH 3/6] ruff fix

---
 torchao/dtypes/floatx/cutlass_semi_sparse_layout.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
index 4ac19eb7e0..9074bd8c79 100644
--- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
+++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
@@ -101,11 +101,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
             )
         elif func is aten.clone.default:
-            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
         elif func is aten.to.dtype_layout:
             dense, scale, _ = args[0].get_plain()
             dense = dense.to(*args[1:], **kwargs)
-            return (scale * dense)
+            return scale * dense
 
     raise NotImplementedError(
         f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"
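The `aten.to.dtype_layout` branch added in patches 1-3 densifies the compressed weight
via `get_plain()`, casts the dense values, and multiplies the scale back in, i.e.
"dequantize, then cast". A minimal standalone sketch of that identity (illustrative
only, not torchao code; assumes per-tensor fp8 e4m3 scaling):

```python
# Standalone sketch of the dequantize-then-cast identity behind the new
# aten.to.dtype_layout branch. Illustrative, not torchao code: `w`, `scale`,
# and `dense` are hypothetical stand-ins for the tensors get_plain() returns.
import torch

w = torch.randn(128, 128, dtype=torch.float16)   # original weight
scale = w.abs().amax().float() / 448.0           # 448 = max normal fp8 e4m3 value
dense = (w.float() / scale).to(torch.float8_e4m3fn)

# What the handler computes for .to(dtype=torch.float): cast, then rescale.
recovered = scale * dense.to(torch.float)

torch.testing.assert_close(w.float(), recovered, atol=1e-1, rtol=1e-1)
```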
From 8f3bdcd23164ce712bfbf0b8de1ae438b4e5b82b Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 16 Sep 2025 14:52:39 -0700
Subject: [PATCH 4/6] update tests

---
 test/sparsity/test_sparse_api.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 9527525a25..be183f3f3d 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -155,6 +155,7 @@ def test_fp8_cutlass_sparse(self, compile):
 
         torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_fp8_cutlass_sparse_lowering_op_clone(self):
         with torch.inference_mode():
             model = nn.Linear(256, 1024).half().cuda().eval()
@@ -167,6 +168,7 @@ def test_fp8_cutlass_sparse_lowering_op_clone(self):
             for o, c in zip(original, cloned):
                 torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_fp8_cutlass_sparse_lowering_op_to(self):
         # Need to run with inference mode to avoid dispatching to `aten.to_copy`
         with torch.inference_mode():

From f3705e5397d39110fbfdee22bf7aacd85ad60e7a Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 16 Sep 2025 15:44:53 -0700
Subject: [PATCH 5/6] fix test to add in layout kwarg

---
 test/sparsity/test_sparse_api.py                    | 4 +++-
 torchao/dtypes/floatx/cutlass_semi_sparse_layout.py | 6 +++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index be183f3f3d..d3ab97c2db 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -180,7 +180,9 @@ def test_fp8_cutlass_sparse_lowering_op_to(self):
             quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
 
             original = torch.ops.aten.to.dtype_layout(
-                model.weight.original_weight_tensor.tensor_impl, dtype=torch.float
+                model.weight.original_weight_tensor.tensor_impl,
+                dtype=torch.float,
+                layout=torch.strided,
             )
             torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)
 
diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
index 9074bd8c79..35e6a83656 100644
--- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
+++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
@@ -106,7 +106,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.to.dtype_layout:
             dense, scale, _ = args[0].get_plain()
-            dense = dense.to(*args[1:], **kwargs)
+            dense = dense.to(
+                *args[1:],
+                dtype=kwargs.get("dtype", dense.dtype),
+                device=kwargs.get("device", dense.device),
+            )
             return scale * dense
 
     raise NotImplementedError(
         f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"

From 2cb3eb7fd2b6845c1f38c99f87364097b4f80434 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Wed, 17 Sep 2025 09:08:25 -0700
Subject: [PATCH 6/6] skip non h100

---
 test/sparsity/test_sparse_api.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index d3ab97c2db..faf55366ec 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -12,22 +12,22 @@
 from torch.testing._internal import common_utils
 
 from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
+from torchao.quantization import (
+    Float8DynamicActivationFloat8SemiSparseWeightConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+)
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_dynamic_activation_int8_weight,
     quantize_,
 )
 from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import is_sm_at_least_90
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
-from torchao.quantization import (
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
-    Float8DynamicActivationFloat8WeightConfig,
-)
-
 
 class TestSemiStructuredSparse(common_utils.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -126,6 +126,7 @@ def test_sparse_marlin(self, compile):
 
         torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("compile", [True, False])
     def test_fp8_cutlass_sparse(self, compile):
@@ -155,6 +156,7 @@ def test_fp8_cutlass_sparse(self, compile):
 
         torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
 
"Need H100 to run") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_fp8_cutlass_sparse_lowering_op_clone(self): with torch.inference_mode(): @@ -168,6 +170,7 @@ def test_fp8_cutlass_sparse_lowering_op_clone(self): for o, c in zip(original, cloned): torch.testing.assert_close(o, c, atol=0.0, rtol=0.0) + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_fp8_cutlass_sparse_lowering_op_to(self): # Need to run with inference mode to avoid dispatching to `aten.to_copy`