Skip to content

Commit

Permalink
Polyval handles non-array as second argument (#601)
Browse files Browse the repository at this point in the history
* Factorize polynomial evaluation

* Polyval handles non-array as second argument

---------

Co-authored-by: Zoltán Vörös <zvoros@gmail.com>
  • Loading branch information
HugoNumworks and v923z committed Jun 27, 2023
1 parent 319df10 commit 112d4f8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
33 changes: 17 additions & 16 deletions code/numpy/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,18 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(poly_polyfit_obj, 2, 3, poly_polyfit);

#if ULAB_NUMPY_HAS_POLYVAL

static mp_float_t poly_eval(mp_float_t x, mp_float_t *p, uint8_t plen) {
mp_float_t y = p[0];
for(uint8_t j=0; j < plen-1; j++) {
y *= x;
y += p[j+1];
}
return y;
}

mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
if(!ndarray_object_is_array_like(o_p) || !ndarray_object_is_array_like(o_x)) {
mp_raise_TypeError(translate("inputs are not iterable"));
if(!ndarray_object_is_array_like(o_p)) {
mp_raise_TypeError(translate("input is not iterable"));
}
#if ULAB_SUPPORTS_COMPLEX
ndarray_obj_t *input;
Expand All @@ -171,6 +180,10 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
i++;
}

if(!ndarray_object_is_array_like(o_x)) {
return mp_obj_new_float(poly_eval(mp_obj_get_float(o_x), p, plen));
}

// polynomials are going to be of type float, except, when both
// the coefficients and the independent variable are integers
ndarray_obj_t *ndarray;
Expand Down Expand Up @@ -198,13 +211,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
#endif
size_t l = 0;
do {
mp_float_t y = p[0];
mp_float_t _x = func(sarray);
for(uint8_t m=0; m < plen-1; m++) {
y *= _x;
y += p[m+1];
}
*array++ = y;
*array++ = poly_eval(func(sarray), p, plen);
sarray += source->strides[ULAB_MAX_DIMS - 1];
l++;
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
Expand Down Expand Up @@ -233,13 +240,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
mp_obj_iter_buf_t x_buf;
mp_obj_t x_item, x_iterable = mp_getiter(o_x, &x_buf);
while ((x_item = mp_iternext(x_iterable)) != MP_OBJ_STOP_ITERATION) {
mp_float_t _x = mp_obj_get_float(x_item);
mp_float_t y = p[0];
for(uint8_t j=0; j < plen-1; j++) {
y *= _x;
y += p[j+1];
}
*array++ = y;
*array++ = poly_eval(mp_obj_get_float(x_item), p, plen);
}
}
m_del(mp_float_t, p, plen);
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.3.2
#define ULAB_VERSION 6.3.3
#define xstr(s) str(s)
#define str(s) #s

Expand Down

0 comments on commit 112d4f8

Please sign in to comment.