Skip to content

Commit

Permalink
Update on "[inductor] do benchmark in sub processes for max autotuning"
Browse files Browse the repository at this point in the history
This PR implements the support to benchmark max-autotune choices in subprocesses. This way crash like triton-lang/triton#1298 will only abort the autotuning child process but the parent process can continue.

There are a few things to note:
- cuda runtime does not work with fork. So we have to use spawn to create child processes. Check the best practice from pytorch multithreading module: https://pytorch.org/docs/stable/notes/multiprocessing.html
- to run a job in a child process, the multiprocessing module needs to pickle both the target function and arguments and pass them to child process. This is the major complexity of this prototype since there are quite a lot of corner cases making pickle fail.

Here I list the pickle related issues I encountered:
- pickle a StorageBox cause infinite recursion. Error: https://gist.github.com/171e5ab404b7855dee2dfa1d9f093442 . Work around by pickle the inner buffer.
- IRNode store fx.Node's in its origin fields. However, we can not pickle a fx.Node. It fails when with the following error when picking the fx.Node.graph: https://gist.github.com/9c289e895d7091d7ec787c67bc3c0d70. Work around by skip origins when pickling a IRNode.
- jinja Template in TritonTemplateKernel can not be pickled: `TypeError: Template.__new__() missing 1 required positional argument: 'source' `. Workaround by pickle the source rather than jinjia Template. During unpickling, rebuild the jinja template.
- due to how select_algorithm.template_kernels is populated, in child process, it's empty. Work around by passing select_algorithm.template_kernels from parent process to child process directly.
  - There is some change in TritonTemplate.generate to make a TritonTemplateKernel pickle'able. A TritonTemplate is refered to in the closure for a TritonTemplateKernel object.
- We can not pass choice to child process directly because of pickle failure for lambda/local function being used. However cloudpickle can handle lambda. Work around by passing the cloudpickle'd choice object to child process. The child project need to unpickle it explictly.

Test:
```
python test/inductor/test_max_autotune.py -k test_max_autotune_mm_plus_mm
```
This is basically the repro I get from Bert Maher.


Benchmark in sub process is about 4x slower than benchmark in the same process. Without doing any profiling, I feel the time may be cost by starting a new process and doing initialization. Some ~thread~ process pool may help.

```
AUTOTUNE ref_mm_plus_mm(2048x64, 64x1536, 2048x64, 64x1536)
  triton_mm_plus_mm_0 0.0276s 100.0%
  triton_mm_plus_mm_6 0.0287s 96.4%
  triton_mm_plus_mm_5 0.0317s 87.1%
  triton_mm_plus_mm_1 0.0328s 84.4%
  ref_mm_plus_mm 0.0379s 73.0%
  triton_mm_plus_mm_7 0.0379s 73.0%
  triton_mm_plus_mm_2 0.0399s 69.2%
  triton_mm_plus_mm_3 0.0410s 67.5%
  triton_mm_plus_mm_4 0.0410s 67.5%
AUTOTUNE takes 12.001659393310547 seconds

AUTOTUNE ref_mm_plus_mm(2048x64, 64x1536, 2048x64, 64x1536)
  triton_mm_plus_mm_0 0.0276s 100.0%
  triton_mm_plus_mm_6 0.0287s 96.4%
  triton_mm_plus_mm_1 0.0317s 87.1%
  triton_mm_plus_mm_5 0.0317s 87.1%
  ref_mm_plus_mm 0.0379s 73.0%
  triton_mm_plus_mm_7 0.0389s 71.1%
  triton_mm_plus_mm_2 0.0399s 69.2%
  triton_mm_plus_mm_3 0.0410s 67.5%
  triton_mm_plus_mm_4 0.0410s 67.5%
AUTOTUNE takes 51.39659810066223 seconds
``` 

The feature is disabled by default and can be enabled by setting the following config or envvar:
```
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
```


cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

Differential Revision: [D43996048](https://our.internmc.facebook.com/intern/diff/D43996048)

[ghstack-poisoned]
  • Loading branch information
shunting314 committed Mar 17, 2023
2 parents da538cd + 2029a41 commit e987601
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions torch/_inductor/select_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,10 @@ def __init__(
self.debug_extra = debug_extra
self.bmreq = bmreq

def benchmark(self, *args, out):
assert self.bmreq is not None
return self.bmreq.benchmark(*args, output_tensor=out)

def __str__(self):
return (
f"TritonTemplateCaller({self.to_callable().__file__}, {self.debug_extra})"
Expand Down Expand Up @@ -716,8 +720,7 @@ def benchmark_in_current_process(choice):
result = choice.benchmark(*example_inputs_extern, out=out_extern)[0]
else:
# triton templates want the base pointer for sliced tensors
assert choice.bmreq is not None
result = choice.bmreq.benchmark(*example_inputs, output_tensor=out)
result = choice.benchmark(*example_inputs, out=out)
if VERIFY:
torch.testing.assert_close(out_extern, expected, **VERIFY)
torch.cuda.synchronize() # shake out any CUDA errors
Expand Down

0 comments on commit e987601

Please sign in to comment.