Update on "Rewrite implementation of faithful cpp signatures"

This rewrite is as per my comments at #44087 (comment)
I did the rewrite by reverting #44087 and then reimplementing it on top.
You may find it easier to review by diffing against master with only #44087
reverted.

There are two main ideas.

First, we now factor cpp argument processing into two phases operating
on three representations of data:

1. `FunctionSchema` - this is the source from native_functions.yaml
2. `Union[Argument, ThisArgument, TensorOptionsArguments]` - this is
   the arguments after doing some basic semantic analysis to group
   them (for TensorOptions) or to identify the `this` argument (if this
   is a method).  There is only ever one of these per function.
3. `Union[CppArgument, CppThisArgument, CppTensorOptionsArguments]` -
   this is the arguments after we've elaborated them to C++.  There
   may be multiple of these per function, one set per actual C++ signature.

You can think of (2) as common processing, whereas (3) bakes in the
specific assumptions of a faithful or non-faithful signature.  A rough
sketch of both representations follows.
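
To make the two phases concrete, here is a minimal sketch of what representations (2) and (3) could look like.  The class and field names mirror the description above but are purely illustrative; the real definitions in tools.codegen carry more fields and detail.

```python
# Illustrative sketch only -- the real dataclasses live in tools.codegen
# (model and api.types) and differ in detail.
from dataclasses import dataclass
from typing import Optional, Union

# Representation (2): schema arguments after grouping/classification.
@dataclass(frozen=True)
class Argument:
    name: str
    type: str                        # schema type, e.g. 'Tensor', 'int[]'
    default: Optional[str] = None

@dataclass(frozen=True)
class ThisArgument:                  # the receiver of a Tensor method
    argument: Argument

@dataclass(frozen=True)
class TensorOptionsArguments:        # dtype/layout/device/pin_memory, grouped
    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

# Representation (3): arguments elaborated with C++ API information.
@dataclass(frozen=True)
class CppArgument:
    type: str                        # C++ type, e.g. 'const Tensor &'
    name: str
    default: Optional[str]
    argument: Union[Argument, TensorOptionsArguments]   # what it came from

@dataclass(frozen=True)
class CppThisArgument:
    argument: ThisArgument

@dataclass(frozen=True)
class CppTensorOptionsArguments:     # retains the TensorOptions grouping in C++
    argument: TensorOptionsArguments
    dtype: CppArgument
    layout: CppArgument
    device: CppArgument
    pin_memory: CppArgument
```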

Second, we now have CppSignature and CppSignatureGroup representing
the *total* public C++ API signature.  These dataclasses know how to
render definitions/declarations, so you no longer have to type them
out manually in the Functions/TensorMethods codegen.  A sketch of their
shape follows.
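
As a rough illustration only -- the field names and helper methods below (decl, defn) are assumptions about the shape of the API, not the exact definitions in tools.codegen.api.types:

```python
# Hypothetical shape of the signature dataclasses.
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass(frozen=True)
class CppSignature:
    returns_type: str
    name: str
    cpp_arguments: Sequence['CppArgument']    # as sketched earlier

    def decl(self) -> str:
        # Render the public declaration (with defaults), e.g. for Functions.h.
        args = ', '.join(
            f'{a.type} {a.name}' + (f'={a.default}' if a.default is not None else '')
            for a in self.cpp_arguments)
        return f'{self.returns_type} {self.name}({args})'

    def defn(self) -> str:
        # Render the signature used in the definition (no defaults).
        args = ', '.join(f'{a.type} {a.name}' for a in self.cpp_arguments)
        return f'{self.returns_type} {self.name}({args})'

@dataclass(frozen=True)
class CppSignatureGroup:
    signature: CppSignature                     # conventional signature
    faithful_signature: Optional[CppSignature]  # present only when it differs
```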

Here is an exhaustive accounting of the changes.

tools.codegen.api.types

- CppSignature and CppSignatureGroup got moved to tools.codegen.api.types
- Add new CppThisArgument and CppTensorOptionsArguments (modeled off
  of ThisArgument and TensorOptionsArguments) so that we can retain
  high-level semantic structure even after elaborating terms with C++
  API information.  Once this is done, we can refine
  CppArgument.argument to no longer contain a ThisArgument (ThisArgument
  is always translated to CppThisArgument).  Note that this doesn't
  apply to TensorOptionsArguments, as those may or may not be expanded,
  so you could still get a single CppArgument for 'options'.
- Add a no_default() functional mutator to easily remove default arguments
  from CppArgument and friends (see the sketch after this list)
- Add an explicit_arguments() method to CppArgument and friends to
  extract the (flat) list of arguments that must be explicitly written in
  the signature.  This is everything except (Cpp)ThisArgument, and it is
  also convenient when you don't care about the extra structure of
  CppTensorOptionsArguments
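
The sketch below continues the illustrative dataclasses from above and shows roughly what these two helpers do; the real method placement and signatures in tools.codegen.api.types may differ.

```python
import dataclasses
from typing import List, Union

def no_default(a: 'CppArgument') -> 'CppArgument':
    # Functional mutator: return a copy with the default stripped, rather
    # than modifying the argument in place.
    return dataclasses.replace(a, default=None)

def explicit_arguments(
    a: 'Union[CppArgument, CppThisArgument, CppTensorOptionsArguments]',
) -> 'List[CppArgument]':
    # Flatten to the arguments that are literally written in the C++ signature.
    if isinstance(a, CppThisArgument):
        return []                    # `this` is implicit in a method signature
    if isinstance(a, CppTensorOptionsArguments):
        # In this sketch the grouped form expands to its four pieces.
        return [a.dtype, a.layout, a.device, a.pin_memory]
    return [a]                       # plain CppArgument
```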

tools.codegen.api.cpp

- group_arguments is back, but it no longer sends things directly to a
  CppSignatureGroup; instead, it moves us from representation (1) to (2)
  (perhaps it should live in model).  Here I changed my mind from my
  PR comment; I discovered it was not necessary to do classification at
  grouping time, and it was simpler and easier to do it later.  A sketch
  of the grouping follows this list.
- argument got split into argument_not_this/argument/argument_faithful.
  argument and argument_faithful do what their names suggest, and I
  needed argument_not_this as a more refined version of argument so
  that the types work out for TensorOptionsArguments
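
A rough sketch of the grouping step, under the simplifying assumption that the scattered dtype/layout/device/pin_memory arguments appear contiguously in the schema; the real group_arguments in tools.codegen.api.cpp handles more cases than this.

```python
from typing import List, Union

TENSOR_OPTIONS_FIELDS = ('dtype', 'layout', 'device', 'pin_memory')

def group_arguments(
    args: 'List[Argument]', *, method: bool,
) -> 'List[Union[Argument, ThisArgument, TensorOptionsArguments]]':
    grouped: list = []
    pending = list(args)
    if method:
        # Simplification: treat the first argument as the method receiver.
        grouped.append(ThisArgument(pending.pop(0)))
    i = 0
    while i < len(pending):
        window = pending[i:i + 4]
        if tuple(a.name for a in window) == TENSOR_OPTIONS_FIELDS:
            # Collapse the four scattered options into one grouped argument.
            grouped.append(TensorOptionsArguments(*window))
            i += 4
        else:
            grouped.append(pending[i])
            i += 1
    return grouped
```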

tools.codegen.api.dispatcher

- Here we start seeing the payoff.  The old version of this code had a
  "scatter" mode and a "gather" mode.  We don't need that anymore:
  cppargument_exprs is 100% type-directed via the passed-in cpp
  arguments, and I was able to write the functions without any reference
  to use_c10_dispatcher.  A sketch of the idea follows.
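
To give a feel for what "type-directed" means here, a simplified sketch; the expression strings mirror the codegen diffs shown below, but the real cppargument_exprs is more general (for instance, it also handles the opposite direction of gathering scattered C++ arguments into a TensorOptions, as in the second diff).

```python
from typing import List

def cppargument_exprs(a, *, dispatcher_takes_scattered_options: bool) -> List[str]:
    # The kind of C++ argument we are handed determines the expressions we
    # pass to the dispatcher; no use_c10_dispatcher flag is consulted here.
    if isinstance(a, CppThisArgument):
        return ['const_cast<Tensor&>(*this)']
    if isinstance(a, CppTensorOptionsArguments):
        if dispatcher_takes_scattered_options:
            # Scatter a gathered `options` into the optionals the dispatcher wants.
            return [
                'optTypeMetaToScalarType(options.dtype_opt())',
                'options.layout_opt()',
                'options.device_opt()',
                'options.pinned_memory_opt()',
            ]
        return ['options']
    return [a.name]                  # plain CppArgument passes through by name
```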

tools.codegen.gen

- Instead of having exprs_str and types_str functions, I moved these to
  live directly on CppSignature, since it seemed pretty logical.
- The actual codegen for TensorMethods/Functions is greatly simplified,
  since (1) all of the heavy lifting now happens in CppSignature(Group)
  construction, and (2) I don't need to proxy one way or another; the
  new dispatcher translation code handles both cases without trouble.
  There is a little faffing about with ordering to reduce the diff
  between the old and new output, which could be removed afterwards.
  A sketch of the simplified rendering follows.
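
For example, rendering one public wrapper can be roughly the snippet below, which produces output in the shape of the codegen diffs that follow.  This is only a sketch: the helper name function_definition and its parameters are placeholders, and in the real gen code the CppSignatureGroup construction and the dispatcher module compute the signature, types, and expressions.

```python
from typing import Sequence

def function_definition(schema: str, sig: 'CppSignature',
                        dispatcher_types: str, exprs: Sequence[str]) -> str:
    # schema is the full ATen schema string, e.g.
    # 'aten::empty_meta(int[] size, *, ScalarType? dtype=None, ...) -> Tensor'
    name = schema.split('(')[0]          # e.g. 'aten::empty_meta'
    call_args = ', '.join(exprs)
    # The CppSignature knows how to print itself; the dispatcher translation
    # supplies the typed<> signature and the call expressions.
    return f'''\
// {schema}
{sig.defn()} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("{name}", "")
        .typed<{dispatcher_types}>();
    return op.call({call_args});
}}
'''
```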

Here are the codegen diffs.  For `use_c10_dispatcher: full`:

```
+// aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed, const TensorOptions & options) {
-    return _cudnn_init_dropout_state(dropout, train, dropout_seed, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::_cudnn_init_dropout_state", "")
+        .typed<Tensor (double, bool, int64_t, c10::optional<ScalarType>, c10::optional<Layout>, c10::optional<Device>, c10::optional<bool>)>();
+    return op.call(dropout, train, dropout_seed, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
 }
```

Otherwise:

```
+// aten::empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
 Tensor empty_meta(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format) {
-    return empty_meta(size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), memory_format);
+    static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::empty_meta", "")
+        .typed<Tensor (IntArrayRef, const TensorOptions &, c10::optional<MemoryFormat>)>();
+    return op.call(size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), memory_format);
 }
```

Things that I probably did not get right:

- The Union[Argument, TensorOptionsArguments, ThisArgument] and
  the Cpp variants are starting to get a little unwieldy.  Not sure if
  this means I should add a supertype (or at the very least an
  alias); in some cases I do purposely omit one of these from the Union.
- Code may not necessarily live in the most logical files.  There isn't
  very much rhyme or reason to it.
- The fields on CppSignature.  They're not very well constrained and
  it will be better if people don't use them directly.
- Disambiguation.  We should do this properly in #44087, and we don't
  need special logic for dropping defaults from faithful signatures;
  there is a more general story here.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D24144035](https://our.internmc.facebook.com/intern/diff/D24144035)

[ghstack-poisoned]
ezyang committed Oct 8, 2020
2 parents 03699a1 + 7d4f506 commit 729aea3
Showing 185 changed files with 15,708 additions and 2,814 deletions.
1 change: 1 addition & 0 deletions .circleci/cimodel/data/pytorch_build_definitions.py
@@ -288,6 +288,7 @@ def instantiate_configs():
rocm_version = None
if compiler_name == "cuda":
cuda_version = fc.find_prop("compiler_version")
restrict_phases = ["build", "test1", "test2"]

elif compiler_name == "rocm":
rocm_version = fc.find_prop("compiler_version")
78 changes: 68 additions & 10 deletions .circleci/config.yml
@@ -6668,7 +6668,7 @@ workflows:
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7"
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test
name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test1
requires:
- pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build
filters:
@@ -6677,7 +6677,21 @@
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test"
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test2
requires:
- pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test2"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6706,10 +6720,18 @@ workflows:
build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1
requires:
- pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test"
build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2
requires:
- pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6780,7 +6802,21 @@ workflows:
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test
name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1
requires:
- pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2
requires:
- pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
filters:
@@ -6789,7 +6825,7 @@
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test"
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6806,7 +6842,21 @@ workflows:
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1
requires:
- pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2
requires:
- pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
filters:
@@ -6815,7 +6865,7 @@
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test"
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6826,10 +6876,18 @@ workflows:
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1
requires:
- pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2
requires:
- pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test"
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
6 changes: 4 additions & 2 deletions .clang-tidy
@@ -1,6 +1,7 @@
---
# NOTE there must be no spaces before the '-', so put the comma last.
Checks: '-*,
InheritParentConfig: true
Checks: '
bugprone-*,
-bugprone-forward-declaration-namespace,
-bugprone-macro-parentheses,
@@ -17,6 +18,7 @@ cppcoreguidelines-*,
-cppcoreguidelines-pro-type-union-access,
-cppcoreguidelines-pro-type-vararg,
-cppcoreguidelines-special-member-functions,
-facebook-hte-RelativeInclude,
hicpp-exception-baseclass,
hicpp-avoid-goto,
modernize-*,
@@ -27,7 +29,7 @@ modernize-*,
-modernize-use-trailing-return-type,
performance-*,
-performance-noexcept-move-constructor,
'
'
HeaderFilterRegex: 'torch/csrc/.*'
AnalyzeTemporaryDtors: false
CheckOptions:
8 changes: 2 additions & 6 deletions .gitmodules
@@ -124,13 +124,9 @@
url = https://github.com/google/XNNPACK.git
[submodule "third_party/fmt"]
ignore = dirty
path = third_party/fmt
url = https://github.com/fmtlib/fmt.git
path = third_party/fmt
url = https://github.com/fmtlib/fmt.git
[submodule "third_party/tensorpipe"]
ignore = dirty
path = third_party/tensorpipe
url = https://github.com/pytorch/tensorpipe.git
[submodule "third_party/valgrind"]
ignore = dirty
path = third_party/valgrind
url = https://sourceware.org/git/valgrind.git
4 changes: 2 additions & 2 deletions .jenkins/pytorch/win-test-helpers/build_pytorch.bat
@@ -95,15 +95,15 @@ if "%USE_CUDA%"=="1" (
copy %TMP_DIR_WIN%\bin\sccache.exe %TMP_DIR_WIN%\bin\nvcc.exe

:: randomtemp is used to resolve the intermittent build error related to CUDA.
:: code: https://github.com/peterjc123/randomtemp
:: code: https://github.com/peterjc123/randomtemp-rust
:: issue: https://github.com/pytorch/pytorch/issues/25393
::
:: Previously, CMake uses CUDA_NVCC_EXECUTABLE for finding nvcc and then
:: the calls are redirected to sccache. sccache looks for the actual nvcc
:: in PATH, and then pass the arguments to it.
:: Currently, randomtemp is placed before sccache (%TMP_DIR_WIN%\bin\nvcc)
:: so we are actually pretending sccache instead of nvcc itself.
curl -kL https://github.com/peterjc123/randomtemp/releases/download/v0.3/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe
curl -kL https://github.com/peterjc123/randomtemp-rust/releases/download/v0.2/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe
set RANDOMTEMP_EXECUTABLE=%TMP_DIR_WIN%\bin\nvcc.exe
set CUDA_NVCC_EXECUTABLE=%TMP_DIR_WIN%\bin\randomtemp.exe
set RANDOMTEMP_BASEDIR=%TMP_DIR_WIN%\bin
18 changes: 18 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -172,6 +172,22 @@ std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int6
return result;
}

std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
self_physical.makeLogicalFromPhysicalListInplace(result);
return result;
}

std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
self_physical.makeLogicalFromPhysicalListInplace(result);
return result;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
// NB: unsqueeze has some special handling of its `dim` argument so we can't call
@@ -527,6 +543,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {

// view operations
m.impl("chunk", chunk_batching_rule);
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
2 changes: 2 additions & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -453,6 +453,8 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("tanh", CppFunction::makeFallthrough());
m.impl("tanh.out", CppFunction::makeFallthrough());
m.impl("tanh_", CppFunction::makeFallthrough());
m.impl("tensor_split.indices", CppFunction::makeFallthrough());
m.impl("tensor_split.sections", CppFunction::makeFallthrough());
m.impl("threshold", CppFunction::makeFallthrough());
m.impl("threshold.out", CppFunction::makeFallthrough());
m.impl("threshold_", CppFunction::makeFallthrough());
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -664,6 +664,7 @@ _(aten, tan) \
_(aten, tanh) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, tensor_split) \
_(aten, th_addmm) \
_(aten, th_clone) \
_(aten, th_norm) \
15 changes: 11 additions & 4 deletions aten/src/ATen/core/type.cpp
@@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << "Tensor";
}
if (auto ndim = value->sizes().size()) {
bool has_valid_strides_info =
bool has_valid_strides_info = *ndim > 0 &&
value->strides().isComplete() && value->strides().size() == ndim;

out << "(";
@@ -41,10 +41,17 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
} else {
out << "*";
}
if (has_valid_strides_info &&
type_verbosity() >= TypeVerbosity::TypeAndStride) {
out << ":" << *value->strides()[i];
}
if (has_valid_strides_info &&
type_verbosity() >= TypeVerbosity::TypeAndStride) {
out << ", strides=[";
for (size_t i = 0; i < *ndim; ++i) {
if (i > 0) {
out << ", ";
}
out << *value->strides()[i];
}
out << "]";
}
if (type_verbosity() >= TypeVerbosity::Full) {
if (value->requiresGrad()) {
33 changes: 1 addition & 32 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -615,23 +615,12 @@ inline T minimum(const T& a, const T& b) {
return c;
}

// To save BC, it will not propagate NaN based on IEEE 754 201X
template <class T,
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec256<T> &max_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = a[i] < min_vec[i] ? min_vec[i] : (a[i] > max_vec[i] ? max_vec[i] : a[i]);
}
return c;
}

template <class T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec256<T> &max_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : (std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]);
c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
}
return c;
}
@@ -646,16 +635,6 @@ Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
return c;
}

template <class T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i];
}
return c;
}

template <class T,
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
@@ -666,16 +645,6 @@ Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
return c;
}

template <class T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : a[i];
}
return c;
}

struct Vec256i;

#ifdef CPU_CAPABILITY_AVX2
26 changes: 0 additions & 26 deletions aten/src/ATen/cpu/vec256/vec256_complex_double.h
@@ -416,32 +416,6 @@ Vec256<c10::complex<double>> inline minimum(const Vec256<c10::complex<double>>&
return _mm256_or_pd(min, isnan);
}

template <>
Vec256<c10::complex<double>> inline clamp(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& min, const Vec256<c10::complex<double>>& max) {
auto abs_a = a.abs_2_();
auto abs_min = min.abs_2_();
auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ);
auto abs_max = max.abs_2_();
auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ);
return _mm256_blendv_pd(_mm256_blendv_pd(a, min, max_mask), max, min_mask);
}

template <>
Vec256<c10::complex<double>> inline clamp_min(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& min) {
auto abs_a = a.abs_2_();
auto abs_min = min.abs_2_();
auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ);
return _mm256_blendv_pd(a, min, max_mask);
}

template <>
Vec256<c10::complex<double>> inline clamp_max(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& max) {
auto abs_a = a.abs_2_();
auto abs_max = max.abs_2_();
auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ);
return _mm256_blendv_pd(a, max, min_mask);
}

template <>
Vec256<c10::complex<double>> inline operator&(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& b) {
return _mm256_and_pd(a, b);
