@@ -71,53 +71,52 @@ Tensor& avg_pool2d_out(
7171 // @lint-ignore CLANGTIDY facebook-hte-CArray
7272 static constexpr const char op_name[] = " avg_pool2d.out" ;
7373
74- ET_SWITCH_FLOATHBF16_TYPES_AND (
75- Long, in_type, ctx, op_name, CTYPE, [&]() {
76- if (divisor_override.has_value ()) {
77- int64_t divisor = divisor_override.value ();
78- // If divisor_override is specified, then we don't need to use `count`
79- // in the calculation. Simply sum x / divisor to get the output.
80- apply_kernel_2d_reduce_then_map_fn<CTYPE>(
81- [](const CTYPE in_val,
82- int64_t in_idx,
83- CTYPE accum,
84- int64_t accum_idx) {
85- // Average pooling does not track indexes, so return 0 for
86- // accum_idx
87- return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
88- },
89- [divisor](const int64_t count, const CTYPE accum) {
90- return accum / static_cast <CTYPE>(divisor);
91- },
92- count_include_pad,
93- in,
94- kernel_size,
95- stride,
96- padding,
97- {},
98- out);
99- } else {
100- apply_kernel_2d_reduce_then_map_fn<CTYPE>(
101- [](const CTYPE in_val,
102- int64_t in_idx,
103- CTYPE accum,
104- int64_t accum_idx) {
105- // Average pooling does not track indexes, so return 0 for
106- // accum_idx
107- return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
108- },
109- [](const int64_t count, const CTYPE accum) {
110- return accum / static_cast <CTYPE>(count);
111- },
112- count_include_pad,
113- in,
114- kernel_size,
115- stride,
116- padding,
117- {},
118- out);
119- }
120- });
74+ ET_SWITCH_FLOATHBF16_TYPES_AND (Long, in_type, ctx, op_name, CTYPE, [&]() {
75+ if (divisor_override.has_value ()) {
76+ int64_t divisor = divisor_override.value ();
77+ // If divisor_override is specified, then we don't need to use `count`
78+ // in the calculation. Simply sum x / divisor to get the output.
79+ apply_kernel_2d_reduce_then_map_fn<CTYPE>(
80+ [](const CTYPE in_val,
81+ int64_t in_idx,
82+ CTYPE accum,
83+ int64_t accum_idx) {
84+ // Average pooling does not track indexes, so return 0 for
85+ // accum_idx
86+ return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
87+ },
88+ [divisor](const int64_t count, const CTYPE accum) {
89+ return accum / static_cast <CTYPE>(divisor);
90+ },
91+ count_include_pad,
92+ in,
93+ kernel_size,
94+ stride,
95+ padding,
96+ {},
97+ out);
98+ } else {
99+ apply_kernel_2d_reduce_then_map_fn<CTYPE>(
100+ [](const CTYPE in_val,
101+ int64_t in_idx,
102+ CTYPE accum,
103+ int64_t accum_idx) {
104+ // Average pooling does not track indexes, so return 0 for
105+ // accum_idx
106+ return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
107+ },
108+ [](const int64_t count, const CTYPE accum) {
109+ return accum / static_cast <CTYPE>(count);
110+ },
111+ count_include_pad,
112+ in,
113+ kernel_size,
114+ stride,
115+ padding,
116+ {},
117+ out);
118+ }
119+ });
121120
122121 return out;
123122}
0 commit comments