30 commits
a76fe8d  implement cpu_kernel_multiple_outputs to support returning multiple v… (RockingJavaBean, Jan 26, 2021)
02cf0ce  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 1, 2021)
dbc385b  add torch.frexp using cpu_kernel_multiple_outputs and gpu_kernel_mult… (RockingJavaBean, Feb 4, 2021)
139b6b5  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 4, 2021)
0197e6d  fix (RockingJavaBean, Feb 4, 2021)
9329cda  CI fix and add multiple outputs support for UnaryUfuncInfo (RockingJavaBean, Feb 5, 2021)
b66297b  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 5, 2021)
ee3ef78  fix flake8 (RockingJavaBean, Feb 5, 2021)
2cbc4a8  fix test_native_functions_yaml and try resolve ROCM issues (RockingJavaBean, Feb 7, 2021)
e66f752  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 7, 2021)
4daccb2  disable frexp on ROCM (RockingJavaBean, Feb 7, 2021)
84d892a  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 8, 2021)
a408087  skip CUDA tests for ROCM platform (RockingJavaBean, Feb 8, 2021)
ea487d1  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 11, 2021)
c5ce1bf  update torch.frexp to return mantissa, exponent as a tuple of (scalar… (RockingJavaBean, Feb 23, 2021)
99e62bc  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 22, 2021)
945a490  fix flake8 issues (RockingJavaBean, Feb 22, 2021)
ff5972a  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 23, 2021)
1ba836d  fix Windows CI (RockingJavaBean, Feb 23, 2021)
e1698c5  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 25, 2021)
91198f6  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Feb 26, 2021)
d975938  update the UnaryUfuncInfo of torch.frexp (RockingJavaBean, Feb 26, 2021)
f46b760  fix test_reference_numerics_hard for torch.half (RockingJavaBean, Feb 26, 2021)
54472cc  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Mar 8, 2021)
52abf28  updates according to code reviews (RockingJavaBean, Mar 9, 2021)
b20c25a  add comments according to review (RockingJavaBean, Mar 10, 2021)
f74b527  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Mar 10, 2021)
742a99f  updates (RockingJavaBean, Mar 10, 2021)
f34c685  improve test (RockingJavaBean, Mar 10, 2021)
c8198e1  Merge branch 'master' of https://github.com/pytorch/pytorch into cpu_… (RockingJavaBean, Mar 15, 2021)
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -292,6 +292,7 @@ _(aten, diagflat) \
_(aten, diagonal) \
_(aten, fill_diagonal_) \
_(aten, diff) \
_(aten, frexp) \
_(aten, digamma) \
_(aten, dim) \
_(aten, dist) \
35 changes: 35 additions & 0 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -680,6 +680,40 @@ Tensor& lgamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_fl
Tensor lgamma(const Tensor& self) { return unary_op_impl_float(self, lgamma_stub); }
Tensor& lgamma_(Tensor& self) { return unary_op_impl_(self, at::lgamma_out); }

std::tuple<Tensor, Tensor> frexp(const Tensor& self) {
Tensor mantissa = at::empty_like(self);
Tensor exponent = at::empty_like(self, self.options().dtype(at::kInt));

at::frexp_out(mantissa, exponent, self);
return std::tuple<Tensor, Tensor>(mantissa, exponent);
}

std::tuple<Tensor&, Tensor&> frexp_out(const Tensor& self,
Tensor& mantissa, Tensor& exponent) {
// torch.frexp is only implemented for floating-point dtypes for now;
// support for integral dtypes should be added in the future.
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"torch.frexp() only supports floating-point dtypes");

TORCH_CHECK(mantissa.dtype() == self.dtype(),
"torch.frexp() expects mantissa to have dtype ", self.dtype(),
" but got ", mantissa.dtype());
TORCH_CHECK(exponent.dtype() == at::kInt,
"torch.frexp() expects exponent to have int dtype "
"but got ", exponent.dtype());

auto iter = TensorIteratorConfig()
.add_output(mantissa)
.add_output(exponent)
.add_input(self)
.check_all_same_dtype(false)
.set_check_mem_overlap(true)
.build();
frexp_stub(iter.device_type(), iter);

return std::tuple<Tensor&, Tensor&>(mantissa, exponent);
}

// alias for lgamma, implements special.gammanln equivalent to
// scipy.special.gammaln
Tensor special_gammaln(const Tensor& self) { return self.lgamma(); }
@@ -712,6 +746,7 @@ DEFINE_DISPATCH(exp2_stub);
DEFINE_DISPATCH(expm1_stub);
DEFINE_DISPATCH(floor_stub);
DEFINE_DISPATCH(frac_stub);
DEFINE_DISPATCH(frexp_stub);
DEFINE_DISPATCH(i0_stub);
DEFINE_DISPATCH(log_stub);
DEFINE_DISPATCH(log10_stub);
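As context for the two entry points added above: `frexp` allocates the mantissa and integer exponent tensors and forwards to `frexp_out`, which validates the dtypes and builds a two-output `TensorIterator` for `frexp_stub`. A minimal ATen-side usage sketch (not part of this diff; it assumes the standard `ATen/ATen.h` headers and the existing `at::exp2` and `at::allclose` helpers) might look like:

```cpp
#include <ATen/ATen.h>
#include <iostream>
#include <tuple>

int main() {
  at::Tensor x = at::randn({4});  // any floating-point tensor

  // Functional variant: returns (mantissa, exponent) as a std::tuple.
  at::Tensor mantissa, exponent;
  std::tie(mantissa, exponent) = at::frexp(x);

  // Out variant, matching the frexp_out overload added in this PR.
  at::Tensor m = at::empty_like(x);
  at::Tensor e = at::empty_like(x, x.options().dtype(at::kInt));
  at::frexp_out(m, e, x);

  // frexp guarantees |mantissa| in [0.5, 1) (or 0) and x == mantissa * 2^exponent.
  std::cout << mantissa << "\n" << exponent << "\n";
  std::cout << at::allclose(x, mantissa * at::exp2(exponent.to(at::kFloat))) << "\n";
}
```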
1 change: 1 addition & 0 deletions aten/src/ATen/native/UnaryOps.h
@@ -40,6 +40,7 @@ DECLARE_DISPATCH(unary_fn, exp2_stub);
DECLARE_DISPATCH(unary_fn, expm1_stub);
DECLARE_DISPATCH(unary_fn, floor_stub);
DECLARE_DISPATCH(unary_fn, frac_stub);
DECLARE_DISPATCH(unary_fn, frexp_stub);
DECLARE_DISPATCH(unary_fn, i0_stub);
DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
91 changes: 91 additions & 0 deletions aten/src/ATen/native/cpu/Loops.h
@@ -127,6 +127,70 @@ basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_
execute_op(data, strides, i, n, std::forward<func_t>(op));
}

// the recursive variadic template for iterating over the returned tuple
template<class T, size_t N>
struct TupleOutput {
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
const T &tuple) {
TupleOutput<T, N - 1>::handle(data, strides, i, tuple);

auto output = std::get<N - 1>(tuple);
using output_type = decltype(output);
output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
*out_ptr = output;
}
};

// Base case for the above recursive template
template<class T>
struct TupleOutput<T, 1> {
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
const T &tuple) {
auto output = std::get<0>(tuple);
using output_type = decltype(output);
output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
*out_ptr = output;
}
};

template<class... Args>
void handle_tuple_outputs(char* C10_RESTRICT data[],
const int64_t* strides,
int64_t i,
const std::tuple<Args...> &tuple) {
TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
}

// Loop operation for `cpu_kernel_multiple_outputs`.
// 1. Use `c10::guts::apply` to invoke the lambda passed to
//    `cpu_kernel_multiple_outputs` on the dereferenced inputs.
// 2. Iterate over the members of the returned tuple and write each one to
//    its corresponding output tensor via `handle_tuple_outputs`.
template <typename func_t>
static inline void
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
using traits = function_traits<func_t>;

using result_type = typename traits::result_type;
constexpr int num_outputs = std::tuple_size<result_type>::value;
constexpr int ntensors = traits::arity + num_outputs;

// Copying strides to temporary array helps auto vectorization in older GCC
// versions.
int64_t strides[ntensors];
for (int arg = 0; arg < ntensors; arg++) {
strides[arg] = strides_[arg];
}

for (; i < n; i++) {
auto output = c10::guts::apply(op, dereference<traits>(
&data[num_outputs],
&strides[num_outputs],
i));
handle_tuple_outputs(data, strides, i, output);
}
}

// Explicitly vectorized loop implementation. All inputs and outputs must be
// the same type and contiguous with one exception: a single input may be
a scalar (stride 0). Its position is indicated by the argument `S`. If `S`
@@ -206,6 +270,33 @@ void cpu_kernel(TensorIteratorBase& iter, func_t&& op) {
iter.cast_outputs();
}

// This function helps write elementwise kernels that require multiple outputs.
// It follows a structure similar to cpu_kernel, but uses the new
// `multiple_outputs_loop` function instead of `basic_loop` to handle the
// multiple return values.
// For now the `needs_dynamic_casting` check is skipped, because the lambda
// (`func_t`) passed to `multiple_outputs_loop` returns a `std::tuple` instead
// of a `scalar_t`; `gpu_kernel_multiple_outputs` is implemented without this
// check as well. `needs_dynamic_casting` could be extended to support both
// `std::tuple` and `thrust::tuple` in the future.
template <typename func_t>
void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op) {
using traits = function_traits<func_t>;
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);

iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
if (is_contiguous<traits>(strides)) {
multiple_outputs_loop(data, strides, 0, n, std::forward<func_t>(op));
} else {
using Indices = std::make_index_sequence<traits::arity>;
unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t _idx) {
multiple_outputs_loop(data, strides, 0, n, std::forward<func_t>(op));
});
}
});
iter.cast_outputs();
}

template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
using traits = function_traits<func_t>;
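The heart of `multiple_outputs_loop` is the `TupleOutput` recursion: element `N-1` of the tuple returned by the lambda is written to output buffer `N-1`, and the recursion then handles the remaining elements down to the single-element base case. The standalone sketch below is illustrative only; it drops `TensorIterator`, `function_traits`, and per-tensor strides, and the names `WriteTuple`/`outs` are invented for the example. It shows the same compile-time recursion writing a `(float, int32_t)` tuple into two flat output arrays:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <tuple>

// Simplified analogue of TupleOutput: recurse first, then write element N-1.
template <class Tuple, size_t N>
struct WriteTuple {
  static void apply(void* outs[], int64_t i, const Tuple& t) {
    WriteTuple<Tuple, N - 1>::apply(outs, i, t);
    using elem_t = typename std::tuple_element<N - 1, Tuple>::type;
    static_cast<elem_t*>(outs[N - 1])[i] = std::get<N - 1>(t);
  }
};

// Base case: a single remaining element goes to the first output buffer.
template <class Tuple>
struct WriteTuple<Tuple, 1> {
  static void apply(void* outs[], int64_t i, const Tuple& t) {
    using elem_t = typename std::tuple_element<0, Tuple>::type;
    static_cast<elem_t*>(outs[0])[i] = std::get<0>(t);
  }
};

int main() {
  float mantissa[4];
  int32_t exponent[4];
  void* outs[2] = {mantissa, exponent};
  for (int64_t i = 0; i < 4; i++) {
    // Stand-in for the user lambda: return both per-element results as one tuple.
    std::tuple<float, int32_t> result(0.5f + 0.1f * i, static_cast<int32_t>(i));
    WriteTuple<decltype(result), 2>::apply(outs, i, result);
  }
  std::printf("mantissa[3]=%f exponent[3]=%d\n", mantissa[3], exponent[3]);
  return 0;
}
```

The real loop additionally offsets each output pointer by `i * strides[k]` and deduces the number of outputs from the lambda's return type via `function_traits`.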
18 changes: 18 additions & 0 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -598,6 +598,23 @@ static void rsqrt_kernel(TensorIterator& iter) {
});
}

static void frexp_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND(kHalf,
// The iter.dtype() here is the dtype of the mantissa output.
// It is a floating-point type and must be the same as the input's dtype.
iter.dtype(),
"frexp_cpu", [&]() {
cpu_kernel_multiple_outputs(
iter,
[](scalar_t a) -> std::tuple<scalar_t, int32_t> {
int32_t exponent;
scalar_t mantissa = std::frexp(a, &exponent);
return std::tuple<scalar_t, int32_t>(mantissa, exponent);
}
);
});
}

// TODO: Disable cont. branch to test more risky code

#define IMPLEMENT_ITERATOR_LAMBDA(op) \
@@ -701,6 +718,7 @@ REGISTER_DISPATCH(clamp_stub, &clamp_kernel);
REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel);
REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel)
REGISTER_DISPATCH(frexp_stub, &frexp_kernel)


IMPLEMENT_COMPLEX_KERNEL(acos)
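Per element, the kernel above is a thin wrapper around `std::frexp`: it splits `x` into a mantissa `m` with `|m|` in `[0.5, 1)` (or zero) and an integer exponent `e` such that `x == m * 2^e`. A small standalone check of those expected pairs (not part of the PR, plain standard C++):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float inputs[] = {8.0f, 20.0f, -5.0f, 0.0f};
  for (float x : inputs) {
    int e = 0;
    float m = std::frexp(x, &e);
    // e.g. frexp(8) -> (0.5, 4), frexp(20) -> (0.625, 5), frexp(-5) -> (-0.625, 3).
    std::printf("frexp(%g) -> mantissa=%g exponent=%d  (check: %g)\n",
                x, m, e, std::ldexp(m, e));
  }
  return 0;
}
```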
22 changes: 22 additions & 0 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -252,6 +252,27 @@ void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, doub
});
}

void frexp_kernel_cuda(TensorIterator& iter) {
#ifdef __HIP_PLATFORM_HCC__
// Reference: https://rocmdocs.amd.com/en/latest/ROCm_API_References/HIP-MATH.html
// https://github.com/ROCm-Developer-Tools/HIP/issues/2169
// ROCm does not support frexp function yet
TORCH_CHECK(false, "torch.frexp() is not implemented on ROCm platform.");
#else
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half,
// The iter.dtype() here is the dtype of the mantissa output.
// It is a floating-point type and must be the same as the input's dtype.
iter.dtype(),
"frexp_cuda", [&]() {
gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t a) -> thrust::tuple<scalar_t, int32_t> {
int32_t exponent;
scalar_t mantissa = std::frexp(a, &exponent);
return {mantissa, exponent};
});
});
#endif
}

REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
REGISTER_DISPATCH(exp_stub, &exp_kernel_cuda);
REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
@@ -270,6 +291,7 @@ REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel_cuda);
REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel_cuda);
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(frexp_stub, &frexp_kernel_cuda);

} // namespace native
} // namespace at
9 changes: 9 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4209,6 +4209,15 @@
- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures

- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
  variants: method, function
  dispatch:
    DefaultBackend: frexp

- func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)
  dispatch:
    DefaultBackend: frexp_out

- func: frobenius_norm(Tensor self) -> Tensor
  variants: function

24 changes: 24 additions & 0 deletions aten/src/ATen/test/tensor_iterator_test.cpp
@@ -210,3 +210,27 @@ TEST(TensorIteratorTest, FailNonPromotingBinaryOp) {
config.add_input(at::ones({1,1}, at::dtype(at::kInt)));
ASSERT_ANY_THROW(config.build());
}

#define MULTIPLE_OUTPUTS_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, CpuKernelMultipleOutputs_##name) { \
auto in1 = random_tensor_for_type(k##name); \
auto in2 = random_tensor_for_type(k##name); \
Tensor out1 = at::empty({0}, in1.options()); \
Tensor out2 = at::empty({0}, in1.options()); \
auto expected1 = in1.add(in2); \
auto expected2 = in1.mul(in2); \
auto iter = at::TensorIteratorConfig() \
.add_output(out1) \
.add_output(out2) \
.add_input(in1) \
.add_input(in2) \
.build(); \
at::native::cpu_kernel_multiple_outputs(iter, [=](ctype a, ctype b) -> std::tuple<ctype, ctype> { \
ctype add = a + b; \
ctype mul = a * b; \
return std::tuple<ctype, ctype>(add, mul); \
}); \
EXPECT_TRUE(out1.equal(expected1)); \
EXPECT_TRUE(out2.equal(expected2)); \
}
AT_FORALL_SCALAR_TYPES(MULTIPLE_OUTPUTS_TEST_ITER_FOR_TYPE)
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -343,6 +343,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: fmod_
.. automethod:: frac
.. automethod:: frac_
.. automethod:: frexp
.. automethod:: gather
.. automethod:: gcd
.. automethod:: gcd_
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -306,6 +306,7 @@ Pointwise Ops
floor_divide
fmod
frac
frexp
imag
ldexp
lerp
4 changes: 3 additions & 1 deletion test/test_namedtuple_return_api.py
@@ -15,7 +15,8 @@
'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_unpack_dual", 'linalg_qr',
'_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
'fake_quantize_per_channel_affine_cachemask', 'linalg_lstsq'
'fake_quantize_per_channel_affine_cachemask', 'linalg_lstsq',
'frexp'
}


@@ -78,6 +79,7 @@ def test_namedtuple_return(self):
names=('output', 'mask',), hasout=False),
op(operators=['_unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
op(operators=['linalg_lstsq'], input=(a,), names=('solution', 'residuals', 'rank', 'singular_values'), hasout=False),
op(operators=['frexp'], input=(), names=('mantissa', 'exponent'), hasout=True),
]

def get_func(f):