Skip to content

Commit

Permalink
implement (u)int8 upcasting rules as per documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
s-ol committed Dec 11, 2023
1 parent 3c835b7 commit 751d133
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions code/ndarray_operators.c
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,

if(lhs->dtype == NDARRAY_UINT8) {
if(rhs->dtype == NDARRAY_UINT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +);
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +);
} else if(rhs->dtype == NDARRAY_INT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, +);
Expand Down Expand Up @@ -264,8 +264,8 @@ mp_obj_t ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs,

if(lhs->dtype == NDARRAY_UINT8) {
if(rhs->dtype == NDARRAY_UINT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *);
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *);
} else if(rhs->dtype == NDARRAY_INT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, *);
Expand Down
2 changes: 1 addition & 1 deletion tests/2d/numpy/operators.py.exp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ array([1.0, 2.0, 3.0], dtype=float64)
array([1.0, 32.0, 729.0], dtype=float64)
array([1.0, 32.0, 729.0], dtype=float64)
array([1.0, 32.0, 729.0], dtype=float64)
array([5, 7, 9], dtype=uint16)
array([5, 7, 9], dtype=uint8)
array([5, 7, 9], dtype=int16)
array([5, 7, 9], dtype=int8)
array([5, 7, 9], dtype=uint16)
Expand Down

0 comments on commit 751d133

Please sign in to comment.