[inductor] generate fused rms/layer norm bwd (Closed)

35 commits
6d566a2  [WIP][inductor] generate fused rms/layer norm bwd (shunting314, Oct 13, 2025)
a781691  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 13, 2025)
8d97d1f  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 14, 2025)
7a687ec  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 14, 2025)
23c0155  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
9c81a1b  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
d2d9dd5  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
6fbfdc4  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
3d4193b  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
f5c65cd  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
cd0f4e8  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
8745ca9  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
70e3f8d  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 15, 2025)
a4cbcf3  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 16, 2025)
12615bd  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 16, 2025)
df0b8a2  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 16, 2025)
21d226c  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 16, 2025)
4f40d41  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 17, 2025)
ac51f02  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 17, 2025)
2adf0dd  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 17, 2025)
4b5902a  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 18, 2025)
0c1102b  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 18, 2025)
cc9e8ba  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 20, 2025)
7662e7b  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 20, 2025)
220f760  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 20, 2025)
0d0f7b7  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 21, 2025)
b2979ae  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 22, 2025)
3dc3f1b  Update on "[WIP][inductor] generate fused rms/layer norm bwd" (shunting314, Oct 22, 2025)
f1bd189  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 22, 2025)
70680fe  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 23, 2025)
bae45b0  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 24, 2025)
a138763  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 24, 2025)
f600fab  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 27, 2025)
cf82e0a  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 27, 2025)
1ac94b7  Update on "[inductor] generate fused rms/layer norm bwd" (shunting314, Oct 27, 2025)
282 changes: 282 additions & 0 deletions test/inductor/test_mix_order_reduction.py
@@ -0,0 +1,282 @@
# Owner(s): ["module: inductor"]

import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F
from torch._dynamo.utils import same
from torch._inductor import metrics, utils
from torch._inductor.test_case import run_tests, TestCase
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class TestBase(TestCase):
def setUp(self):
super().setUp()
metrics.reset()

def check_numeric(self, f, args, tol=1e-3):
ref = f(*args)
act = torch.compile(f)(*args)
self.assertTrue(same(ref, act, tol=tol))


class SkipPatternTest(TestBase):
"""
Illustate the cases that we skip mix-order reduction. We skip in cases
like when the outer reduction is followed by a pointwise that load
the un-reduced tensor.
"""

@inductor_config.patch(split_reductions=False)
def test_dimension_too_close(self):
"""
Skip if the two reduction size are too close.
We require one reduction dimension to be much larger so we can split
that dimension and make it efficient.
"""

def f(x):
out1 = x.sum(dim=1)
out2 = x.sum(dim=0)
return out1, out2

x = torch.randn(768, 768, device=GPU_TYPE)
torch.compile(f)(x)
self.assertEqual(2, metrics.generated_kernel_count)

@inductor_config.patch(split_reductions=False)
def test_skip_if_outer_reduction_followed_by_full_pointwise(self):
"""
Skip for now if the outer reduction is followed by a pointwise node
accessing the original tensor. Accessing the reduced tensor is fine
(e.g. to support torch.mean).
"""

def f(x):
out1 = x.sum(dim=1)
out2 = x.sum(dim=0, keepdim=True) + x
return out1, out2

x = torch.randn(32768, 768, device=GPU_TYPE)
self.check_numeric(f, (x,))
self.assertEqual(0, metrics.codegen_mix_order_reduction)

@inductor_config.patch(split_reductions=False)
def test_skip_due_to_non_persistent_reduction(self):
"""
We only generate mix order reduction if one of the reduction is
persistent reduction.
"""

def f(x):
return x.sum(dim=1), x.sum(dim=0)

x = torch.randn(32768, 2048, device=GPU_TYPE)
self.check_numeric(f, (x,))
self.assertEqual(0, metrics.codegen_mix_order_reduction)


@instantiate_parametrized_tests
class MixOrderReductionTest(TestBase):
@parametrize(
"name",
[
"sum",
"prod",
"mean",
],
)
@parametrize("swap", (False, True))
@parametrize("shape", ((32768, 768), (32769, 768)))
@inductor_config.patch(split_reductions=False)
def test_mix_order_reduction(self, name, swap, shape):
def f(x):
if swap:
return reduction_fn(x, dim=0), reduction_fn(x, dim=1)
else:
return reduction_fn(x, dim=1), reduction_fn(x, dim=0)

reduction_fn = getattr(torch, name)
M, N = shape
dtype = torch.float
x = torch.randn(M, N, dtype=dtype, device=GPU_TYPE)

opt_f = torch.compile(f)

ref = f(x)
act = opt_f(x)

self.assertTrue(same(ref, act, tol=1e-3), f"ref:\n{ref}\nact:\n{act}")

expected_num_kernel = 1 + (not inductor_config.triton.mix_order_reduction)
if name == "mean" and inductor_config.triton.mix_order_reduction:
# for mean we generate one more kernel to do the division
# this kernel should be very cheap since tensor size is small
expected_num_kernel = 2
self.assertEqual(
expected_num_kernel,
metrics.generated_kernel_count,
)

@inductor_config.patch(split_reductions=False)
def test_multi_workspace_allocation(self):
def f(x, y):
return x.sum(dim=0), x.sum(dim=1), y.sum(dim=0), y.sum(dim=1)

x = torch.randn(128 * 15, 128, device=GPU_TYPE)
y = torch.randn(256 * 15, 256, device=GPU_TYPE)

self.check_numeric(f, (x, y))
expected_mix_order_reduction = (
0 if not inductor_config.triton.mix_order_reduction else 2
)
self.assertEqual(
expected_mix_order_reduction, metrics.codegen_mix_order_reduction
)

@parametrize(
"wdtype",
[
torch.bfloat16, # extra down cast for dw is needed
torch.float,
],
)
@parametrize("shape", ((32768, 768), (32769, 768)))
@inductor_config.patch(split_reductions=False)
def test_rms_norm_bwd(self, wdtype, shape):
def f(x, w, eps):
orig_dtype = x.dtype

x = x.float()
rsqrt = torch.rsqrt((x * x).sum(dim=-1) / x.shape[-1] + eps)
y = (x * rsqrt[:, None] * w).to(dtype=orig_dtype)
return y

def fwd_bwd(f):
x.grad = None
w.grad = None
out = f(x, w, eps)
out.backward(dy)
return x.grad, w.grad

torch.manual_seed(1337)

# M, N = 1152 * 500, 384
M, N = shape
x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True)
w = torch.randn(N, dtype=wdtype, device=GPU_TYPE, requires_grad=True)
dy = torch.randn_like(x)
eps = 1e-5

opt_f = torch.compile(f)

ref = fwd_bwd(f)
act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f)

self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}")
expected_num_kernel = 1 + (not inductor_config.triton.mix_order_reduction)
if wdtype == torch.bfloat16 and inductor_config.triton.mix_order_reduction:
# one extra kernel for downcasting
expected_num_kernel = 2
FileCheck().check_count(
"@triton.jit",
expected_num_kernel,
exactly=True,
).run(bwd_wrapper)

@parametrize(
"wbdtype",
[
torch.bfloat16, # extra down cast for dw/db is needed
torch.float,
],
)
@parametrize("shape", ((32768, 768), (32769, 768)))
@inductor_config.patch(split_reductions=False)
def test_layer_norm_bwd_with_bias(self, wbdtype, shape):
def f(x, w, b, eps):
return F.layer_norm(x, x.shape[-1:], w.float(), b.float(), eps)

def fwd_bwd(f):
x.grad = None
w.grad = None
b.grad = None
out = f(x, w, b, eps)
out.backward(dy)
return x.grad, w.grad, b.grad

# M, N = 1152 * 500, 384
M, N = shape
xdtype = torch.float
x = torch.randn(M, N, dtype=xdtype, device=GPU_TYPE, requires_grad=True)
w = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True)
b = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True)
dy = torch.randn_like(x)
eps = 1e-5

opt_f = torch.compile(f)

ref = fwd_bwd(f)
act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f)

self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}")
expected_num_kernel = 1 + (not inductor_config.triton.mix_order_reduction)
if wbdtype == torch.bfloat16 and inductor_config.triton.mix_order_reduction:
# one extra kernel for downcasting
expected_num_kernel = 2
FileCheck().check_count(
"@triton.jit",
expected_num_kernel,
exactly=True,
).run(bwd_wrapper)

@parametrize("shape", ((32768, 768), (32769, 768)))
@inductor_config.patch(split_reductions=False)
def test_layer_norm_bwd_no_bias(self, shape):
def f(x, w, eps):
return F.layer_norm(x, x.shape[-1:], w, bias=None, eps=eps)

def fwd_bwd(f):
x.grad = None
w.grad = None
out = f(x, w, eps)
out.backward(dy)
return x.grad, w.grad

# M, N = 1152 * 500, 384
M, N = shape
xdtype = torch.float
wbdtype = torch.float
x = torch.randn(M, N, dtype=xdtype, device=GPU_TYPE, requires_grad=True)
w = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True)
dy = torch.randn_like(x)
eps = 1e-5

opt_f = torch.compile(f)

ref = fwd_bwd(f)
act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f)

self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}")
FileCheck().check_count(
"@triton.jit",
1 + (not inductor_config.triton.mix_order_reduction),
exactly=True,
).run(bwd_wrapper)


@inductor_config.patch(
"triton.mix_order_reduction", not inductor_config.triton.mix_order_reduction
)
class NoMixOrderReductionTest(MixOrderReductionTest):
pass


if __name__ == "__main__":
if HAS_GPU:
run_tests()
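For context on what the tests above exercise: a mix-order reduction pairs an inner (per-row, dim=-1) reduction with an outer (per-column, dim=0) reduction over the same data. The rms-norm backward is the motivating case: dx needs one sum per row while dw needs a sum over all rows per column, and with `triton.mix_order_reduction` disabled the two orders land in separate kernels (which is why the tests expect two kernels in that mode). The sketch below is reference math for where the two reduction orders come from, not code from this PR.

```python
import torch

def rms_norm_bwd_reference(x, w, dy, eps=1e-5):
    """Hand-written rms-norm backward showing the two reduction orders."""
    n = x.shape[-1]
    r = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)  # (M, 1)
    dyw = dy * w                                               # (M, N)
    # inner reduction: one sum per row (dim=-1)
    row_dot = (dyw * x).sum(dim=-1, keepdim=True)              # (M, 1)
    dx = r * dyw - (r**3) * x * row_dot / n                    # (M, N)
    # outer reduction: one sum per column, over all M rows (dim=0)
    dw = (dy * x * r).sum(dim=0)                               # (N,)
    return dx, dw
```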
9 changes: 7 additions & 2 deletions torch/_inductor/choices.py
@@ -7,6 +7,7 @@

import torch
from torch._inductor.runtime.runtime_utils import next_power_of_2
from torch._inductor.scheduler import MixOrderReduction
from torch.utils._sympy.value_ranges import bound_sympy

from . import config
@@ -487,7 +488,9 @@ def can_fuse(
- config.triton.tiling_prevents_reduction_fusion
- config.aggressive_fusion (will cause this function to be called more times)
"""
if shared_data_score == 0 and (
if (
shared_data_score == 0 and not MixOrderReduction.can_fuse(node1, node2)
) and (
not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
):
if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"):
@@ -547,7 +550,9 @@ def can_fuse_horizontal(
shared_data_score: int,
) -> bool:
"""Hook for heuristics to prevent horizontal (consumer/consumer) fusions"""
if shared_data_score < config.score_fusion_memory_threshold:
if (
shared_data_score < config.score_fusion_memory_threshold
) and not MixOrderReduction.can_fuse(node1, node2):
WhyNoFuse(node1, node2)("score_fusion_memory_threshold")
return False
if scheduler.are_long_distant_nodes(node1, node2):
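The net effect of the two hunks above is that the usual "no shared data" and "below score_fusion_memory_threshold" rejections are bypassed when the node pair qualifies as a mix-order reduction. A condensed restatement of the first predicate, for readability only (the authoritative code is in InductorChoices.can_fuse above):

```python
def reject_for_no_shared_data(node1, node2, shared_data_score, aggressive_fusion):
    # Reject only if the nodes share no data, are not a fusable
    # mix-order reduction pair, and aggressive fusion does not apply.
    return (
        shared_data_score == 0
        and not MixOrderReduction.can_fuse(node1, node2)  # new escape hatch
        and (
            not aggressive_fusion
            or node1.is_reduction()
            or node2.is_reduction()
        )
    )
```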
24 changes: 23 additions & 1 deletion torch/_inductor/codegen/common.py
@@ -59,7 +59,15 @@
triton_type,
unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
from ..virtualized import (
NullHandler,
ops,
OpsHandler,
OpsValue,
ReductionType,
StoreMode,
V,
)


if TYPE_CHECKING:
@@ -2162,6 +2170,14 @@ def reduction(
) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
raise NotImplementedError

def partial_accumulate(
self,
name: str,
reduction_type: ReductionType,
value: CSEVariable,
) -> None:
raise NotImplementedError

def scan(
self,
dtypes: tuple[torch.dtype, ...],
@@ -2626,6 +2642,9 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A
if isinstance(V.kernel, CUDATemplateKernel):
return ValueRanges.unknown()

if isinstance(V.interpreter, NullHandler):
return ValueRanges.unknown()

fx_node = V.interpreter.current_node
if fx_node.target == name and self.kernel.node_to_bounds is not None:
assert isinstance(self.kernel.node_to_bounds, dict), type(
@@ -2753,6 +2772,9 @@ def store(
def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
self.kernel.device_assert_async(cond, msg)

def partial_accumulate(self, *args: Any) -> None:
self.kernel.partial_accumulate(*args)

def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
self.kernel.store_buffer_names.add(name)
self._update_store_cache(name, value)
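The common.py changes thread a new partial_accumulate op through the generic codegen machinery: the abstract ops handler declares it, and the CSE proxy forwards it to the active kernel via V.kernel.partial_accumulate, so support is decided per backend. A minimal sketch of that indirection, with hypothetical names for illustration only:

```python
# Hypothetical backend kernel, illustrating only the per-backend hook.
class ExampleKernel:
    def partial_accumulate(self, name, reduction_type, value):
        # Fold `value` into the partial-result buffer `name` using
        # `reduction_type` ("sum", "prod", ...).  A backend that cannot do
        # this keeps the base-class behaviour and raises NotImplementedError,
        # as the C++ backend in the next file does explicitly.
        ...

# Codegen for a mix-order reduction could then emit, via the ops handler:
#     ops.partial_accumulate("partial_buf", "sum", some_value)
# and the CSE proxy routes that call to V.kernel.partial_accumulate(...).
```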
8 changes: 8 additions & 0 deletions torch/_inductor/codegen/cpp.py
@@ -1120,6 +1120,14 @@ def sign(x):
code.writeline("()")
return code

def partial_accumulate(
self,
name: str,
reduction_type: str,
value: CSEVariable,
) -> None:
raise NotImplementedError


CppOverrides._initialize_pointwise_overrides("cpp")

3 changes: 3 additions & 0 deletions torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -123,6 +123,9 @@ def codegen_template(
template_node, epilogue_nodes, prologue_nodes
)

def codegen_mix_order_reduction(self, node):
return self._triton_scheduling.codegen_mix_order_reduction(node)

def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
return self._triton_scheduling.codegen_node(node)
