Skip to content

Commit

Permalink
[quant] Quantized flip dispatch (#46235)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #46235

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D24689161

Pulled By: z-a-f

fbshipit-source-id: 6833c2639b29ea5f6c81c880b8928c5a1951c7b8
  • Loading branch information
z-a-f authored and facebook-github-bot committed Nov 3, 2020
1 parent f41f3e3 commit 31ebac3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
33 changes: 24 additions & 9 deletions aten/src/ATen/native/TensorTransformations.cpp
Expand Up @@ -61,15 +61,30 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) {
}
}

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});
if (in_tensor.is_quantized()) {
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(),
"flip_quantized_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool,
in_tensor.scalar_type(),
"flip_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});
}

return out_tensor;
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
Expand Up @@ -3882,7 +3882,7 @@
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU: flip_cpu
CPU, QuantizedCPU: flip_cpu
CUDA: flip_cuda

- func: fliplr(Tensor self) -> Tensor
Expand Down

0 comments on commit 31ebac3

Please sign in to comment.