[inductor] fix benchmark call for inplace update (#103547)
Enabling coordinate descent tuning for a few models causes an illegal memory access (or triggers a device assert before that). Command:
```
TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 python benchmarks/dynamo/huggingface.py --amp --performance --training --inductor -d cuda --only CamemBert
```

It turns out that we cannot benchmark this kernel: https://gist.github.com/shunting314/a78997f54b5751f2887f4576956036ce

Digging further shows that this kernel has an inplace argument that is modified when the kernel runs. Our benchmark API simply calls a kernel multiple times. Since each run may have side effects, earlier calls may change the inplace argument in a way that makes the following calls fail.
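The failure mode can be reproduced outside Inductor with a toy example. This is only a sketch under stated assumptions: `toy_kernel` and `naive_bench` are hypothetical names, plain Python lists stand in for device tensors, and an `IndexError` plays the role of the illegal memory access.

```python
def toy_kernel(indices, table):
    """Hypothetical stand-in for a kernel with an inplace argument."""
    total = 0
    for i in range(len(indices)):
        total += table[indices[i]]  # read through the index buffer
        indices[i] += 1             # inplace update, visible to the next call
    return total


def naive_bench(fn, *args, rep=5):
    """Call fn repeatedly on the SAME arguments, as a simple benchmark loop would."""
    for _ in range(rep):
        fn(*args)


table = [10, 20, 30]
indices = [0, 1]

try:
    naive_bench(toy_kernel, indices, table)
    failed = False
except IndexError:
    # Earlier calls pushed the indices past the end of the table.
    failed = True

print("benchmark failed:", failed)
```

A single call with the original arguments is fine; the benchmark loop only fails because the side effects accumulate across repeated calls.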

This PR clones those inplace arguments before each benchmark call. This can increase the time for each benchmark call, but it should not affect autotuning since the same amount of time is added for every tuning config.
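The idea of cloning before every call can be sketched in plain Python. This is an illustration, not the Inductor implementation: `toy_kernel`, `bench_with_clones`, and `inplace_positions` are hypothetical, and `copy.deepcopy` on lists stands in for cloning tensors. The clone is deliberately made inside the timed region, mirroring the note above that every config pays the same extra cost.

```python
import copy
import time


def toy_kernel(indices, table):
    """Hypothetical kernel that updates `indices` in place."""
    total = 0
    for i in range(len(indices)):
        total += table[indices[i]]
        indices[i] += 1  # inplace side effect
    return total


def bench_with_clones(fn, *args, inplace_positions=(0,), rep=5):
    """Time fn over `rep` calls, cloning the inplace arguments before each call."""
    elapsed = []
    for _ in range(rep):
        start = time.perf_counter()
        # Fresh copies of the inplace arguments; other arguments are shared.
        call_args = [
            copy.deepcopy(arg) if pos in inplace_positions else arg
            for pos, arg in enumerate(args)
        ]
        fn(*call_args)
        elapsed.append(time.perf_counter() - start)
    return min(elapsed)


table = [10, 20, 30]
indices = [0, 1]
best = bench_with_clones(toy_kernel, indices, table)
print("caller's indices after benchmarking:", indices)
```

Because each call sees a fresh copy, every run starts from the same state and the caller's arguments are left untouched, at the cost of including the copy in the measured time.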

Pull Request resolved: #103547
Approved by: https://github.com/jansel
shunting314 authored and pytorchmergebot committed Jun 14, 2023
1 parent 8761619 commit 2e1369d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torch/_inductor/triton_heuristics.py
@@ -196,8 +196,10 @@ def kernel_call():
                 launcher.config.pre_hook(
                     {**zip(self.arg_names, args), **launcher.config.kwargs}
                 )
+
+            cloned_args = self.clone_args(*args)
             launcher(
-                *args,
+                *cloned_args,
                 grid=grid,
                 stream=stream,
             )
@@ -222,9 +224,8 @@ def clone_args(self, *args):

     @dynamo_timed
     def benchmark_all_configs(self, *args, **kwargs):
-        cloned_args = self.clone_args(*args)
         timings = {
-            launcher: self.bench(launcher, *cloned_args, **kwargs)
+            launcher: self.bench(launcher, *args, **kwargs)
             for launcher in self.launchers
         }
