Skip to content

Commit

Permalink
[fix] check for histogramdd when bins is int[]
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed May 4, 2023
1 parent db4572d commit bb97087
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/native/Histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ std::vector<Tensor>& histogramdd_bin_edges_out_cpu(const Tensor& self, IntArrayR

auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);

const int64_t bin_size = bin_ct.size();
TORCH_CHECK(
N == bin_size,
"histogramdd: The size of bins must be equal to the innermost dimension of the input.");
for (const auto dim : c10::irange(N)) {
linspace_out(outer_bin_edges.first[dim], outer_bin_edges.second[dim],
bin_ct[dim] + 1, bin_edges_out[dim]);
Expand Down
9 changes: 9 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,6 +2748,12 @@ def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs):
yield SampleInput(input_tensor, bins_tensor,
weight=weight_tensor, density=density)

def error_inputs_histogramdd(opinfo, device, **kwargs):
invalid_bins = [1, 1, 1, 1, 1]
make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input."
yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg)

def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

Expand Down Expand Up @@ -16024,8 +16030,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
dtypes=floating_types(),
dtypesIfCUDA=_dispatch_dtypes(), # histogramdd is only implemented on CPU
sample_inputs_func=sample_inputs_histogramdd,
error_inputs_func=error_inputs_histogramdd,
supports_autograd=False,
skips=(
# Not implemented on CUDA
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# JIT tests don't work with Tensor keyword arguments
# https://github.com/pytorch/pytorch/issues/58507
Expand Down

0 comments on commit bb97087

Please sign in to comment.