Update on "[ONNX] Add binary_cross_entropy_with_logits op to ONNX opset version 12 (#49675)"


Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.
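
As a rough usage sketch (not part of this diff), a module that calls the op can be exported at opset 12 along the following lines; the module and output filename here are hypothetical:

import torch
import torch.nn.functional as F

class BCEWithLogitsModule(torch.nn.Module):
    def forward(self, logits, target):
        # The op this commit maps to ONNX opset 12.
        return F.binary_cross_entropy_with_logits(logits, target)

logits = torch.randn(3, 5)
target = torch.empty(3, 5).random_(2)

# opset_version=12 selects the opset that carries the new mapping.
torch.onnx.export(BCEWithLogitsModule(), (logits, target), "bce_with_logits.onnx", opset_version=12)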

[ghstack-poisoned]
BowenBao committed Jan 25, 2021
2 parents 1264483 + 81ecff0 commit c0604c8
Showing 109 changed files with 2,552 additions and 1,114 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -134,6 +134,8 @@ genrule(
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterDefaultBackend.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CUDAFunctions.h",
"aten/src/ATen/Functions.h",
"aten/src/ATen/Functions.cpp",
"aten/src/ATen/NativeFunctions.h",
1 change: 1 addition & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -325,6 +325,7 @@ if(USE_CUDA AND NOT USE_ROCM)
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static
)
else()
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
3 changes: 2 additions & 1 deletion aten/src/ATen/CPUGeneratorImpl.cpp
@@ -2,6 +2,7 @@
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/C++17.h>
#include <c10/util/MathConstants.h>
#include <algorithm>

namespace at {
@@ -153,7 +154,7 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
// intermediate values.
if (legacy_pod->normal_is_valid) {
auto r = legacy_pod->normal_rho;
auto theta = 2.0 * M_PI * legacy_pod->normal_x;
auto theta = 2.0 * c10::pi<double> * legacy_pod->normal_x;
// we return the sin version of the normal sample when in caching mode
double_normal_sample = c10::optional<double>(r * ::sin(theta));
}
47 changes: 0 additions & 47 deletions aten/src/ATen/LegacyTHFunctionsCPU.cpp
@@ -776,53 +776,6 @@ std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self) {
}
return std::tuple<Tensor, Tensor>(res1, res2);
}
Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & input2) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type);
auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type);
THDoubleTensor_orgqr(result_, self_, input2_);
break;
}
case ScalarType::Float: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type);
auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type);
THFloatTensor_orgqr(result_, self_, input2_);
break;
}
default:
AT_ERROR("_th_orgqr_out not supported on CPUType for ", dispatch_scalar_type);
}
return result;
}
Tensor _th_orgqr(const Tensor & self, const Tensor & input2) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type);
auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type);
THDoubleTensor_orgqr(result_, self_, input2_);
break;
}
case ScalarType::Float: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type);
auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type);
THFloatTensor_orgqr(result_, self_, input2_);
break;
}
default:
AT_ERROR("_th_orgqr not supported on CPUType for ", dispatch_scalar_type);
}
return result;
}
Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
2 changes: 0 additions & 2 deletions aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -42,8 +42,6 @@ Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
Tensor _th_potri(const Tensor & self, bool upper);
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self);
Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & input2);
Tensor _th_orgqr(const Tensor & self, const Tensor & input2);
Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);
Tensor _th_ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);

11 changes: 2 additions & 9 deletions aten/src/ATen/core/DistributionsHelper.h
@@ -1,17 +1,10 @@
#pragma once

// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#endif

#include <ATen/core/Array.h>
#include <ATen/core/TransformationHelper.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/MathConstants.h>
#include <c10/util/Optional.h>
#include <c10/macros/Macros.h>

@@ -220,7 +213,7 @@ struct normal_distribution {
const dist_acctype<T> u1 = uniform(generator);
const dist_acctype<T> u2 = uniform(generator);
const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
const dist_acctype<T> theta = static_cast<T>(2.0) * static_cast<T>(M_PI) * u1;
const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1;
if (std::is_same<T, double>::value) {
maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
} else {
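
For reference, the r/theta computation above is the Box-Muller transform; a self-contained Python sketch of the same sampling step (independent of ATen's generator plumbing, helper name is illustrative):

import math
import random

def box_muller_pair(mean=0.0, std=1.0):
    # Two uniforms -> two independent normal samples, mirroring r and theta above.
    u1 = random.random()
    u2 = random.random()
    r = math.sqrt(-2.0 * math.log(1.0 - u2))
    theta = 2.0 * math.pi * u1
    # The ATen code caches the sin variant so a later call can reuse it.
    return mean + std * r * math.cos(theta), mean + std * r * math.sin(theta)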
3 changes: 2 additions & 1 deletion aten/src/ATen/core/TransformationHelper.h
@@ -1,6 +1,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/MathConstants.h>
#include <ATen/NumericUtils.h>
#include <limits>
#include <cstdint>
@@ -101,7 +102,7 @@ C10_HOST_DEVICE inline T normal(T val, T mean, T std) {
template <typename T>
C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
// https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
return median + sigma * at::tan(static_cast<T>(M_PI) * (val - static_cast<T>(0.5)));
return median + sigma * at::tan(c10::pi<T> * (val - static_cast<T>(0.5)));
}

/**
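
The cauchy() transform above is inverse-CDF sampling: the quantile function of the Cauchy distribution applied to a uniform value. A short Python sketch of the same idea:

import math
import random

def cauchy_sample(median=0.0, sigma=1.0):
    # Quantile function of the Cauchy distribution: median + sigma * tan(pi * (u - 0.5)).
    u = random.random()
    return median + sigma * math.tan(math.pi * (u - 0.5))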
1 change: 0 additions & 1 deletion aten/src/ATen/core/interned_strings.h
@@ -141,7 +141,6 @@ namespace c10 {
_(prim, GetAttr) \
_(prim, HasAttr) \
_(prim, profile) \
_(prim, profile_optional) \
_(prim, profile_ivalue) \
_(prim, AddStatValue) \
_(prim, TimePoint) \
2 changes: 0 additions & 2 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -21,8 +21,6 @@
#include <bitset>

#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/Utils.h>
#include <ATen/native/Copy.h>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -210,7 +210,7 @@ template <> class Vec256<BFloat16> {
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);
const auto pi = _mm256_set1_ps(c10::pi<float>);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
3 changes: 2 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_complex_double.h
@@ -204,7 +204,8 @@ template <> class Vec256<c10::complex<double>> {
}
Vec256<c10::complex<double>> acos() const {
// acos(x) = pi/2 - asin(x)
const __m256d pi_2 = _mm256_setr_pd(M_PI/2, 0.0, M_PI/2, 0.0);
constexpr auto pi_2d = c10::pi<double> / 2;
const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0);
return _mm256_sub_pd(pi_2, asin());
}
Vec256<c10::complex<double>> atan() const;
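
The identity acos(x) = pi/2 - asin(x) used here (and in the float variant below) can be spot-checked numerically, for example:

import cmath

z = 0.3 + 0.4j
assert abs(cmath.acos(z) - (cmath.pi / 2 - cmath.asin(z))) < 1e-12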
3 changes: 2 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_complex_float.h
@@ -242,7 +242,8 @@ template <> class Vec256<c10::complex<float>> {
}
Vec256<c10::complex<float>> acos() const {
// acos(x) = pi/2 - asin(x)
const __m256 pi_2 = _mm256_setr_ps(M_PI/2, 0.0, M_PI/2, 0.0, M_PI/2, 0.0, M_PI/2, 0.0);
constexpr float pi_2f = c10::pi<float> / 2;
const __m256 pi_2 = _mm256_setr_ps(pi_2f, 0.0, pi_2f, 0.0, pi_2f, 0.0, pi_2f, 0.0);
return _mm256_sub_ps(pi_2, asin());
}
Vec256<c10::complex<float>> atan() const;
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_double.h
@@ -112,7 +112,7 @@ template <> class Vec256<double> {
const auto nan_vec = _mm256_set1_pd(NAN);
const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_pd(M_PI);
const auto pi = _mm256_set1_pd(c10::pi<double>);

const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_float.h
@@ -119,7 +119,7 @@ template <> class Vec256<float> {
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);
const auto pi = _mm256_set1_ps(c10::pi<float>);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -5,7 +5,7 @@

#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/cpu/vec256/vec256_base.h>
#include <ATen/native/quantized/affine_quantizer.h>
#include <ATen/native/quantized/affine_quantizer_base.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -300,6 +300,8 @@ PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(uint64_t increment) {
at::cuda::assertNotCapturing("Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. "
"Cannot call CUDAGeneratorImpl::philox_engine_inputs");
// rounds increment up to the nearest multiple of 4
increment = ((increment + 3) / 4) * 4;
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
uint64_t offset = this->philox_offset_per_thread_;
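
The added rounding makes the requested increment a multiple of 4 before it is consumed, keeping the Philox offset aligned as the note requires; a small sketch of the arithmetic (example values are made up):

def round_up_to_multiple_of_4(increment):
    # Integer round-up, same as ((increment + 3) / 4) * 4 with integer division in the C++ above.
    return ((increment + 3) // 4) * 4

assert round_up_to_multiple_of_4(6) == 8
assert round_up_to_multiple_of_4(8) == 8  # already-aligned values are unchanged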
