diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h index a93cb8bb5aad2..7c64e5c2ce457 100644 --- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h +++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -2245,31 +2245,52 @@ struct Vectorized()>> { return Vectorized{ret}; } + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg(const typename Vectorized::vinner_type &a) + { + const auto swap_mask = ZSimdVectBinary{ + 0, 1, 2, 3, 20, 21, 22, 23, 8, 9, 10, 11, 28, 29, 30, 31}; + + auto a_neg = a.neg(); + vtype v0 = vec_perm(a_neg.vec0(), a.vec0(), swap_mask); + vtype v1 = vec_perm(a_neg.vec1(), a.vec1(), swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg(const typename Vectorized::vinner_type &a) + { + auto a_neg = a.neg(); + auto v0 = vec_permi(a_neg.vec0(), a.vec0(), 1); + auto v1 = vec_permi(a_neg.vec1(), a.vec1(), 1); + return { v0, v1 }; + } + Vectorized inline operator/(const Vectorized& b) const { - // re + im*i = (a + bi) / (c + di) - // re = (ac + bd)/abs_2() - // im = (bc - ad)/abs_2() - vinner_type bv = b.vec(); -#if !defined(ZVECTOR_SIMULATE_X86_MULT) - vinner_type vi = bv.mergeo(); - vinner_type vr = bv.mergee(); - vinner_type abs_b = b.abs_2_(); - vi = vi ^ isign_mask(); - vinner_type ret = _vec * vr; - vinner_type vx_swapped = _vec.swapped(); - ret = fmadd(vx_swapped, vi, ret); - ret = ret / abs_b; -#else - // Vectorized x86 simulation - vinner_type ac_bd = _vec * b; - vinner_type d_c = bv.swapped(); - d_c = d_c ^ rsign_mask(); - vinner_type ad_bc = _vec * d_c; - vinner_type abs_b = b.abs_2_(); - vinner_type re_im = vinner_type::horizontal_add_perm(ac_bd, ad_bc); - vinner_type ret = re_im / abs_b; -#endif - return Vectorized{ret}; + // Unfortunately, this breaks some tests + // Implement it like it's done for avx2 + auto fabs_cd = b.vec().abs(); // |c| |d| + auto fabs_dc = fabs_cd.swapped(); // |d| |c| + auto scale = vinner_type {1.0} / maximum(fabs_cd, fabs_dc); // 1/sc 1/sc + auto a2 = vec() * scale; // a/sc b/sc + auto b2 = b.vec() * scale; // c/sc d/sc + auto acbd2 = a2 * b2; // ac/sc^2 bd/sc^2 + + auto dc2 = b2.swapped(); // d/sc c/sc + dc2 = Vectorized::real_neg(dc2); // -d/|c,d| c/sc + auto adbc2 = a2 * dc2; // -ad/sc^2 bc/sc^2 + auto sum1 = acbd2 + acbd2.swapped(); // (ac+bd)/sc^2 (ac+bd)/sc^2 + auto sum2 = adbc2 + adbc2.swapped(); // (bc-ad)/sc^2 (bc-ad)/sc^2 + auto res2 = vinner_type::mergee(sum1, sum2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + + // get the denominator + auto denom2 = Vectorized{b2}.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 + res2 = res2 / denom2; + return Vectorized{ res2 }; } Vectorized angle2_() const {