[Inductor] support vertical reduction in cpp #97644

Closed
wants to merge 11 commits
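In the Inductor CPP backend, a reduction whose reduced dimension is also the vectorized (innermost) one is a horizontal reduction: the lanes of the SIMD accumulator must be collapsed into a single scalar at the end. A vertical reduction instead reduces over some other dimension, so each SIMD lane accumulates one independent output element and the accumulator can be stored back as a whole vector. Previously only the horizontal case was vectorized; this PR adds codegen for the vertical case and lifts the tiling restriction that limited vectorization to the innermost loop. A sketch of both kernel shapes follows the cpp.py hunk below.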
19 changes: 19 additions & 0 deletions test/inductor/test_cpu_repro.py
@@ -1192,6 +1192,25 @@ def run_node_alt(*args, **kwargs):
         fn(torch.randn([8, 128]))
         self.assertGreater(len(strings), 3)
 
+    def test_vertical_sum_cpu_only(self):
+        def fn1(a):
+            return a.sum(dim=0)
+
+        def fn2(a):
+            return a.sum(dim=1)
+
+        metrics.reset()
+        x = torch.randn(100, 100)
+        opt_fn1 = torch._dynamo.optimize("inductor")(fn1)
+        self.assertTrue(same(fn1(x), opt_fn1(x)))
+        assert metrics.generated_cpp_vec_kernel_count == 1
+
+        metrics.reset()
+        x = torch.randn(100, 100, 100)
+        opt_fn2 = torch._dynamo.optimize("inductor")(fn2)
+        self.assertTrue(same(fn2(x), opt_fn2(x)))
+        assert metrics.generated_cpp_vec_kernel_count == 1
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
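Both tests exercise the new vertical path: fn1 reduces dim 0 of a 2-D tensor and fn2 reduces dim 1 of a 3-D tensor, so in each case the innermost (vectorizable) dimension survives into the output and the accumulator can be stored back as whole vectors. The assertion on metrics.generated_cpp_vec_kernel_count confirms that the vectorized C++ kernel was actually emitted rather than the scalar fallback.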
53 changes: 32 additions & 21 deletions torch/_inductor/codegen/cpp.py
@@ -1281,23 +1281,38 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
                 f"{reduction_combine_vec(reduction_type, tmpvar_vec, value)};"
             )
 
-        reduce_all_body = "{"
-        if reduction_type == "sum":
-            reduce_all_body += "return x + y;"
-        else:
-            reduce_all_body += f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);"
-        reduce_all_body += "}"
-        vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
-        next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
-        self.reduction_suffix.writeline(
-            DeferredLine(
-                name,
-                f"{reduction_combine(reduction_type, tmpvar, next_value)};",
-            )
-        )
-        # NOTE(jgong5): we do not generate the real stores here with the assumption that
-        # the scalar kernel that handles the loop tail would be generated and generates
-        # the stores there.
+        if self.tiling_idx >= self.reduction_depth:
+            # Horizontal reduction
+            # NOTE(jgong5): we do not generate the real stores here with the assumption that
+            # the scalar kernel that handles the loop tail would be generated and generates
+            # the stores there.
+            reduce_all_body = "{"
+            if reduction_type == "sum":
+                reduce_all_body += "return x + y;"
+            else:
+                reduce_all_body += (
+                    f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);"
+                )
+            reduce_all_body += "}"
+            vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
+            next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
+            self.reduction_suffix.writeline(
+                DeferredLine(
+                    name,
+                    f"{reduction_combine(reduction_type, tmpvar, next_value)};",
+                )
+            )
+        elif name not in V.graph.removed_buffers:
+            # Vertical reduction
+            var = self.args.output(name)
+            new_index = self.scale_index_with_offset(
+                index, self.tiling_factor, itervar_idx=self.tiling_idx
+            )
+            self.reduction_suffix.writeline(
+                DeferredLine(
+                    name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
+                )
+            )
         self.cse.store_cache[name] = tmpvar
 
 
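To make the two branches concrete, here is a minimal C++ sketch in the spirit of what this backend emits, assuming float32, contiguous layout, and a single tile; sum_inner and sum_outer are hypothetical names, not the PR's literal output:

#include <cstdint>
#include <ATen/cpu/vec/vec.h>

using Vec = at::vec::Vectorized<float>;

// Horizontal: the reduced dim is the vectorized one. Accumulate into a
// vector register, then collapse its lanes to a scalar with vec_reduce_all
// (the same call the generated code above builds as a string).
float sum_inner(const float* in, int64_t N) {
  Vec acc(0.f);
  int64_t n = 0;
  for (; n + Vec::size() <= N; n += Vec::size()) {
    acc = acc + Vec::loadu(in + n);
  }
  float s = at::vec::vec_reduce_all<float>(
      [](Vec& x, Vec& y) { return x + y; }, acc);
  for (; n < N; ++n) {  // scalar tail loop handles the remainder
    s += in[n];
  }
  return s;
}

// Vertical: the reduced dim is NOT the vectorized one. Each SIMD lane owns
// one output element, so the accumulator is written back whole with store();
// no lane-wise collapse is needed.
void sum_outer(const float* in, float* out, int64_t M, int64_t N, int64_t n0) {
  Vec acc(0.f);
  for (int64_t m = 0; m < M; ++m) {
    acc = acc + Vec::loadu(in + m * N + n0);  // column block [n0, n0 + lanes)
  }
  acc.store(out + n0);
}

Note how the vertical kernel never calls vec_reduce_all; that is exactly the distinction the tiling_idx >= self.reduction_depth test above encodes.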
@@ -2228,11 +2243,7 @@ def select_tiling():
             ) as vec_checker:
                 run(vec_checker)
             if vec_checker.simd_vec:
-                if (
-                    len(tiling_indices) == 1
-                    and tiling_indices[0] == len(self.itervars) - 1
-                ):
-                    # TODO(jgong5): support vec on outer loop
+                if len(tiling_indices) == 1:
                     return [tiling_factor], tiling_indices
                 if len(tiling_indices) == 2 and self.reduction_depth == len(
                     self.itervars
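With the vertical path available, a single tiling index no longer has to be the innermost itervar, so the tiling_indices[0] == len(self.itervars) - 1 check (and its TODO about vectorizing on an outer loop) is dropped: any single-index tiling that passes the vectorization check is now accepted.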