diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
index 35e6a83656..e49e8e8129 100644
--- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
+++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
@@ -106,12 +106,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
         elif func is aten.to.dtype_layout:
             dense, scale, _ = args[0].get_plain()
-            dense = dense.to(
+            product = dense.to(scale.dtype) * scale
+            return product.to(
                 *args[1:],
                 dtype=kwargs.get("dtype", dense.dtype),
                 device=kwargs.get("device", dense.device),
             )
-            return scale * dense
 
         raise NotImplementedError(
             f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"
@@ -135,11 +135,12 @@ def get_plain(self):
         # semi-structured format, so multiplying with identity matrix,
         # and using identity scale factors, for the conversion.
         cols = self.shape[1]
-        input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device)
-        input_scale = torch.ones(
-            (cols,), dtype=self.scale.dtype, device=self.sparse.device
-        )
+        plain_input = torch.eye(cols, device=self.sparse.device)
+        input = plain_input.to(dtype=self.sparse.dtype)
+        plain_input_scale = torch.ones((cols,), device=self.sparse.device)
+        input_scale = plain_input_scale.to(dtype=self.scale.dtype)
         sparse_scale = torch.ones_like(self.scale)
+        out_dtype = torch.bfloat16
 
         dense = (
             rowwise_scaled_linear_sparse_cutlass_f8f8(
diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py
index cbeb9cdb6f..ebbe844d83 100644
--- a/torchao/quantization/linear_activation_quantized_tensor.py
+++ b/torchao/quantization/linear_activation_quantized_tensor.py
@@ -133,11 +133,14 @@ def _same_metadata(
 
 @implements([torch.nn.functional.linear, aten.linear.default])
 def _(func, types, args, kwargs):
-    input_tensor, weight_tensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
-    )
+
+    input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
+    weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
+    bias = kwargs.get("bias", args[2] if len(args) > 2 else None)
+
+    assert input_tensor is not None, "input tensor must not be None"
+    assert weight_tensor is not None, "weight tensor must not be None"
+
     if isinstance(weight_tensor, LinearActivationQuantizedTensor):
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
 
@@ -216,6 +219,11 @@ def _(func, types, args, kwargs):
         for tensor_name in self_tensors:
             getattr(self, tensor_name).copy_(getattr(src, tensor_name))
         return
+    elif type(self) is torch.Tensor and type(src) is LinearActivationQuantizedTensor:
+        new_src = src.to(dtype=self.dtype, device=self.device)
+        self.copy_(new_src)
+        return
+
     raise ValueError(
         f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
    )
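
A minimal usage sketch of the two new dispatch paths in linear_activation_quantized_tensor.py, kept outside the patch itself. The names `activations` and `lqt_weight` are hypothetical placeholders (the sketch assumes `lqt_weight` is an already-constructed LinearActivationQuantizedTensor), and the illustrative calls are left commented out so the snippet stays self-contained.

import torch
import torch.nn.functional as F

# Assumed setup (hypothetical names, not from the patch):
#   activations: a regular float tensor of shape (batch, in_features)
#   lqt_weight:  a LinearActivationQuantizedTensor wrapping the weight

# 1) F.linear called with keyword arguments now resolves input/weight/bias via
#    kwargs.get(...) instead of positional args[0]/args[1]/args[2]:
# out = F.linear(input=activations, weight=lqt_weight, bias=None)

# 2) aten.copy_ with a plain torch.Tensor destination and a
#    LinearActivationQuantizedTensor source now converts the source to the
#    destination's dtype/device and copies, instead of raising ValueError:
# plain = torch.empty(lqt_weight.shape, dtype=torch.bfloat16)
# plain.copy_(lqt_weight)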