68 changes: 68 additions & 0 deletions test/sparsity/test_sparse_api.py
@@ -12,12 +12,17 @@
from torch.testing._internal import common_utils

from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    quantize_,
)
from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
from torchao.utils import is_sm_at_least_90

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -121,6 +126,69 @@ def test_sparse_marlin(self, compile):

        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize("compile", [True, False])
    def test_fp8_cutlass_sparse(self, compile):
        input = torch.rand((256, 256)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(256, 1024),
                nn.Linear(1024, 256),
            )
            .half()
            .cuda()
            .eval()
        )

        apply_fake_sparsity(model)
        model_copy = copy.deepcopy(model)

        # Quantized
        quantize_(model_copy.bfloat16(), Float8DynamicActivationFloat8WeightConfig())
        dense_result = model_copy(input.bfloat16()).half()

        # Sparse + quantized
        quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
        if compile:
            model = torch.compile(model)
        sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_clone(self):
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

            original = model.weight.original_weight_tensor.tensor_impl.get_plain()
            cloned = model.weight.original_weight_tensor.tensor_impl.clone().get_plain()

            for o, c in zip(original, cloned):
                torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_to(self):
        # Need to run in inference mode to avoid dispatching to `aten._to_copy`
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            model_copy = copy.deepcopy(model)
            expected = model_copy.weight.to(dtype=torch.float)

            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

            original = torch.ops.aten.to.dtype_layout(
                model.weight.original_weight_tensor.tensor_impl,
                dtype=torch.float,
                layout=torch.strided,
            )
            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)


class TestBlockSparseWeight(common_utils.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
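Aside: distilled from the new `test_fp8_cutlass_sparse` test above, the end-to-end flow this config enables looks like the sketch below. This is a minimal sketch, not part of the PR: it assumes an sm90+ GPU, and the `torch.no_grad()` context is added here for plain inference use.

import torch
import torch.nn as nn

from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig
from torchao.quantization.quant_api import quantize_
from torchao.sparsity import apply_fake_sparsity

# Prune first, then quantize: weights end up fp8 with 2:4 semi-structured sparsity.
model = nn.Sequential(nn.Linear(256, 1024), nn.Linear(1024, 256)).half().cuda().eval()
apply_fake_sparsity(model)  # prunes weights into the 2:4 pattern, as the tests do
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

x = torch.rand((256, 256)).half().cuda()
with torch.no_grad():
    y = model(x)  # forward now runs through the fp8 CUTLASS semi-sparse path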
12 changes: 12 additions & 0 deletions torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
@@ -100,6 +100,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            raise ValueError(
                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
            )
        elif func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )
        elif func is aten.to.dtype_layout:
            # Dequantize: move the plain fp8 values to the requested dtype/device,
            # then rescale them back to the original range.
            dense, scale, _ = args[0].get_plain()
            dense = dense.to(
                *args[1:],
                dtype=kwargs.get("dtype", dense.dtype),
                device=kwargs.get("device", dense.device),
            )
            return scale * dense

        raise NotImplementedError(
            f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"