Skip to content

Commit

Permalink
fix: fix int fill with float rounding issue
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
henryiii committed Aug 29, 2023
1 parent c8359e5 commit c3cada0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
21 changes: 21 additions & 0 deletions include/bh_python/fill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <boost/mp11.hpp>
#include <boost/variant2/variant.hpp>

#include <cmath>
#include <stdexcept>
#include <type_traits>
#include <vector>
Expand Down Expand Up @@ -77,6 +78,7 @@ inline bool is_value<std::string>(py::handle h) {

template <class T>
decltype(auto) special_cast(py::handle x) {
py::print("Regular special cast with", x);
return py::cast<T>(x);
}

Expand All @@ -95,6 +97,25 @@ inline decltype(auto) special_cast<c_array_t<std::string>>(py::handle x) {
return py::cast<B>(x);
}

// Allow single floats to be integers
template <>
inline decltype(auto) special_cast<int>(py::handle x) {
try {
return static_cast<int>(std::floor(py::cast<double>(x)));
} catch(const py::cast_error&) {
return py::cast<int>(x);
}
}

// Allow float arrays to be integers
template <>
inline decltype(auto) special_cast<c_array_t<int>>(py::handle x) {
auto np = py::module_::import("numpy");
to_int = np.attr("floor")(x)

return py::cast<c_array_t<int>>(to_int);
}

using arg_t = variant::variant<c_array_t<double>,
double,
c_array_t<int>,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ def test_fill_int_1d():
h[-3]


def test_fill_int_with_float_single_1d():
h = bh.Histogram(bh.axis.Integer(-1, 2))
h.fill(0.3)
h.fill(-0.3)
assert h.values() == approx(np.array([1, 1, 0]))


def test_fill_int_with_float_array_1d():
h = bh.Histogram(bh.axis.Integer(-1, 2))
h.fill([-0.3, 0.3])
assert h.values() == approx(np.array([1, 1, 0]))


def test_fill_1d(flow):
h = bh.Histogram(bh.axis.Regular(3, -1, 2, underflow=flow, overflow=flow))
with pytest.raises(ValueError):
Expand Down

0 comments on commit c3cada0

Please sign in to comment.