Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def get_extensions():
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),
]
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
cutlass_90a_sources.append(
Expand Down
66 changes: 66 additions & 0 deletions test/sparsity/test_activation24.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PerRow,
quantize_,
)
from torchao.quantization.quant_api import _float8_cutlass_quant

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True

Expand Down Expand Up @@ -141,3 +142,68 @@ def srelu_linear(x):
custom_output = reference_linear_copy(input_tensor)

torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)


@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_sparse24_fp8_sm90_cutlass_gemm_eye(
M=512, K=256, dtype=torch.float8_e4m3fn
) -> None:
torch.manual_seed(0)

A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
A_aqt = _float8_cutlass_quant(A_dense, dtype)
A = A_aqt.tensor_impl.float8_data

# NOTE: CUTLASS compression kernel expects the input to be *exactly*
# 2:4 sparse already (eg it does not select the largest values)
A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
assert torch.allclose(
A_packed.float().sum(), A.float().sum()
) # Check all values are there

# Check MM without scale
eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T
A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
A_packed, A_mdata, eye
)
assert torch.allclose(A.float(), A_reconstructed.float())

# Check MM with scale
b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm(
A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
)
assert torch.allclose(
A.float() * b_scale * a_scale, A_reconstructed.float(), rtol=0.01
)


@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor(
M=512, N=1024, K=256, dtype=torch.float8_e4m3fn
) -> None:
def _to_fp8_rowwise(x: torch.Tensor, dtype):
max_v = torch.finfo(dtype).max
x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
x = (x / x_scale).to(dtype)
return x, x_scale

torch.manual_seed(0)
A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
A, a_scale = _to_fp8_rowwise(A_dense, dtype)

B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16)
B, b_scale = _to_fp8_rowwise(B_dense, dtype)

B = B.T
b_scale = b_scale.T

A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale
)
out_ref = torch._scaled_mm(
A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
)
assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)
Loading
Loading