[torch][segment_reduce] Support for multi dimension (cpu only) #59951
Conversation
💊 CI failures summary and remediations, as of commit 6dc7c90 (more details on the Dr. CI page):
1 failure not recognized by patterns.
This pull request was exported from Phabricator. Differential Revision: D29105457
This code is becoming quite complicated, and yet it can only handle reduction over the 0-th axis of a contiguous tensor, and it cannot easily be extended to handle other axes. At this point, the implementation should be switched to TensorIterator. Please see, e.g., how TensorIterator handles max along a dimension:
pytorch/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
Lines 20 to 125 in 8e92a3a
template <typename scalar_t, typename scalar_t_2 = int64_t, typename loop1d_t>
static inline void compare_base_kernel_core(
    Tensor& result1,
    Tensor& result2,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const loop1d_t& loop) {
  auto self_sizes = ensure_nonempty_vec(self.sizes().vec());
  self_sizes[dim] = 1;
  // result1 and result2 may be a empty tensor, if not,
  // reshape them as self dims
  if (!keepdim) {
    if (result1.ndimension() >= dim) {
      result1.unsqueeze_(dim);
    }
    if (result2.ndimension() >= dim) {
      result2.unsqueeze_(dim);
    }
  }
  at::native::resize_output(result1, self_sizes);
  at::native::resize_output(result2, self_sizes);
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(self.sizes(), /*squash_dims=*/dim)
    .add_output(result1)
    .add_output(result2)
    .add_input(self)
    .build();
  iter.for_each(loop, /* grain_size */ 1);
  if (!keepdim) {
    result1.squeeze_(dim);
    result2.squeeze_(dim);
  }
}

template <typename scalar_t, typename scalar_t_2=int64_t, typename func_t>
static inline void compare_base_kernel(Tensor& result1, Tensor& result2,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const func_t& f) {
  auto self_dim_stride = ensure_nonempty_stride(self, dim);
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* result1_data_bytes = data[0];
    auto* result2_data_bytes = data[1];
    const auto* self_data_bytes = data[2];
    for (int64_t i = 0; i < n; ++i) {
      f((scalar_t*)result1_data_bytes,
        (scalar_t_2*)result2_data_bytes,
        (scalar_t*)self_data_bytes,
        self_dim_stride);
      result1_data_bytes += strides[0];
      result2_data_bytes += strides[1];
      self_data_bytes += strides[2];
    }
  };
  compare_base_kernel_core<scalar_t, scalar_t_2>(
      result1, result2, self, dim, keepdim, loop);
}

static void min_kernel_impl(
    Tensor& result,
    Tensor& indice,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
  TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
    "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());
  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
    compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
      scalar_t* result_data, int64_t* indice_data,
      const scalar_t* self_data, auto self_dim_stride) {
      using value_t = typename c10::scalar_value_type<scalar_t>::type;
      value_t (*zabs_)(scalar_t) = zabs<scalar_t, value_t>;
      scalar_t min_number = self_data[0];
      int64_t index = 0;
      for (int64_t i = 0; i < self_dim_size; ++i) {
        scalar_t value = self_data[i * self_dim_stride];
        if (!(zabs_(value) >= zabs_(min_number))) {
          min_number = value;
          index = i;
          if (_isnan<scalar_t>(value)) {
            break;
          }
        }
      }
      *result_data = min_number;
      *indice_data = index;
      }
    );
  });
}
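For illustration only (this is not code from the PR or from TensorCompareKernel.cpp): a minimal sketch of how the same squash_dims pattern could drive a reduction along an arbitrary dim. The function name is hypothetical, the inner loop assumes a float tensor, and dtype dispatch and error checking are omitted; header locations may also differ across PyTorch versions.

#include <ATen/ATen.h>
#include <ATen/native/TensorIterator.h>

// Sketch: sum along `dim` using TensorIterator with squash_dims, mirroring
// compare_base_kernel_core above. The iterator walks every non-reduced
// dimension; the inner loop strides through the reduced one manually, so
// any `dim` works, not just 0.
at::Tensor sum_along_dim_sketch(const at::Tensor& self, int64_t dim) {
  auto out_sizes = self.sizes().vec();
  out_sizes[dim] = 1;
  auto result = at::empty(out_sizes, self.options());

  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .resize_outputs(false)
                  .declare_static_shape(self.sizes(), /*squash_dims=*/dim)
                  .add_output(result)
                  .add_input(self)
                  .build();

  const int64_t dim_size = self.size(dim);
  const int64_t dim_stride = self.stride(dim);  // element stride along `dim`

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    auto* out_bytes = data[0];
    const auto* in_bytes = data[1];
    for (int64_t i = 0; i < n; ++i) {
      const float* in = reinterpret_cast<const float*>(in_bytes);
      float acc = 0.f;
      for (int64_t j = 0; j < dim_size; ++j) {
        acc += in[j * dim_stride];
      }
      *reinterpret_cast<float*>(out_bytes) = acc;
      out_bytes += strides[0];
      in_bytes += strides[1];
    }
  });

  return result.squeeze(dim);
}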
  auto output = at::empty({batch_size}, data.options());
  int64_t segment_count = lengths.numel();
  auto output_shape = data.sizes().vec();
  output_shape[0] = segment_count;
This should be axis, not 0? You can assert that axis == 0 if you don't support anything else.
Good catch. I am actually checking this on the caller side.
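A minimal sketch of what such a caller-side guard could look like; this is an assumption for illustration, not the PR's exact code, and the message text is invented:

// Hypothetical caller-side guard: reject any axis other than 0 until other
// axes are supported, then index the output shape by the reduced axis
// rather than a hard-coded 0.
TORCH_CHECK(axis == 0, "segment_reduce: only axis == 0 is currently supported");
auto output_shape = data.sizes().vec();
output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());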
@@ -169,7 +211,7 @@ Tensor segment_reduce_kernel(
     auto min_length = lengths_value.min().item<int64_t>();
     TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
     TORCH_CHECK(min_length != 0 || initial.has_value());
-    TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.numel());
+    TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.size(0));
data.size(0) is slightly faster; also, it should really be axis, not 0?
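A hedged sketch of the suggested form of that check, with axis substituted for the hard-coded dimension (this is not the PR's final code, and the message text is invented):

// Illustrative only: validate lengths against the reduced axis instead of
// assuming the reduction happens over dimension 0.
TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.size(axis),
            "segment_reduce: lengths must sum to the size of the reduced axis");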
  }
  int64_t lengths_cum_sum = 0;
  for (int64_t i = 0; i < segment_count; ++i) {
    for (int64_t l = 0; l < stride_count; ++l) {
Why are you splitting stride_count and seg_element_count into two loops? They are conceptually the same: the dimensions that are not being reduced.
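A rough sketch of the single-loop structure being suggested, assuming axis == 0, a contiguous float tensor, and a lengths_ptr raw pointer (all illustrative assumptions, not the PR's code):

// Illustrative only: treat every non-reduced dimension as one flattened
// "inner" extent, so a single loop replaces the separate stride_count /
// seg_element_count loops. Sum reduction shown as an example.
const int64_t inner_count = data.numel() / data.size(0); // product of non-reduced dims
const float* data_ptr = data.data_ptr<float>();          // float assumed for brevity
float* output_ptr = output.data_ptr<float>();
const int64_t* lengths_ptr = lengths.data_ptr<int64_t>();

int64_t row_offset = 0; // first row (along axis 0) of the current segment
for (int64_t seg = 0; seg < segment_count; ++seg) {
  const int64_t len = lengths_ptr[seg];
  for (int64_t inner = 0; inner < inner_count; ++inner) {
    float acc = 0.f;
    for (int64_t k = 0; k < len; ++k) {
      acc += data_ptr[(row_offset + k) * inner_count + inner];
    }
    output_ptr[seg * inner_count + inner] = acc;
  }
  row_offset += len;
}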
Differential Revision: D28922838
fbshipit-source-id: 6544b91df1ed2bc4ef50191c9016395023f3e2ea
[torch][segment_reduce] Support for multi dimension (cpu only) (pytorch#59951)
Summary: Pull Request resolved: pytorch#59951. Add support for multi-d input for cpu forward/backward implementation. Next step: Adding cuda support for multi-d input.
Test Plan: Added unit tests.
Differential Revision: D29105457
fbshipit-source-id: d61fe767a80410501272231219d751cf29225b0b
This pull request was exported from Phabricator. Differential Revision: D29105457
Force-pushed from ab05b61 to 6dc7c90
This pull request has been merged in a727f65.
Summary:
Add support for multi-d input in the CPU forward/backward implementation.
Next step: add CUDA support for multi-d input.
Test Plan: Added unit tests.
Differential Revision: D29105457