Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
)
elif func is aten.to.dtype_layout:
dense, scale, _ = args[0].get_plain()
dense = dense.to(
product = dense.to(scale.dtype) * scale
return product.to(
*args[1:],
dtype=kwargs.get("dtype", dense.dtype),
device=kwargs.get("device", dense.device),
)
return scale * dense

raise NotImplementedError(
f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"
Expand All @@ -135,11 +135,12 @@ def get_plain(self):
# semi-structured format, so multiplying with identity matrix,
# and using identity scale factors, for the conversion.
cols = self.shape[1]
input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device)
input_scale = torch.ones(
(cols,), dtype=self.scale.dtype, device=self.sparse.device
)
plain_input = torch.eye(cols, device=self.sparse.device)
input = plain_input.to(dtype=self.sparse.dtype)
plain_input_scale = torch.ones((cols,), device=self.sparse.device)
input_scale = plain_input_scale.to(dtype=self.scale.dtype)
sparse_scale = torch.ones_like(self.scale)

out_dtype = torch.bfloat16
dense = (
rowwise_scaled_linear_sparse_cutlass_f8f8(
Expand Down
18 changes: 13 additions & 5 deletions torchao/quantization/linear_activation_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,14 @@ def _same_metadata(

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)

input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
bias = kwargs.get("bias", args[2] if len(args) > 2 else None)

assert input_tensor is not None, "input tensor must not be None"
assert weight_tensor is not None, "weight tensor must not be None"

if isinstance(weight_tensor, LinearActivationQuantizedTensor):
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

Expand Down Expand Up @@ -216,6 +219,11 @@ def _(func, types, args, kwargs):
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
elif type(self) is torch.Tensor and type(src) is LinearActivationQuantizedTensor:
new_src = src.to(dtype=self.dtype, device=self.device)
self.copy_(new_src)
return

raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)
Expand Down