Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Parallelize Max Autotune step 2: Use multiple GPUs #109127

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions test/inductor/test_max_autotune.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# Owner(s): ["module: inductor"]
from typing import Optional

import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.autotune_process import BenchmarkRequest, TuningProcessPool
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.select_algorithm import (
AlgorithmSelectorCache,
ChoiceCaller,
TritonTemplateCaller,
)
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
Expand Down Expand Up @@ -118,7 +124,8 @@ def test_benchmark_choice_fail_in_subproc(self):
self.assertNotEqual(0, child.exitcode)

@parametrize("autotune_in_subproc", (True, False))
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
@parametrize("autotune_multi_device", (True, False))
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
"""
This crash previously due to a triton issue: https://github.com/openai/triton/issues/1298 .
With autotuning in subprocess, we don't crash anymore.
Expand All @@ -134,7 +141,11 @@ def mm_plus_mm(a, b, c, d):
d = torch.randn(k, n).cuda()

with config.patch(
{"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
{
"max_autotune": True,
"autotune_in_subproc": autotune_in_subproc,
"autotune_multi_device": autotune_multi_device,
}
):
torch.compile(mm_plus_mm)(a, b, c, d)

Expand Down Expand Up @@ -318,6 +329,53 @@ def fn(
torch.testing.assert_close(y1, y1_expected)


class TestBenchmarkRequest(BenchmarkRequest):
def __init__(self, value: Optional[float] = None) -> None:
self.value = value

def benchmark(
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
) -> float:
if self.value is None:
raise Exception("Failed to run")
return self.value


class TestTritonTemplateCaller(TritonTemplateCaller):
def __init__(self, bmreq: TestBenchmarkRequest):
self.bmreq = bmreq

def __str__(self) -> str:
return "test"


class TestTuningProcess(TestCase):
def test_tuning_pool(self):
# Use only one device:
with config.patch({"autotune_multi_device": False}):
tuning_pool = TuningProcessPool()
tuning_pool.initialize()

# First cause the tuning process to crash.
bmreq = TestBenchmarkRequest(value=None)
choice = TestTritonTemplateCaller(bmreq)

timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))

# Then send another request and make sure the sub-process
# has restarted and is operational.
value = 3.14
choice.bmreq.value = value

timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], value)

tuning_pool.terminate()


if __name__ == "__main__":
if HAS_CUDA:
run_tests()
2 changes: 1 addition & 1 deletion test/inductor/test_select_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def patches(fn):
def skip_cache(self, choices, name, key, generate):
return {choice: generate(choice) for choice in choices}
return generate(choices)

for patcher in [
dynamo_config.patch(verbose=True),
Expand Down
Loading
Loading