Skip to content

Commit

Permalink
Merge pull request #2480 from scipp/fix-operator-binding
Browse files Browse the repository at this point in the history
Clean up function bindings
  • Loading branch information
SimonHeybrock committed Mar 9, 2022
2 parents f305237 + b273d89 commit 7bc56d6
Show file tree
Hide file tree
Showing 16 changed files with 693 additions and 320 deletions.
2 changes: 2 additions & 0 deletions docs/reference/free-functions.rst
Expand Up @@ -15,6 +15,7 @@ General
choose
collapse
histogram
logical_not
logical_and
logical_or
logical_xor
Expand Down Expand Up @@ -53,6 +54,7 @@ Math
log10
mod
multiply
negative
norm
pow
reciprocal
Expand Down
18 changes: 9 additions & 9 deletions lib/cmake/scipp-functions.cmake
Expand Up @@ -59,27 +59,27 @@ scipp_binary(comparison less_equal)
scipp_binary(comparison not_equal)
setup_scipp_category(comparison)

scipp_function("unary" arithmetic operator- OP unary_minus)
scipp_function("unary" arithmetic operator- OP negative)
scipp_function("binary" arithmetic operator+ OP add)
scipp_function("binary" arithmetic operator- OP subtract)
scipp_function("binary" arithmetic operator* OP multiply SKIP_VARIABLE)
scipp_function("binary" arithmetic operator/ OP divide)
scipp_function("binary" arithmetic floor_divide)
scipp_function("binary" arithmetic operator% OP mod)
scipp_function("inplace" arithmetic operator+= OP add_equals)
scipp_function("inplace" arithmetic operator-= OP subtract_equals)
scipp_function("inplace" arithmetic operator*= OP multiply_equals)
scipp_function("inplace" arithmetic operator/= OP divide_equals)
scipp_function("inplace" arithmetic operator%= OP mod_equals)
scipp_function("inplace" arithmetic operator+= OP add_equals SKIP_PYTHON)
scipp_function("inplace" arithmetic operator-= OP subtract_equals SKIP_PYTHON)
scipp_function("inplace" arithmetic operator*= OP multiply_equals SKIP_PYTHON)
scipp_function("inplace" arithmetic operator/= OP divide_equals SKIP_PYTHON)
scipp_function("inplace" arithmetic operator%= OP mod_equals SKIP_PYTHON)
setup_scipp_category(arithmetic)

scipp_function("unary" logical operator~ OP logical_not)
scipp_function("binary" logical operator| OP logical_or)
scipp_function("binary" logical operator& OP logical_and)
scipp_function("binary" logical operator^ OP logical_xor)
scipp_function("inplace" logical operator|= OP logical_or_equals)
scipp_function("inplace" logical operator&= OP logical_and_equals)
scipp_function("inplace" logical operator^= OP logical_xor_equals)
scipp_function("inplace" logical operator|= OP logical_or_equals SKIP_PYTHON)
scipp_function("inplace" logical operator&= OP logical_and_equals SKIP_PYTHON)
scipp_function("inplace" logical operator^= OP logical_xor_equals SKIP_PYTHON)
setup_scipp_category(logical)

scipp_function("unary" bins bin_sizes SKIP_VARIABLE BASE_INCLUDE dataset/bins.h)
Expand Down
34 changes: 18 additions & 16 deletions lib/cmake/scipp-util.cmake
Expand Up @@ -3,7 +3,7 @@
# Copyright (c) 2022 Scipp contributors (https://github.com/scipp)
# ~~~
function(scipp_function template category function_name)
set(options SKIP_VARIABLE OUT)
set(options SKIP_VARIABLE SKIP_PYTHON OUT)
set(oneValueArgs OP PREPROCESS_VARIABLE BASE_INCLUDE)
cmake_parse_arguments(
PARSE_ARGV 3 SCIPP_FUNCTION "${options}" "${oneValueArgs}" ""
Expand Down Expand Up @@ -57,21 +57,23 @@ function(scipp_function template category function_name)
configure_in_module("variable" ${OPNAME})
endif()
configure_in_module("dataset" ${OPNAME})
configure_file(templates/python_${template}.cpp.in python/${src})
set(python_SRC_FILES
${python_SRC_FILES} ${src}
PARENT_SCOPE
)
set(python_binders_fwd python_${category}_binders_fwd)
set(python_binders python_${category}_binders)
set(${python_binders_fwd}
"${${python_binders_fwd}}\nvoid init_${OPNAME}(pybind11::module &)ENDL"
PARENT_SCOPE
)
set(${python_binders}
"${${python_binders}}\n init_${OPNAME}(m)ENDL"
PARENT_SCOPE
)
if(NOT SCIPP_FUNCTION_SKIP_PYTHON)
configure_file(templates/python_${template}.cpp.in python/${src})
set(python_SRC_FILES
${python_SRC_FILES} ${src}
PARENT_SCOPE
)
set(python_binders_fwd python_${category}_binders_fwd)
set(python_binders python_${category}_binders)
set(${python_binders_fwd}
"${${python_binders_fwd}}\nvoid init_${OPNAME}(pybind11::module &)ENDL"
PARENT_SCOPE
)
set(${python_binders}
"${${python_binders}}\n init_${OPNAME}(m)ENDL"
PARENT_SCOPE
)
endif()
endfunction()

macro(scipp_unary)
Expand Down
2 changes: 1 addition & 1 deletion lib/core/include/scipp/core/element/arithmetic.h
Expand Up @@ -187,7 +187,7 @@ constexpr auto mod_equals =
[](units::Unit &a, const units::Unit &b) { a %= b; },
[](auto &&a, const auto &b) { a = mod(a, b); }};

