This repository was archived by the owner on Aug 1, 2025. It is now read-only.
40 changes: 37 additions & 3 deletions benchmarks/common.py
@@ -331,6 +331,30 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs):
return format_speedup(speedup, pvalue, is_correct=is_correct)


def dump_experiment(args, model_iter_fn, model, example_inputs):
    """
    Run the model once under torchdynamo so the backend can dump its graph.
    """
    should_randomize_input = args.randomize_input

    inputs = (
        randomize_input(copy.deepcopy(example_inputs))
        if should_randomize_input
        else example_inputs
    )

    with torchdynamo.run():
        # The timing result is discarded; the call only triggers compilation so
        # that the "inductor_dump" backend can draw the graph.
        timed(model, model_iter_fn, inputs, return_result=True)

    return current_name


def overhead_experiment(*args, model_iter_fn):
"""
Measure overheads of TorchDynamo by running with no backend (only
@@ -822,6 +846,11 @@ def parse_args():
action="store_true",
help="Use same settings as --inductor for baseline comparisons",
)
parser.add_argument(
"--inductor-dump",
action="store_true",
help="Dump the graphs of computebuffers",
)
parser.add_argument(
"--raise-on-assertion-error",
action="store_true",
@@ -1076,7 +1105,7 @@ def main(runner, original_dir=None):
args.isolate = True
# TODO(whc) should we move this to a more general part of the script?
torch.backends.cuda.matmul.allow_tf32 = True
elif args.inductor or args.inductor_dynamic:
elif args.inductor or args.inductor_dynamic or args.inductor_dump:
import torchinductor.config

torchinductor.config.debug = args.verbose
@@ -1089,8 +1118,13 @@
else:
torchinductor.config.dynamic_shapes = False

optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
experiment = speedup_experiment
if args.inductor_dump:
optimize_ctx = torchdynamo.optimize("inductor_dump", nopython=args.nopython)
experiment = dump_experiment
output_filename = "inductor_dump.csv"
else:
optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
experiment = speedup_experiment
output_filename = "inductor.csv"
elif args.online_autotune:
optimize_ctx = torchdynamo.optimize(online_autotuner, nopython=args.nopython)
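For orientation, the whole benchmark-side change reduces to the following selection inside main(); this is condensed from the hunks above, with args coming from parse_args().

# Condensed from the diff above: --inductor-dump swaps the backend string, the
# experiment function, and the output CSV that main() selects.
if args.inductor_dump:
    optimize_ctx = torchdynamo.optimize("inductor_dump", nopython=args.nopython)
    experiment = dump_experiment
    output_filename = "inductor_dump.csv"
else:
    optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
    experiment = speedup_experiment
    output_filename = "inductor.csv"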
6 changes: 6 additions & 0 deletions torchdynamo/convert_frame.py
@@ -83,6 +83,12 @@ def _wrap_compiler_fn(compiler_fn):
from torchinductor.compile_fx import compile_fx

return compile_fx

elif compiler_fn == "inductor_dump":
from torchinductor.compile_fx import compile_fx_aot_dump

return compile_fx_aot_dump

elif isinstance(compiler_fn, str):
from .optimizations import BACKENDS

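With this branch added to _wrap_compiler_fn, "inductor_dump" can be selected like any other named backend. A minimal usage sketch, assuming torchdynamo and torchinductor are importable; the toy model below is a stand-in for a real benchmark model, and in the benchmark suite this path is exercised through dump_experiment instead.

import torch
import torchdynamo

# Hypothetical stand-in model; any module dynamo can trace would do.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(2, 8),)

with torchdynamo.optimize("inductor_dump", nopython=False):
    model(*example_inputs)  # compile_fx_aot_dump draws the compute-buffer graphs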
156 changes: 156 additions & 0 deletions torchinductor/compile_fx.py
@@ -15,6 +15,13 @@
from torchdynamo.testing import same
from torchdynamo.utils import identity
from torchdynamo.utils import init_logging
from functorch._src.partitioners import draw_graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from . import ir
from .codegen.cpp import CppScheduling
from .codegen.triton import TritonScheduling
from torch.fx.passes.tools_common import legalize_graph

from . import config
from .decomposition import decompositions
@@ -160,6 +167,126 @@ def compile_fx_inner(
raise


def get_fake_func(name):
def func1(*args):
return 0
func1.__name__ = name
return func1


def create_fx_graph(nodes, fname, backend="triton", print_graph=False):
    """
    Build a torch.fx graph whose nodes mirror the given inductor buffers/kernels,
    then draw it to a file named fname.
    """
    func_dict = {}
    name_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

if backend == "triton":
group_fn = TritonScheduling(None).group_fn
group_fn_NHW_C = TritonScheduling(None).group_fn_NHW_C
else:
group_fn = CppScheduling(None).group_fn

# create call_function node for each Buffer and Kernel
for node in nodes:
name = node.get_name()
        node_type = type(node).__name__
if node_type in func_dict:
fake_f = func_dict[node_type]
else:
fake_f = get_fake_func(node_type)
func_dict[node_type] = fake_f
fx_node = graph.call_function(fake_f, args=(), kwargs=None)
fx_node.name = name

        # gather metadata
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        stride = None
        layout = None
        try:
            stride = node.get_stride()
            layout = type(node.layout)
            sizes = node.get_size()
            if isinstance(node, ir.ComputedBuffer):
                sizes, _ = node.simplify_reorder_and_tile()
            elif isinstance(node, ir.ExternKernel):
                sizes, _ = node.get_group_stride()

            if isinstance(node, ir.Convolution):
                group = group_fn_NHW_C(sizes)
            else:
                group = group_fn(sizes)
        except Exception:
            # fall back to a dummy group if the node does not expose sizes/strides
            group = torch.Size([0])

        metadata = TensorMetadata(group, dtype, False, stride, layout, None, None)
        fx_node.meta["tensor_meta"] = metadata

name_to_fx_node[name] = fx_node
if first_node is None:
first_node = fx_node

    # create edges between nodes
    for node in nodes:
        name = node.get_name()
        deps = node.get_reads()
        fx_node = name_to_fx_node[name]

        new_args = []
        for dep in deps:
            if dep.name in name_to_fx_node:
                dep_node = name_to_fx_node[dep.name]
            else:
                # reads that do not come from another compute buffer are treated as graph inputs
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                name_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    outputs = []
    for v in name_to_fx_node.values():
        if len(v.users) == 0:
            outputs.append(v)
    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))


    if print_graph:
        print(graph)
    print("creating module")
    gm = GraphModule({}, graph)
    legalize_graph(gm)  # topologically sorts gm.graph
    gm.graph.lint()
    print("drawing graph")
    draw_graph(gm, fname, clear_meta=False)


def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname="image", print_graph=False):
    """
    Lower gm with inductor and dump its compute-buffer graph to a file named fname.
    """
    init_logging()
    wrap = identity

    try:
        graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs))
        with V.set_graph_handler(graph):
            wrap(graph.run)(*example_inputs)
        create_fx_graph(graph.buffers, fname, print_graph=print_graph)
except Exception:
if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1":
wrap(functools.partial(dump_to_repro, gm))(*example_inputs)

raise


def cudagraphify(model, inputs, static_input_idxs=()):
"""
Assumes inputs[static_input_idxs[i]] are always the same memory address
@@ -233,6 +360,8 @@ def is_not_gradout(x):
return len(static_arg_idxs)


model_name = "hf_Bert"

def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
"""Main entrypoint to a compile given FX graph"""
model_ = normalize_ir(model_, example_inputs_)
@@ -246,6 +375,33 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs):
fixed = count_tangents(model)
return compile_fx_inner(model, example_inputs, num_fixed=fixed)


return aot_autograd(
model_,
example_inputs_,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=decompositions,
partition_fn=min_cut_rematerialization_partition,
)


def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
    """Entry point that dumps the forward/backward compute-buffer graphs instead of compiling"""
    model_ = normalize_ir(model_, example_inputs_)

    def fw_compiler(model: torch.fx.GraphModule, example_inputs):
        draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=False)
        # return the uncompiled graph module so the model still runs eagerly
        return model

    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
        draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=False)
        # return the uncompiled graph module so the model still runs eagerly
        return model

return aot_autograd(
model_,
example_inputs_,
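The core trick in create_fx_graph is to represent each inductor buffer or kernel as a call_function node whose target is a fake function named after the IR class, so the buffer graph can be rendered with ordinary FX tooling. A minimal, self-contained sketch of that pattern; the buffer names and connectivity below are invented for illustration.

import torch
from torch.fx.graph_module import GraphModule


def get_fake_func(name):
    def func1(*args):
        return 0

    func1.__name__ = name
    return func1


# Build a tiny fake "compute buffer" graph by hand.
graph = torch.fx.Graph()
arg0 = graph.placeholder("arg0")
buf0 = graph.call_function(get_fake_func("ComputedBuffer"), args=(arg0,))
buf1 = graph.call_function(get_fake_func("ExternKernel"), args=(buf0,))
graph.output(buf1)

gm = GraphModule({}, graph)
gm.graph.lint()
print(gm.graph)  # the printed graph mirrors the buffer dependencies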