diff --git a/benchmarks/common.py b/benchmarks/common.py
index 992db7c06c..da6f6391d1 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -331,6 +331,24 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs):
     return format_speedup(speedup, pvalue, is_correct=is_correct)
 
 
+def dump_experiment(args, model_iter_fn, model, example_inputs):
+    """
+    Run the model once under torchdynamo so the backend can dump its graphs.
+    """
+    should_randomize_input = args.randomize_input
+
+    inputs = (
+        randomize_input(copy.deepcopy(example_inputs))
+        if should_randomize_input
+        else example_inputs
+    )
+
+    with torchdynamo.run():
+        timed(model, model_iter_fn, inputs, return_result=True)
+
+    return current_name
+
+
 def overhead_experiment(*args, model_iter_fn):
     """
     Measure overheads of TorchDynamo by running with no backend (only
@@ -822,6 +840,11 @@ def parse_args():
         action="store_true",
         help="Use same settings as --inductor for baseline comparisons",
     )
+    parser.add_argument(
+        "--inductor-dump",
+        action="store_true",
+        help="Dump the graph of inductor ComputedBuffers for each compiled graph",
+    )
     parser.add_argument(
         "--raise-on-assertion-error",
         action="store_true",
@@ -1076,7 +1099,7 @@ def main(runner, original_dir=None):
         args.isolate = True
         # TODO(whc) should we move this to a more general part of the script?
         torch.backends.cuda.matmul.allow_tf32 = True
-    elif args.inductor or args.inductor_dynamic:
+    elif args.inductor or args.inductor_dynamic or args.inductor_dump:
         import torchinductor.config
 
         torchinductor.config.debug = args.verbose
@@ -1089,8 +1112,13 @@
         else:
             torchinductor.config.dynamic_shapes = False
 
-        optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
-        experiment = speedup_experiment
-        output_filename = "inductor.csv"
+        if args.inductor_dump:
+            optimize_ctx = torchdynamo.optimize("inductor_dump", nopython=args.nopython)
+            experiment = dump_experiment
+            output_filename = "inductor_dump.csv"
+        else:
+            optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
+            experiment = speedup_experiment
+            output_filename = "inductor.csv"
     elif args.online_autotune:
         optimize_ctx = torchdynamo.optimize(online_autotuner, nopython=args.nopython)
diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py
index 6d967ab606..28bd3e05f9 100644
--- a/torchdynamo/convert_frame.py
+++ b/torchdynamo/convert_frame.py
@@ -83,6 +83,12 @@ def _wrap_compiler_fn(compiler_fn):
         from torchinductor.compile_fx import compile_fx
 
         return compile_fx
+
+    elif compiler_fn == "inductor_dump":
+        from torchinductor.compile_fx import compile_fx_aot_dump
+
+        return compile_fx_aot_dump
+
     elif isinstance(compiler_fn, str):
         from .optimizations import BACKENDS
 
diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 91345d1192..b58d46e4c4 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -15,6 +15,13 @@
 from torchdynamo.testing import same
 from torchdynamo.utils import identity
 from torchdynamo.utils import init_logging
+from functorch._src.partitioners import draw_graph
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes.shape_prop import TensorMetadata
+from torch.fx.passes.tools_common import legalize_graph
+from . import ir
+from .codegen.cpp import CppScheduling
+from .codegen.triton import TritonScheduling
 
 from . import config
 from .decomposition import decompositions
@@ -160,6 +167,126 @@ def compile_fx_inner(
         raise
 
 
+def get_fake_func(name):
+    def func1(*args):
+        return 0
+
+    func1.__name__ = name
+    return func1
+
+
+def create_fx_graph(nodes, fname, backend="triton", print_graph=False):
+    func_dict = {}
+    name_to_fx_node = {}
+    graph = torch.fx.Graph()
+    first_node = None
+
+    if backend == "triton":
+        scheduling = TritonScheduling(None)
+        group_fn = scheduling.group_fn
+        group_fn_NHW_C = scheduling.group_fn_NHW_C
+    else:
+        group_fn = CppScheduling(None).group_fn
+
+    # create a call_function node for each Buffer and Kernel
+    for node in nodes:
+        name = node.get_name()
+        node_type = str(type(node)).split(".")[-1].replace("'>", "")
+        if node_type in func_dict:
+            fake_f = func_dict[node_type]
+        else:
+            fake_f = get_fake_func(node_type)
+            func_dict[node_type] = fake_f
+        fx_node = graph.call_function(fake_f, args=(), kwargs=None)
+        fx_node.name = name
+
+        # gather metadata for the node
+        dtype = None
+        if isinstance(node, ir.ComputedBuffer):
+            dtype = node.data.dtype
+
+        stride = None
+        layout = None
+        try:
+            stride = node.get_stride()
+            layout = type(node.layout)
+            sizes = node.get_size()
+            if isinstance(node, ir.ComputedBuffer):
+                sizes, _ = node.simplify_reorder_and_tile()
+            elif isinstance(node, ir.ExternKernel):
+                sizes, _ = node.get_group_stride()
+
+            if isinstance(node, ir.Convolution):
+                group = group_fn_NHW_C(sizes)
+            else:
+                group = group_fn(sizes)
+        except Exception:
+            group = torch.Size([0])
+
+        metadata = TensorMetadata(group, dtype, False, stride, layout, None, None)
+        fx_node.meta["tensor_meta"] = metadata
+
+        name_to_fx_node[name] = fx_node
+        if first_node is None:
+            first_node = fx_node
+
+    # create edges between nodes
+    for node in nodes:
+        name = node.get_name()
+        deps = node.get_reads()
+        fx_node = name_to_fx_node[name]
+
+        new_args = []
+        for dep in deps:
+            if dep.name in name_to_fx_node:
+                dep_node = name_to_fx_node[dep.name]
+            else:
+                # reads not produced by any buffer are treated as graph inputs
+                with graph.inserting_before(first_node):
+                    dep_node = graph.placeholder(dep.name)
+                name_to_fx_node[dep.name] = dep_node
+            new_args.append(dep_node)
+
+        fx_node.args = tuple(new_args)
+
+    # nodes without users become the graph outputs
+    outputs = [v for v in name_to_fx_node.values() if len(v.users) == 0]
+    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
+
+    if print_graph:
+        print(graph)
+
+    gm = GraphModule({}, graph)
+    gm = legalize_graph(gm)
+    gm.graph.lint()
+    draw_graph(gm, fname, clear_meta=False)
+
+
+ """ + init_logging() + wrap=identity + + try: + graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs)) + with V.set_graph_handler(graph): + wrap(graph.run)(*example_inputs) + # import pprint + # pprint.pprint(graph.buffers) + # breakpoint() + create_fx_graph(graph.buffers, fname, print_graph=print_graph) + except Exception: + if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1": + wrap(functools.partial(dump_to_repro, gm))(*example_inputs) + + raise + + def cudagraphify(model, inputs, static_input_idxs=()): """ Assumes inputs[static_input_idxs[i]] are always the same memory address @@ -233,6 +360,8 @@ def is_not_gradout(x): return len(static_arg_idxs) +model_name = "hf_Bert" + def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): """Main entrypoint to a compile given FX graph""" model_ = normalize_ir(model_, example_inputs_) @@ -246,6 +375,33 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs): fixed = count_tangents(model) return compile_fx_inner(model, example_inputs, num_fixed=fixed) + + return aot_autograd( + model_, + example_inputs_, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + decompositions=decompositions, + partition_fn=min_cut_rematerialization_partition, + ) + + +def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): + """Main entrypoint to a compile given FX graph""" + model_ = normalize_ir(model_, example_inputs_) + + + def fw_compiler(model: torch.fx.GraphModule, example_inputs): + global model_name + draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=False) + return model + + + def bw_compiler(model: torch.fx.GraphModule, example_inputs): + global model_name + draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=False) + return model + return aot_autograd( model_, example_inputs_,