Skip to content

Commit

Permalink
Floordiv (#593)
Browse files Browse the repository at this point in the history
* implement floor division

* fix 3D, 4D loops

* add missing array declaration in 3D, and 4D

* Add test cases for floor division and fix it for ints (#599)

* Add test cases for floor division

* Fix define name in comment

* Fix floor division of ints

---------

Co-authored-by: Maciej Sokołowski <matemaciek@gmail.com>
  • Loading branch information
v923z and matemaciek committed Apr 23, 2023
1 parent 4407f8c commit 47ad73a
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 2 deletions.
6 changes: 6 additions & 0 deletions code/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -1936,6 +1936,12 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
return ndarray_binary_power(lhs, rhs, ndim, shape, lstrides, rstrides);
break;
#endif
#if NDARRAY_HAS_BINARY_OP_FLOOR_DIVIDE
case MP_BINARY_OP_FLOOR_DIVIDE:
COMPLEX_DTYPE_NOT_IMPLEMENTED(lhs->dtype);
return ndarray_binary_floor_divide(lhs, rhs, ndim, shape, lstrides, rstrides);
break;
#endif
default:
return MP_OBJ_NULL; // op not supported
break;
Expand Down
98 changes: 97 additions & 1 deletion code/ndarray_operators.c
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,102 @@ mp_obj_t ndarray_binary_true_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
}
#endif /* NDARRAY_HAS_BINARY_OP_TRUE_DIVIDE */

#if NDARRAY_HAS_BINARY_OP_FLOOR_DIVIDE
/*
 * Element-wise floor division of two ndarrays (lhs // rhs).
 *
 * The operand dtypes select both the dtype of the freshly allocated dense
 * result and the loop macro that fills it:
 *
 *   lhs \ rhs | uint8   int8    uint16  int16   float
 *   ----------+--------------------------------------
 *   uint8     | uint8   int16   uint16  int16   float
 *   int8      | int16   int8    uint16  int16   float
 *   uint16    | uint16  uint16  uint16  float   float
 *   int16     | int16   int16   float   int16   float
 *   float     | float   float   float   float   float
 *
 * FLOOR_DIVIDE_LOOP_UINT is used when both operands are unsigned integers,
 * FLOOR_DIVIDE_LOOP when the result is an integer and at least one operand
 * is signed, and FLOOR_DIVIDE_LOOP_FLOAT whenever the result is float.
 * The macros are defined elsewhere; the type arguments passed to them are
 * (result type, left element type, right element type) and must match the
 * actual storage type of each operand, otherwise the raw bytes are
 * misinterpreted.
 *
 * Parameters:
 *   lhs, rhs   - operands; only their dtype and array pointer are read here
 *   ndim       - number of dimensions of the result
 *   shape      - shape of the result
 *   lstrides   - strides used to walk lhs
 *   rstrides   - strides used to walk rhs
 *   NOTE(review): shape/strides are presumably already broadcast by the
 *   caller - confirm against ndarray_binary_op.
 *
 * Returns the result array wrapped with MP_OBJ_FROM_PTR. If the dtype pair
 * matches none of the branches below, no array is allocated and the wrapped
 * NULL pointer is returned.
 */
mp_obj_t ndarray_binary_floor_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
                                     uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {

    ndarray_obj_t *results = NULL;
    uint8_t *larray = (uint8_t *)lhs->array;
    uint8_t *rarray = (uint8_t *)rhs->array;

    if(lhs->dtype == NDARRAY_UINT8) {
        if(rhs->dtype == NDARRAY_UINT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
            FLOOR_DIVIDE_LOOP_UINT(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_UINT16) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
            FLOOR_DIVIDE_LOOP_UINT(results, uint16_t, uint8_t, uint16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT16) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, uint8_t, int16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_FLOAT) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint8_t, mp_float_t, larray, lstrides, rarray, rstrides);
        }
    } else if(lhs->dtype == NDARRAY_INT8) {
        if(rhs->dtype == NDARRAY_UINT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, int8_t, uint8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT8);
            FLOOR_DIVIDE_LOOP(results, int8_t, int8_t, int8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_UINT16) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
            FLOOR_DIVIDE_LOOP(results, uint16_t, int8_t, uint16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT16) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, int8_t, int16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_FLOAT) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, int8_t, mp_float_t, larray, lstrides, rarray, rstrides);
        }
    } else if(lhs->dtype == NDARRAY_UINT16) {
        if(rhs->dtype == NDARRAY_UINT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
            FLOOR_DIVIDE_LOOP_UINT(results, uint16_t, uint16_t, uint8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
            FLOOR_DIVIDE_LOOP(results, uint16_t, uint16_t, int8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_UINT16) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
            FLOOR_DIVIDE_LOOP_UINT(results, uint16_t, uint16_t, uint16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT16) {
            // no 16-bit integer type holds both operand ranges, so the result is float
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint16_t, int16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_FLOAT) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint16_t, mp_float_t, larray, lstrides, rarray, rstrides);
        }
    } else if(lhs->dtype == NDARRAY_INT16) {
        if(rhs->dtype == NDARRAY_UINT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, int16_t, uint8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT8) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, int16_t, int8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_UINT16) {
            // no 16-bit integer type holds both operand ranges, so the result is float
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, int16_t, uint16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT16) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
            FLOOR_DIVIDE_LOOP(results, int16_t, int16_t, int16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_FLOAT) {
            results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
            // the left operand is int16 here: it must be read as int16_t, not
            // uint16_t, or negative values are reinterpreted as large positives
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, int16_t, mp_float_t, larray, lstrides, rarray, rstrides);
        }
    } else if(lhs->dtype == NDARRAY_FLOAT) {
        // a float lhs always yields a float result; note that the result is
        // allocated before the rhs dtype is inspected
        results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
        if(rhs->dtype == NDARRAY_UINT8) {
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, uint8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT8) {
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, int8_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_UINT16) {
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, uint16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_INT16) {
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, int16_t, larray, lstrides, rarray, rstrides);
        } else if(rhs->dtype == NDARRAY_FLOAT) {
            FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, mp_float_t, larray, lstrides, rarray, rstrides);
        }
    }

    return MP_OBJ_FROM_PTR(results);

}
#endif /* NDARRAY_HAS_BINARY_OP_FLOOR_DIVIDE */

#if NDARRAY_HAS_BINARY_OP_POWER
mp_obj_t ndarray_binary_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
Expand Down Expand Up @@ -812,7 +908,7 @@ mp_obj_t ndarray_inplace_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t
}
return MP_OBJ_FROM_PTR(lhs);
}
#endif /* NDARRAY_HAS_INPLACE_DIVIDE */
#endif /* NDARRAY_HAS_INPLACE_TRUE_DIVIDE */

#if NDARRAY_HAS_INPLACE_POWER
mp_obj_t ndarray_inplace_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rstrides) {
Expand Down
Loading

0 comments on commit 47ad73a

Please sign in to comment.