diff --git a/code/numpy/poly.c b/code/numpy/poly.c index 62eb1688..ff4965d8 100644 --- a/code/numpy/poly.c +++ b/code/numpy/poly.c @@ -145,9 +145,18 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(poly_polyfit_obj, 2, 3, poly_polyfit); #if ULAB_NUMPY_HAS_POLYVAL +static mp_float_t poly_eval(mp_float_t x, mp_float_t *p, uint8_t plen) { + mp_float_t y = p[0]; + for(uint8_t j=0; j < plen-1; j++) { + y *= x; + y += p[j+1]; + } + return y; +} + mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) { - if(!ndarray_object_is_array_like(o_p) || !ndarray_object_is_array_like(o_x)) { - mp_raise_TypeError(translate("inputs are not iterable")); + if(!ndarray_object_is_array_like(o_p)) { + mp_raise_TypeError(translate("input is not iterable")); } #if ULAB_SUPPORTS_COMPLEX ndarray_obj_t *input; @@ -171,6 +180,10 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) { i++; } + if(!ndarray_object_is_array_like(o_x)) { + return mp_obj_new_float(poly_eval(mp_obj_get_float(o_x), p, plen)); + } + // polynomials are going to be of type float, except, when both // the coefficients and the independent variable are integers ndarray_obj_t *ndarray; @@ -198,13 +211,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) { #endif size_t l = 0; do { - mp_float_t y = p[0]; - mp_float_t _x = func(sarray); - for(uint8_t m=0; m < plen-1; m++) { - y *= _x; - y += p[m+1]; - } - *array++ = y; + *array++ = poly_eval(func(sarray), p, plen); sarray += source->strides[ULAB_MAX_DIMS - 1]; l++; } while(l < source->shape[ULAB_MAX_DIMS - 1]); @@ -233,13 +240,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) { mp_obj_iter_buf_t x_buf; mp_obj_t x_item, x_iterable = mp_getiter(o_x, &x_buf); while ((x_item = mp_iternext(x_iterable)) != MP_OBJ_STOP_ITERATION) { - mp_float_t _x = mp_obj_get_float(x_item); - mp_float_t y = p[0]; - for(uint8_t j=0; j < plen-1; j++) { - y *= _x; - y += p[j+1]; - } - *array++ = y; + *array++ = poly_eval(mp_obj_get_float(x_item), p, plen); } } m_del(mp_float_t, p, plen); diff --git a/code/ulab.c b/code/ulab.c index fc770e9b..07113d18 100644 --- a/code/ulab.c +++ b/code/ulab.c @@ -33,7 +33,7 @@ #include "user/user.h" #include "utils/utils.h" -#define ULAB_VERSION 6.3.2 +#define ULAB_VERSION 6.3.3 #define xstr(s) str(s) #define str(s) #s