12 changes: 11 additions & 1 deletion docs/source/api_ref_dtypes.rst
@@ -22,7 +22,6 @@ Layouts and Tensor Subclasses
FloatxTensor
FloatxTensorCoreLayout
MarlinSparseLayout
BlockSparseLayout
UintxLayout
MarlinQQQTensor
MarlinQQQLayout
@@ -43,6 +42,17 @@ Quantization techniques
to_affine_quantized_floatx_static
to_marlinqqq_quantized_intx
to_nf4

Prototype
---------
.. currentmodule:: torchao.prototype.dtypes

.. autosummary::
:toctree: generated/
:nosignatures:

BlockSparseLayout

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation.
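
The entry documented under torchao.prototype.dtypes is used the same way as before, only imported from the prototype namespace. A minimal usage sketch follows; the layout keyword on Int8DynamicActivationInt8WeightConfig, the toy CUDA model, and the fp16 dtype are assumptions for illustration and are not spelled out in this diff:

import torch
from torchao.prototype.dtypes import BlockSparseLayout
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Toy model; the block-sparse int8 kernel path targets CUDA weights.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).half().cuda()

# Pass the prototype layout through the quantization config (blocksize defaults to 64).
quantize_(
    model,
    Int8DynamicActivationInt8WeightConfig(layout=BlockSparseLayout(blocksize=64)),
)

out = model(torch.randn(16, 128, dtype=torch.float16, device="cuda"))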
29 changes: 28 additions & 1 deletion test/sparsity/test_sparse_api.py
@@ -253,7 +253,7 @@ def test_sparse(self, compile):
quantize_(model_copy, Int8DynamicActivationInt8WeightConfig())
reference = model_copy(input)

from torchao.dtypes import BlockSparseLayout
from torchao.prototype.dtypes import BlockSparseLayout

quantize_(
model,
@@ -267,6 +267,33 @@ def test_sparse(self, compile):

torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)

# TODO: Remove this test once the deprecated API has been removed
def test_sparse_deprecated(self):
import sys
import warnings

# We need to clear the cache to force re-importing and trigger the warning again.
modules_to_clear = [
"torchao.dtypes.uintx.block_sparse_layout",
"torchao.dtypes",
]
for mod in modules_to_clear:
if mod in sys.modules:
del sys.modules[mod]

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")  # Ensure all warnings are captured
from torchao.dtypes import BlockSparseLayout  # noqa: F401

self.assertTrue(
any(
issubclass(warning.category, DeprecationWarning)
and "BlockSparseLayout" in str(warning.message)
for warning in w
),
f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
)


common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
2 changes: 1 addition & 1 deletion torchao/dtypes/__init__.py
@@ -14,7 +14,6 @@
)
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
CutlassInt4PackedLayout,
Int4CPULayout,
Int4XPULayout,
@@ -29,6 +28,7 @@
UintxLayout,
to_marlinqqq_quantized_intx,
)
from .uintx.block_sparse_layout import BlockSparseLayout
from .utils import (
Layout,
PlainLayout,
8 changes: 4 additions & 4 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -25,10 +25,6 @@
_linear_f16_bf16_act_floatx_weight_check,
_linear_f16_bf16_act_floatx_weight_impl,
)
from torchao.dtypes.uintx.block_sparse_layout import (
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
_linear_int4_act_int4_weight_cutlass_check,
_linear_int4_act_int4_weight_cutlass_impl,
@@ -94,6 +90,10 @@
_linear_bf16_act_uint4_weight_check,
_linear_bf16_act_uint4_weight_impl,
)
from torchao.prototype.dtypes.uintx.block_sparse_layout import (
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
_dequantize_affine_no_zero_point,
4 changes: 0 additions & 4 deletions torchao/dtypes/uintx/__init__.py
@@ -1,6 +1,3 @@
from .block_sparse_layout import (
BlockSparseLayout,
)
from .cutlass_int4_packed_layout import (
CutlassInt4PackedLayout,
)
@@ -39,7 +36,6 @@

__all__ = [
"UintxLayout",
"BlockSparseLayout",
"MarlinSparseLayout",
"SemiSparseLayout",
"TensorCoreTiledLayout",
239 changes: 15 additions & 224 deletions torchao/dtypes/uintx/block_sparse_layout.py
@@ -3,231 +3,22 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)
# Backward compatibility stub - imports from the new location
import warnings

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.uintx.plain_layout import (
PlainAQTTensorImpl,
_aqt_is_int8_reduced_range,
)
from torchao.dtypes.utils import (
Layout,
PlainLayout,
warnings.warn(
Review comment (Contributor): add #2752 to the message

Review comment (Contributor): also, I'd remove torchao v0.16.0 and just say "in a future release of torchao", just in case the work gets delayed

"Importing BlockSparseLayout from torchao.dtypes is deprecated. "
"Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. "
"This import path will be removed in torchao v0.16.0. "
"Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
DeprecationWarning,
stacklevel=2,
)

logger = logging.getLogger(__name__)

aten = torch.ops.aten


@dataclass(frozen=True)
class BlockSparseLayout(Layout):
"""BlockSparseLayout is a data class that represents the layout of a block sparse matrix.

Attributes:
blocksize (int): The size of the blocks in the sparse matrix. Default is 64.
"""

blocksize: int = 64


@register_layout(BlockSparseLayout)
class BlockSparseAQTTensorImpl(PlainAQTTensorImpl):
bsr_crow_indices: Optional[torch.Tensor]
bsr_col_indices: Optional[torch.Tensor]
bsr_values: Optional[torch.Tensor]
scale: Optional[torch.Tensor]
zero_point: Optional[torch.Tensor]

__slots__ = [
"bsr_crow_indices",
"bsr_col_indices",
"bsr_values",
"scale",
"zero_point",
]

@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
_layout: Layout,
requires_grad: bool = False,
):
if bsr_values is None:
raise ValueError("bsr values must be provided!")
else:
previous_tensor = bsr_values

kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__( # noqa: PYI034
self,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
_layout: Layout,
requires_grad: bool = False,
):
self.bsr_crow_indices = bsr_crow_indices
self.bsr_col_indices = bsr_col_indices
self.bsr_values = bsr_values
self.scale = scale
self.zero_point = zero_point
self._layout = _layout

def __tensor_flatten__(self):
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (self.shape, self._layout, self.requires_grad)
return inner_tensors, tensor_meta

@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, _layout, requires_grad = tensor_meta
return cls(
shape=shape,
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
bsr_values=inner_tensors.get("bsr_values", None),
scale=inner_tensors.get("scale", None),
zero_point=inner_tensors.get("zero_point", None),
_layout=_layout,
requires_grad=requires_grad,
)

@classmethod
def from_plain(cls, int_data, scale, zero_point, _layout):
bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize)
return cls(
shape=int_data.shape,
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
scale=scale,
zero_point=zero_point,
_layout=_layout,
requires_grad=False,
)

def get_plain(self):
int_data_expanded = torch.ops.blocksparse.bsr_to_dense(
self.crow_indices(),
self.col_indices(),
self.values(),
self.shape[0],
self.shape[1],
)
return int_data_expanded, self.scale, self.zero_point

def _apply_fn_to_data(self, func):
return self.__class__(
shape=self.shape,
bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices),
bsr_values=func(self.bsr_values),
scale=self.scale,
zero_point=self.zero_point,
_layout=self._layout,
requires_grad=self.requires_grad,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

# Need the following for bsr specific functions
if func is aten.crow_indices.default:
return args[0].bsr_crow_indices.detach()

if func is aten.col_indices.default:
return args[0].bsr_col_indices.detach()

if func is aten.values.default:
return args[0].bsr_values.detach()

if func is aten._nnz.default:
return args[0].bsr_values.shape[0]

raise NotImplementedError(
f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)


def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8_reduced_range(input_tensor)
and isinstance(weight_tensor, AffineQuantizedTensor)
and weight_tensor.is_cuda
and input_tensor.dtype == weight_tensor.dtype
and isinstance(input_tensor._layout, PlainLayout)
and isinstance(weight_tensor._layout, BlockSparseLayout)
)


def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias):
x_vals_int8 = input_tensor.tensor_impl.int_data
x_scales = input_tensor.tensor_impl.scale
w_vals = weight_tensor.tensor_impl
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
tmp_t = tmp.t()

y = torch.ops.blocksparse.int_addmm(
w_vals.crow_indices(),
w_vals.col_indices(),
w_vals.values(),
tmp_t,
w_scales,
x_scales.reshape(-1),
)
y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
y = y.reshape(*y_shape)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
from torchao.prototype.dtypes.uintx.block_sparse_layout import (
BlockSparseAQTTensorImpl, # noqa: F401
BlockSparseLayout, # noqa: F401
_linear_int8_act_int8_weight_block_sparse_check, # noqa: F401
_linear_int8_act_int8_weight_block_sparse_impl, # noqa: F401
)
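
The old module is now a stub that re-exports from the prototype location and warns on import. A small sketch of the expected behaviour, mirroring the deprecation test added above; note the warning only fires when the stub module is executed, so an already-imported module has to be dropped from sys.modules to trigger it again:

import sys
import warnings

# Force the stub module to be re-executed so the deprecation warning is emitted again.
sys.modules.pop("torchao.dtypes.uintx.block_sparse_layout", None)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old path still resolves, but goes through the stub above.
    from torchao.dtypes.uintx.block_sparse_layout import BlockSparseLayout  # noqa: F401

assert any(issubclass(c.category, DeprecationWarning) for c in caught)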
11 changes: 11 additions & 0 deletions torchao/prototype/dtypes/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .uintx import BlockSparseLayout

__all__ = [
"BlockSparseLayout",
]
11 changes: 11 additions & 0 deletions torchao/prototype/dtypes/uintx/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .block_sparse_layout import BlockSparseLayout

__all__ = [
"BlockSparseLayout",
]