Skip to content

Commit

Permalink
fix reshape (#660)
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Feb 10, 2024
1 parent 1c37edb commit acfec3e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 14 deletions.
18 changes: 5 additions & 13 deletions code/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -558,13 +558,9 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
ndarray_obj_t *ndarray_new_ndarray_from_tuple(mp_obj_tuple_t *_shape, uint8_t dtype) {
// creates a dense array from a tuple
// the function should work in the general n-dimensional case
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
for(size_t i = 0; i < ULAB_MAX_DIMS; i++) {
if(i >= _shape->len) {
shape[ULAB_MAX_DIMS - 1 - i] = 0;
} else {
shape[ULAB_MAX_DIMS - 1 - i] = mp_obj_get_int(_shape->items[i]);
}
size_t *shape = m_new0(size_t, ULAB_MAX_DIMS);
for(size_t i = 0; i < _shape->len; i++) {
shape[ULAB_MAX_DIMS - 1 - i] = mp_obj_get_int(_shape->items[_shape->len - 1 - i]);
}
return ndarray_new_dense_ndarray(_shape->len, shape, dtype);
}
Expand Down Expand Up @@ -2021,7 +2017,7 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
mp_obj_t *items = m_new(mp_obj_t, 1);
items[0] = _shape;
shape = mp_obj_new_tuple(1, items);
} else {
} else { // at this point it's certain that _shape is a tuple
shape = MP_OBJ_TO_PTR(_shape);
}

Expand Down Expand Up @@ -2072,11 +2068,7 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
if(inplace) {
mp_raise_ValueError(MP_ERROR_TEXT("cannot assign new shape"));
}
if(mp_obj_is_type(_shape, &mp_type_tuple)) {
ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype);
} else {
ndarray = ndarray_new_linear_array(source->len, source->dtype);
}
ndarray = ndarray_new_dense_ndarray(shape->len, new_shape, source->dtype);
ndarray_copy_array(source, ndarray, 0);
}
return MP_OBJ_FROM_PTR(ndarray);
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.5.0
#define ULAB_VERSION 6.5.1
#define xstr(s) str(s)
#define str(s) #s

Expand Down
17 changes: 17 additions & 0 deletions tests/2d/numpy/reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
try:
from ulab import numpy as np
except ImportError:
import numpy as np

dtypes = (np.uint8, np.int8, np.uint16, np.int16, np.float)

for dtype in dtypes:
print()
print('=' * 50)
a = np.array(range(12), dtype=dtype).reshape((3, 4))
print(a)
b = a[0,:]
print(b.reshape((1,4)))
b = a[:,0]
print(b.reshape((1,3)))

35 changes: 35 additions & 0 deletions tests/2d/numpy/reshape.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

==================================================
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint8)
array([[0, 1, 2, 3]], dtype=uint8)
array([[0, 4, 8]], dtype=uint8)

==================================================
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)
array([[0, 1, 2, 3]], dtype=int8)
array([[0, 4, 8]], dtype=int8)

==================================================
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint16)
array([[0, 1, 2, 3]], dtype=uint16)
array([[0, 4, 8]], dtype=uint16)

==================================================
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int16)
array([[0, 1, 2, 3]], dtype=int16)
array([[0, 4, 8]], dtype=int16)

==================================================
array([[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0],
[8.0, 9.0, 10.0, 11.0]], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0]], dtype=float64)
array([[0.0, 4.0, 8.0]], dtype=float64)

0 comments on commit acfec3e

Please sign in to comment.