19 changes: 19 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -4816,6 +4816,24 @@ def forward(self, input: torch.Tensor):

self.assertTrue(torch.allclose(module(input), traced(input)))

@patch.object(config.triton, "autotune", True)
def test_inplace_add_alpha_autotune(self):
def fn(x, y):
aten.add_.Tensor(x, y, alpha=0.55)
return (x,)

x1 = torch.zeros(2, 3, 4, 10, device="cuda")
x2 = torch.zeros(2, 3, 4, 10, device="cuda")
x3 = torch.zeros(2, 3, 4, 10, device="cuda")
y = torch.randn(2, 3, 4, 10, device="cuda").to(
memory_format=torch.channels_last
)
fn_fx = make_fx(fn)(x1, y)
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
fn(x2, y)
fn_compiled([x3, y])
assert same(x2, x3)

def test_permute_linear_fusion(self):
class TestModule(torch.nn.Module):
def __init__(self, k: int, n: int):
@@ -5108,6 +5126,7 @@ def decorator(fn):
meta=meta,
configs=configs,
save_cache_hook=False,
mutated_arg_names=["in_out_ptr0"],
)

return decorator
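Note: the new test drives an in-place aten.add_ with a non-default alpha through the autotuner. A minimal eager-mode sketch of the semantics it checks (the shapes and the 0.55 alpha are taken from the test above; everything else is illustrative):

import torch

# add_ with alpha mutates x in place, computing x += alpha * y
x = torch.zeros(2, 3, 4, 10)
y = torch.randn(2, 3, 4, 10).to(memory_format=torch.channels_last)

expected = x + 0.55 * y   # out-of-place reference, computed before the mutation
x.add_(y, alpha=0.55)     # in-place update, as in the test's fn()
assert torch.allclose(x, expected)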
30 changes: 28 additions & 2 deletions torch/_inductor/codegen/triton.py
@@ -515,11 +515,18 @@ class TritonKernel(Kernel):
overrides = TritonOverrides
sexpr = texpr

def __init__(self, *groups, pid_cache=None, reduction_hint=ReductionHint.DEFAULT):
def __init__(
self,
*groups,
mutations=None,
pid_cache=None,
reduction_hint=ReductionHint.DEFAULT,
):
if pid_cache is None:
pid_cache = {}
super(TritonKernel, self).__init__()
self.numels = [V.graph.sizevars.simplify(s) for s in groups]
self.mutations = mutations
self.range_trees = []
self.range_tree_nodes = {}
self.iter_vars_count = itertools.count()
@@ -1011,10 +1018,21 @@ def codegen_kernel(self, name=None):
)

argdefs, _, signature = self.args.python_argdefs()

mutated_args = []
for mutation in self.mutations:
if mutation in self.args.input_buffers:
mutated_args.append(self.args.input_buffers[mutation])
if mutation in self.args.inplace_buffers:
mutated_args.append(self.args.inplace_buffers[mutation])
if mutation in self.args.output_buffers:
mutated_args.append(self.args.output_buffers[mutation])

triton_meta = {
"signature": dict(enumerate(map(signature_of, signature))),
"device": V.graph.scheduler.current_device.index,
"constants": {},
"mutated_arg_names": mutated_args,
}

for tree in self.range_trees:
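The loop above maps graph-level buffer names (the names the scheduler records as mutated) to the kernel argument names that end up in triton_meta. A self-contained sketch of that lookup; the buffer name buf0 is an assumption, while in_out_ptr0 matches the argument name used in the updated test:

def collect_mutated_args(mutations, input_buffers, inplace_buffers, output_buffers):
    # mirror of the lookup above: find each mutated buffer in whichever
    # argument table it was registered and record the argument name
    mutated_args = []
    for mutation in mutations:
        for table in (input_buffers, inplace_buffers, output_buffers):
            if mutation in table:
                mutated_args.append(table[mutation])
    return mutated_args

# assumed example: buffer "buf0" is stored in place via argument "in_out_ptr0"
print(collect_mutated_args({"buf0"}, {}, {"buf0": "in_out_ptr0"}, {}))  # ['in_out_ptr0']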
@@ -1289,7 +1307,15 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
reduction_hint_val = ReductionHint.DEFAULT
else:
reduction_hint_val = ReductionHint.DEFAULT
with TritonKernel(*tiled_groups, reduction_hint=reduction_hint_val) as kernel:

mutations = set()
for node in node_schedule:
if hasattr(node, "get_mutations"):
mutations.update(node.get_mutations())

with TritonKernel(
*tiled_groups, reduction_hint=reduction_hint_val, mutations=mutations
) as kernel:
stack = contextlib.ExitStack()
for node in node_schedule:
if node not in (EnableReduction, DisableReduction):
28 changes: 20 additions & 8 deletions torch/_inductor/triton_ops/autotune.py
@@ -42,11 +42,12 @@ class CachingAutotuner(KernelInterface):
configs, and does not rely on the Triton JIT.
"""

def __init__(self, fn, meta, configs, save_cache_hook):
def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
super().__init__()
self.fn = fn
self.meta = meta
self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
self.configs = configs
self.launchers = []
self.lock = threading.Lock()
@@ -141,12 +142,17 @@ def autotune_to_one_config(self, *args, **kwargs):
"""Do the actual autotuning"""
from ..compile_fx import clone_preserve_strides

# clone the input args to avoid autotune contaminating them if
# the kernel does in-place stores
cloned_args = [
clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg
for arg in args
]
# clone inplace buffers to avoid autotune contaminating them if
# the kernel does in-place stores. avoid cloning other buffers because
# it leads to increased memory use
cloned_args = []
for i, arg in enumerate(args):
if self.fn.arg_names[i] in self.mutated_arg_names:
assert isinstance(arg, torch.Tensor)
cloned_args.append(clone_preserve_strides(arg))
else:
cloned_args.append(arg)

timings = {
launcher: self.bench(launcher, *cloned_args, **kwargs)
for launcher in self.launchers
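Cloning only the mutated arguments keeps benchmarking correct without copying every input: each timing run re-executes the kernel, so an in-place store would otherwise accumulate into the shared buffer and later runs would time against contaminated data. A small illustration of that effect (the three-run count is an arbitrary assumption):

import torch

x = torch.zeros(4)
y = torch.ones(4)

# simulate timing an in-place kernel three times on the same, uncloned buffer
for _ in range(3):
    x.add_(y, alpha=0.55)   # the kernel under test mutates x in place

print(x)  # tensor([1.6500, ...]) -- three accumulated updates, not the single-run 0.55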
@@ -251,9 +257,15 @@ def save_cache_hook(cfg):
else:
save_cache_hook = None

mutated_arg_names = meta.pop("mutated_arg_names", ())

def decorator(fn):
return CachingAutotuner(
fn, meta=meta, configs=configs, save_cache_hook=save_cache_hook
fn,
meta=meta,
configs=configs,
save_cache_hook=save_cache_hook,
mutated_arg_names=mutated_arg_names,
)

return decorator
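End-to-end, the new field travels from the generated kernel's meta dict into CachingAutotuner: cached_autotune pops it before the remaining meta is handed to Triton. A hedged sketch of that hand-off; the signature/device values below are assumptions, while the key names and the in_out_ptr0 entry follow the codegen change and the test above:

meta = {
    "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},  # assumed values
    "device": 0,
    "constants": {},
    "mutated_arg_names": ["in_out_ptr0"],
}

mutated_arg_names = meta.pop("mutated_arg_names", ())  # same pop as in cached_autotune
assert "mutated_arg_names" not in meta                 # Triton never sees the extra key
assert mutated_arg_names == ["in_out_ptr0"]            # forwarded to CachingAutotuner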