Closed
Description
Consider the following example (Python 2):
import numpy
import ctest
print ctest.square(numpy.arange(3, dtype=float))
print ctest.square(numpy.arange(3, dtype=float) * 1j)
print ctest.square(9.0)
print ctest.square(9.0 * 1j)
where ctest is my pybind11 module:
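#include <complex>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>    // py::vectorize
#include <pybind11/complex.h>  // std::complex <-> Python complex conversion
namespace py = pybind11;

template <typename T>
T square(T x) {
    return x * x;
}

// Four overloads of "square", registered in this order.
void pyexport(py::module& m) {
    m.def("square", py::vectorize(square<std::complex<double>>));
    m.def("square", py::vectorize(square<std::complex<float>>));
    m.def("square", py::vectorize(square<float>));
    m.def("square", py::vectorize(square<double>));
}

// Entry point not shown in the original report (assumed); pybind11 >= 2.2 spells it PYBIND11_MODULE.
PYBIND11_MODULE(ctest, m) { pyexport(m); }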
The problem is that my calls are always resolved to the std::complex overloads; the Python code above returns
[ 0.+0.j 1.+0.j 4.+0.j]
[ 0.+0.j -1.+0.j -4.+0.j]
(81+0j)
(-81+0j)
- real array input: should be real valued
- complex array input: OK
- real scalar input: should be real valued
- complex scalar input: OK
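The output is consistent with overloads being tried in registration order, with the first one that accepts the argument after implicit conversion being called. Because a real value converts losslessly to complex, the complex overload registered first accepts every input. The toy pure-Python model below sketches that behaviour; it is not pybind11's actual code, a list of numbers stands in for a NumPy array, and the helper names to_target and first_match_dispatch are hypothetical.

```python
# Toy model, NOT pybind11 internals: a list of numbers stands in for an array,
# and each overload is a (target_type, function) pair tried in registration order.

def to_target(values, target):
    """Convert every element to `target`, mimicking an implicit (possibly lossy) cast."""
    if target is float:
        # Like NumPy's complex -> float cast: only the real part survives.
        return [v.real if isinstance(v, complex) else float(v) for v in values]
    return [complex(v) for v in values]


def first_match_dispatch(overloads, values):
    """Call the first overload whose argument conversion succeeds."""
    for target, func in overloads:
        try:
            args = to_target(values, target)
        except (TypeError, ValueError):
            continue  # conversion rejected; try the next overload
        return [func(a) for a in args]
    raise TypeError("Incompatible function arguments")


def square(x):
    return x * x


# Same order as the first pyexport(): the complex overload is registered first.
complex_first = [(complex, square), (float, square)]

print(first_match_dispatch(complex_first, [0.0, 1.0, 2.0]))  # [0j, (1+0j), (4+0j)]
print(first_match_dispatch(complex_first, [0j, 1j, 2j]))     # [0j, (-1+0j), (-4+0j)]
```

In this model the real input comes back complex for the same reason as in the report: the conversion to complex never fails, so the later float overload is never reached.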
If I now reorder my defs like this:
void pyexport(py::module& m) {
    m.def("square", py::vectorize(square<float>));
    m.def("square", py::vectorize(square<double>));
    m.def("square", py::vectorize(square<std::complex<double>>));
    m.def("square", py::vectorize(square<std::complex<float>>));
}
the std::complex overloads are never selected:
[ 0. 1. 4.]
test.py:8: ComplexWarning: Casting complex values to real discards the imaginary part
print ctest.square(numpy.arange(3, dtype=float) * 1j)
[ 0. 0. 0.]
81.0
Traceback (most recent call last):
File "test.py", line 11, in <module>
print ctest.square(9.0 * 1j)
TypeError: Incompatible function arguments. The following argument types are supported:
1. (array[float]) -> object
2. (array[float]) -> object
3. (array[complex]) -> object
4. (array[complex]) -> object
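- real array input: OK
- complex array input: implicitly cast back to a real-valued array, discarding the imaginary part (not OK)
- real scalar input: OK
- complex scalar input: fails to resolve to any overload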
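The reordered output is the mirror image: the float overloads now accept the complex array through a lossy cast. One order-independent strategy is to try every overload without conversions first and fall back to implicit conversions only if nothing matches exactly. Later pybind11 documentation describes a two-pass overload resolution of this kind, together with py::arg().noconvert() for opting individual arguments out of conversion. The sketch below shows the idea in the same toy model (lists stand in for arrays; to_target and two_pass_dispatch are hypothetical names). It does not attempt to reproduce the scalar TypeError from the last call.

```python
# Toy model of two-pass overload resolution, NOT pybind11 internals:
# a list of numbers stands in for an array; overloads are (target_type, function).

def to_target(values, target):
    """Implicit (possibly lossy) cast; complex -> float keeps only the real part."""
    if target is float:
        return [v.real if isinstance(v, complex) else float(v) for v in values]
    return [complex(v) for v in values]


def two_pass_dispatch(overloads, values):
    # Pass 1: accept only an overload whose type matches every element exactly.
    for target, func in overloads:
        if all(isinstance(v, target) for v in values):
            return [func(v) for v in values]
    # Pass 2: fall back to the first overload that accepts a conversion.
    for target, func in overloads:
        try:
            args = to_target(values, target)
        except (TypeError, ValueError):
            continue
        return [func(a) for a in args]
    raise TypeError("Incompatible function arguments")


def square(x):
    return x * x


float_first = [(float, square), (complex, square)]
complex_first = [(complex, square), (float, square)]

# Registration order no longer changes the result for exactly-typed input.
for order in (float_first, complex_first):
    print(two_pass_dispatch(order, [0.0, 1.0, 2.0]))  # [0.0, 1.0, 4.0]
    print(two_pass_dispatch(order, [0j, 1j, 2j]))     # [0j, (-1+0j), (-4+0j)]
```

Input with no exact match (for example a list of Python ints) still falls through to the second pass, so registration order continues to matter for that case.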