diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 0bf0fe4d8c..faf55366ec 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -12,12 +12,17 @@ from torch.testing._internal import common_utils from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout +from torchao.quantization import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, + Float8DynamicActivationFloat8WeightConfig, +) from torchao.quantization.quant_api import ( int4_weight_only, int8_dynamic_activation_int8_weight, quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ +from torchao.utils import is_sm_at_least_90 logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -121,6 +126,69 @@ def test_sparse_marlin(self, compile): torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1) + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize("compile", [True, False]) + def test_fp8_cutlass_sparse(self, compile): + input = torch.rand((256, 256)).half().cuda() + model = ( + nn.Sequential( + nn.Linear(256, 1024), + nn.Linear(1024, 256), + ) + .half() + .cuda() + .eval() + ) + + apply_fake_sparsity(model) + model_copy = copy.deepcopy(model) + + # Quantized + quantize_(model_copy.bfloat16(), Float8DynamicActivationFloat8WeightConfig()) + dense_result = model_copy(input.bfloat16()).half() + + # Sparse + quantized + quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + if compile: + model = torch.compile(model) + sparse_result = model(input) + + torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1) + + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_fp8_cutlass_sparse_lowering_op_clone(self): + with torch.inference_mode(): + model = nn.Linear(256, 1024).half().cuda().eval() + apply_fake_sparsity(model) + quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + + original = model.weight.original_weight_tensor.tensor_impl.get_plain() + cloned = model.weight.original_weight_tensor.tensor_impl.clone().get_plain() + + for o, c in zip(original, cloned): + torch.testing.assert_close(o, c, atol=0.0, rtol=0.0) + + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_fp8_cutlass_sparse_lowering_op_to(self): + # Need to run with inference mode to avoid dispatching to `aten.to_copy` + with torch.inference_mode(): + model = nn.Linear(256, 1024).half().cuda().eval() + apply_fake_sparsity(model) + model_copy = copy.deepcopy(model) + expected = model_copy.weight.to(dtype=torch.float) + + quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + + original = torch.ops.aten.to.dtype_layout( + model.weight.original_weight_tensor.tensor_impl, + dtype=torch.float, + layout=torch.strided, + ) + torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1) + class TestBlockSparseWeight(common_utils.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 45fe451712..35e6a83656 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -100,6 +100,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + elif func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten.to.dtype_layout: + dense, scale, _ = args[0].get_plain() + dense = dense.to( + *args[1:], + dtype=kwargs.get("dtype", dense.dtype), + device=kwargs.get("device", dense.device), + ) + return scale * dense raise NotImplementedError( f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"