[inductor] graph replayer #106952

3 changes: 3 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -24,6 +24,8 @@
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph

from torch._inductor.debug import save_args
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -273,6 +275,7 @@ def fake_tensor_prop(
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
@save_args
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
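Since `@save_args` is applied closest to the function definition, its wrapper receives the same positional arguments as `compile_fx_inner` itself, which is why the wrapper can treat `args[0]` as the `GraphModule`. Below is a minimal, illustrative sketch of that decorator ordering; the names `outer` and `record_first_arg` are stand-ins and not part of this PR.

# Illustrative sketch: the innermost decorator's wrapper receives the
# function's positional args directly, so args[0] is the first parameter.
import functools


def outer(fn):  # stands in for DebugContext.wrap / time_and_log
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def record_first_arg(fn):  # stands in for save_args
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        print("args[0]:", args[0])  # the GraphModule in the real code
        return fn(*args, **kwargs)

    return wrapper


@outer
@record_first_arg
def compile_fn(gm, example_inputs):
    return gm


compile_fn("fake_gm", [])  # prints: args[0]: fake_gm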
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -83,6 +83,8 @@
# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

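As a usage sketch (not part of the diff): setting the environment variable before importing torch flips `config.save_args` to `True`, so `@save_args` pickles the (metadata-only) inputs of every non-empty `compile_fx_inner` call under /tmp/inductor_saved_args. The toy function below is illustrative, and a CUDA device is assumed to match the replay path.

import os

os.environ["TORCHINDUCTOR_SAVE_ARGS"] = "1"  # must be set before inductor's config module is imported

import torch


def f(x):
    return torch.sin(x) + torch.cos(x)


compiled = torch.compile(f)
compiled(torch.randn(8, device="cuda"))
# leaves /tmp/inductor_saved_args/compile_fx_inner_<n>.pkl behind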
69 changes: 68 additions & 1 deletion torch/_inductor/debug.py
@@ -4,7 +4,9 @@
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
@@ -14,13 +16,15 @@
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled

import torch
import torch._dynamo.utils as dynamo_utils
from torch import fx as fx

from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir # noqa: F811, this is needed
from .scheduler import (
@@ -387,3 +391,66 @@ def graph_diagram(self, nodes: SchedulerNodeList):

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))


save_args_cnt = itertools.count()


def save_args(fn):
    if not config.save_args:
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        gm = args[0]  # NOTE this is specific to compile_fx_inner
        if dynamo_utils.count_calls(gm.graph) == 0:
            return fn(*args, **kwargs)

        folder = "/tmp/inductor_saved_args"
        if not os.path.exists(folder):
            os.mkdir(folder)

        def handle_tensor(x):
            if isinstance(x, torch.Tensor):
                return _extract_tensor_metadata(x)
            else:
                return x

        args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

        with open(f"{folder}/{fn.__name__}_{next(save_args_cnt)}.pkl", "wb") as f:
            pickle.dump((args_to_save, kwargs_to_save), f)

        return fn(*args, **kwargs)

    return wrapper


def load_args_and_run_compile_fx_inner(path):
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x):
        # TODO don't use TensorMetadata since it has a few drawbacks
        # 1. pytree will flatten it
        # 2. it does not store device information.
        if isinstance(x, TensorMetadata):
            return torch._dynamo.testing.rand_strided(
                x.shape,
                x.stride,
                x.dtype,
                torch.device("cuda"),
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        # don't call tree_map since TensorMetadata as a namedtuple will be
        # flattened through.
        # args, kwargs = tree_map(handle_tensor, (args, kwargs))
        args = list(args)
        args[1] = list(map(handle_tensor, args[1]))
        return compile_fx_inner(*args, **kwargs)
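
For completeness, replaying a saved graph then looks roughly like the sketch below; the exact pickle file name is illustrative, since saved files are numbered by `save_args_cnt`.

from torch._inductor.debug import load_args_and_run_compile_fx_inner

# Re-runs compile_fx_inner on the pickled arguments, with tensors
# re-materialized by rand_strided on CUDA under FakeTensorMode.
load_args_and_run_compile_fx_inner(
    "/tmp/inductor_saved_args/compile_fx_inner_0.pkl"  # illustrative file name
)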