[inductor] Parallelize Max Autotune step 2: Use multiple GPUs
Test Plan:
`python test/inductor/test_max_autotune.py`
`TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`
`TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --inference --only hf_Bart`

ghstack-source-id: 290da2b3040ca6a5913d13c0c5bc3ac37ee71d4a
Pull Request resolved: #109127
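
For illustration, the same knobs can be set through the inductor config API instead of the `TORCHINDUCTOR_*` environment variables above. A minimal sketch (the toy `mm_plus_mm` function and tensor shapes are only an example; the config keys match the ones exercised in this diff):

```python
# Sketch: enable max-autotune with subprocess benchmarking across multiple GPUs
# via torch._inductor.config rather than environment variables.
import torch
from torch._inductor import config

def mm_plus_mm(a, b, c, d):
    return a @ b + c @ d

a, b, c, d = (torch.randn(128, 128, device="cuda") for _ in range(4))

with config.patch(
    {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "autotune_multi_device": True,
    }
):
    out = torch.compile(mm_plus_mm)(a, b, c, d)
```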
masnesral committed Sep 12, 2023
1 parent 9d01dc3 commit 56b36bc
Showing 6 changed files with 281 additions and 107 deletions.
64 changes: 61 additions & 3 deletions test/inductor/test_max_autotune.py
@@ -1,13 +1,19 @@
# Owner(s): ["module: inductor"]
from typing import Optional

import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.autotune_process import BenchmarkRequest, TuningProcessPool
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.select_algorithm import (
AlgorithmSelectorCache,
ChoiceCaller,
TritonTemplateCaller,
)
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
@@ -118,7 +124,8 @@ def test_benchmark_choice_fail_in_subproc(self):
self.assertNotEqual(0, child.exitcode)

@parametrize("autotune_in_subproc", (True, False))
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
@parametrize("autotune_multi_device", (True, False))
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
"""
This previously crashed due to a Triton issue: https://github.com/openai/triton/issues/1298 .
With autotuning in a subprocess, we don't crash anymore.
@@ -134,7 +141,11 @@ def mm_plus_mm(a, b, c, d):
d = torch.randn(k, n).cuda()

with config.patch(
{"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
{
"max_autotune": True,
"autotune_in_subproc": autotune_in_subproc,
"autotune_multi_device": autotune_multi_device,
}
):
torch.compile(mm_plus_mm)(a, b, c, d)

@@ -318,6 +329,53 @@ def fn(
torch.testing.assert_close(y1, y1_expected)


class TestBenchmarkRequest(BenchmarkRequest):
def __init__(self, value: Optional[float] = None) -> None:
self.value = value

def benchmark(
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
) -> float:
if self.value is None:
raise Exception("Failed to run")
return self.value


class TestTritonTemplateCaller(TritonTemplateCaller):
def __init__(self, bmreq: TestBenchmarkRequest):
self.bmreq = bmreq

def __str__(self) -> str:
return "test"


class TestTuningProcess(TestCase):
def test_tuning_pool(self):
# Use only one device:
with config.patch({"autotune_multi_device": False}):
tuning_pool = TuningProcessPool()
tuning_pool.initialize()

# First cause the tuning process to crash.
bmreq = TestBenchmarkRequest(value=None)
choice = TestTritonTemplateCaller(bmreq)

timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))

# Then send another request and make sure the sub-process
# has restarted and is operational.
value = 3.14
choice.bmreq.value = value

timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], value)

tuning_pool.terminate()


if __name__ == "__main__":
if HAS_CUDA:
run_tests()
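
As a usage note, the new `TestTuningProcess` test above captures the full `TuningProcessPool` lifecycle: initialize the pool, submit choices for benchmarking (a crashed sub-process is reported as an `inf` timing and restarted), and terminate. A condensed sketch of that flow, reusing the stub classes defined in the test:

```python
# Condensed sketch of the flow exercised by TestTuningProcess: a failing
# request is reported as float("inf"), the sub-process restarts, and a
# follow-up request succeeds.
from torch._inductor import config
from torch._inductor.autotune_process import TuningProcessPool

with config.patch({"autotune_multi_device": False}):
    pool = TuningProcessPool()
    pool.initialize()

    choice = TestTritonTemplateCaller(TestBenchmarkRequest(value=None))
    assert pool.benchmark([choice])[choice] == float("inf")  # crashed sub-process

    choice.bmreq.value = 3.14  # sub-process has restarted; request now succeeds
    assert pool.benchmark([choice])[choice] == 3.14

    pool.terminate()
```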
2 changes: 1 addition & 1 deletion test/inductor/test_select_algorithm.py
@@ -20,7 +20,7 @@

def patches(fn):
def skip_cache(self, choices, name, key, generate):
return {choice: generate(choice) for choice in choices}
return generate(choices)

for patcher in [
dynamo_config.patch(verbose=True),
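
The one-line change to `skip_cache` above reflects the new benchmarking interface: the `generate` callback is now handed the full list of choices and returns a dict mapping each choice to its timing, so a whole batch can be dispatched to the tuning pool at once. Side by side, for illustration (the `_old`/`_new` suffixes are placeholders, not names from the codebase):

```python
# Illustration of the interface change in the cache-bypass test stub.
def skip_cache_old(self, choices, name, key, generate):
    # before: benchmark each choice individually
    return {choice: generate(choice) for choice in choices}

def skip_cache_new(self, choices, name, key, generate):
    # after: pass the whole list; generate() returns {choice: timing}
    return generate(choices)
```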
