[Inductor] Add prologue fusion #121211
@@ -802,6 +802,59 @@ def fn(
        y1_expected = fn(x1, w, b, mul1)
        torch.testing.assert_close(y1, y1_expected)

    def test_triton_template_with_prologues_and_dynamic_shape(self):
        def const(x: torch.Tensor, val) -> torch.Tensor:
            return torch.full(x.size(), val).cuda()

        def fn(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
            return torch.matmul(
                torch.transpose(x, 0, 1) * torch.transpose(const(x, 0.05), 0, 1)
                + torch.transpose(const(x, 0.1), 0, 1),
                torch.transpose(w, 0, 1),
            )

        torch.backends.cuda.matmul.allow_tf32 = False

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "Triton",
                "prologue_fusion": True,
                "max_prologue_opcount": 16,
            }
        ):
            compiled_fn = torch.compile(
                fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
            )

            counters["inductor"]["cuda_prologue_fusion_counter"] = 0

            M0 = 5
            K = 5
            N = 5
            w = torch.rand(N, K).cuda()
            x0 = torch.rand(K, M0).cuda()
            y0 = compiled_fn(x0, w)
            y0_expected = fn(x0, w)
            torch.testing.assert_close(y0, y0_expected)

            M1 = 8
            K = 8
            N = 8
            w = torch.rand(N, K).cuda()
            x1 = torch.rand(K, M1).cuda()
            y1 = compiled_fn(x1, w)
            y1_expected = fn(x1, w)
            torch.testing.assert_close(y1, y1_expected)
Reviewer: Can we add an assert that the prologue was actually generated? Either by inspecting the output code (there are some helpers to get this) or by counting kernels.

Author: I added an assert to check that the fusion count is as expected.
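The assertion pattern requested above amounts to resetting a named counter before compilation and checking it afterwards. In the test, `counters` is PyTorch's nested counter table (in current PyTorch it lives in `torch._dynamo.utils`); a minimal dependency-free sketch of the same pattern, with `fake_compile` as a hypothetical stand-in for the compilation step:

```python
from collections import defaultdict

# Stand-in for PyTorch's nested counters table: group name -> counter name -> int.
counters = defaultdict(lambda: defaultdict(int))

def fake_compile():
    # Pretend the compiler performed exactly one prologue fusion.
    counters["inductor"]["cuda_prologue_fusion_counter"] += 1

# Reset before the run, then assert on the observed count afterwards.
counters["inductor"]["cuda_prologue_fusion_counter"] = 0
fake_compile()
actual_count = counters["inductor"]["cuda_prologue_fusion_counter"]
assert actual_count == 1, f"Expected fuse count of 1 but got {actual_count}"
```

Counting fusions this way is more robust than string-matching the generated kernel source, since it does not depend on kernel naming or formatting.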
            actual_count = counters["inductor"]["cuda_prologue_fusion_counter"]
            assert (
                actual_count == 1
            ), f"Expected fuse count of 1 but got {actual_count}"

    @config.patch(
        benchmark_kernel=True,
        fallback_random=True,
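For reference, the prologue being fused in the test above is just an elementwise scale-and-shift applied to the transposed input before the matmul; prologue fusion folds that pointwise work into the GEMM kernel's input loads instead of materializing an intermediate tensor. A dependency-free sketch of the same arithmetic (pure Python, hypothetical helper names, using an identity `w` so the result is easy to check by hand):

```python
def transpose(m):
    # Swap rows and columns of a list-of-lists matrix.
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    # Naive O(n^3) matrix multiply on lists of lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def prologue(m, scale=0.05, shift=0.1):
    # The elementwise work that prologue fusion absorbs into the matmul's loads.
    return [[v * scale + shift for v in row] for row in m]

x = [[1.0, 2.0], [3.0, 4.0]]  # K x M input, as in the test
w = [[1.0, 0.0], [0.0, 1.0]]  # N x K; identity, so y == prologue(transpose(x))
y = matmul(prologue(transpose(x)), transpose(w))
```

With `w` the identity, each entry of `y` is simply `x^T * 0.05 + 0.1`, which is exactly what the fused kernel must reproduce.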
@@ -181,7 +181,7 @@ def define_kernel(self, src_code: str, node_schedule) -> str:
        return kernel_name

    def codegen_template(
-        self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
+        self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode], isEpilogue=True
Reviewer: isEpilogue => is_epilogue

Author: Fixed the naming convention.
    ):
        """
        Codegen a CUDA template, possibly with fused epilogues
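As a rough illustration of why a single flag can serve both fusion directions (hypothetical helper, not the actual Inductor codegen), the boolean decides whether the fused pointwise nodes are emitted before the template body, attached to the input loads, or after it, attached to the output store:

```python
def codegen_template_sketch(template_node, fused_nodes, is_epilogue=True):
    # Toy model of one codegen path handling both prologue and epilogue fusion:
    # emit placeholder "instructions" as strings in the order they would run.
    parts = []
    if not is_epilogue:
        # Prologue fusion: pointwise ops apply to inputs as tiles are loaded.
        parts.extend(f"load+{n}" for n in fused_nodes)
    parts.append(f"template:{template_node}")
    if is_epilogue:
        # Epilogue fusion: pointwise ops apply to the accumulator before the store.
        parts.extend(f"store+{n}" for n in fused_nodes)
    return ";".join(parts)

print(codegen_template_sketch("mm", ["mul", "add"], is_epilogue=False))
# → "load+mul;load+add;template:mm"
```

Reusing one entry point keeps the scheduler's fusion bookkeeping uniform; only the placement of the fused nodes relative to the template body changes.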
Reviewer: Can we add another test where there is an additional usage of the prologue?

Author: I actually updated this test so that it would be more aligned with our target.