From bc67c1a2b27d2760838e7c37c54c5996b977943c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 00:49:49 +0000 Subject: [PATCH 01/16] Move block_sparse_layout to prototype --- docs/source/api_ref_dtypes.rst | 1 - docs/source/api_ref_sparsity.rst | 10 + torchao/dtypes/__init__.py | 5 +- torchao/dtypes/affine_quantized_tensor_ops.py | 2 +- torchao/dtypes/uintx/__init__.py | 4 - torchao/dtypes/uintx/block_sparse_layout.py | 243 ++---------------- torchao/prototype/sparsity/__init__.py | 6 + .../prototype/sparsity/block_sparse_layout.py | 233 +++++++++++++++++ 8 files changed, 274 insertions(+), 230 deletions(-) create mode 100644 torchao/prototype/sparsity/block_sparse_layout.py diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 6cbec7465e..37e1407435 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -22,7 +22,6 @@ Layouts and Tensor Subclasses FloatxTensor FloatxTensorCoreLayout MarlinSparseLayout - BlockSparseLayout UintxLayout MarlinQQQTensor MarlinQQQLayout diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index 9fc6644683..acb4cd3fb6 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -15,3 +15,13 @@ torchao.sparsity apply_fake_sparsity WandaSparsifier PerChannelNormObserver + +Prototype +--------- +.. currentmodule:: torchao.prototype.sparsity + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + BlockSparseLayout diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..8033a4de66 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -13,8 +13,11 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 + +# Import BlockSparseLayout from prototype for backward compatibility +from torchao.prototype.sparsity import BlockSparseLayout + from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index ffadece729..47f285a121 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,7 +25,7 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.block_sparse_layout import ( +from torchao.prototype.sparsity.block_sparse_layout import ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -39,7 +36,6 @@ __all__ = [ "UintxLayout", - "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 0c6046c313..7d6b76ac95 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -3,231 +3,28 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import logging -from dataclasses import dataclass -from typing import Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.uintx.plain_layout import ( - PlainAQTTensorImpl, - _aqt_is_int8_reduced_range, +warnings.warn( + "Importing from torchao.dtypes.uintx.block_sparse_layout is deprecated. " + "Please use 'from torchao.prototype.sparsity import BlockSparseLayout' instead. " + "This import path will be removed in torchao v0.16.0.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import ( - Layout, - PlainLayout, -) - -logger = logging.getLogger(__name__) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class BlockSparseLayout(Layout): - """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. - - Attributes: - blocksize (int): The size of the blocks in the sparse matrix. Default is 64. - """ - - blocksize: int = 64 - - -@register_layout(BlockSparseLayout) -class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): - bsr_crow_indices: Optional[torch.Tensor] - bsr_col_indices: Optional[torch.Tensor] - bsr_values: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - zero_point: Optional[torch.Tensor] - - __slots__ = [ - "bsr_crow_indices", - "bsr_col_indices", - "bsr_values", - "scale", - "zero_point", - ] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - if bsr_values is None: - raise ValueError("bsr values must be provided!") - else: - previous_tensor = bsr_values - - kwargs = { - "device": previous_tensor.device, - "dtype": previous_tensor.dtype, - "layout": previous_tensor.layout, - "requires_grad": requires_grad, - } - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( # noqa: PYI034 - self, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - self.bsr_crow_indices = bsr_crow_indices - self.bsr_col_indices = bsr_col_indices - self.bsr_values = bsr_values - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self._layout, self.requires_grad) - return inner_tensors, tensor_meta - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta: Tuple[torch.Size, bool], - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, _layout, requires_grad = tensor_meta - return cls( - shape=shape, - bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), - bsr_col_indices=inner_tensors.get("bsr_col_indices", None), - bsr_values=inner_tensors.get("bsr_values", None), - scale=inner_tensors.get("scale", None), - zero_point=inner_tensors.get("zero_point", None), - _layout=_layout, - requires_grad=requires_grad, - ) - - @classmethod - def from_plain(cls, int_data, scale, zero_point, _layout): - bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) - return cls( - shape=int_data.shape, - bsr_crow_indices=bsr_tensor.crow_indices(), - bsr_col_indices=bsr_tensor.col_indices(), - bsr_values=bsr_tensor.values(), - scale=scale, - zero_point=zero_point, - _layout=_layout, - requires_grad=False, - ) - - def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense( - self.crow_indices(), - self.col_indices(), - self.values(), - self.shape[0], - self.shape[1], - ) - return int_data_expanded, self.scale, self.zero_point - - def _apply_fn_to_data(self, func): - return self.__class__( - shape=self.shape, - bsr_crow_indices=func(self.bsr_crow_indices), - bsr_col_indices=func(self.bsr_col_indices), - bsr_values=func(self.bsr_values), - scale=self.scale, - zero_point=self.zero_point, - _layout=self._layout, - requires_grad=self.requires_grad, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - # Need the following for bsr specific functions - if func is aten.crow_indices.default: - return args[0].bsr_crow_indices.detach() - - if func is aten.col_indices.default: - return args[0].bsr_col_indices.detach() - - if func is aten.values.default: - return args[0].bsr_values.detach() - - if func is aten._nnz.default: - return args[0].bsr_values.shape[0] - - raise NotImplementedError( - f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - -def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor._layout, BlockSparseLayout) - ) - - -def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals = weight_tensor.tensor_impl - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - tmp_t = tmp.t() - - y = torch.ops.blocksparse.int_addmm( - w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1), - ) - y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) - y = y.reshape(*y_shape) +from torchao.prototype.sparsity.block_sparse_layout import ( + BlockSparseLayout, + BlockSparseAQTTensorImpl, + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y +__all__ = [ + "BlockSparseLayout", + "BlockSparseAQTTensorImpl", + "_linear_int8_act_int8_weight_block_sparse_check", + "_linear_int8_act_int8_weight_block_sparse_impl", +] diff --git a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py index 821e5049e0..403458ceae 100644 --- a/torchao/prototype/sparsity/__init__.py +++ b/torchao/prototype/sparsity/__init__.py @@ -19,6 +19,11 @@ WeightNormSparsifier, ) +# Block Sparse Layout +from torchao.prototype.sparsity.block_sparse_layout import ( + BlockSparseLayout, +) + __all__ = [ "BaseScheduler", "CubicSL", @@ -30,4 +35,5 @@ "get_arg_info_from_tensor_fqn", "module_to_fqn", "WeightNormSparsifier", + "BlockSparseLayout", ] diff --git a/torchao/prototype/sparsity/block_sparse_layout.py b/torchao/prototype/sparsity/block_sparse_layout.py new file mode 100644 index 0000000000..0c6046c313 --- /dev/null +++ b/torchao/prototype/sparsity/block_sparse_layout.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + PlainAQTTensorImpl, + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import ( + Layout, + PlainLayout, +) + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class BlockSparseLayout(Layout): + """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. + + Attributes: + blocksize (int): The size of the blocks in the sparse matrix. Default is 64. + """ + + blocksize: int = 64 + + +@register_layout(BlockSparseLayout) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( # noqa: PYI034 + self, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + self.bsr_crow_indices = bsr_crow_indices + self.bsr_col_indices = bsr_col_indices + self.bsr_values = bsr_values + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self._layout, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, _layout, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + scale=inner_tensors.get("scale", None), + zero_point=inner_tensors.get("zero_point", None), + _layout=_layout, + requires_grad=requires_grad, + ) + + @classmethod + def from_plain(cls, int_data, scale, zero_point, _layout): + bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) + return cls( + shape=int_data.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + scale=scale, + zero_point=zero_point, + _layout=_layout, + requires_grad=False, + ) + + def get_plain(self): + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) + return int_data_expanded, self.scale, self.zero_point + + def _apply_fn_to_data(self, func): + return self.__class__( + shape=self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + _layout=self._layout, + requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y From dd2e7d613c8b3af0b9eeb05a0b15fd1867bcc795 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 04:03:24 +0000 Subject: [PATCH 02/16] test fixes --- test/sparsity/test_sparse_api.py | 2 +- torchao/dtypes/__init__.py | 4 +--- torchao/dtypes/uintx/__init__.py | 4 ++++ torchao/dtypes/uintx/block_sparse_layout.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 003a50c4d1..4bfcb5bb0a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -253,7 +253,7 @@ def test_sparse(self, compile): quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) - from torchao.dtypes import BlockSparseLayout + from torchao.prototype.sparsity import BlockSparseLayout quantize_( model, diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8033a4de66..d2042d1fb9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,10 +14,8 @@ ) from .nf4tensor import NF4Tensor, to_nf4 -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout - from .uintx import ( + BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 1d269fc4c4..6d1bc95653 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,3 +1,6 @@ +from .block_sparse_layout import ( + BlockSparseLayout, +) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -36,6 +39,7 @@ __all__ = [ "UintxLayout", + "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 7d6b76ac95..edff0b0493 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -8,7 +8,7 @@ import warnings warnings.warn( - "Importing from torchao.dtypes.uintx.block_sparse_layout is deprecated. " + "Importing BlockSparseLayout from torchao.dtypes is deprecated. " "Please use 'from torchao.prototype.sparsity import BlockSparseLayout' instead. " "This import path will be removed in torchao v0.16.0.", DeprecationWarning, From efa74cdf1b5aa84384f0ddc983441fca32fcdd0c Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sun, 2 Nov 2025 20:07:02 -0800 Subject: [PATCH 03/16] Remove unused import in __init__.py --- torchao/dtypes/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d2042d1fb9..07f03c7ed9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -13,7 +13,6 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 - from .uintx import ( BlockSparseLayout, CutlassInt4PackedLayout, From 37225378308bdf3aedb4aca5309520bf2de34fa4 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sun, 2 Nov 2025 20:08:10 -0800 Subject: [PATCH 04/16] Clean up exports in block_sparse_layout.py Removed unused exports from block_sparse_layout.py --- torchao/dtypes/uintx/block_sparse_layout.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index edff0b0493..137f13d56f 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -21,10 +21,3 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) - -__all__ = [ - "BlockSparseLayout", - "BlockSparseAQTTensorImpl", - "_linear_int8_act_int8_weight_block_sparse_check", - "_linear_int8_act_int8_weight_block_sparse_impl", -] From 6c1b8ef35018107d692231858577e9f40a9e7ea1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:06:41 +0000 Subject: [PATCH 05/16] test fixes --- torchao/dtypes/__init__.py | 5 ++++- torchao/dtypes/uintx/__init__.py | 4 ---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..8033a4de66 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -13,8 +13,11 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 + +# Import BlockSparseLayout from prototype for backward compatibility +from torchao.prototype.sparsity import BlockSparseLayout + from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -39,7 +36,6 @@ __all__ = [ "UintxLayout", - "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", From 1702bdf932fb76b9d38f2ec1457bbee5236bf8e4 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:13:05 +0000 Subject: [PATCH 06/16] Fix ruff import sorting and add noqa for re-exported imports --- torchao/dtypes/__init__.py | 7 +++---- torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++---- torchao/dtypes/uintx/block_sparse_layout.py | 8 ++++---- torchao/prototype/sparsity/__init__.py | 9 ++++----- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8033a4de66..33a22b4925 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,3 +1,6 @@ +# Import BlockSparseLayout from prototype for backward compatibility +from torchao.prototype.sparsity import BlockSparseLayout + from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, @@ -13,10 +16,6 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 - -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout - from .uintx import ( CutlassInt4PackedLayout, Int4CPULayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 47f285a121..8f6305d0c0 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.prototype.sparsity.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( _linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl, @@ -94,6 +90,10 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.sparsity.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 137f13d56f..ef162d3945 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -16,8 +16,8 @@ ) from torchao.prototype.sparsity.block_sparse_layout import ( - BlockSparseLayout, - BlockSparseAQTTensorImpl, - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, + BlockSparseAQTTensorImpl, # noqa: F401 + BlockSparseLayout, # noqa: F401 + _linear_int8_act_int8_weight_block_sparse_check, # noqa: F401 + _linear_int8_act_int8_weight_block_sparse_impl, # noqa: F401 ) diff --git a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py index 403458ceae..20ea714577 100644 --- a/torchao/prototype/sparsity/__init__.py +++ b/torchao/prototype/sparsity/__init__.py @@ -1,5 +1,9 @@ # Sparsifier # Scheduler +# Block Sparse Layout +from torchao.prototype.sparsity.block_sparse_layout import ( + BlockSparseLayout, +) from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL @@ -19,11 +23,6 @@ WeightNormSparsifier, ) -# Block Sparse Layout -from torchao.prototype.sparsity.block_sparse_layout import ( - BlockSparseLayout, -) - __all__ = [ "BaseScheduler", "CubicSL", From 18d682a61dd3f0d16ec7e4c1cba7d9de4cb1419a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:29:19 +0000 Subject: [PATCH 07/16] Move block_sparse_layout to prototype/dtypes/uintx and update all imports - Move torchao/prototype/sparsity/block_sparse_layout.py to torchao/prototype/dtypes/uintx/block_sparse_layout.py - Update all imports to use 'from torchao.prototype.dtypes import BlockSparseLayout' - Update deprecation warning message - Create torchao/prototype/dtypes/__init__.py and torchao/prototype/dtypes/uintx/__init__.py - Remove BlockSparseLayout from torchao/prototype/sparsity/__init__.py - Update documentation --- docs/source/api_ref_sparsity.rst | 10 ---------- test/sparsity/test_sparse_api.py | 2 +- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 2 +- torchao/dtypes/uintx/block_sparse_layout.py | 4 ++-- torchao/prototype/dtypes/__init__.py | 12 ++++++++++++ torchao/prototype/dtypes/uintx/__init__.py | 12 ++++++++++++ .../uintx}/block_sparse_layout.py | 0 torchao/prototype/sparsity/__init__.py | 5 ----- 9 files changed, 29 insertions(+), 20 deletions(-) create mode 100644 torchao/prototype/dtypes/__init__.py create mode 100644 torchao/prototype/dtypes/uintx/__init__.py rename torchao/prototype/{sparsity => dtypes/uintx}/block_sparse_layout.py (100%) diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index acb4cd3fb6..9fc6644683 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -15,13 +15,3 @@ torchao.sparsity apply_fake_sparsity WandaSparsifier PerChannelNormObserver - -Prototype ---------- -.. currentmodule:: torchao.prototype.sparsity - -.. autosummary:: - :toctree: generated/ - :nosignatures: - - BlockSparseLayout diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 4bfcb5bb0a..66cd032a9a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -253,7 +253,7 @@ def test_sparse(self, compile): quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) - from torchao.prototype.sparsity import BlockSparseLayout + from torchao.prototype.dtypes import BlockSparseLayout quantize_( model, diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 33a22b4925..ebf6c98553 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,5 +1,5 @@ # Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout +from torchao.prototype.dtypes import BlockSparseLayout from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8f6305d0c0..2b6f47e692 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -90,7 +90,7 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) -from torchao.prototype.sparsity.block_sparse_layout import ( +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index ef162d3945..2840a9b7fa 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -9,13 +9,13 @@ warnings.warn( "Importing BlockSparseLayout from torchao.dtypes is deprecated. " - "Please use 'from torchao.prototype.sparsity import BlockSparseLayout' instead. " + "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. " "This import path will be removed in torchao v0.16.0.", DeprecationWarning, stacklevel=2, ) -from torchao.prototype.sparsity.block_sparse_layout import ( +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( BlockSparseAQTTensorImpl, # noqa: F401 BlockSparseLayout, # noqa: F401 _linear_int8_act_int8_weight_block_sparse_check, # noqa: F401 diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py new file mode 100644 index 0000000000..c8dfc67f70 --- /dev/null +++ b/torchao/prototype/dtypes/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .uintx import BlockSparseLayout + +__all__ = [ + "BlockSparseLayout", +] + diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py new file mode 100644 index 0000000000..2b8e2e66ba --- /dev/null +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .block_sparse_layout import BlockSparseLayout + +__all__ = [ + "BlockSparseLayout", +] + diff --git a/torchao/prototype/sparsity/block_sparse_layout.py b/torchao/prototype/dtypes/uintx/block_sparse_layout.py similarity index 100% rename from torchao/prototype/sparsity/block_sparse_layout.py rename to torchao/prototype/dtypes/uintx/block_sparse_layout.py diff --git a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py index 20ea714577..821e5049e0 100644 --- a/torchao/prototype/sparsity/__init__.py +++ b/torchao/prototype/sparsity/__init__.py @@ -1,9 +1,5 @@ # Sparsifier # Scheduler -# Block Sparse Layout -from torchao.prototype.sparsity.block_sparse_layout import ( - BlockSparseLayout, -) from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL @@ -34,5 +30,4 @@ "get_arg_info_from_tensor_fqn", "module_to_fqn", "WeightNormSparsifier", - "BlockSparseLayout", ] From de054ab08cc45e244877a7f09bf91fe1c8d1e2b1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:37:19 +0000 Subject: [PATCH 08/16] Apply ruff formatting (remove trailing newlines) --- torchao/prototype/dtypes/__init__.py | 1 - torchao/prototype/dtypes/uintx/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index c8dfc67f70..54d395e673 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -9,4 +9,3 @@ __all__ = [ "BlockSparseLayout", ] - diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 2b8e2e66ba..107e6a344b 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -9,4 +9,3 @@ __all__ = [ "BlockSparseLayout", ] - From 4ef291fa13fdf89d636d84aecf3368a6584a3736 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:46:15 +0000 Subject: [PATCH 09/16] Add Prototype section to dtypes documentation with BlockSparseLayout --- docs/source/api_ref_dtypes.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 37e1407435..abf938c322 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -42,6 +42,17 @@ Quantization techniques to_affine_quantized_floatx_static to_marlinqqq_quantized_intx to_nf4 + +Prototype +--------- +.. currentmodule:: torchao.prototype.dtypes + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + BlockSparseLayout + .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. From d699ee0a418c913017be653a9b2c584eb0bf312a Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 3 Nov 2025 09:49:54 -0800 Subject: [PATCH 10/16] Clean up imports in affine_quantized_tensor_ops.py Removed redundant imports of block sparse layout functions. --- torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 2b6f47e692..45e350068e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,6 +25,10 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( _linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl, @@ -90,10 +94,6 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) -from torchao.prototype.dtypes.uintx.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, From ab8479957f0bbc74530bc0882d9c26881fe28185 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 11:11:58 -0800 Subject: [PATCH 11/16] Update internal links --- torchao/dtypes/__init__.py | 4 +--- torchao/dtypes/uintx/__init__.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index ebf6c98553..07f03c7ed9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,6 +1,3 @@ -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.dtypes import BlockSparseLayout - from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, @@ -17,6 +14,7 @@ ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( + BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 1d269fc4c4..6d1bc95653 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,3 +1,6 @@ +from .block_sparse_layout import ( + BlockSparseLayout, +) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -36,6 +39,7 @@ __all__ = [ "UintxLayout", + "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", From 40ab18861fb676d462558276212d386bd9211177 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 21:56:08 +0000 Subject: [PATCH 12/16] test fixes --- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/uintx/__init__.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..6cfcc741fa 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,7 +14,6 @@ ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, @@ -33,6 +32,7 @@ Layout, PlainLayout, ) +from .uintx.block_sparse_layout import BlockSparseLayout __all__ = [ "NF4Tensor", diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -39,7 +36,6 @@ __all__ = [ "UintxLayout", - "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", From 894857e3a6582842fbfb06090b057f1f32ce8b1a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 22:08:46 +0000 Subject: [PATCH 13/16] ruff fixes --- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 6cfcc741fa..b1e7fc9875 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -28,11 +28,11 @@ UintxLayout, to_marlinqqq_quantized_intx, ) +from .uintx.block_sparse_layout import BlockSparseLayout from .utils import ( Layout, PlainLayout, ) -from .uintx.block_sparse_layout import BlockSparseLayout __all__ = [ "NF4Tensor", diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 45e350068e..2b6f47e692 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.prototype.dtypes.uintx.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( _linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl, @@ -94,6 +90,10 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, From 9bff1628cee5f8efa2c04aca5be61df15d66442e Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 4 Nov 2025 11:22:07 -0800 Subject: [PATCH 14/16] Add test cases --- test/sparsity/test_sparse_api.py | 27 +++++++++++++++++++++ torchao/dtypes/uintx/block_sparse_layout.py | 3 ++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 66cd032a9a..c9d41a98a9 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -267,6 +267,33 @@ def test_sparse(self, compile): torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) + # TODO: Remove this test once the deprecated API has been removed + def test_sparse_deprecated(self): + import sys + import warnings + + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.block_sparse_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import BlockSparseLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + self.assertTrue( + any( + issubclass(warning.category, DeprecationWarning) + and "BlockSparseLayout" in str(warning.message) + for warning in w + ), + f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}", + ) + common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 2840a9b7fa..c1d6436268 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -10,7 +10,8 @@ warnings.warn( "Importing BlockSparseLayout from torchao.dtypes is deprecated. " "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. " - "This import path will be removed in torchao v0.16.0.", + "This import path will be removed in torchao v0.16.0. " + "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ", DeprecationWarning, stacklevel=2, ) From d23530543adbc1147abd529b5051a7da49f8c64d Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 5 Nov 2025 11:19:39 -0800 Subject: [PATCH 15/16] Update deprecation warning message for BlockSparseLayout --- torchao/dtypes/uintx/block_sparse_layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index c1d6436268..b663c12696 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -10,7 +10,7 @@ warnings.warn( "Importing BlockSparseLayout from torchao.dtypes is deprecated. " "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. " - "This import path will be removed in torchao v0.16.0. " + "This import path will be removed in future torchao releases. " "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ", DeprecationWarning, stacklevel=2, From 3c6d09d1cb1478b9b59a487fd63c976099304f43 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 5 Nov 2025 11:20:18 -0800 Subject: [PATCH 16/16] Fix deprecation warning message for BlockSparseLayout --- torchao/dtypes/uintx/block_sparse_layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index b663c12696..6ca4e8745a 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -10,7 +10,7 @@ warnings.warn( "Importing BlockSparseLayout from torchao.dtypes is deprecated. " "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. " - "This import path will be removed in future torchao releases. " + "This import path will be removed in a future torchao release. " "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ", DeprecationWarning, stacklevel=2,