From cee10363ad51fbd15bcdd26e19af7da8a1c1a60a Mon Sep 17 00:00:00 2001 From: Jordan Moxon Date: Fri, 22 Feb 2019 15:55:18 -0800 Subject: [PATCH] Add support for ComplexDataVector to Pypp tests --- tests/Unit/Pypp/PyppFundamentals.hpp | 62 ++++++++++++++++++++++++++++ tests/Unit/Pypp/Test_Pypp.cpp | 21 ++++++++++ 2 files changed, 83 insertions(+) diff --git a/tests/Unit/Pypp/PyppFundamentals.hpp b/tests/Unit/Pypp/PyppFundamentals.hpp index a59eabe9ed00..3300cf91d75f 100644 --- a/tests/Unit/Pypp/PyppFundamentals.hpp +++ b/tests/Unit/Pypp/PyppFundamentals.hpp @@ -212,6 +212,32 @@ struct ToPyObject { } }; +template <> +struct ToPyObject { + static PyObject* convert(const ComplexDataVector& t) { + PyObject* npy_array = PyArray_SimpleNew( // NOLINT + 1, (std::array{{static_cast(t.size())}}.data()), + NPY_COMPLEX128); + + if (npy_array == nullptr) { + throw std::runtime_error{"Failed to convert argument."}; + } + for (size_t i = 0; i < t.size(); ++i) { + // clang-tidy: Do not use pointer arithmetic + // clang-tidy: Do not use reinterpret cast + const auto data = + static_cast*>(PyArray_GETPTR1( // NOLINT + reinterpret_cast(npy_array), // NOLINT + static_cast(i))); + if (data == nullptr) { + throw std::runtime_error{"Failed to access argument of PyArray."}; + } + *data = t[i]; + } + return npy_array; + } +}; + template <> struct FromPyObject { static long convert(PyObject* t) { @@ -397,6 +423,42 @@ struct FromPyObject { } }; +template <> +struct FromPyObject { + static ComplexDataVector convert(PyObject* p) { + if (p == nullptr) { + throw std::runtime_error{"Received null PyObject."}; + } + // clang-tidy: c-style casts. (Expanded from macro) + if (not PyArray_CheckExact(p)) { // NOLINT + throw std::runtime_error{ + "Cannot convert non-array type to ComplexDataVector."}; + } + // clang-tidy: reinterpret_cast + const auto npy_array = reinterpret_cast(p); // NOLINT + if (PyArray_TYPE(npy_array) != NPY_COMPLEX128) { + throw std::runtime_error{ + "Cannot convert array of non-complex type to ComplexDataVector."}; + } + if (PyArray_NDIM(npy_array) != 1) { + throw std::runtime_error{ + "Cannot convert array of ndim != 1 to ComplexDataVector."}; + } + // clang-tidy: c-style casts, pointer arithmetic. (Expanded from macro) + ComplexDataVector t(static_cast(PyArray_Size(p))); // NOLINT + for (size_t i = 0; i < t.size(); ++i) { + // clang-tidy: pointer arithmetic. (Expanded from macro) + const auto value = static_cast*>( + PyArray_GETPTR1(npy_array, static_cast(i))); // NOLINT + if (value == nullptr) { + throw std::runtime_error{"Failed to get argument from PyArray."}; + } + t[i] = *value; + } + return t; + } +}; + // This function is needed because one cannot cast an array of size_t's (used by // SpECTRE) to an array of longs (used by NumPy). template diff --git a/tests/Unit/Pypp/Test_Pypp.cpp b/tests/Unit/Pypp/Test_Pypp.cpp index ea81e1e18bb3..2243cb2a259f 100644 --- a/tests/Unit/Pypp/Test_Pypp.cpp +++ b/tests/Unit/Pypp/Test_Pypp.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -12,6 +13,7 @@ #include #include +#include "DataStructures/ComplexDataVector.hpp" #include "DataStructures/DataVector.hpp" #include "DataStructures/Tensor/Tensor.hpp" #include "Utilities/Gsl.hpp" @@ -137,6 +139,25 @@ SPECTRE_TEST_CASE("Unit.Pypp.DataVector", "[Pypp][Unit]") { CHECK_THROWS(pypp::call("PyppPyTests", "ndarray_of_floats")); } +SPECTRE_TEST_CASE("Unit.Pypp.ComplexDataVector", "[Pypp][Unit]") { + pypp::SetupLocalPythonEnvironment local_python_env{"Pypp/"}; + std::complex test_value_0{1.3, 2.2}; + std::complex test_value_1{4.0, 3.1}; + std::complex test_value_2{4.2, 5.7}; + std::complex test_value_3{6.8, 7.3}; + const auto ret = pypp::call( + "numpy", "multiply", ComplexDataVector{test_value_0, test_value_1}, + ComplexDataVector{test_value_2, test_value_3}); + CHECK_ITERABLE_APPROX(ret[0], test_value_0 * test_value_2); + CHECK_ITERABLE_APPROX(ret[1], test_value_1 * test_value_3); + CHECK_THROWS(pypp::call( + "numpy", "multiply", ComplexDataVector{test_value_0, test_value_1}, + ComplexDataVector{test_value_2, test_value_3})); + CHECK_THROWS(pypp::call("PyppPyTests", "two_dim_ndarray")); + CHECK_THROWS( + pypp::call("PyppPyTests", "ndarray_of_floats")); +} + SPECTRE_TEST_CASE("Unit.Pypp.Tensor.Double", "[Pypp][Unit]") { pypp::SetupLocalPythonEnvironment local_python_env{"Pypp/"};