Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 3 additions & 29 deletions aten/src/ATen/native/Histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ namespace at { namespace native {

DEFINE_DISPATCH(histogramdd_stub);
DEFINE_DISPATCH(histogramdd_linear_stub);
DEFINE_DISPATCH(histogram_select_outer_bin_edges_stub);

namespace {

Expand Down Expand Up @@ -153,22 +154,6 @@ void histogramdd_prepare_out(const Tensor& input, TensorList bins,
histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
}

template<typename scalar_t>
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
// Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
Tensor min, max;
std::tie(min, max) = aminmax(input, 0);

TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());

const scalar_t *min_data = min.data_ptr<scalar_t>();
std::copy(min_data, min_data + N, leftmost_edges.begin());

const scalar_t *max_data = max.data_ptr<scalar_t>();
std::copy(max_data, max_data + N, rightmost_edges.begin());
}

/* Determines the outermost bin edges. For simplicity when calling into aminmax,
* assumes that input has already been reshaped to (M, N).
*/
Expand All @@ -192,19 +177,8 @@ select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>>
}
} else if (input.numel() > 0) {
// non-empty input
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
if (input.is_mps()) {
Tensor min, max;
std::tie(min, max) = at::aminmax(input, 0);

for (const auto i : c10::irange(N)) {
leftmost_edges[i] = min[i].item().to<scalar_t>();
rightmost_edges[i] = max[i].item().to<scalar_t>();
}
} else {
infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
}
});

histogram_select_outer_bin_edges_stub(input.device().type(), input, N, leftmost_edges, rightmost_edges);
}

for (const auto dim : c10::irange(N)) {
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/Histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ namespace at::native {

using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);

DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);

} // namespace at::native
26 changes: 25 additions & 1 deletion aten/src/ATen/native/cpu/HistogramKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/aminmax.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
Expand Down Expand Up @@ -282,10 +283,33 @@ static void histogramdd_linear_kernel_impl(const Tensor& self, const c10::option
}
}

template<typename scalar_t>
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
// Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
Tensor min, max;
std::tie(min, max) = aminmax(input, 0);

TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());

const scalar_t *min_data = min.data_ptr<scalar_t>();
std::copy(min_data, min_data + N, leftmost_edges.begin());

const scalar_t *max_data = max.data_ptr<scalar_t>();
std::copy(max_data, max_data + N, rightmost_edges.begin());
}

static void histogram_select_outer_bin_edges_impl(const Tensor& input, const int64_t N,
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
});
}

} // namespace

REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);

REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_impl);

} // namespace at::native
15 changes: 15 additions & 0 deletions aten/src/ATen/native/mps/operations/HistogramKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/aminmax.h>
#include <ATen/ops/sum.h>
#endif

Expand Down Expand Up @@ -396,6 +397,20 @@ static void histogramdd_linear_kernel(const Tensor& self,
}
}

static void histogram_select_outer_bin_edges_kernel(const Tensor& input,
const int64_t N,
std::vector<double>& leftmost_edges,
std::vector<double>& rightmost_edges) {
Tensor min, max;
std::tie(min, max) = at::aminmax(input, 0);

for (const auto i : c10::irange(N)) {
leftmost_edges[i] = min[i].item().to<double>();
rightmost_edges[i] = max[i].item().to<double>();
}
}

REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel);
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_kernel);
} // namespace at::native