Skip to content

Commit

Permalink
Render full numpy numeric names (e.g. numpy.int32)
Browse files Browse the repository at this point in the history
  • Loading branch information
sizmailov authored and wjakob committed Jun 10, 2020
1 parent 63df87f commit 22b2504
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 32 deletions.
6 changes: 3 additions & 3 deletions include/pybind11/numpy.h
Expand Up @@ -1007,22 +1007,22 @@ struct npy_format_descriptor_name;
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
static constexpr auto name = _<std::is_same<T, bool>::value>(
_("bool"), _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>()
_("bool"), _<std::is_signed<T>::value>("numpy.int", "numpy.uint") + _<sizeof(T)*8>()
);
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
_("float") + _<sizeof(T)*8>(), _("longdouble")
_("numpy.float") + _<sizeof(T)*8>(), _("numpy.longdouble")
);
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
|| std::is_same<typename T::value_type, double>::value>(
_("complex") + _<sizeof(typename T::value_type)*16>(), _("longcomplex")
_("numpy.complex") + _<sizeof(typename T::value_type)*16>(), _("numpy.longcomplex")
);
};

Expand Down
34 changes: 19 additions & 15 deletions tests/test_eigen.py
Expand Up @@ -79,15 +79,17 @@ def test_mutator_descriptors():
m.fixed_mutator_a(zc)
with pytest.raises(TypeError) as excinfo:
m.fixed_mutator_r(zc)
assert ('(arg0: numpy.ndarray[float32[5, 6], flags.writeable, flags.c_contiguous]) -> None'
assert ('(arg0: numpy.ndarray[numpy.float32[5, 6],'
' flags.writeable, flags.c_contiguous]) -> None'
in str(excinfo.value))
with pytest.raises(TypeError) as excinfo:
m.fixed_mutator_c(zr)
assert ('(arg0: numpy.ndarray[float32[5, 6], flags.writeable, flags.f_contiguous]) -> None'
assert ('(arg0: numpy.ndarray[numpy.float32[5, 6],'
' flags.writeable, flags.f_contiguous]) -> None'
in str(excinfo.value))
with pytest.raises(TypeError) as excinfo:
m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype='float32'))
assert ('(arg0: numpy.ndarray[float32[5, 6], flags.writeable]) -> None'
assert ('(arg0: numpy.ndarray[numpy.float32[5, 6], flags.writeable]) -> None'
in str(excinfo.value))
zr.flags.writeable = False
with pytest.raises(TypeError):
Expand Down Expand Up @@ -179,15 +181,15 @@ def test_negative_stride_from_python(msg):
m.double_threer(second_row)
assert msg(excinfo.value) == """
double_threer(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.ndarray[float32[1, 3], flags.writeable]) -> None
1. (arg0: numpy.ndarray[numpy.float32[1, 3], flags.writeable]) -> None
Invoked with: """ + repr(np.array([ 5., 4., 3.], dtype='float32')) # noqa: E501 line too long

with pytest.raises(TypeError) as excinfo:
m.double_threec(second_col)
assert msg(excinfo.value) == """
double_threec(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.ndarray[float32[3, 1], flags.writeable]) -> None
1. (arg0: numpy.ndarray[numpy.float32[3, 1], flags.writeable]) -> None
Invoked with: """ + repr(np.array([ 7., 4., 1.], dtype='float32')) # noqa: E501 line too long

Expand Down Expand Up @@ -607,17 +609,19 @@ def test_special_matrix_objects():

def test_dense_signature(doc):
assert doc(m.double_col) == """
double_col(arg0: numpy.ndarray[float32[m, 1]]) -> numpy.ndarray[float32[m, 1]]
double_col(arg0: numpy.ndarray[numpy.float32[m, 1]]) -> numpy.ndarray[numpy.float32[m, 1]]
"""
assert doc(m.double_row) == """
double_row(arg0: numpy.ndarray[float32[1, n]]) -> numpy.ndarray[float32[1, n]]
"""
assert doc(m.double_complex) == """
double_complex(arg0: numpy.ndarray[complex64[m, 1]]) -> numpy.ndarray[complex64[m, 1]]
"""
assert doc(m.double_mat_rm) == """
double_mat_rm(arg0: numpy.ndarray[float32[m, n]]) -> numpy.ndarray[float32[m, n]]
double_row(arg0: numpy.ndarray[numpy.float32[1, n]]) -> numpy.ndarray[numpy.float32[1, n]]
"""
assert doc(m.double_complex) == ("""
double_complex(arg0: numpy.ndarray[numpy.complex64[m, 1]])"""
""" -> numpy.ndarray[numpy.complex64[m, 1]]
""")
assert doc(m.double_mat_rm) == ("""
double_mat_rm(arg0: numpy.ndarray[numpy.float32[m, n]])"""
""" -> numpy.ndarray[numpy.float32[m, n]]
""")


def test_named_arguments():
Expand Down Expand Up @@ -654,10 +658,10 @@ def test_sparse():
@pytest.requires_eigen_and_scipy
def test_sparse_signature(doc):
assert doc(m.sparse_copy_r) == """
sparse_copy_r(arg0: scipy.sparse.csr_matrix[float32]) -> scipy.sparse.csr_matrix[float32]
sparse_copy_r(arg0: scipy.sparse.csr_matrix[numpy.float32]) -> scipy.sparse.csr_matrix[numpy.float32]
""" # noqa: E501 line too long
assert doc(m.sparse_copy_c) == """
sparse_copy_c(arg0: scipy.sparse.csc_matrix[float32]) -> scipy.sparse.csc_matrix[float32]
sparse_copy_c(arg0: scipy.sparse.csc_matrix[numpy.float32]) -> scipy.sparse.csc_matrix[numpy.float32]
""" # noqa: E501 line too long


Expand Down
18 changes: 9 additions & 9 deletions tests/test_numpy_array.py
Expand Up @@ -286,13 +286,13 @@ def test_overload_resolution(msg):
m.overloaded("not an array")
assert msg(excinfo.value) == """
overloaded(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.ndarray[float64]) -> str
2. (arg0: numpy.ndarray[float32]) -> str
3. (arg0: numpy.ndarray[int32]) -> str
4. (arg0: numpy.ndarray[uint16]) -> str
5. (arg0: numpy.ndarray[int64]) -> str
6. (arg0: numpy.ndarray[complex128]) -> str
7. (arg0: numpy.ndarray[complex64]) -> str
1. (arg0: numpy.ndarray[numpy.float64]) -> str
2. (arg0: numpy.ndarray[numpy.float32]) -> str
3. (arg0: numpy.ndarray[numpy.int32]) -> str
4. (arg0: numpy.ndarray[numpy.uint16]) -> str
5. (arg0: numpy.ndarray[numpy.int64]) -> str
6. (arg0: numpy.ndarray[numpy.complex128]) -> str
7. (arg0: numpy.ndarray[numpy.complex64]) -> str
Invoked with: 'not an array'
"""
Expand All @@ -307,8 +307,8 @@ def test_overload_resolution(msg):
assert m.overloaded3(np.array([1], dtype='intc')) == 'int'
expected_exc = """
overloaded3(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.ndarray[int32]) -> str
2. (arg0: numpy.ndarray[float64]) -> str
1. (arg0: numpy.ndarray[numpy.int32]) -> str
2. (arg0: numpy.ndarray[numpy.float64]) -> str
Invoked with: """

Expand Down
10 changes: 5 additions & 5 deletions tests/test_numpy_vectorize.py
Expand Up @@ -109,7 +109,7 @@ def test_type_selection():

def test_docs(doc):
assert doc(m.vectorized_func) == """
vectorized_func(arg0: numpy.ndarray[int32], arg1: numpy.ndarray[float32], arg2: numpy.ndarray[float64]) -> object
vectorized_func(arg0: numpy.ndarray[numpy.int32], arg1: numpy.ndarray[numpy.float32], arg2: numpy.ndarray[numpy.float64]) -> object
""" # noqa: E501 line too long


Expand Down Expand Up @@ -160,12 +160,12 @@ def test_passthrough_arguments(doc):
assert doc(m.vec_passthrough) == (
"vec_passthrough(" + ", ".join([
"arg0: float",
"arg1: numpy.ndarray[float64]",
"arg2: numpy.ndarray[float64]",
"arg3: numpy.ndarray[int32]",
"arg1: numpy.ndarray[numpy.float64]",
"arg2: numpy.ndarray[numpy.float64]",
"arg3: numpy.ndarray[numpy.int32]",
"arg4: int",
"arg5: m.numpy_vectorize.NonPODClass",
"arg6: numpy.ndarray[float64]"]) + ") -> object")
"arg6: numpy.ndarray[numpy.float64]"]) + ") -> object")

b = np.array([[10, 20, 30]], dtype='float64')
c = np.array([100, 200]) # NOT a vectorized argument
Expand Down

0 comments on commit 22b2504

Please sign in to comment.