diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 1b86b3f2d634..fdee519c4bd0 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -61,15 +61,30 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] { - flip_cpu_kernel( - total_dims, - stride_contiguous_v, - flip_dims_b, - in_tensor, - out_tensor - ); - }); + if (in_tensor.is_quantized()) { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(), + "flip_quantized_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, + in_tensor.scalar_type(), + "flip_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } return out_tensor; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 398aa7474eab..349256477df3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3882,7 +3882,7 @@ use_c10_dispatcher: full variants: function, method dispatch: - CPU: flip_cpu + CPU, QuantizedCPU: flip_cpu CUDA: flip_cuda - func: fliplr(Tensor self) -> Tensor