From acfec3e9af5d351df88a7d71786f11a12f7545d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20V=C3=B6r=C3=B6s?= Date: Sat, 10 Feb 2024 20:46:34 +0100 Subject: [PATCH] fix reshape (#660) --- code/ndarray.c | 18 +++++------------- code/ulab.c | 2 +- tests/2d/numpy/reshape.py | 17 +++++++++++++++++ tests/2d/numpy/reshape.py.exp | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 14 deletions(-) create mode 100644 tests/2d/numpy/reshape.py create mode 100644 tests/2d/numpy/reshape.py.exp diff --git a/code/ndarray.c b/code/ndarray.c index ffd3d621..84ce8495 100644 --- a/code/ndarray.c +++ b/code/ndarray.c @@ -558,13 +558,9 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt ndarray_obj_t *ndarray_new_ndarray_from_tuple(mp_obj_tuple_t *_shape, uint8_t dtype) { // creates a dense array from a tuple // the function should work in the general n-dimensional case - size_t *shape = m_new(size_t, ULAB_MAX_DIMS); - for(size_t i = 0; i < ULAB_MAX_DIMS; i++) { - if(i >= _shape->len) { - shape[ULAB_MAX_DIMS - 1 - i] = 0; - } else { - shape[ULAB_MAX_DIMS - 1 - i] = mp_obj_get_int(_shape->items[i]); - } + size_t *shape = m_new0(size_t, ULAB_MAX_DIMS); + for(size_t i = 0; i < _shape->len; i++) { + shape[ULAB_MAX_DIMS - 1 - i] = mp_obj_get_int(_shape->items[_shape->len - 1 - i]); } return ndarray_new_dense_ndarray(_shape->len, shape, dtype); } @@ -2021,7 +2017,7 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) { mp_obj_t *items = m_new(mp_obj_t, 1); items[0] = _shape; shape = mp_obj_new_tuple(1, items); - } else { + } else { // at this point it's certain that _shape is a tuple shape = MP_OBJ_TO_PTR(_shape); } @@ -2072,11 +2068,7 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) { if(inplace) { mp_raise_ValueError(MP_ERROR_TEXT("cannot assign new shape")); } - if(mp_obj_is_type(_shape, &mp_type_tuple)) { - ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype); - } else { - ndarray = ndarray_new_linear_array(source->len, source->dtype); - } + ndarray = ndarray_new_dense_ndarray(shape->len, new_shape, source->dtype); ndarray_copy_array(source, ndarray, 0); } return MP_OBJ_FROM_PTR(ndarray); diff --git a/code/ulab.c b/code/ulab.c index f55768bc..df73f7bb 100644 --- a/code/ulab.c +++ b/code/ulab.c @@ -33,7 +33,7 @@ #include "user/user.h" #include "utils/utils.h" -#define ULAB_VERSION 6.5.0 +#define ULAB_VERSION 6.5.1 #define xstr(s) str(s) #define str(s) #s diff --git a/tests/2d/numpy/reshape.py b/tests/2d/numpy/reshape.py new file mode 100644 index 00000000..7f4add6a --- /dev/null +++ b/tests/2d/numpy/reshape.py @@ -0,0 +1,17 @@ +try: + from ulab import numpy as np +except ImportError: + import numpy as np + +dtypes = (np.uint8, np.int8, np.uint16, np.int16, np.float) + +for dtype in dtypes: + print() + print('=' * 50) + a = np.array(range(12), dtype=dtype).reshape((3, 4)) + print(a) + b = a[0,:] + print(b.reshape((1,4))) + b = a[:,0] + print(b.reshape((1,3))) + diff --git a/tests/2d/numpy/reshape.py.exp b/tests/2d/numpy/reshape.py.exp new file mode 100644 index 00000000..806a26c3 --- /dev/null +++ b/tests/2d/numpy/reshape.py.exp @@ -0,0 +1,35 @@ + +================================================== +array([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], dtype=uint8) +array([[0, 1, 2, 3]], dtype=uint8) +array([[0, 4, 8]], dtype=uint8) + +================================================== +array([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], dtype=int8) +array([[0, 1, 2, 3]], dtype=int8) +array([[0, 4, 8]], dtype=int8) + +================================================== +array([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], dtype=uint16) +array([[0, 1, 2, 3]], dtype=uint16) +array([[0, 4, 8]], dtype=uint16) + +================================================== +array([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], dtype=int16) +array([[0, 1, 2, 3]], dtype=int16) +array([[0, 4, 8]], dtype=int16) + +================================================== +array([[0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [8.0, 9.0, 10.0, 11.0]], dtype=float64) +array([[0.0, 1.0, 2.0, 3.0]], dtype=float64) +array([[0.0, 4.0, 8.0]], dtype=float64)