[inductor] graph replayer
ghstack-source-id: c9abfc68edb85f6de8344189b397daf98c2c0865
Pull Request resolved: #106952
shunting314 committed Aug 10, 2023
1 parent f8817d8 commit 4bb5f76
Showing 3 changed files with 109 additions and 1 deletion.
24 changes: 24 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -24,6 +24,8 @@
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -287,9 +289,31 @@ def compile_fx_inner(
    user_visible_outputs=frozenset(),
    layout_opt=None,
):
"""
Inductor API that compiles a single graph.
If you change the argument list for this funtion, make sure you
also update the call to save_args_for_compile_fx_inner below accordingly.
"""
if dynamo_utils.count_calls(gm.graph) == 0:
return make_boxed_func(gm.forward)

    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            num_fixed=num_fixed,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

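Read end to end, the hook above is driven purely by the environment. A minimal sketch of a run that produces saved-argument files (my reading of the diff; the toy function f below is hypothetical and not part of the change):

# Sketch: exercising the save-args hook.
# TORCHINDUCTOR_SAVE_ARGS is read once, when torch._inductor.config is
# imported (see the config.py hunk below), so set it before importing torch.
import os

os.environ["TORCHINDUCTOR_SAVE_ARGS"] = "1"

import torch

def f(x):
    return x.sin() + x.cos()

# With the default inductor backend, compilation goes through
# compile_fx_inner, which now dumps its arguments to
# /tmp/inductor_saved_args/compile_fx_inner_<n>.pkl.
torch.compile(f)(torch.randn(8, 8))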
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -83,6 +83,8 @@
# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

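Because save_args is an ordinary inductor config knob, it should also be possible to flip it programmatically with config.patch, the same mechanism the replay path below uses to disable it. A hedged sketch (the toy function f is hypothetical):

import torch
from torch._inductor import config

def f(x):
    return (x * 2).relu()

# Same form as the config.patch("save_args", False) call in debug.py below.
with config.patch("save_args", True):
    torch.compile(f)(torch.randn(4))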
84 changes: 83 additions & 1 deletion torch/_inductor/debug.py
@@ -1,10 +1,13 @@
import collections
import contextlib
import cProfile
import dataclasses
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
@@ -19,8 +22,9 @@
from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir # noqa: F811, this is needed
from .scheduler import (
@@ -387,3 +391,81 @@ def graph_diagram(self, nodes: SchedulerNodeList):

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(*args, **kwargs):
    """
    Save the arguments of a compile_fx_inner call to the file system.
    Later on, one can replay the compile_fx_inner call with the saved
    arguments using load_args_and_run_compile_fx_inner.
    """

    folder = "/tmp/inductor_saved_args"
    if not os.path.exists(folder):
        os.mkdir(folder)

    def handle_tensor(x):
        """
        Pickling a FakeTensor results in the error:
            AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'
        Convert all tensors to metadata instead. This may also make pickling faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call are saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
        """
        # Call print rather than log.debug: log.debug adds a message
        # prefix to each line, which makes the code snippet harder to
        # copy. Not a big deal, since this code is already guarded by
        # the log-level check above.
        print(message)


def load_args_and_run_compile_fx_inner(path):
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x):
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)

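Putting the two halves together, a replay session would look roughly like the following. The pickle path is hypothetical; take it from the message printed at DEBUG log level when the arguments were saved. Note that tensors were saved as metadata only, so the replay reconstructs random fake tensors with matching shape, stride, dtype, and device.

from torch._inductor.debug import load_args_and_run_compile_fx_inner

# Hypothetical path; use the one printed when the call was saved.
load_args_and_run_compile_fx_inner(
    "/tmp/inductor_saved_args/compile_fx_inner_0.pkl"
)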