diff --git a/include/nanobind/stl/complex.h b/include/nanobind/stl/complex.h new file mode 100644 index 00000000..3c4c8a1c --- /dev/null +++ b/include/nanobind/stl/complex.h @@ -0,0 +1,67 @@ +/* + nanobind/stl/complex.h: type caster for std::complex<...> + + Copyright (c) 2023 Degottex Gilles and Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include + +NAMESPACE_BEGIN(NB_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template struct type_caster> { + NB_TYPE_CASTER(std::complex, const_name("complex") ) + + template + bool from_python(handle src, uint8_t flags, + cleanup_list *cleanup) noexcept { + (void) flags; + (void) cleanup; + + if (PyComplex_Check(src.ptr())) { + value = std::complex( + (T) PyComplex_RealAsDouble(src.ptr()), + (T) PyComplex_ImagAsDouble(src.ptr()) + ); + return true; + } + + if (Recursive && !PyFloat_CheckExact(src.ptr()) && + !PyLong_CheckExact(src.ptr()) && + PyObject_HasAttrString(src.ptr(), "imag")) { + try { + object tmp = handle(&PyComplex_Type)(src); + return from_python(tmp, flags, cleanup); + } catch (...) { + return false; + } + } + + make_caster caster; + if (caster.from_python(src, flags, cleanup)) { + value = std::complex(caster.operator cast_t()); + return true; + } + + return true; + } + + template + static handle from_cpp(T2 &&value, rv_policy policy, + cleanup_list *cleanup) noexcept { + (void) policy; + (void) cleanup; + + return PyComplex_FromDoubles((double) value.real(), + (double) value.imag()); + } +}; + +NAMESPACE_END(detail) +NAMESPACE_END(NB_NAMESPACE) diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 6e9ca652..0ddc76f5 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -12,6 +12,7 @@ #include #include #include +#include NB_MAKE_OPAQUE(std::vector>) @@ -422,4 +423,20 @@ NB_MODULE(test_stl_ext, m) { vec.flip(); return vec; }); + + + m.def("complex_value_float", [](const std::complex& x){ + return x; + }); + m.def("complex_value_double", [](const std::complex& x){ + return x; + }); + + m.def("complex_array_float", [](const std::vector>& x){ + return x; + }); + m.def("complex_array_double", [](const std::vector>& x){ + return x; + }); + } diff --git a/tests/test_stl.py b/tests/test_stl.py index 19e42c12..bb1b1d9d 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -3,7 +3,6 @@ import sys from common import collect, skip_on_pypy - @pytest.fixture def clean(): collect() @@ -758,3 +757,66 @@ def test67_vector_bool(): bool_vector = [True, False, True, False] result = t.flip_vector_bool(bool_vector) assert result == [not x for x in bool_vector] + + +def test68_complex_value(): + # double: 64bits + assert t.complex_value_double(1.0) == 1.0 + assert t.complex_value_double(1.0j) == 1.0j + assert t.complex_value_double(0.0) == 0.0 + assert t.complex_value_double(0.0j) == 0.0j + assert t.complex_value_double(0) == 0 + assert t.complex_value_float(1.0) == 1.0 + assert t.complex_value_float(1.0j) == 1.0j + assert t.complex_value_float(0.0) == 0.0 + assert t.complex_value_float(0.0j) == 0.0j + assert t.complex_value_float(0) == 0 + + val_64 = 2.7-3.2j + val_32 = 2.700000047683716-3.200000047683716j + assert val_64 != val_32 + + assert t.complex_value_float(val_32) == val_32 + assert t.complex_value_float(val_64) == val_32 + assert t.complex_value_double(val_32) == val_32 + assert t.complex_value_double(val_64) == val_64 + + try: + import numpy as np + assert t.complex_value_float(np.complex64(val_32)) == val_32 + assert t.complex_value_float(np.complex64(val_64)) == val_32 + assert t.complex_value_double(np.complex64(val_32)) == val_32 + assert t.complex_value_double(np.complex64(val_64)) == val_32 + assert t.complex_value_float(np.complex128(val_32)) == val_32 + assert t.complex_value_float(np.complex128(val_64)) == val_32 + assert t.complex_value_double(np.complex128(val_32)) == val_32 + assert t.complex_value_double(np.complex128(val_64)) == val_64 + except ImportError: + pass + +def test69_complex_array(): + val1_64 = 2.7-3.2j + val1_32 = 2.700000047683716-3.200000047683716j + val2_64 = 3.1415 + val2_32 = 3.1414999961853027+0j + + # test 64 bit casts + assert t.complex_array_double([val1_64, -1j, val2_64]) == [val1_64, -0-1j, val2_64] + + # test 32 bit casts + assert t.complex_array_float([val1_64, -1j, val2_64]) == [val1_32, (-0-1j), val2_32] + + try: + import numpy as np + + # test 64 bit casts + assert t.complex_array_double(np.array([val1_64, -1j, val2_64])) == [val1_64, -0-1j, val2_64] + assert t.complex_array_double(np.array([val1_64, -1j, val2_64],dtype=np.complex128)) == [val1_64, -0-1j, val2_64] + assert t.complex_array_double(np.array([val1_64, -1j, val2_64],dtype=np.complex64)) == [val1_32, -0-1j, val2_32] + + # test 32 bit casts + assert t.complex_array_float(np.array([val1_64, -1j, val2_64])) == [val1_32, (-0-1j), val2_32] + assert t.complex_array_float(np.array([val1_64, -1j, val2_64],dtype=np.complex128)) == [val1_32, (-0-1j), val2_32] + assert t.complex_array_float(np.array([val1_64, -1j, val2_64],dtype=np.complex64)) == [val1_32, (-0-1j), val2_32] + except ImportError: + pass