[aotinductor] Allow specifying a .so name in the aot_inductor.output_path config #112651

Closed
22 changes: 21 additions & 1 deletion test/inductor/test_aot_inductor.py
@@ -215,7 +215,7 @@ def forward(self, x):
with config.patch({"always_keep_tensor_constants": True}):
self.check_model(Model().to(self.device), example_inputs)

def test_output_path(self):
def test_output_path_1(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -231,6 +231,26 @@ def forward(self, x, y):
with config.patch("aot_inductor.output_path", "tmp_output_"):
self.check_model(Model(), example_inputs)

def test_output_path_2(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)

def forward(self, x, y):
return x + self.linear(y)

model = Model().to(device=self.device)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
expected_path = os.path.join(tempfile.mkdtemp(), "model.so")
actual_path = AOTInductorModelRunner.compile(
model, example_inputs, options={"aot_inductor.output_path": expected_path}
)
self.assertEqual(actual_path, expected_path)

@requires_cuda()
def test_multi_device(self):
class Model(torch.nn.Module):
35 changes: 30 additions & 5 deletions torch/_inductor/codecache.py
@@ -1494,6 +1494,15 @@ def run_command_and_check(cmd: str):
raise exc.CppCompileError(cmd, e.output) from e


@functools.lru_cache(None)
def split_aot_inductor_output_path(path: str) -> Tuple[str, str]:
"""Returns the path where the AOT Inductor compiled kernels are stored."""
if path.endswith(".so"):
return os.path.split(path)
else:
return path, ""


class CudaKernelParamCache:
cache: Dict[str, Dict[str, str]] = dict()
clear = staticmethod(cache.clear)
@@ -1504,7 +1513,9 @@ def set(cls, key: str, params: Dict[str, str], cubin: str) -> None:
cubin,
"cubin",
hash_type="cubin",
specified_dir=config.aot_inductor.output_path,
specified_dir=split_aot_inductor_output_path(
config.aot_inductor.output_path
)[0],
)
params["cubin_path"] = path
cls.cache[key] = params
@@ -1545,14 +1556,24 @@ def compile(
else:
ld_command = "ld"
objcopy_command = "objcopy"

(
specified_output_path,
specified_so_name,
) = split_aot_inductor_output_path(config.aot_inductor.output_path)
key, input_path = write(
source_code,
"cpp",
extra=cpp_command,
specified_dir=config.aot_inductor.output_path,
specified_dir=specified_output_path,
)

if key not in cls.cache:
if key not in cls.cache or (
specified_output_path
and os.path.dirname(cls.cache[key]) != specified_output_path
or specified_so_name
and os.path.basename(cls.cache[key]) != specified_so_name
[Inline review comment from a Contributor]
We cache the full path for the output.so, which is either config.aot_inductor.output_path or os.path.splitext(input_path)[0] + ".so". So it seems we wouldn't have any case where we cache specified_so_name?
):
from filelock import FileLock

lock_dir = get_lock_dir()
@@ -1565,7 +1586,11 @@ def compile(
with open(output_json, "w") as f:
f.write(serialized_extern_kernel_nodes)

output_so = os.path.splitext(input_path)[0] + ".so"
output_so = (
config.aot_inductor.output_path
if specified_so_name
else os.path.splitext(input_path)[0] + ".so"
)

if not os.path.exists(output_so):
output_o = os.path.splitext(input_path)[0] + ".o"
@@ -1605,7 +1630,7 @@ def _to_bytes(t: torch.Tensor) -> bytes:
consts_key, consts_path = write(
aot_constants,
"bin",
specified_dir=config.aot_inductor.output_path,
specified_dir=specified_output_path,
)

consts_o = os.path.splitext(consts_path)[0] + ".o"
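To make the new path handling concrete, here is a minimal standalone sketch mirroring the split_aot_inductor_output_path helper added above (the PR version additionally wraps it in functools.lru_cache); the assertions are illustrative examples, not part of the PR:

import os
from typing import Tuple

def split_aot_inductor_output_path(path: str) -> Tuple[str, str]:
    # A path ending in ".so" splits into (directory, library name);
    # any other value is treated as a directory with no library name.
    if path.endswith(".so"):
        return os.path.split(path)
    return path, ""

assert split_aot_inductor_output_path("/tmp/out/model.so") == ("/tmp/out", "model.so")
assert split_aot_inductor_output_path("relative_dir") == ("relative_dir", "")
assert split_aot_inductor_output_path("") == ("", "")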
4 changes: 3 additions & 1 deletion torch/_inductor/config.py
@@ -483,7 +483,9 @@ class aot_inductor:
# AOTInductor output path
# If an absolute path is specified, the generated lib files will be stored under the directory;
# If a relative path is specified, it will be used as a subdirectory under the default caching path;
# If not specified, a temp directory will be created under the default caching path
# If not specified, a temp directory will be created under the default caching path.
# If the specified path ends with a library file name such as "model.so", that file name
# will be used for the generated library.
output_path = ""

debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"
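For context, a hedged usage sketch of the config option documented above, based on the patterns in this PR's tests. config.patch and the "aot_inductor.output_path" key appear in the tests; the compilation step itself is elided, since the tests use a test-only helper (AOTInductorModelRunner.compile) and the public AOT Inductor entry point may differ across PyTorch versions:

import os
import tempfile

from torch._inductor import config

# Assumption for illustration: any writable directory works here,
# and "model.so" is the desired name for the generated library.
so_path = os.path.join(tempfile.mkdtemp(), "model.so")

# Directory-only form: generated files go under the directory,
# with a hash-derived .so file name chosen by the cache.
with config.patch("aot_inductor.output_path", "tmp_output_"):
    pass  # compile the model here

# Full-path form (new in this PR): the library is written to so_path.
with config.patch("aot_inductor.output_path", so_path):
    pass  # compile the model here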