Add ScalarTensor or 0dim overload for _foreach_add #111079

Closed · wants to merge 5 commits · Changes from 4 commits
38 changes: 38 additions & 0 deletions aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -100,6 +100,43 @@ namespace at::native {
return result; \
}

#define FOREACH_BINARY_OP_TENSOR_ALPHA(OP) \
void foreach_tensor_##OP##_tensor_kernel_slow_( \
TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
TORCH_CHECK( \
scalar.dim() == 0 && scalar.numel() == 1, \
"scalar tensor expected to be 0 dim but it has ", \
scalar.dim(), \
" dimensions and ", \
scalar.numel(), \
" elements."); \
check_foreach_api_restrictions(tensors); \
\
for (auto& t : tensors) { \
t.OP##_(scalar, alpha); \
} \
} \
\
std::vector<Tensor> foreach_tensor_##OP##_tensor_kernel_slow( \
TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
TORCH_CHECK( \
scalar.dim() == 0 && scalar.numel() == 1, \
"scalar tensor expected to be 0 dim but it has ", \
scalar.dim(), \
" dimensions and ", \
scalar.numel(), \
" elements."); \
check_foreach_api_restrictions(tensors); \
\
std::vector<Tensor> result; \
result.reserve(tensors.size()); \
for (const auto& t : tensors) { \
result.emplace_back(t.OP(scalar, alpha)); \
} \
\
return result; \
}

#define FOREACH_BINARY_OP_SCALAR(OP) \
void foreach_tensor_##OP##_scalar_kernel_slow_( \
TensorList tensors, const Scalar& scalar) { \
@@ -295,6 +332,7 @@ FOREACH_BINARY_OP_LIST_ALPHA(add);
FOREACH_BINARY_OP_LIST_ALPHA(sub);
FOREACH_BINARY_OP_LIST_ALPHA(lerp);

FOREACH_BINARY_OP_TENSOR_ALPHA(add);
FOREACH_BINARY_OP_TENSOR(mul);

FOREACH_BINARY_OP_SCALAR(add);
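For a rough sense of what the new slow path computes, here is a small Python sketch. It assumes the `_foreach_add` Tensor overload from this patch (or an equivalent change) is available; the list comprehension mirrors the per-tensor `t.OP(scalar, alpha)` loop in the macro above.

```python
import torch

tensors = [torch.ones(2, 2), torch.full((2, 2), 3.0)]
scalar = torch.tensor(2.0)  # must be a 0-dim tensor with a single element
alpha = 0.5

# Out-of-place overload: each result should equal t + alpha * scalar.
out = torch._foreach_add(tensors, scalar, alpha=alpha)
for o, t in zip(out, tensors):
    torch.testing.assert_close(o, t + alpha * scalar)

# A non-0-dim "scalar" should trip the TORCH_CHECK shown above.
try:
    torch._foreach_add(tensors, torch.tensor([1.0, 2.0]))
except RuntimeError as err:
    print(err)  # scalar tensor expected to be 0 dim but ...
```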
57 changes: 49 additions & 8 deletions aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu
@@ -8,6 +8,7 @@
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_foreach_add_native.h>
#include <ATen/ops/_foreach_mul_native.h>

#include <ATen/ops/empty_like_native.h>
@@ -18,7 +19,8 @@ namespace at::native {
template <typename T, template <class> class Op>
std::vector<Tensor> foreach_binary_op(
TensorList tensors,
const Tensor& scalar) {
const Tensor& scalar,
const Scalar& alpha = 1) {
TORCH_CHECK(
scalar.dim() == 0 && scalar.numel() == 1,
"scalar tensor expected to be 0 dim but it has ",
@@ -51,12 +53,16 @@ std::vector<Tensor> foreach_binary_op(
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Op<opmath_t>(),
scalar.data_ptr<T>());
scalar.data_ptr<T>(),
alpha.to<opmath_t>());
return tensor_lists[1];
}

template <typename T, template <class> class Op>
void foreach_binary_op_(TensorList tensors, const Tensor& scalar) {
void foreach_binary_op_(
TensorList tensors,
const Tensor& scalar,
const Scalar& alpha = 1) {
TORCH_CHECK(
scalar.dim() == 0 && scalar.numel() == 1,
"scalar tensor expected to be 0 dim but has ",
@@ -82,7 +88,8 @@ void foreach_binary_op_(TensorList tensors, const Tensor& scalar) {
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>(),
scalar.data_ptr<T>());
scalar.data_ptr<T>(),
alpha.to<opmath_t>());
increment_version(tensors);
}

@@ -115,32 +122,66 @@ void foreach_binary_op_(TensorList tensors, const Tensor& scalar) {
return FUNCTION<OP>(tensors, scalar); \
}

#define FOREACH_BINARY_OP_SCALAR_TENSOR_ALPHA(FUNCTION, NAME, OP) \
void foreach_tensor_##NAME##_tensor_kernel_cuda_( \
TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
check_foreach_api_restrictions(tensors); \
if (!(can_use_fast_route(ArrayRef<TensorList>{tensors}, alpha) && \
tensors[0].scalar_type() == scalar.scalar_type())) { \
return at::native::foreach_tensor_##NAME##_tensor_kernel_slow_( \
tensors, scalar, alpha); \
} \
\
FUNCTION##_<OP>(tensors, scalar, alpha); \
} \
\
std::vector<Tensor> foreach_tensor_##NAME##_tensor_kernel_cuda( \
TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
check_foreach_api_restrictions(tensors); \
if (!(can_use_fast_route(ArrayRef<TensorList>{tensors}, alpha) && \
tensors[0].scalar_type() == scalar.scalar_type())) { \
return at::native::foreach_tensor_##NAME##_tensor_kernel_slow( \
tensors, scalar, alpha); \
} \
\
return FUNCTION<OP>(tensors, scalar, alpha); \
}

template <template <class> class Op>
std::vector<Tensor> all_types_complex_bool_half_bfloat16(
TensorList tensors,
const Tensor& scalar) {
const Tensor& scalar,
const Scalar& alpha = 1) {
return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool,
kHalf,
kBFloat16,
tensors[0].scalar_type(),
"foreach_binary_op_scalar_cuda",
[&]() { return foreach_binary_op<scalar_t, Op>(tensors, scalar); });
[&]() {
return foreach_binary_op<scalar_t, Op>(tensors, scalar, alpha);
});
}

template <template <class> class Op>
void all_types_complex_bool_half_bfloat16_(
TensorList tensors,
const Tensor& scalar) {
const Tensor& scalar,
const Scalar& alpha = 1) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool,
kHalf,
kBFloat16,
tensors[0].scalar_type(),
"foreach_binary_op_scalar_cuda_",
[&]() { foreach_binary_op_<scalar_t, Op>(tensors, scalar); });
[&]() { foreach_binary_op_<scalar_t, Op>(tensors, scalar, alpha); });
}

FOREACH_BINARY_OP_SCALAR_TENSOR_ALPHA(
all_types_complex_bool_half_bfloat16,
add,
std::plus);

FOREACH_BINARY_OP_SCALAR_TENSOR(
all_types_complex_bool_half_bfloat16,
mul,
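The CUDA entry points above take the fused multi_tensor_apply path only when can_use_fast_route passes and the 0-dim tensor's dtype matches the input tensors'; otherwise they fall back to the slow per-tensor kernels. A hedged sketch of the observable behavior (requires a CUDA device and the new overload; which path is taken is asserted in the comments, not checked by the code):

```python
import torch

xs = [torch.rand(3, device="cuda") for _ in range(2)]

# Matching dtype between xs and the 0-dim tensor: eligible for the fused
# fast path per the dispatch condition in the macro above.
fast = torch._foreach_add(xs, torch.tensor(2.0, device="cuda"), alpha=3)

# Mismatched dtype: should fall back to the slow path, with identical numerics.
slow = torch._foreach_add(
    xs, torch.tensor(2.0, dtype=torch.float64, device="cuda"), alpha=3)

for f, s in zip(fast, slow):
    torch.testing.assert_close(f, s)
```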
79 changes: 33 additions & 46 deletions aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -152,49 +152,6 @@ __device__ __forceinline__ void binary_op_scalar(
}
}

Contributor Author: I ended up inlining this change because I could not specify that scalar is of type opmath_t; I could not figure out where to put `using opmath_t = at::opmath_type<T>;` after the template specification but before the function signature.

template <int res_arg_index, typename Op, typename T, typename scalar_t = T>
__device__ __forceinline__ void binary_op_scalar_tensor(
T r_args[][kILP],
T** args,
scalar_t* scalar,
const int64_t n,
const int chunk_size,
const bool all_aligned,
Op op) {
using opmath_t = at::opmath_type<T>;
// to make things simple, we put aligned case in a different code path
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
i_start * kILP < n && i_start * kILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(*scalar)));
}
// store
load_store(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * kILP) {
// Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
// has depth 1
load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(*scalar)));
}
store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
T r_args[][kILP],
@@ -354,7 +311,8 @@ struct BinaryOpScalarTensorFunctor {
int chunk_size,
TensorListMetadata<depth>& tl,
Op op,
T* scalar) {
T* scalar,
opmath_t alpha) {
const int tensor_loc = tl.block_to_tensor[blockIdx.x];
const int chunk_idx = tl.block_to_chunk[blockIdx.x];
auto n = tl.numel_for_tensor[tensor_loc];
@@ -365,8 +323,37 @@
n -= chunk_idx * chunk_size;
T r_args[r_args_depth][kILP];

binary_op_scalar_tensor<res_arg_index>(
r_args, args, scalar, n, chunk_size, all_aligned, op);
// to make things simple, we put aligned case in a different code path
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
i_start * kILP < n && i_start * kILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] = static_cast<T>(op(
static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
}
// store
load_store(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * kILP) {
// Regardless if depth is 1 (for inplace) or 2 (for out of place),
// r_args has depth 1
load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] = static_cast<T>(op(
static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
}
store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
};

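The inlined kernel body now computes op(x, alpha * scalar) with both operands promoted to opmath_t (float for half and bfloat16 inputs) before casting back to T. A small sketch of that precision behavior, assuming the kernel's float32 accumulation and the availability of the new overload (requires CUDA):

```python
import torch

x = torch.full((4,), 0.1, dtype=torch.float16, device="cuda")
scalar = torch.tensor(0.2, dtype=torch.float16, device="cuda")
alpha = 3

# Per element, the functor should compute in float32 (opmath_t) and cast back.
expected = (x.float() + alpha * scalar.float()).half()

out = torch._foreach_add([x], scalar, alpha=alpha)[0]
torch.testing.assert_close(out, expected)
```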
15 changes: 15 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -10117,6 +10117,21 @@
CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
autogen: _foreach_add.ScalarList_out

- func: _foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
dispatch:
CPU: foreach_tensor_add_tensor_kernel_slow
CUDA: foreach_tensor_add_tensor_kernel_cuda

- func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
dispatch:
CPU: foreach_tensor_add_tensor_kernel_slow_
CUDA: foreach_tensor_add_tensor_kernel_cuda_
autogen: _foreach_add.Tensor_out

- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
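The two entries above register an out-of-place and an in-place variant (plus the autogenerated _foreach_add.Tensor_out). Assuming the patch is applied, usage would look roughly like:

```python
import torch

xs = [torch.zeros(2), torch.ones(2)]
s = torch.tensor(5.0)

# _foreach_add.Tensor: out-of-place, returns a new list of tensors.
ys = torch._foreach_add(xs, s, alpha=2)

# _foreach_add_.Tensor: in-place, mutates xs and returns nothing.
torch._foreach_add_(xs, s, alpha=2)

for x, y in zip(xs, ys):
    torch.testing.assert_close(x, y)  # both should equal the original x + 2 * s
```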
3 changes: 3 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -134,9 +134,12 @@ aten::_foreach_add.Scalar
aten::_foreach_add.ScalarList
aten::_foreach_add.ScalarList_out
aten::_foreach_add.Scalar_out
aten::_foreach_add.Tensor
aten::_foreach_add.Tensor_out
aten::_foreach_add_.List
aten::_foreach_add_.Scalar
aten::_foreach_add_.ScalarList
aten::_foreach_add_.Tensor
aten::_foreach_addcdiv.Scalar
aten::_foreach_addcdiv.ScalarList
aten::_foreach_addcdiv.ScalarList_out
5 changes: 2 additions & 3 deletions test/test_foreach.py
@@ -778,15 +778,14 @@ def test_tensors_grouping(self):
def test_0dim_tensor_overload_exception(self):
    # check exceptions of fast path
    tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]

    with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
        torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))

Contributor Author (@janeyx99, Oct 11, 2023): This looks to be already covered by the d="cuda" case in the next few lines.

    with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
        torch._foreach_mul(tensors, torch.tensor(1.0, device="cpu"))

    tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")]
    with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
        torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
    with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
        torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))

@onlyCUDA
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -9064,7 +9064,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
foreach_binary_op_db: List[OpInfo] = [
    ForeachFuncInfo(
        "add",
        foreach_inputs_sample_func(2, True, True),
        foreach_inputs_sample_func(2, True, True, True),
        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
        supports_alpha_param=True,
5 changes: 4 additions & 1 deletion torchgen/api/autograd.py
@@ -316,7 +316,10 @@ def is_foreach_func(f: NativeFunction) -> bool:
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such one in `functional_info_by_signature`. There however are some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {"_foreach_mul.Tensor"}
_foreach_with_tensor_overload = {
"_foreach_add.Tensor",
"_foreach_mul.Tensor",
}

Collaborator: If we already had this code for mul, why is it so much work to add support for add?

Contributor Author: [image]


# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
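Listing _foreach_add.Tensor here lets torchgen match the overload to the non-foreach add.Tensor and reuse its derivative, as is already done for _foreach_mul.Tensor. A hedged autograd sketch under that assumption (gradients should flow to both the tensor list and the 0-dim scalar tensor, with the stated values being the expected ones rather than verified output):

```python
import torch

xs = [torch.rand(3, requires_grad=True) for _ in range(2)]
s = torch.tensor(2.0, requires_grad=True)

ys = torch._foreach_add(xs, s, alpha=3)
torch.stack(ys).sum().backward()

print([x.grad for x in xs])  # expected: all-ones tensors, since d(x + 3*s)/dx = 1
print(s.grad)                # expected: alpha * total element count = 3 * 6 = 18
```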