[Inductor] support vertical reduction in cpp
ghstack-source-id: ad095c04d6af249e14cbc4f44ba952fa2101bc7a
Pull Request resolved: #97644
jgong5 committed Mar 27, 2023
1 parent 906a694 commit ac46096
Showing 2 changed files with 39 additions and 19 deletions.
test/inductor/test_torchinductor.py (10 additions, 0 deletions)
@@ -6988,6 +6988,16 @@ def run_node_alt(*args, **kwargs):
         fn(torch.randn([8, 128]))
         self.assertGreater(len(strings), 3)
 
+    def test_vertical_sum_cpu_only(self):
+        def fn(a):
+            return a.sum(dim=0)
+
+        metrics.reset()
+        x = torch.randn(100, 100)
+        opt_fn = torch._dynamo.optimize("inductor")(fn)
+        self.assertTrue(same(fn(x), opt_fn(x)))
+        assert metrics.generated_cpp_vec_kernel_count == 1
+
 
 if HAS_CUDA and not TEST_WITH_ASAN:
     import triton
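For context, the new test exercises a reduction over the outer dimension (`a.sum(dim=0)`), which the CPP backend previously left to the scalar kernel. A minimal standalone sketch of the same scenario, assuming a CPU build with TorchInductor and using the public `torch.compile` entry point instead of the test-suite helpers:

# Sketch only (not part of this commit): reproduce the scenario the new test covers.
import torch
from torch._inductor import metrics

def row_sum(a):
    return a.sum(dim=-1)  # reduce along the inner (vectorized) dim: horizontal reduction

def column_sum(a):
    return a.sum(dim=0)   # reduce across rows: vertical reduction, newly vectorized here

x = torch.randn(100, 100)
for fn in (row_sum, column_sum):
    metrics.reset()
    compiled = torch.compile(fn, backend="inductor")
    assert torch.allclose(fn(x), compiled(x))
    # After this change both cases should report one vectorized C++ kernel.
    print(fn.__name__, metrics.generated_cpp_vec_kernel_count)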
torch/_inductor/codegen/cpp.py (29 additions, 19 deletions)
@@ -1253,21 +1253,34 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
             None, f"{reduction_combine_vec(reduction_type, tmpvar_vec, value)};"
         )
 
-        reduce_all_body = "{"
-        if reduction_type == "sum":
-            reduce_all_body += "return x + y;"
-        else:
-            reduce_all_body += f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);"
-        reduce_all_body += "}"
-        vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
-        next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
-        self.reduction_suffix.writeline(
-            name,
-            f"{reduction_combine(reduction_type, tmpvar, next_value)};",
-        )
-        # NOTE(jgong5): we do not generate the real stores here with the assumption that
-        # the scalar kernel that handles the loop tail would be generated and generates
-        # the stores there.
+        if self.tiling_idx >= self.reduction_depth:
+            # Horizontal reduction
+            # NOTE(jgong5): we do not generate the real stores here with the assumption that
+            # the scalar kernel that handles the loop tail would be generated and generates
+            # the stores there.
+            reduce_all_body = "{"
+            if reduction_type == "sum":
+                reduce_all_body += "return x + y;"
+            else:
+                reduce_all_body += (
+                    f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);"
+                )
+            reduce_all_body += "}"
+            vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
+            next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
+            self.reduction_suffix.writeline(
+                name,
+                f"{reduction_combine(reduction_type, tmpvar, next_value)};",
+            )
+        elif name not in V.graph.removed_buffers:
+            # Vertical reduction
+            var = self.args.output(name)
+            new_index = self.scale_index_with_offset(
+                index, self.tiling_factor, itervar_idx=self.tiling_idx
+            )
+            self.reduction_suffix.writeline(
+                name, f"{tmpvar_vec}.store({var} + {cexpr(new_index)});"
+            )
         self.cse.store_cache[name] = tmpvar
 
 
@@ -2153,10 +2166,7 @@ def select_tiling():
                 ) as vec_checker:
                     run(vec_checker)
                     if vec_checker.simd_vec:
-                        if (
-                            len(tiling_indices) == 1
-                            and tiling_indices[0] == len(self.itervars) - 1
-                        ):
+                        if len(tiling_indices) == 1:
                             # TODO(jgong5): support vec on outer loop
                             return [tiling_factor], tiling_indices
                         if len(tiling_indices) == 2 and self.reduction_depth == len(
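The key difference between the two branches in `reduction` above: the horizontal path keeps a vector accumulator and folds its lanes into a single scalar per output with `vec_reduce_all`, while the new vertical path keeps one independent output per lane and writes the whole accumulator back with `tmpvar_vec.store(...)` at the offset computed by `scale_index_with_offset`. The `select_tiling` change then relaxes the earlier requirement that the single tiling index be the innermost loop variable, so the vertical case can actually be selected. A rough NumPy sketch of the two access patterns (illustration only, with a made-up lane count and no tail handling, which the real kernels leave to a scalar loop):

# Illustration only (not the generated code): the two reduction shapes for a 2D
# float32 input, with a hypothetical lane count standing in for the tiling factor.
import numpy as np

LANES = 16  # hypothetical SIMD width

def horizontal_sum(a):
    """a.sum(axis=-1): reduce along the vectorized inner dim."""
    out = np.zeros(a.shape[0], dtype=a.dtype)
    for i in range(a.shape[0]):
        acc = np.zeros(LANES, dtype=a.dtype)           # vector accumulator
        for j in range(0, a.shape[1] - LANES + 1, LANES):
            acc += a[i, j:j + LANES]                    # lane-wise adds inside the loop
        out[i] = acc.sum()                              # fold lanes once (vec_reduce_all)
    return out

def vertical_sum(a):
    """a.sum(axis=0): reduce across rows; each lane owns one output column."""
    out = np.zeros(a.shape[1], dtype=a.dtype)
    for j in range(0, a.shape[1] - LANES + 1, LANES):
        acc = np.zeros(LANES, dtype=a.dtype)            # one partial sum per column
        for i in range(a.shape[0]):
            acc += a[i, j:j + LANES]
        out[j:j + LANES] = acc                          # store the whole vector (tmpvar_vec.store)
    return out

a = np.random.rand(100, 96).astype(np.float32)          # 96 columns = full vectors only
assert np.allclose(horizontal_sum(a), a.sum(axis=1))
assert np.allclose(vertical_sum(a), a.sum(axis=0))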
