diff --git a/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h
index eee51eafc6..c7f61b3301 100644
--- a/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h
+++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h
@@ -80,6 +80,7 @@ Tensor pack_weights_cpu(
       weight_scales.dtype() == torch::kFloat32,
       "weight_scales must be float32");
   CHECK_MSG(weight_scales.dim() == 1, "weight_scales must be 1D");
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
   CHECK_MSG(
       weight_scales.size(0) == ((n * k) / group_size),
       "expected 1 scale per group");
@@ -134,9 +135,9 @@ Tensor pack_weights_without_zeros_cpu(
     const Tensor& weight_qvals,
     const Tensor& weight_scales,
     // TODO(T200095131): convert to int64_t when supported by AOTI
-    // group_size is a meta tensor with size (group_size)
+    // group_size is a tensor with size (0, group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_cpu(
       weight_qvals, weight_scales, std::nullopt, group_size);
 }
@@ -151,7 +152,7 @@ Tensor pack_weights_with_zeros_cpu(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_cpu(
       weight_qvals, weight_scales, weight_zeros, group_size);
 }
@@ -164,6 +165,7 @@ Tensor pack_weights_meta(
     const Tensor& weight_scales,
     const std::optional<Tensor>& weight_zeros,
     int64_t group_size) {
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
   int n = weight_qvals.size(0);
   int k = weight_qvals.size(1);
 
@@ -190,7 +192,7 @@ Tensor pack_weights_without_zeros_meta(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_meta(
       weight_qvals, weight_scales, std::nullopt, group_size);
 }
@@ -205,7 +207,7 @@ Tensor pack_weights_with_zeros_meta(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_meta(
       weight_qvals, weight_scales, weight_zeros, group_size);
 }
@@ -216,16 +218,19 @@ template
 Tensor linear_out_cpu(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
-    // int64_t when supported by AOTI Currently they are meta tensors with size
-    // equal to the int they wrap
+    // int64_t when supported by AOTI Currently they are tensors with size
+    // equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
     const Tensor& activations,
     Tensor& out) {
-  int n = n_tensor.size(0);
-  int k = k_tensor.size(0);
-  int group_size = group_size_tensor.size(0);
+  int n = n_tensor.size(1);
+  int k = k_tensor.size(1);
+  int group_size = group_size_tensor.size(1);
+  CHECK_MSG(n >= 1, "n must be >= 1");
+  CHECK_MSG(k >= 1, "k must be >= 1");
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
 
 #ifdef USE_ATEN
   CHECK_MSG(
@@ -303,8 +308,8 @@ template
 Tensor linear_cpu(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
-    // int64_t when supported by AOTI Currently they are meta tensors with size
-    // equal to the int they wrap
+    // int64_t when supported by AOTI Currently they are tensors with size
+    // equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
@@ -327,14 +332,17 @@ Tensor linear_meta(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
     // int64_t when supported by AOTI
-    // Currently they are meta tensors with size equal to the int they wrap
+    // Currently they are tensors with size equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
     const Tensor& activations) {
-  int n = n_tensor.size(0);
-  int k = k_tensor.size(0);
+  int n = n_tensor.size(1);
+  int k = k_tensor.size(1);
+  CHECK_MSG(n >= 1, "n must be >= 1");
+  CHECK_MSG(k >= 1, "k must be >= 1");
+  CHECK_MSG(activations.dim() == 2, "activations must be 2D");
   int m = activations.size(0);
   int k_ = activations.size(1);
   CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
 
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 26797bdb1c..531d7efaaa 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -80,9 +80,11 @@ def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros)
 
         # TODO(T200095131): convert self.n, self.k, self.group_size to
         # int when supported by AOTI
-        self._n = torch.empty(n, dtype=torch.int8)
-        self._k = torch.empty(k, dtype=torch.int8)
-        self._group_size = torch.empty(self.group_size, dtype=torch.int8)
+        # AOTI does not allow a tensor of size (n, 0), so we do (0, n)
+        self._n = torch.empty(0, n, dtype=torch.int8)
+        self._k = torch.empty(0, k, dtype=torch.int8)
+        self._group_size = torch.empty(0, group_size, dtype=torch.int8)
+
         weight_qvals, weight_scales, weight_zeros = _quantize(
             weights, self.group_size, self.nbit, self.has_weight_zeros
         )
@@ -109,7 +111,7 @@ def forward(self, x):
         assert x.dim() >= 3
         lead_shape = x.shape[0:-2]
         m, k = x.shape[-2], x.shape[-1]
-        n = self._n.shape[0]
+        n = self._n.shape[1]
         x = x.reshape(-1, m, k)
 
         res = [
@@ -254,7 +256,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
                 if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative):
                     raise e
                 logger.warning(
-                    "_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
+                    f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
                     + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
                 )
                 qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
diff --git a/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py
index d431d26939..1a05dccc56 100644
--- a/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py
+++ b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py
@@ -77,16 +77,17 @@ def test_accuracy(self):
 
     def test_export_compile_aoti(self):
         group_size = 32
-        m = 1
-        n = 256
-        k = 256
+        m = 3
+        k0 = 512
+        k1 = 256
+        k2 = 128
+        k3 = 1024
         nbit = 4
         has_weight_zeros = False
-        n_layers = 3
-        layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
+        layers = [torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), torch.nn.Linear(k2, k3, bias=False)]
         model = torch.nn.Sequential(*layers)
 
-        activations = torch.randn(m, k, dtype=torch.float32)
+        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
         print("Quantizing model")
         quantizer = Int8DynActIntxWeightQuantizer(