Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,53 @@ def _apply_fn_to_data(self, fn):
def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias):
    """Decide whether the FP8-activation / FP8-sparse-weight CUTLASS path applies.

    The check runs in two stages:

    1. Type, layout, dtype and rank checks on ``input_tensor``,
       ``weight_tensor`` and ``bias``.
    2. Rank checks on the two quantization scales. A scale that carries a
       trailing dimension of size 1 (input scale with the same rank as the
       input, or a 2-D weight scale) is accepted after that dimension is
       squeezed away.

    Side effect: when, and only when, the function returns ``True``, any
    scale that needed squeezing is written back to
    ``<tensor>.tensor_impl.scale`` in its squeezed form. When the function
    returns ``False`` neither tensor is modified.
    NOTE(review): presumably the matching ``_impl`` relies on receiving the
    squeezed scales -- confirm against the caller.

    Args:
        input_tensor: Candidate activation tensor.
        weight_tensor: Candidate weight tensor.
        bias: Optional bias; must be 1-D with the input's dtype when given.

    Returns:
        bool: ``True`` if every condition holds, ``False`` otherwise.
    """
    # Imported lazily, as in the original code (likely to avoid a circular
    # import between the layout modules -- TODO confirm).
    from torchao.dtypes.floatx import Float8Layout

    base_check = (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, Float8Layout)
        and input_tensor.dtype in (torch.float16, torch.bfloat16)
        and len(input_tensor.shape) >= 2
        and input_tensor.tensor_impl.scale.dtype == torch.float32
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
        and weight_tensor.dtype == input_tensor.dtype
        and len(weight_tensor.shape) == 2
        and weight_tensor.tensor_impl.scale.dtype == torch.float32
        and (bias is None or bias.dtype == input_tensor.dtype)
        and (bias is None or len(bias.shape) == 1)
    )
    if not base_check:
        return False

    # Work on local candidates so a failed check never leaves a
    # half-modified tensor behind (no squeeze-then-unsqueeze round trip).
    input_scale = input_tensor.tensor_impl.scale
    squeeze_input_scale = (
        len(input_scale.shape) == len(input_tensor.shape)
        and len(input_scale.shape) > 1
        and input_scale.shape[-1] == 1
    )
    if squeeze_input_scale:
        input_scale = torch.squeeze(input_scale, dim=-1)

    weight_scale = weight_tensor.tensor_impl.scale
    squeeze_weight_scale = (
        len(weight_scale.shape) == 2 and weight_scale.shape[-1] == 1
    )
    if squeeze_weight_scale:
        weight_scale = torch.squeeze(weight_scale, dim=-1)

    extra_check = (
        len(input_scale.shape) == len(input_tensor.shape) - 1
        and len(weight_scale.shape) == 1
    )
    if not extra_check:
        return False

    # Commit the reshaped scales only now that every condition has passed.
    if squeeze_input_scale:
        input_tensor.tensor_impl.scale = input_scale
    if squeeze_weight_scale:
        weight_tensor.tensor_impl.scale = weight_scale

    return True


def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
Expand Down
Loading