Fix reduction + () + multi-level reduction optimization (#111781)
Summary:

In #111122, an optimization was introduced for the reduction() + () + multi-level reduction pattern. In that case, the first-level reduction ranges of the multi-level reduction are made the same as the preceding reduction's ranges, so that Inductor has a better chance of fusing the first reduction together with the first-level reduction of the multi-level reduction kernel.
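
(Note: the snippet below is an illustrative sketch of the kind of graph this pattern describes, not code from this PR; the function name `ln_amax`, the shapes, and the CUDA device are assumptions. A layer norm, which is a reduction followed by pointwise ops, feeds a large amax that Inductor may lower as a multi-level reduction.)

import torch

def ln_amax(x: torch.Tensor, amax_buffer: torch.Tensor) -> torch.Tensor:
    # reduction() + (): layer_norm reduces over the last dim, then applies pointwise ops
    y = torch.nn.functional.layer_norm(x, [x.shape[-1]])
    # large amax over the whole tensor: a candidate for a split (multi-level) reduction
    amax_buffer.fill_(torch.amax(torch.abs(y)))
    return y

compiled = torch.compile(ln_amax)             # compile with Inductor (assumes a CUDA device)
x = torch.randn(4, 2048, 4096, device="cuda")
buf = torch.zeros((), device="cuda")
out = compiled(x, buf)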

There is a corner case where the multi-level reduction kernel has `keepdim=True`. In that case, the ranges of the multi-level reduction kernel are not empty, and the dim info would need to be propagated to the point where the inner loader of the first-level reduction kernel is created. To keep the logic simple, the optimization is simply disabled when `keepdim=True` for now.
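
(Illustration of the corner case in plain PyTorch terms, not Inductor internals: with `keepdim=False` a full amax produces a 0-dim result, i.e. empty output ranges, while `keepdim=True` keeps singleton dims, so the ranges are non-empty.)

import torch

x = torch.randn(1, 10, 512)
a0 = torch.amax(torch.abs(x))                # shape torch.Size([]): empty ranges (keepdim=False)
a1 = torch.amax(torch.abs(x), keepdim=True)  # shape torch.Size([1, 1, 1]): non-empty ranges
assert torch.equal(a0, a1.reshape(-1)[0])    # same value either way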




imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D50544876

Pulled By: ipiszy
ipiszy authored and facebook-github-bot committed Oct 23, 2023
1 parent 185e762 commit 903763c
Showing 2 changed files with 18 additions and 6 deletions.
9 changes: 7 additions & 2 deletions test/inductor/test_fp8.py
@@ -203,10 +203,13 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
     @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("amax_keep_dim", (True, False))
     @parametrize(
         "shape", ((1, 1, 15), (1, 10, 15), (1, 10, 512), (1, 10, 4096), (4, 2048, 4096))
     )
-    def test_layernorm_fp8_quant(self, float8_dtype: torch.dtype, shape: Tuple[int]):
+    def test_layernorm_fp8_quant(
+        self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: Tuple[int]
+    ):
         batch_size, sequence_length, hidden_size = shape

         def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
@@ -217,7 +220,9 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
                 bias=None,
                 eps=1e-05,
             )
-            amax_buffer.fill_(torch.amax(torch.abs(x)))
+            amax_buffer.fill_(
+                torch.amax(torch.abs(x), keepdim=amax_keep_dim).reshape(-1)[0]
+            )
             x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
             return bits_fp8
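
(Note on the test change above: `Tensor.fill_` takes a number or a 0-dim tensor, so the amax result is flattened and indexed before filling; this keeps the call valid for both `amax_keep_dim` settings. The snippet below is illustrative and not part of the diff.)

import torch

x = torch.randn(1, 10, 512)
buf = torch.zeros(())
# keepdim=True yields shape (1, 1, 1); reshape(-1)[0] recovers a 0-dim value that fill_ accepts.
# With keepdim=False the amax result is already 0-dim, and reshape(-1)[0] still yields a 0-dim value.
buf.fill_(torch.amax(torch.abs(x), keepdim=True).reshape(-1)[0])
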
15 changes: 11 additions & 4 deletions torch/_inductor/ir.py
@@ -736,13 +736,20 @@ def outer_reduction_splits(reduction_numel_hint, numel_hint):
         if split == 1:
             # No need to split.
             return ReductionHint.INNER, split
-        if input_node is not None and isinstance(input_node, TensorBox):
-            ranges, reduction_ranges = extract_input_node_reduction_ranges(
+        if (
+            len(ranges) == 0
+            and input_node is not None
+            and isinstance(input_node, TensorBox)
+        ):
+            # Only handles the case where keep_dim = False.
+            # Otherwise, we need to propagate reduction dim info to the stage where
+            # the intermediate loader of the first Reduction is generated.
+            new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                 input_node
             )
-            if reduction_ranges is not None:
+            if new_ranges is not None and new_reduction_ranges is not None:
                 extracted_numel_hint = V.graph.sizevars.symbolic_hint(
-                    sympy_product(ranges + reduction_ranges)
+                    sympy_product(new_ranges + new_reduction_ranges)
                 )
                 if reduction_numel_hint == extracted_numel_hint:
                     # If the input_node or its dependent nodes are also Reduction nodes,
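
(For context on the terminology in this hunk, a conceptual sketch only, not Inductor's implementation: a split, i.e. multi-level, reduction computes a large reduction in two stages: first-level partial reductions over blocks, then a second-level reduction over the partial results. The block count `split = 1024` is a made-up value.)

import torch

x = torch.randn(1 << 20)
split = 1024                             # hypothetical number of first-level blocks
partial = x.view(split, -1).amax(dim=1)  # first-level reduction: one partial result per block
result = partial.amax()                  # second-level reduction over the partial results
assert torch.equal(result, x.amax())     # max is exact, so the two-stage result matches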