constexpr auto unary_minus =
constexpr auto negative =
overloaded{arg_list<double, float, int64_t, int32_t, Eigen::Vector3d>,
[](const auto x) { return -x; }};

Expand Down
2 changes: 1 addition & 1 deletion lib/core/test/element_arithmetic_test.cpp
Expand Up @@ -51,7 +51,7 @@ TEST_F(ElementArithmeticTest, non_in_place) {
EXPECT_EQ(divide(a, b), a / b);
}

TEST_F(ElementArithmeticTest, unary_minus) { EXPECT_EQ(unary_minus(a), -a); }
TEST_F(ElementArithmeticTest, negative) { EXPECT_EQ(negative(a), -a); }

TEST(ElementArithmeticIntegerDivisionTest, truediv_32bit) {
const int32_t a = 2;
Expand Down
20 changes: 0 additions & 20 deletions lib/templates/python_inplace.cpp.in

This file was deleted.

4 changes: 2 additions & 2 deletions lib/templates/python_unary.cpp.in
Expand Up @@ -13,11 +13,11 @@ namespace py = pybind11;

template <typename T> void bind_@OPNAME@(py::module &m) {
m.def(
"@NAME@", [](const T &x) { return @NAME@(x); },
"@OPNAME@", [](const T &x) { return @NAME@(x); },
py::arg("x"), py::call_guard<py::gil_scoped_release>());
if constexpr(std::is_same_v<T, Variable> && @GENERATE_OUT@)
m.def(
"@NAME@", [](const T &x, T &out) { return @NAME@(x, out); },
"@OPNAME@", [](const T &x, T &out) { return @NAME@(x, out); },
py::arg("x"), py::arg("out"), py::keep_alive<0, 2>(),
py::call_guard<py::gil_scoped_release>());
}
Expand Down
5 changes: 3 additions & 2 deletions src/scipp/__init__.py
Expand Up @@ -33,7 +33,7 @@
from . import units
from . import geometry
# Import functions
from ._scipp.core import as_const, choose, logical_and, logical_or, logical_xor
from ._scipp.core import as_const, choose
# Import python functions
from .show import show, make_svg
from .table import table
Expand Down Expand Up @@ -64,13 +64,14 @@

from .coords import transform_coords, show_graph

from .core import add, divide, floor_divide, mod, multiply, subtract
from .core import add, divide, floor_divide, mod, multiply, negative, subtract
from .core import lookup, histogram, bin, bins, bins_like
from .core import less, greater, less_equal, greater_equal, equal, not_equal, identical, isclose, allclose
from .core import counts_to_density, density_to_counts
from .core import cumsum
from .core import merge
from .core import groupby
from .core import logical_not, logical_and, logical_or, logical_xor
from .core import abs, nan_to_num, norm, reciprocal, pow, sqrt, exp, log, log10, round, floor, ceil, erf, erfc, midpoints
from .core import dot, islinspace, issorted, allsorted, cross, sort, values, variances, stddevs, rebin, where
from .core import mean, nanmean, sum, nansum, min, max, nanmin, nanmax, all, any
Expand Down
3 changes: 2 additions & 1 deletion src/scipp/core/__init__.py
Expand Up @@ -70,13 +70,14 @@ def custom_formatwarning(msg, *args, **kwargs):
setattr(_cls, '__array_ufunc__', None)
del _cls

from .arithmetic import add, divide, floor_divide, mod, multiply, subtract
from .arithmetic import add, divide, floor_divide, mod, multiply, negative, subtract
from .bins import lookup, histogram, bin, bins, bins_like
from .comparison import less, greater, less_equal, greater_equal, equal, not_equal, identical, isclose, allclose
from .counts import counts_to_density, density_to_counts
from .cumulative import cumsum
from .dataset import irreducible_mask, merge
from .groupby import groupby
from .logical import logical_not, logical_and, logical_or, logical_xor
from .math import abs, cross, dot, nan_to_num, norm, reciprocal, pow, sqrt, exp, log, log10, round, floor, ceil, erf, erfc, midpoints
from .operations import islinspace, issorted, allsorted, sort, values, variances, stddevs, rebin, where, to
from .reduction import mean, nanmean, sum, nansum, min, max, nanmin, nanmax, all, any
Expand Down

0 comments on commit 7bc56d6

Please sign in to comment.