Merge branch 'master' into develop/numpy/unary-float-op/erfinv
kshitij12345 committed Dec 23, 2020
2 parents 08a59d7 + 68d438c commit c3166a1
Showing 151 changed files with 2,976 additions and 797 deletions.
2 changes: 1 addition & 1 deletion .circleci/cimodel/data/dimensions.py
@@ -8,8 +8,8 @@
]

ROCM_VERSIONS = [
"3.9",
"3.10",
"4.0",
]

ROCM_VERSION_LABELS = ["rocm" + v for v in ROCM_VERSIONS]
208 changes: 104 additions & 104 deletions .circleci/config.yml

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion BUILD.bazel
@@ -373,7 +373,6 @@ filegroup(
filegroup(
name = "thc_srcs_cu",
srcs = [
"aten/src/THC/THCBlas.cu.cc",
"aten/src/THC/THCReduceApplyUtils.cu.cc",
"aten/src/THC/THCSleep.cu.cc",
"aten/src/THC/THCSortUtils.cu.cc",
1 change: 0 additions & 1 deletion android/gradle/android_tasks.gradle
@@ -1,4 +1,3 @@

import java.nio.file.Files
import java.nio.file.Paths
import java.io.FileOutputStream
1 change: 0 additions & 1 deletion android/pytorch_android/host/build.gradle
@@ -38,4 +38,3 @@ dependencies {
}

apply from: rootProject.file('gradle/release.gradle')

1 change: 0 additions & 1 deletion android/settings.gradle
@@ -4,4 +4,3 @@ project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torch

project(':pytorch_host').projectDir = file('pytorch_android/host')
project(':test_app').projectDir = file('test_app/app')

3 changes: 3 additions & 0 deletions aten/src/ATen/MemoryOverlap.cpp
@@ -48,6 +48,9 @@ MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) {
if (!a->is_contiguous() || !b->is_contiguous()) {
return MemOverlapStatus::TOO_HARD;
}
if (!a->has_storage() || !b->has_storage()) {
return MemOverlapStatus::NO;
}
if (a->storage().data() == b->storage().data()) {
const auto a_begin = static_cast<char*>(a->data());
const auto a_end = a_begin + a->numel() * a->itemsize();
14 changes: 7 additions & 7 deletions aten/src/ATen/SparseTensorUtils.cpp
@@ -78,17 +78,17 @@ Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes,
}

Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
/*
Find the CSR representation for a row `indices` from the COO format
/*
Find the CSR representation for a row `indices` from the COO format
Inputs:
`indices` is the row pointer from COO indices
`dim` is the row dimensionality
`nnz` is the number of non-zeros
Output:
`dim` is the row dimensionality
`nnz` is the number of non-zeros
Output:
`csr` is a compressed row array in a CSR format
*/
Tensor csr = native::zeros({dim + 1}, kLong);
Tensor csr = at::zeros({dim + 1}, kLong);

// TODO: eliminate this conditional when zero-size dims supported correctly
if (nnz > 0) {
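
For reference, the conversion that coo_to_csr performs can be sketched in plain C++ outside ATen. This is an illustrative counting-sort formulation (assuming the usual coalesced, row-sorted COO layout), not the kernel in this file:

#include <cstdint>
#include <vector>

// Build a CSR row-pointer array from the row indices of a COO tensor.
// indices: row index of each of the nnz non-zeros.
// dim:     number of rows.
// Returns csr of length dim + 1, where row r occupies [csr[r], csr[r + 1]).
std::vector<int64_t> coo_rows_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
  std::vector<int64_t> csr(dim + 1, 0);
  for (int64_t i = 0; i < nnz; ++i) {
    csr[indices[i] + 1] += 1;   // count non-zeros per row
  }
  for (int64_t r = 0; r < dim; ++r) {
    csr[r + 1] += csr[r];       // prefix sum into row pointers
  }
  return csr;
}

// Example: rows {0, 0, 1, 3} with dim = 4 give csr = {0, 2, 3, 3, 4}.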
19 changes: 19 additions & 0 deletions aten/src/ATen/TensorUtils.cpp
@@ -19,6 +19,25 @@ std::ostream& operator<<(std::ostream & out, TensorGeometryArg t) {
return out;
}

void checkDim(
CheckedFrom c,
const Tensor& tensor,
const char* name,
int pos, // 1-indexed
int64_t dim) {
TORCH_CHECK(
tensor.dim() == dim,
"Expected ",
dim,
"-dimensional tensor, but got ",
tensor.dim(),
"-dimensional tensor for ",
TensorGeometryArg(TensorArg({tensor, name, pos})),
" (while checking arguments for ",
c,
")");
}

void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) {
TORCH_CHECK(t->dim() == dim,
"Expected ", dim, "-dimensional tensor, but got ", t->dim(),
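
A hedged usage sketch of the new overload; the operator and argument names below are hypothetical, not taken from this commit:

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

// Validate that `weight`, passed as the 2nd positional argument of a
// hypothetical operator, is a 2-D tensor; on failure this raises the
// "Expected 2-dimensional tensor, but got ..." error built above.
void my_op_check_inputs(const at::Tensor& weight) {
  at::checkDim("my_op", weight, "weight", /*pos=*/2, /*dim=*/2);
}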
6 changes: 6 additions & 0 deletions aten/src/ATen/TensorUtils.h
@@ -50,6 +50,12 @@ using CheckedFrom = const char*;
// conversion will blow up if you have undefined tensors.

TORCH_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t);
TORCH_API void checkDim(
CheckedFrom c,
const Tensor& tensor,
const char* name,
int pos, // 1-indexed
int64_t dim);
TORCH_API void checkDim(
CheckedFrom c,
const TensorGeometryArg& t,
5 changes: 5 additions & 0 deletions aten/src/ATen/core/Formatting.cpp
@@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
stream << ", axis: " << tensor_.q_per_channel_axis();
}
}

auto& fw_grad = tensor.fw_grad(/* level */ 0);
if (fw_grad.defined()) {
stream << ", tangent:" << std::endl << fw_grad;
}
stream << " ]";
}
return stream;
1 change: 1 addition & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -510,4 +510,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("_version", CppFunction::makeFallthrough());
m.impl("requires_grad_", CppFunction::makeFallthrough());
m.impl("retain_grad", CppFunction::makeFallthrough());
m.impl("_fw_primal", CppFunction::makeFallthrough());
}
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -436,6 +436,7 @@ _(aten, logdet) \
_(aten, logit) \
_(aten, logspace) \
_(aten, logsumexp) \
_(aten, xlogy) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lstsq) \
@@ -552,6 +553,7 @@ _(aten, permute) \
_(aten, pin_memory) \
_(aten, pinverse) \
_(aten, pixel_shuffle) \
_(aten, pixel_unshuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_base.h
@@ -251,7 +251,7 @@ struct Vec256 {
Vec256<T> angle() const {
// other_t_angle is for SFINAE and clarity. Make sure it is not changed.
static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
return Vec256(0);
return map(at::native::angle_impl<T>); // compiler is unable to resolve the overload without <T>
}
template <typename complex_t_angle = T,
typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
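
The elementwise rule that at::native::angle_impl applies to real inputs, and that the AVX2 specializations below reproduce with compare-and-blend masks, amounts to the following scalar sketch (illustrative only, not the ATen implementation):

#include <cmath>

// angle() of a real value: 0 for non-negative inputs, pi for negative
// inputs, NaN propagated (complex inputs use std::arg instead). The AVX2
// versions detect NaN lanes via x == x being false only for NaN, which is
// what the not_nan_mask / nan_mask pair in the hunks below computes.
template <typename T>
T angle_of_real(T x) {
  if (std::isnan(x)) {
    return x;                        // propagate NaN
  }
  return x < T(0) ? T(M_PI) : T(0);
}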
18 changes: 17 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -203,7 +203,23 @@ template <> class Vec256<BFloat16> {
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> angle() const {
return _mm256_set1_epi16(0);
__m256 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto angle_lambda = [](__m256 values) {
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
};
auto o1 = angle_lambda(lo);
auto o2 = angle_lambda(hi);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> real() const {
return *this;
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_double.h
@@ -108,7 +108,16 @@ template <> class Vec256<double> {
return _mm256_andnot_pd(mask, values);
}
Vec256<double> angle() const {
return _mm256_set1_pd(0);
const auto zero_vec = _mm256_set1_pd(0.f);
const auto nan_vec = _mm256_set1_pd(NAN);
const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_pd(M_PI);

const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
angle = _mm256_blendv_pd(angle, nan_vec, nan_mask);
return angle;
}
Vec256<double> real() const {
return *this;
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_float.h
@@ -115,7 +115,16 @@ template <> class Vec256<float> {
return _mm256_andnot_ps(mask, values);
}
Vec256<float> angle() const {
return _mm256_set1_ps(0);
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
}
Vec256<float> real() const {
return *this;
12 changes: 0 additions & 12 deletions aten/src/ATen/cpu/vec256/vec256_int.h
@@ -121,9 +121,6 @@ class Vec256<int64_t> : public Vec256i {
auto inverse = _mm256_xor_si256(values, is_larger);
return _mm256_sub_epi64(inverse, is_larger);
}
Vec256<int64_t> angle() const {
return _mm256_set1_epi64x(0);
}
Vec256<int64_t> real() const {
return *this;
}
@@ -250,9 +247,6 @@ class Vec256<int32_t> : public Vec256i {
Vec256<int32_t> abs() const {
return _mm256_abs_epi32(values);
}
Vec256<int32_t> angle() const {
return _mm256_set1_epi32(0);
}
Vec256<int32_t> real() const {
return *this;
}
@@ -467,9 +461,6 @@ class Vec256<int16_t> : public Vec256i {
Vec256<int16_t> abs() const {
return _mm256_abs_epi16(values);
}
Vec256<int16_t> angle() const {
return _mm256_set1_epi16(0);
}
Vec256<int16_t> real() const {
return *this;
}
@@ -719,9 +710,6 @@ class Vec256<int8_t> : public Vec256i {
Vec256<int8_t> abs() const {
return _mm256_abs_epi8(values);
}
Vec256<int8_t> angle() const {
return _mm256_set1_epi8(0);
}
Vec256<int8_t> real() const {
return *this;
}
27 changes: 27 additions & 0 deletions aten/src/ATen/native/AutogradComposite.cpp
@@ -0,0 +1,27 @@
#include <ATen/ATen.h>

namespace at {
namespace native {

/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
/// This function is backward differentiable.
at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
"already has a forward gradient at the same level ", level, " is not supported.");

auto dual_tensor = primal.view(primal.sizes());
dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
return dual_tensor;
}

/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
}

} // namespace native

} // namespace at
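
A minimal usage sketch of the two composites added here, assuming the generated declarations in ATen/NativeFunctions.h expose them as written (autograd-level bookkeeping and error handling omitted):

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <tuple>

void dual_tensor_sketch() {
  at::Tensor primal  = at::randn({2, 3});
  at::Tensor tangent = at::ones({2, 3});

  // Pack primal and tangent into a dual tensor at forward-AD level 0.
  at::Tensor dual = at::native::make_dual(primal, tangent, /*level=*/0);

  // Recover a view of the primal and the tangent that was attached.
  at::Tensor p, t;
  std::tie(p, t) = at::native::unpack_dual(dual, /*level=*/0);
}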
38 changes: 38 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -62,6 +62,7 @@ DEFINE_DISPATCH(igammac_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);
DEFINE_DISPATCH(xlogy_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -1101,5 +1102,42 @@ Tensor& ldexp_(Tensor& self, const Tensor& other) {
return at::ldexp_out(self, self, other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
xlogy_stub(iter.device_type(), iter);
return result;
}

Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
}

Tensor xlogy(const Tensor& x, const Tensor& y) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, x, y);
xlogy_stub(iter.device_type(), iter);
return iter.output();
}

Tensor xlogy(Scalar x, const Tensor& y) {
return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
}

Tensor xlogy(const Tensor& x, Scalar y) {
return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
}

Tensor& xlogy_(Tensor& x, const Tensor& y) {
return at::xlogy_out(x, x, y);
}

Tensor& xlogy_(Tensor& x, Scalar y) {
return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
}

} // namespace native
} // namespace at
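
The elementwise math behind the new xlogy entry points is x * log(y) with 0 * log(y) defined as 0; the exact NaN convention lives in the xlogy_stub kernels, which are not part of this hunk. A hedged scalar sketch:

#include <cmath>

// Scalar sketch of xlogy: x * log(y), with the x == 0 case pinned to 0 so
// that entropy-style terms p * log(q) stay finite at p == 0. Assumed (not
// shown in this diff): NaN in y propagates.
double xlogy_scalar(double x, double y) {
  if (std::isnan(y)) return y;    // propagate NaN from y
  if (x == 0.0)      return 0.0;  // 0 * log(y) := 0, even when y == 0
  return x * std::log(y);
}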
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -74,5 +74,6 @@ DECLARE_DISPATCH(binary_fn, igammac_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);

}} // namespace at::native
