[inductor] Fix edge case in JIT vs. AOT fusion after finalizing MultiTemplateBuffer (#126622)

# Context
Here's an edge case that causes the JIT pass and the AOT pass to pick different fusions.
```py
# JIT -- buf3 is a MultiTemplateBuffer
V.graph.buffers = [buf0, buf1, buf2, buf3, buf4]
                                ^          ^
# JIT pass calls finalize_multi_template_buffers()
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]

# AOT -- proximity_score(buf2, buf4) is now "better" for fusion than it was in JIT
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]
                                ^    ^
```

It happens like this:
* JIT starts with the original set of nodes, ordered by V.graph.buffers.
* During the JIT pass, finalize_multi_template_buffers() is called, which can change the order of the buffers.
* That makes the order of the buffers (and of the scheduler nodes) differ between the two passes.
* Now each node's min/max order is different from before.
* As a result, the proximity between two nodes differs between the passes (see the sketch below): https://github.com/pytorch/pytorch/blob/ad67553c5c1672d65b810acd7a6a01e11695098b/torch/_inductor/scheduler.py#L2316-L2335
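
To make the effect concrete, here is a toy sketch. `make_scorer` is a stand-in for the real scoring at the link above (which uses each node's min/max order rather than a plain list index), but the impact of reordering is the same:

```py
# Toy proximity score: negated distance in the buffer list, so buffers
# that sit closer together score higher for fusion.
def make_scorer(order):
    pos = {name: i for i, name in enumerate(order)}
    return lambda a, b: -abs(pos[a] - pos[b])

jit_score = make_scorer(["buf0", "buf1", "buf2", "buf3", "buf4"])  # JIT order
aot_score = make_scorer(["buf0", "buf1", "buf2", "buf4", "buf3"])  # AOT order

print(jit_score("buf2", "buf4"))  # -2: buf3 sits between them
print(aot_score("buf2", "buf4"))  # -1: adjacent now, so this fusion wins in AOT
```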

# Error
```
$ TORCH_LOGS="+fusion" python test/inductor/test_max_autotune.py -k test_jit_fusion_matches_aot_fusion
======================================================================
FAIL: test_jit_fusion_matches_aot_fusion (__main__.TestMaxAutotune)
----------------------------------------------------------------------
Traceback (most recent call last):
  ...
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1718, in compile_to_fn
    code, linemap = self.codegen_with_cpp_wrapper()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1618, in codegen_with_cpp_wrapper
    return self.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1636, in codegen
    self.scheduler.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2602, in codegen
    self.get_backend(device).codegen_node(node)  # type: ignore[possibly-undefined]
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 66, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3377, in codegen_node
    return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3602, in codegen_node_schedule
    final_kernel.call_kernel(final_kernel.kernel_name)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3055, in call_kernel
    grid = wrapper.generate_default_grid(name, grid)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cpp_wrapper_cuda.py", line 174, in generate_default_grid
    params is not None
AssertionError: cuda kernel parameters for triton_poi_fused_add_0 should already exist at this moment, only found dict_keys(['Placeholder.DESCRIPTIVE_NAME', 'triton_poi_fused_add_mul_0', 'triton_poi_fused_pow_1'])
```
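
In other words: with cpp_wrapper, inductor compiles twice. The first (JIT) pass records the CUDA kernel params for the kernels it actually generated (here `triton_poi_fused_add_mul_0` and `triton_poi_fused_pow_1`), and the second (AOT) pass regenerates code and looks those params up by kernel name. Because the AOT pass chose a different fusion, it asks for `triton_poi_fused_add_0`, which the JIT pass never produced, and the assertion fires.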

Pull Request resolved: #126622
Approved by: https://github.com/chenyang78
ghstack dependencies: #125982
ColinPeppler authored and pytorchmergebot committed May 20, 2024
1 parent 7aa853a commit 8c38d0c
Showing 2 changed files with 24 additions and 1 deletion.
21 changes: 21 additions & 0 deletions test/inductor/test_max_autotune.py
```diff
@@ -463,6 +463,27 @@ def fn(a, b, c):
         self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
         self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
 
+    @skipIfRocm
+    @fresh_inductor_cache()
+    @config.patch(max_autotune=True, max_fusion_size=2)
+    def test_jit_fusion_matches_aot_fusion(self):
+        # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2)
+        # due to proximity; we want to make sure the AOT-compile pass does
+        # the same. AOT could do fuse(buf2, buf4) instead if buf3 were pushed
+        # to the end of the V.graph.buffers list, because fuse(buf2, buf4)
+        # would have a better proximity score than fuse(buf1, buf2). This is
+        # possible since finalizing MultiTemplateBuffers needs to replace buffers.
+        def fn(x, number):
+            buf0 = x + x
+            buf1 = number.item()
+            buf2 = x * x
+            buf3 = x @ x  # MultiTemplateBuffer
+            buf4 = x**2
+            return buf0, buf1, buf2, buf3, buf4
+
+        inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda"))
+        torch._export.aot_compile(fn, args=inputs)
+
     @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
     def test_precompilations(self):
         def fn(a, b, c):
```
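
A note on the repro: `torch._export.aot_compile` exercises both compilation passes, so without the scheduler fix below this test fails with exactly the kernel-params assertion shown above; with the fix, both passes see the same buffer order and pick the same fusions.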
4 changes: 3 additions & 1 deletion torch/_inductor/scheduler.py
```diff
@@ -1752,7 +1752,9 @@ def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer):
             del V.graph.name_to_buffer[replaced_name]
             new_node.name = orig_name
 
-            V.graph.buffers.remove(orig_node)
+            orig = V.graph.buffers.index(orig_node)
+            V.graph.buffers.remove(new_node)
+            V.graph.buffers[orig] = new_node
             V.graph.name_to_buffer[orig_name] = new_node
 
         for i, node in enumerate(self.nodes):
```
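
To see why the in-place swap matters, here is a minimal standalone sketch (hypothetical names; `buf3_final` stands in for the buffer that replaces the MultiTemplateBuffer, which finalization appends at the end of the list, as in the diagram up top):

```py
buffers = ["buf0", "buf1", "buf2", "buf3", "buf4", "buf3_final"]

# Before the fix: just drop the original; the replacement stays at the
# end, shifting buf4 earlier and changing later nodes' min/max order.
before = buffers.copy()
before.remove("buf3")
print(before)  # ['buf0', 'buf1', 'buf2', 'buf4', 'buf3_final']

# After the fix: swap the replacement into the original's slot, so the
# relative order of every other buffer is preserved.
after = buffers.copy()
idx = after.index("buf3")       # remember where the original lived
after.remove("buf3_final")      # take the replacement off the end
after[idx] = "buf3_final"       # splice it into the original's slot
print(after)  # ['buf0', 'buf1', 'buf2', 'buf3_final', 'buf4']
```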
