inductor cpp wrapper: add GIL release and acquire
ghstack-source-id: ed6c39793b7d9fc578ff7d8ad1442880a49a326a
Pull Request resolved: #111888
chunyuan-w committed Oct 30, 2023
1 parent 69b9e54 commit 2815962
Showing 3 changed files with 32 additions and 0 deletions.
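
For context: this change makes Inductor's cpp wrapper drop the GIL while the compiled C++ code runs, so several Python threads can execute a compiled model concurrently. Below is a minimal pybind11 sketch of the RAII release pattern being wired in (illustrative only; `run_compiled_kernel` and the module name are hypothetical stand-ins, not part of this commit):

#include <chrono>
#include <thread>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for the generated kernel calls: pure C++ work
// that never touches the CPython API, so it is safe to run without the GIL.
static void run_compiled_kernel() {
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
}

// Release the GIL for the duration of the C++ work so other Python
// threads can make progress; the RAII guard re-acquires it on scope exit.
static void wrapper_call() {
    py::gil_scoped_release release;  // GIL dropped here
    run_compiled_kernel();           // runs concurrently with Python threads
}                                    // GIL re-acquired when `release` destructs

PYBIND11_MODULE(gil_demo, m) {
    m.def("wrapper_call", &wrapper_call);
}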
2 changes: 2 additions & 0 deletions test/inductor/test_cpp_wrapper.py
@@ -192,6 +192,7 @@ class BaseTest(NamedTuple):
         ),
         BaseTest("test_mm_views"),
         BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
+        BaseTest("test_multi_threading"),
         BaseTest("test_profiler_mark_wrapper_call"),
         BaseTest("test_randint"),
         BaseTest("test_randn_with_dtype_and_device"),
@@ -267,6 +268,7 @@ class BaseTest(NamedTuple):
         BaseTest("test_linear2"),
         BaseTest("test_mm_views"),
         BaseTest("test_multi_device"),
+        BaseTest("test_multi_threading"),
         BaseTest("test_profiler_mark_wrapper_call"),
         BaseTest("test_reduction1"),  # Reduction
         BaseTest("test_relu"),  # multiple inputs
24 changes: 24 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -12,6 +12,7 @@
 import re
 import subprocess
 import sys
+import threading
 import time
 import typing
 import unittest
@@ -2600,6 +2601,29 @@ def fn(x):
         )
         self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
 
+    def test_multi_threading(self):
+        model = torch.nn.Linear(2, 3).eval()
+        inp = torch.randn(4, 2)
+
+        num_run = 3
+
+        def run_weights_sharing_model(m, inp):
+            with torch.no_grad():
+                for i in range(num_run):
+                    y = m(inp)
+
+        numb_instance = 2
+        threads = []
+        compiled_m = torch.compile(model)
+        for i in range(1, numb_instance + 1):
+            thread = threading.Thread(
+                target=run_weights_sharing_model, args=(compiled_m, inp)
+            )
+            threads.append(thread)
+            thread.start()
+        for thread in threads:
+            thread.join()
+
     @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging")
     def test_adaptive_avg_pool2d_low_prec(self):
         class Model(torch.nn.Module):
6 changes: 6 additions & 0 deletions torch/_inductor/codegen/wrapper.py
@@ -1311,6 +1311,12 @@ def write_wrapper_decl(self):
                 auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, num_inputs());
                 """
             )
+        else:
+            self.prefix.splice(
+                """
+                py::gil_scoped_release release;
+                """
+            )
 
         if inputs_len != 0:
             for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
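The spliced `py::gil_scoped_release release;` lands at the top of the generated wrapper body on the non-AOT path, so the GIL stays released for the whole call. A hand-written approximation of what the generated entry point then looks like (illustrative; the signature and identifiers here are assumptions, not the exact codegen output):

#include <vector>

#include <ATen/ATen.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hand-written approximation of the generated cpp wrapper entry point
// after this change: `release` drops the GIL for the whole wrapper body
// and re-acquires it when the function returns.
std::vector<at::Tensor> inductor_entry_cpp(std::vector<at::Tensor> inputs) {
    py::gil_scoped_release release;
    // ... generated aten/kernel calls on `inputs`; no CPython API use here,
    // so it is safe to run without the GIL ...
    return inputs;  // placeholder for the real outputs
}

Releasing the GIL here is what lets the new test_multi_threading test drive the same compiled module from two threads at once; any wrapper code that does call back into Python (e.g. profiler hooks) would first re-take the GIL, presumably the "acquire" half of the commit title.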
