Merge 384efe4 into 4aaa72a
nilsvu committed Feb 8, 2023
2 parents 4aaa72a + 384efe4 · commit dc49ca2
Showing 16 changed files with 1,069 additions and 432 deletions.
2 changes: 2 additions & 0 deletions src/DataStructures/Python/DataVector.cpp
@@ -4,6 +4,7 @@
 #include "DataStructures/Python/DataVector.hpp"
 
 #include <memory>
+#include <pybind11/numpy.h>
 #include <pybind11/operators.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -202,5 +203,6 @@ void bind_datavector(py::module& m) {
       // NOLINTNEXTLINE(misc-redundant-expression)
       .def(py::self != py::self)
       .def("__neg__", +[](const DataVector& t) { return DataVector{-t}; });
+  py::implicitly_convertible<py::array, DataVector>();
 }
 } // namespace py_bindings
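
Note on usage: with `py::implicitly_convertible<py::array, DataVector>()` registered, pybind11 will try the buffer-based `DataVector` constructor whenever a NumPy array is passed to a bound function that expects a `DataVector`. A minimal Python sketch of what this enables (not part of the diff; the import path and the hypothetical `some_bound_function` are assumptions):

    import numpy as np
    from spectre.DataStructures import DataVector

    # Explicit construction from a NumPy array:
    dv = DataVector(np.array([1.0, 2.0, 3.0]))

    # With the implicit conversion registered, bound C++ functions that take a
    # DataVector argument also accept a plain NumPy array, e.g. (hypothetical
    # bound function):
    #   some_bound_function(np.array([1.0, 2.0, 3.0]))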
28 changes: 22 additions & 6 deletions src/DataStructures/Tensor/Python/Tensor.cpp
@@ -4,6 +4,7 @@
 #include "DataStructures/Tensor/Python/Tensor.hpp"
 
 #include <memory>
+#include <pybind11/numpy.h>
 #include <pybind11/operators.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -143,24 +144,32 @@ void bind_tensor_impl(py::module& m, const std::string& name) { // NOLINT
         throw std::runtime_error(
             "Incompatible format: expected a double array.");
       }
-      if (info.ndim != 2) {
+      // 1D arrays are allowed to construct scalars. Higher-rank
+      // tensors require 2D arrays.
+      if (not(info.ndim == 2 or
+              (info.ndim == 1 and TensorType::rank() == 0))) {
         throw std::runtime_error(
             "Tensor data is expected to be 2D with shape (size, "
-            "num_points).");
+            "num_points). Scalars can be 1D.");
       }
-      const auto size = static_cast<size_t>(info.shape[0]);
+      const size_t size =
+          info.ndim == 1 ? 1 : static_cast<size_t>(info.shape[0]);
       if (size != TensorType::size()) {
         throw std::runtime_error(
             "This tensor type has " +
             std::to_string(TensorType::size()) +
             " independent components, but data has first dimension " +
             std::to_string(size) + ".");
       }
-      const auto num_points = static_cast<size_t>(info.shape[1]);
+      const auto num_points =
+          static_cast<size_t>(info.shape[info.ndim == 1 ? 0 : 1]);
       auto data = static_cast<double*>(info.ptr);
       const std::array<size_t, 2> strides{
-          {static_cast<size_t>(info.strides[0] / info.itemsize),
-           static_cast<size_t>(info.strides[1] / info.itemsize)}};
+          {info.ndim == 1
+               ? 1
+               : static_cast<size_t>(info.strides[0] / info.itemsize),
+           static_cast<size_t>(info.strides[info.ndim == 1 ? 0 : 1] /
+                               info.itemsize)}};
       if (copy) {
         TensorType result{num_points};
         for (size_t i = 0; i < size; ++i) {
@@ -191,7 +200,14 @@ void bind_tensor_impl(py::module& m, const std::string& name) { // NOLINT
   }
 
   if constexpr (TensorType::rank() <= 1) {
+    // Scalars and vectors have an unambiguous storage order for their
+    // components: a scalar has only a single component, and a vector has Dim
+    // components in an obvious order. Therefore, they can be constructed from
+    // the underlying std::array (Python list) of DataVectors and also
+    // implicitly converted from Numpy arrays. Component ordering for
+    // higher-rank tensors isn't obvious, so we don't enable the conversion.
     tensor.def(py::init<typename TensorType::storage_type>());
+    py::implicitly_convertible<py::array, TensorType>();
   }
 }
 } // namespace
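
Taken together, the constructor above accepts a 1D array of length `num_points` for rank-0 tensors and a 2D array of shape `(size, num_points)` otherwise, and scalars and vectors additionally convert implicitly from NumPy arrays. A rough Python sketch of the resulting behavior (import paths, the `Scalar[DataVector]` lookup from the `__init__.py` diff below, calling the constructor without an explicit `copy` argument, and the hypothetical `some_bound_function` are assumptions):

    import numpy as np
    from spectre.DataStructures import DataVector
    from spectre.DataStructures.Tensor import Scalar

    num_points = 4

    # Rank 0: a 1D array of length num_points is now accepted ...
    s = Scalar[DataVector](np.random.rand(num_points))
    # ... in addition to the existing 2D form with shape (1, num_points).
    s = Scalar[DataVector](np.random.rand(1, num_points))

    # Because scalars and vectors are implicitly convertible from py::array,
    # a bound C++ function that takes a Scalar argument can be called with a
    # plain NumPy array, e.g. (hypothetical bound function):
    #   some_bound_function(np.random.rand(num_points))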
7 changes: 7 additions & 0 deletions src/DataStructures/Tensor/Python/__init__.py
@@ -57,3 +57,10 @@ def __init__(self, inverse: bool):
 Scalar = {DataVector: ScalarDV, float: ScalarD}
 Jacobian = JacobianMeta(inverse=False)
 InverseJacobian = JacobianMeta(inverse=True)
+
+# Define a type annotation that means "any tensor". This should really be a
+# common superclass of all Tensor types, but we currently don't have that in C++
+# so we use a type alias to `typing.Any` as a workaround.
+from typing import Any
+
+Tensor = Any
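
Since `Tensor` resolves to `typing.Any`, the alias documents intent in signatures without adding runtime or static checking. A hypothetical annotation sketch (the helper function is illustrative, not part of the package):

    from spectre.DataStructures.Tensor import Tensor

    # Hypothetical helper: the annotation says "any bound tensor type", but is
    # not enforced because Tensor aliases typing.Any.
    def print_tensor_info(tensor: Tensor) -> None:
        print(type(tensor).__name__)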
