Skip to content

Commit

Permalink
fix the np.delete bug (#653)
Browse files Browse the repository at this point in the history
* fix the `np.delete` bug

* fix the `np.delete` bug, add unittest code

* increment the version number and update the change log

* update the expected file `delete.py.exp`
  • Loading branch information
hiltay committed Dec 25, 2023
1 parent e329206 commit 7a93706
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 1 deletion.
5 changes: 5 additions & 0 deletions code/numpy/transform.c
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ static mp_obj_t transform_delete(size_t n_args, const mp_obj_t *pos_args, mp_map
mp_raise_TypeError(MP_ERROR_TEXT("wrong index type"));
}
index_len = MP_OBJ_SMALL_INT_VALUE(mp_obj_len_maybe(indices));
if (index_len == 0){
// if the second positional argument is empty
// return the original array
return MP_OBJ_FROM_PTR(ndarray);
}
}

if(index_len > axis_len) {
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.4.2
#define ULAB_VERSION 6.4.3
#define xstr(s) str(s)
#define str(s) #s

Expand Down
6 changes: 6 additions & 0 deletions docs/ulab-change-log.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
Mon, 25 Dec 2023

version 6.4.3

fix the 'np.delete' error that occurs when passing an empty iterable object as the second positional argument (#653)

Thu, 11 Dec 2023

version 6.4.2
Expand Down
2 changes: 2 additions & 0 deletions tests/2d/numpy/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
a = np.array(range(25), dtype=dtype).reshape((5,5))
print(np.delete(a, [1, 2], axis=0))
print(np.delete(a, [1, 2], axis=1))
print(np.delete(a, [], axis=1))
print(np.delete(a, [1, 5, 10]))
print(np.delete(a, []))

for dtype in dtypes:
a = np.array(range(25), dtype=dtype).reshape((5,5))
Expand Down
50 changes: 50 additions & 0 deletions tests/2d/numpy/delete.py.exp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=uint8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint8)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=uint8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint8)
array([[0, 1, 2, 3, 4],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int8)
Expand All @@ -15,7 +25,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=int8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int8)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=int8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int8)
array([[0, 1, 2, 3, 4],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint16)
Expand All @@ -24,7 +44,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=uint16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint16)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=uint16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint16)
array([[0, 1, 2, 3, 4],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int16)
Expand All @@ -33,7 +63,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=int16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int16)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=int16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int16)
array([[0.0, 1.0, 2.0, 3.0, 4.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0, 24.0]], dtype=float64)
Expand All @@ -42,7 +82,17 @@ array([[0.0, 3.0, 4.0],
[10.0, 13.0, 14.0],
[15.0, 18.0, 19.0],
[20.0, 23.0, 24.0]], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0, 24.0]], dtype=float64)
array([0.0, 2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 9.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0, 24.0]], dtype=float64)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[15, 16, 17, 18, 19],
Expand Down

0 comments on commit 7a93706

Please sign in to comment.