[inductor] graph replayer #106952

Closed · wants to merge 5 commits
Changes from 3 commits
3 changes: 3 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -24,6 +24,8 @@
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -273,6 +275,7 @@ def fake_tensor_prop(
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
@save_args_for_compile_fx_inner
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -83,6 +83,8 @@
# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

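For reference, a minimal sketch of how the new knob would be enabled (the config name and env var come straight from this diff; the tiny compiled function below is only an illustration):

# Either export TORCHINDUCTOR_SAVE_ARGS=1 before launching the process,
# or flip the config knob from Python before the first compilation.
import torch
import torch._inductor.config as inductor_config

inductor_config.save_args = True  # same effect as TORCHINDUCTOR_SAVE_ARGS=1

@torch.compile
def f(x):
    return (x.sin() + 1).relu()

f(torch.randn(8, 8))  # each compile_fx_inner call pickles its args under /tmp/inductor_saved_args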
102 changes: 101 additions & 1 deletion torch/_inductor/debug.py
@@ -1,10 +1,13 @@
import collections
import contextlib
import cProfile
import dataclasses
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
@@ -14,13 +17,15 @@
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled

import torch
import torch._dynamo.utils as dynamo_utils
from torch import fx as fx

from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir # noqa: F811, this is needed
from .scheduler import (
@@ -387,3 +392,98 @@ def graph_diagram(self, nodes: SchedulerNodeList):

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(fn):
    """
    This is a decorator applied to the compile_fx_inner function.
    When config.save_args is True, it saves the arguments of each
    compile_fx_inner call to the filesystem. Later on, one can replay the
    compile_fx_inner call with the saved arguments using
    load_args_and_run_compile_fx_inner.
    """

    assert (
        fn.__name__ == "compile_fx_inner"
    ), "This decorator only works for compile_fx_inner right now"

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not config.save_args:
            return fn(*args, **kwargs)
        gm = args[0]
        if dynamo_utils.count_calls(gm.graph) == 0:
            return fn(*args, **kwargs)

        folder = "/tmp/inductor_saved_args"
        if not os.path.exists(folder):
            os.mkdir(folder)

        def handle_tensor(x):
            """
            Pickling a FakeTensor fails with:
                AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

            Convert all tensors to their metadata instead. This may also make
            pickling faster.
            """
            if isinstance(x, torch.Tensor):
                return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
            else:
                return x

        args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

        path = f"{folder}/{fn.__name__}_{next(save_args_cnt)}.pkl"
        with open(path, "wb") as f:
            pickle.dump((args_to_save, kwargs_to_save), f)

        if log.isEnabledFor(logging.DEBUG):
            message = f"""
Arguments for a compile_fx_inner call are saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
"""
            # Call print rather than log.debug: log.debug prefixes every line,
            # which makes the code snippet harder to copy.
            # Not a big deal since this branch is already guarded by the
            # log-level check.
            print(message)

        return fn(*args, **kwargs)

    return wrapper
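As an aside, a small self-contained sketch of the round trip that the decorator above and the loader below rely on (real tensor -> TensorMetadata -> freshly allocated tensor with the same layout); the example tensor t is arbitrary:

import torch
from torch._dynamo.testing import rand_strided
from torch.fx.passes.shape_prop import _extract_tensor_metadata

t = torch.randn(4, 8).transpose(0, 1)   # non-contiguous example tensor
meta = _extract_tensor_metadata(t)      # TensorMetadata named tuple (shape, dtype, stride, ...)
clone = rand_strided(meta.shape, meta.stride, meta.dtype, t.device)

assert clone.shape == t.shape and clone.stride() == t.stride() and clone.dtype == t.dtype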


def load_args_and_run_compile_fx_inner(path):
    """
    Replay a compile_fx_inner call whose arguments were saved by
    save_args_for_compile_fx_inner: rebuild fake tensors from the saved
    metadata and invoke compile_fx_inner with them.
    """
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x):
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)
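Putting the two halves together, a hedged end-to-end sketch of the intended replay workflow (the pickle filename below is illustrative; the actual name depends on the per-process call counter):

# 1. Run the workload with saving enabled, e.g.
#      TORCHINDUCTOR_SAVE_ARGS=1 python your_script.py
#    Each non-trivial compile_fx_inner call writes
#    /tmp/inductor_saved_args/compile_fx_inner_<n>.pkl
#
# 2. Replay a single graph compilation in isolation (under a debugger,
#    with extra logging, etc.) without rerunning the whole model:
from torch._inductor.debug import load_args_and_run_compile_fx_inner

load_args_and_run_compile_fx_inner(
    "/tmp/inductor_saved_args/compile_fx_inner_0.pkl"
)

Note that the replay rebuilds the inputs from the saved metadata (same shape, stride, dtype, and device, but random data) under FakeTensorMode with allow_non_fake_inputs=True, so it reproduces the compilation itself, not the numerical results of the original run.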