diff --git a/neon2rvv.h b/neon2rvv.h
index 0cf9b21527b2..d67a75fd01a2 100644
--- a/neon2rvv.h
+++ b/neon2rvv.h
@@ -2248,13 +2248,17 @@ FORCE_INLINE float32x2_t vrecps_f32(float32x2_t __a, float32x2_t __b) {
   return __riscv_vfnmsac_vv_f32m1(vdup_n_f32(2.0), __a, __b, 2);
 }
 
-// FORCE_INLINE float32x4_t vrecpsq_f32(float32x4_t __a, float32x4_t __b);
+FORCE_INLINE float32x4_t vrecpsq_f32(float32x4_t __a, float32x4_t __b) {
+  return __riscv_vfnmsac_vv_f32m1(vdupq_n_f32(2.0), __a, __b, 4);
+}
 
 FORCE_INLINE float32x2_t vrsqrts_f32(float32x2_t __a, float32x2_t __b) {
   return __riscv_vfdiv_vf_f32m1(__riscv_vfnmsac_vv_f32m1(vdup_n_f32(3.0), __a, __b, 2), 2.0, 2);
 }
 
-// FORCE_INLINE float32x4_t vrsqrtsq_f32(float32x4_t __a, float32x4_t __b);
+FORCE_INLINE float32x4_t vrsqrtsq_f32(float32x4_t __a, float32x4_t __b) {
+  return __riscv_vfdiv_vf_f32m1(__riscv_vfnmsac_vv_f32m1(vdupq_n_f32(3.0), __a, __b, 4), 2.0, 4);
+}
 
 FORCE_INLINE int8x8_t vshl_s8(int8x8_t __a, int8x8_t __b) {
   // implementation only works within defined range 'b' in [0, 7]
diff --git a/tests/impl.cpp b/tests/impl.cpp
index 89a276b57c13..c14b5ff13f1a 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -8600,7 +8600,25 @@ result_t test_vrecps_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #endif // ENABLE_TEST_ALL
 }
 
-result_t test_vrecpsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vrecpsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  const float *_b = impl.test_cases_float_pointer2;
+  float _c[4];
+  _c[0] = 2.0 - _a[0] * _b[0];
+  _c[1] = 2.0 - _a[1] * _b[1];
+  _c[2] = 2.0 - _a[2] * _b[2];
+  _c[3] = 2.0 - _a[3] * _b[3];
+
+  float32x4_t a = vld1q_f32(_a);
+  float32x4_t b = vld1q_f32(_b);
+  float32x4_t c = vrecpsq_f32(a, b);
+
+  return validate_float_error(c, _c[0], _c[1], _c[2], _c[3], 0.0001f);
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}
 
 result_t test_vrsqrts_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #ifdef ENABLE_TEST_ALL
@@ -8620,7 +8638,25 @@ result_t test_vrsqrts_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #endif // ENABLE_TEST_ALL
 }
 
-result_t test_vrsqrtsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vrsqrtsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  const float *_b = impl.test_cases_float_pointer2;
+  float _c[4];
+  _c[0] = (3.0 - _a[0] * _b[0]) / 2.0;
+  _c[1] = (3.0 - _a[1] * _b[1]) / 2.0;
+  _c[2] = (3.0 - _a[2] * _b[2]) / 2.0;
+  _c[3] = (3.0 - _a[3] * _b[3]) / 2.0;
+
+  float32x4_t a = vld1q_f32(_a);
+  float32x4_t b = vld1q_f32(_b);
+  float32x4_t c = vrsqrtsq_f32(a, b);
+
+  return validate_float_error(c, _c[0], _c[1], _c[2], _c[3], 0.0001f);
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}
 
 result_t test_vshl_s8(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #ifdef ENABLE_TEST_ALL