[inductor] graph replayer
ghstack-source-id: 955a5e0ed33ad030a7998367e376857c5c29e92a
Pull Request resolved: #106952
shunting314 committed Aug 10, 2023
1 parent f8817d8 commit a2ca918
Showing 3 changed files with 106 additions and 1 deletion.
3 changes: 3 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -24,6 +24,8 @@
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -273,6 +275,7 @@ def fake_tensor_prop(
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
@save_args_for_compile_fx_inner
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -83,6 +83,8 @@
# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

102 changes: 101 additions & 1 deletion torch/_inductor/debug.py
@@ -1,10 +1,13 @@
import collections
import contextlib
import cProfile
import dataclasses
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
@@ -14,13 +17,15 @@
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled

import torch
import torch._dynamo.utils as dynamo_utils
from torch import fx as fx

from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir # noqa: F811, this is needed
from .scheduler import (
@@ -387,3 +392,98 @@ def graph_diagram(self, nodes: SchedulerNodeList):

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(fn):
"""
This is a decorator applied on the compile_fx_inner function.
When config.save_args is True, this function will save the arguments
for each compile_fx_inner call to the filesystem. Later on
one can replay the compile_fx_inner call with the saved arguments
using load_args_and_run_compile_fx_inner.
"""

assert (
fn.__name__ == "compile_fx_inner"
), "This decorator only works for compile_fx_inner right now"

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not config.save_args:
            return fn(*args, **kwargs)
        gm = args[0]
        if dynamo_utils.count_calls(gm.graph) == 0:
            return fn(*args, **kwargs)

        folder = "/tmp/inductor_saved_args"
        if not os.path.exists(folder):
            os.mkdir(folder)

        def handle_tensor(x):
            """
            Pickling a FakeTensor results in an error:
                AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'
            Convert all Tensors to metadata instead. This may also make pickling faster.
            """
            if isinstance(x, torch.Tensor):
                return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
            else:
                return x

        args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

        path = f"{folder}/{fn.__name__}_{next(save_args_cnt)}.pkl"
        with open(path, "wb") as f:
            pickle.dump((args_to_save, kwargs_to_save), f)

        if log.isEnabledFor(logging.DEBUG):
            message = f"""
Arguments for a compile_fx_inner call are saved to {path}. To replay the call,
run the following:
from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
"""
            # Call print rather than log.debug: log.debug prints a message
            # prefix on each line, which makes the code snippet harder to
            # copy. Not a big deal since this code is already guarded by
            # checking the log level.
            print(message)

        return fn(*args, **kwargs)

    return wrapper


def load_args_and_run_compile_fx_inner(path):
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x):
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)
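
A replay sketch based on the docstring above (the pickle filename is illustrative; real files follow the compile_fx_inner_<N>.pkl naming produced by the counter):

    from torch._inductor.debug import load_args_and_run_compile_fx_inner

    # Re-run compile_fx_inner with the pickled arguments; tensors are recreated
    # from the saved metadata (random data matching shape/stride/dtype/device)
    # under a FakeTensorMode.
    load_args_and_run_compile_fx_inner("/tmp/inductor_saved_args/compile_fx_inner_0.pkl")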
