Skip to content

Commit

Permalink
Merge pull request numpy#298 from howjmay/vrecps_vrsqrts
Browse files Browse the repository at this point in the history
feat: Add vrecpsq_f32 and vrsqrtsq_f32
  • Loading branch information
howjmay committed Dec 23, 2023
2 parents 90ef4fc + 7a3604d commit 51067ad
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
8 changes: 6 additions & 2 deletions neon2rvv.h
Expand Up @@ -2248,13 +2248,17 @@ FORCE_INLINE float32x2_t vrecps_f32(float32x2_t __a, float32x2_t __b) {
return __riscv_vfnmsac_vv_f32m1(vdup_n_f32(2.0), __a, __b, 2);
}

// FORCE_INLINE float32x4_t vrecpsq_f32(float32x4_t __a, float32x4_t __b);
FORCE_INLINE float32x4_t vrecpsq_f32(float32x4_t __a, float32x4_t __b) {
return __riscv_vfnmsac_vv_f32m1(vdupq_n_f32(2.0), __a, __b, 4);
}

FORCE_INLINE float32x2_t vrsqrts_f32(float32x2_t __a, float32x2_t __b) {
return __riscv_vfdiv_vf_f32m1(__riscv_vfnmsac_vv_f32m1(vdup_n_f32(3.0), __a, __b, 2), 2.0, 2);
}

// FORCE_INLINE float32x4_t vrsqrtsq_f32(float32x4_t __a, float32x4_t __b);
FORCE_INLINE float32x4_t vrsqrtsq_f32(float32x4_t __a, float32x4_t __b) {
return __riscv_vfdiv_vf_f32m1(__riscv_vfnmsac_vv_f32m1(vdupq_n_f32(3.0), __a, __b, 4), 2.0, 4);
}

FORCE_INLINE int8x8_t vshl_s8(int8x8_t __a, int8x8_t __b) {
// implementation only works within defined range 'b' in [0, 7]
Expand Down
40 changes: 38 additions & 2 deletions tests/impl.cpp
Expand Up @@ -8600,7 +8600,25 @@ result_t test_vrecps_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#endif // ENABLE_TEST_ALL
}

result_t test_vrecpsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
result_t test_vrecpsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const float *_a = impl.test_cases_float_pointer1;
const float *_b = impl.test_cases_float_pointer2;
float _c[4];
_c[0] = 2.0 - _a[0] * _b[0];
_c[1] = 2.0 - _a[1] * _b[1];
_c[2] = 2.0 - _a[2] * _b[2];
_c[3] = 2.0 - _a[3] * _b[3];

float32x4_t a = vld1q_f32(_a);
float32x4_t b = vld1q_f32(_b);
float32x4_t c = vrecpsq_f32(a, b);

return validate_float_error(c, _c[0], _c[1], _c[2], _c[3], 0.0001f);
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

result_t test_vrsqrts_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
Expand All @@ -8620,7 +8638,25 @@ result_t test_vrsqrts_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#endif // ENABLE_TEST_ALL
}

result_t test_vrsqrtsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
result_t test_vrsqrtsq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const float *_a = impl.test_cases_float_pointer1;
const float *_b = impl.test_cases_float_pointer2;
float _c[4];
_c[0] = (3.0 - _a[0] * _b[0]) / 2.0;
_c[1] = (3.0 - _a[1] * _b[1]) / 2.0;
_c[2] = (3.0 - _a[2] * _b[2]) / 2.0;
_c[3] = (3.0 - _a[3] * _b[3]) / 2.0;

float32x4_t a = vld1q_f32(_a);
float32x4_t b = vld1q_f32(_b);
float32x4_t c = vrsqrtsq_f32(a, b);

return validate_float_error(c, _c[0], _c[1], _c[2], _c[3], 0.0001f);
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

result_t test_vshl_s8(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
Expand Down

0 comments on commit 51067ad

Please sign in to comment.