From 95f0b1eefc8cb9782b008b9cf7a51bb9c25592f0 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 22 Jul 2022 00:15:08 +0000 Subject: [PATCH 01/19] add dump graph utils --- torchinductor/compile_fx.py | 81 +++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 91345d1192..ab267b1861 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -15,6 +15,9 @@ from torchdynamo.testing import same from torchdynamo.utils import identity from torchdynamo.utils import init_logging +from functorch._src.partitioners import draw_graph +from torch.fx.graph_module import GraphModule + from . import config from .decomposition import decompositions @@ -160,6 +163,84 @@ def compile_fx_inner( raise +class FakeModule(object): + def __init__(self, _name): + self.__name__ = _name + + +def get_fake_func(name): + def func1(*args): + return 0 + func1.__name__ = name + return func1 + + +def create_fx_graph(nodes, fname): + name_to_fx_node = {} + fake_root = {} + graph = torch.fx.Graph() + first_node = None + for node in nodes: + name = node.get_name() + # fx_node = graph.call_module(name, args=(), kwargs=None) + fake_f = get_fake_func(name) + fx_node = graph.call_function(fake_f, args=(), kwargs=None) + fake_root[name] = torch.fx.GraphModule({}, torch.fx.Graph()) + name_to_fx_node[name] = fx_node + if first_node is None: + first_node = fx_node + + for node in nodes: + name = node.get_name() + deps = node.get_reads() + fx_node = name_to_fx_node[node.name] + + new_args = [] + for dep in deps: + if dep.name in name_to_fx_node: + dep_node = name_to_fx_node[dep.name] + else: + fake_root[dep.name] = FakeModule(dep.name) + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + outputs = [] + for k,v in name_to_fx_node.items(): + if len(v.users) == 0: + outputs.append(v) + + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + + print(graph) + gm = GraphModule({}, graph) + draw_graph(gm, fname) + + +def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image"): + """ + Dump the graph of a compute box to a file with fname. + """ + init_logging() + wrap=identity + + try: + graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs)) + with V.set_graph_handler(graph): + wrap(graph.run)(*example_inputs) + # import pprint + # pprint.pprint(graph.buffers) + # breakpoint() + create_fx_graph(graph.buffers, fname) + except Exception: + if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1": + wrap(functools.partial(dump_to_repro, gm))(*example_inputs) + + raise + + def cudagraphify(model, inputs, static_input_idxs=()): """ Assumes inputs[static_input_idxs[i]] are always the same memory address From f4dee87046d3eae114b4c78d617485b8048c5b85 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 22 Jul 2022 00:19:11 +0000 Subject: [PATCH 02/19] cleanup --- torchinductor/compile_fx.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index ab267b1861..6e8c6ff910 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -163,11 +163,6 @@ def compile_fx_inner( raise -class FakeModule(object): - def __init__(self, _name): - self.__name__ = _name - - def get_fake_func(name): def func1(*args): return 0 @@ -177,15 +172,12 @@ def func1(*args): def create_fx_graph(nodes, fname): name_to_fx_node = {} - fake_root = {} graph = torch.fx.Graph() first_node = None for node in nodes: name = node.get_name() - # fx_node = graph.call_module(name, args=(), kwargs=None) fake_f = get_fake_func(name) fx_node = graph.call_function(fake_f, args=(), kwargs=None) - fake_root[name] = torch.fx.GraphModule({}, torch.fx.Graph()) name_to_fx_node[name] = fx_node if first_node is None: first_node = fx_node @@ -200,7 +192,6 @@ def create_fx_graph(nodes, fname): if dep.name in name_to_fx_node: dep_node = name_to_fx_node[dep.name] else: - fake_root[dep.name] = FakeModule(dep.name) with graph.inserting_before(first_node): dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox new_args.append(dep_node) @@ -208,7 +199,7 @@ def create_fx_graph(nodes, fname): fx_node.args = tuple(new_args) outputs = [] - for k,v in name_to_fx_node.items(): + for _,v in name_to_fx_node.items(): if len(v.users) == 0: outputs.append(v) From 761b8d79c5fe458046f3f1d2a851444a2ea9c90a Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 22 Jul 2022 01:13:06 +0000 Subject: [PATCH 03/19] add meta data --- torchinductor/compile_fx.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 6e8c6ff910..379566dd7c 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -17,7 +17,10 @@ from torchdynamo.utils import init_logging from functorch._src.partitioners import draw_graph from torch.fx.graph_module import GraphModule - +from torch.fx.passes.shape_prop import TensorMetadata +from . import ir +from .codegen.cpp import CppScheduling +from .codegen.triton import TritonScheduling from . import config from .decomposition import decompositions @@ -170,14 +173,39 @@ def func1(*args): return func1 -def create_fx_graph(nodes, fname): +def create_fx_graph(nodes, fname, backend = "triton"): name_to_fx_node = {} graph = torch.fx.Graph() first_node = None + + if backend == "triton": + group_fn = TritonScheduling(None).group_fn + group_fn_NHW_C = TritonScheduling(None).group_fn_NHW_C + else: + group_fn = CppScheduling(None).group_fn + for node in nodes: name = node.get_name() fake_f = get_fake_func(name) fx_node = graph.call_function(fake_f, args=(), kwargs=None) + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + sizes = node.get_size() + if isinstance(node, ir.ComputedBuffer): + sizes, _ = node.simplify_reorder_and_tile() + elif isinstance(node, ir.ExternKernel): + sizes, _ = node.get_group_stride() + + if isinstance(node, ir.Convolution): + group = group_fn_NHW_C(sizes) + else: + group = group_fn(sizes) + + metadata = TensorMetadata(group, dtype, False, node.get_stride(), type(node.layout), None, None) + fx_node.meta["tensor_meta"] = metadata + name_to_fx_node[name] = fx_node if first_node is None: first_node = fx_node @@ -207,7 +235,7 @@ def create_fx_graph(nodes, fname): print(graph) gm = GraphModule({}, graph) - draw_graph(gm, fname) + draw_graph(gm, fname, clear_meta=False) def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image"): From cafe8d8a672eab8aa9b8957b6132015e4e0a52a7 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 22 Jul 2022 23:45:55 +0000 Subject: [PATCH 04/19] add node type to target --- torchinductor/compile_fx.py | 154 ++++++++++++++++++++++++++++++------ 1 file changed, 131 insertions(+), 23 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 379566dd7c..485e2c4a6b 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -165,6 +165,58 @@ def compile_fx_inner( raise +def legalize_graph(graph: torch.fx.Graph): + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + """ + # Build an adjacency list representation of node dependencies in the graph. This also + # serves as a list of nodes that still need to be inserted into the new, topologically + # sorted graph. + dependencies = {node: node.all_input_nodes.copy() for node in graph.nodes} + + # Construct a new graph that will contain all nodes in topologically sorted order. + new_graph = torch.fx.Graph() + value_remap = {} + + # Copy over all nodes with no dependencies. + for node, deps in dependencies.items(): + if not deps: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + + # Remove the copied over nodes from the adjacency list. + for copied_node in value_remap.keys(): + del dependencies[copied_node] + + # While there are still nodes to insert into the new graph: + while dependencies: + copied_this_round = [] + + # Copy over all nodes whose dependencies already exist in the new graph. + for node, deps in dependencies.items(): + all_deps_copied = True + for dep in deps: + if dep not in value_remap: + all_deps_copied = False + + if all_deps_copied: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + copied_this_round.append(node) + + # Delete all nodes copied over in this iteration from dependencies. + for copied_node in copied_this_round: + del dependencies[copied_node] + + # Replace the old graph with the new, topologically sorted one. + return new_graph + def get_fake_func(name): def func1(*args): @@ -173,7 +225,11 @@ def func1(*args): return func1 -def create_fx_graph(nodes, fname, backend = "triton"): +def create_fx_graph(nodes, fname, backend = "triton", print_graph = False): + + func_dict = {} + # import pprint + # pprint.pprint(nodes) name_to_fx_node = {} graph = torch.fx.Graph() first_node = None @@ -184,32 +240,47 @@ def create_fx_graph(nodes, fname, backend = "triton"): else: group_fn = CppScheduling(None).group_fn + # create call_function node for each Buffer and Kernel for node in nodes: - name = node.get_name() - fake_f = get_fake_func(name) + name = node.get_name() + node_type = str(type(node)).split(".")[-1].replace("'>","") + if node_type in func_dict: + fake_f = func_dict[node_type] + else: + fake_f = get_fake_func(node_type) + func_dict[node_type] = fake_f fx_node = graph.call_function(fake_f, args=(), kwargs=None) + fx_node.name = name + + # gather meta data dtype = None if isinstance(node, ir.ComputedBuffer): dtype = node.data.dtype - sizes = node.get_size() - if isinstance(node, ir.ComputedBuffer): - sizes, _ = node.simplify_reorder_and_tile() - elif isinstance(node, ir.ExternKernel): - sizes, _ = node.get_group_stride() - - if isinstance(node, ir.Convolution): - group = group_fn_NHW_C(sizes) - else: - group = group_fn(sizes) + try: + stride = node.get_stride() + layout = type(node.layout) + sizes = node.get_size() + if isinstance(node, ir.ComputedBuffer): + sizes, _ = node.simplify_reorder_and_tile() + elif isinstance(node, ir.ExternKernel): + sizes, _ = node.get_group_stride() + + if isinstance(node, ir.Convolution): + group = group_fn_NHW_C(sizes) + else: + group = group_fn(sizes) + except: + group = torch.Size([0]) - metadata = TensorMetadata(group, dtype, False, node.get_stride(), type(node.layout), None, None) + metadata = TensorMetadata(group, dtype, False, stride, layout, None, None) fx_node.meta["tensor_meta"] = metadata name_to_fx_node[name] = fx_node if first_node is None: first_node = fx_node + # create edges between nodes for node in nodes: name = node.get_name() deps = node.get_reads() @@ -222,6 +293,7 @@ def create_fx_graph(nodes, fname, backend = "triton"): else: with graph.inserting_before(first_node): dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox + name_to_fx_node[dep.name] = dep_node new_args.append(dep_node) fx_node.args = tuple(new_args) @@ -230,15 +302,20 @@ def create_fx_graph(nodes, fname, backend = "triton"): for _,v in name_to_fx_node.items(): if len(v.users) == 0: outputs.append(v) - graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) - - print(graph) + graph = legalize_graph(graph) + graph.lint() + + # if print_graph: + # print(graph) + print("starting creating module") gm = GraphModule({}, graph) + print(gm) + print("starting drawing") draw_graph(gm, fname, clear_meta=False) -def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image"): +def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image", print_graph = False): """ Dump the graph of a compute box to a file with fname. """ @@ -252,7 +329,7 @@ def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor # import pprint # pprint.pprint(graph.buffers) # breakpoint() - create_fx_graph(graph.buffers, fname) + create_fx_graph(graph.buffers, fname, print_graph=print_graph) except Exception: if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1": wrap(functools.partial(dump_to_repro, gm))(*example_inputs) @@ -332,19 +409,50 @@ def is_not_gradout(x): assert static_arg_idxs == list(range(len(static_arg_idxs))) return len(static_arg_idxs) +# def get_input_meta(args): +# input_meta = [] +# if len(args) > 0 and isinstance(args[0], tuple): # joint input +# input_meta += get_input_meta(args[0]) +# input_meta += get_input_meta(args[1]) +# return input_meta +# for arg in args: +# if(type(arg) == int or type(arg) == float): +# input_meta.append((type(arg),)) +# else: +# input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) +# return input_meta + +model_name = "alexnet" def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): """Main entrypoint to a compile given FX graph""" model_ = normalize_ir(model_, example_inputs_) num_example_inputs = len(example_inputs_) + # def fw_compiler(model: torch.fx.GraphModule, example_inputs): + # fixed = len(example_inputs) - num_example_inputs + # return compile_fx_inner(model, example_inputs, num_fixed=fixed) + + # def bw_compiler(model: torch.fx.GraphModule, example_inputs): + # fixed = count_tangents(model) + # return compile_fx_inner(model, example_inputs, num_fixed=fixed) + def fw_compiler(model: torch.fx.GraphModule, example_inputs): - fixed = len(example_inputs) - num_example_inputs - return compile_fx_inner(model, example_inputs, num_fixed=fixed) + # import pickle + # model.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen + # model.to_folder("hf_Bert_forward_0") + # input_meta = get_input_meta(example_inputs) + # pickle.dump(input_meta, open("hf_Bert_forward_0/hf_Bert_forward_0.input", "wb")) # noqa: E501 + global model_name + draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=True) + return model + def bw_compiler(model: torch.fx.GraphModule, example_inputs): - fixed = count_tangents(model) - return compile_fx_inner(model, example_inputs, num_fixed=fixed) + # print(model) + global model_name + draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=True) + return model return aot_autograd( model_, From f7768ae73436a8926ed81af7c3e4726666c533ca Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 23 Jul 2022 01:18:58 +0000 Subject: [PATCH 05/19] add model name --- torchinductor/compile_fx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 485e2c4a6b..b433999298 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -310,7 +310,7 @@ def create_fx_graph(nodes, fname, backend = "triton", print_graph = False): # print(graph) print("starting creating module") gm = GraphModule({}, graph) - print(gm) + # print(gm) print("starting drawing") draw_graph(gm, fname, clear_meta=False) @@ -422,7 +422,7 @@ def is_not_gradout(x): # input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) # return input_meta -model_name = "alexnet" +model_name = "hf_Bert" def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): """Main entrypoint to a compile given FX graph""" From c584cdd2fa8df37cc027ee298227850f30eb1489 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 23 Jul 2022 02:43:18 +0000 Subject: [PATCH 06/19] refactor --- benchmarks/common.py | 40 +++++++++++++++++++++++++++++++--- torchdynamo/convert_frame.py | 6 +++++ torchinductor/compile_fx.py | 28 +++++++++++++++++++----- torchinductor/decomposition.py | 1 + 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index 992db7c06c..da6f6391d1 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -331,6 +331,30 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs): return format_speedup(speedup, pvalue, is_correct=is_correct) +def dump_experiment(args, model_iter_fn, model, example_inputs): + """ + Run the model to dump the graph + """ + timings = np.zeros((1, 2), np.float64) + # if we randomize the input, we should also check the result is correct + should_check_result = should_randomize_input = args.randomize_input + is_correct = True + + inputs = ( + randomize_input(copy.deepcopy(example_inputs)) + if should_randomize_input + else example_inputs + ) + + with torchdynamo.run(): + timed( + model, model_iter_fn, inputs, return_result=True + ) + + + return current_name + + def overhead_experiment(*args, model_iter_fn): """ Measure overheads of TorchDynamo by running with no backend (only @@ -822,6 +846,11 @@ def parse_args(): action="store_true", help="Use same settings as --inductor for baseline comparisons", ) + parser.add_argument( + "--inductor-dump", + action="store_true", + help="Dump the graphs of computebuffers", + ) parser.add_argument( "--raise-on-assertion-error", action="store_true", @@ -1076,7 +1105,7 @@ def main(runner, original_dir=None): args.isolate = True # TODO(whc) should we move this to a more general part of the script? torch.backends.cuda.matmul.allow_tf32 = True - elif args.inductor or args.inductor_dynamic: + elif args.inductor or args.inductor_dynamic or args.inductor_dump: import torchinductor.config torchinductor.config.debug = args.verbose @@ -1089,8 +1118,13 @@ def main(runner, original_dir=None): else: torchinductor.config.dynamic_shapes = False - optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython) - experiment = speedup_experiment + if args.inductor_dump: + optimize_ctx = torchdynamo.optimize("inductor_dump", nopython=args.nopython) + experiment = dump_experiment + output_filename = "inductor_dump.csv" + else: + optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython) + experiment = speedup_experiment output_filename = "inductor.csv" elif args.online_autotune: optimize_ctx = torchdynamo.optimize(online_autotuner, nopython=args.nopython) diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index 6d967ab606..28bd3e05f9 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -83,6 +83,12 @@ def _wrap_compiler_fn(compiler_fn): from torchinductor.compile_fx import compile_fx return compile_fx + + elif compiler_fn == "inductor_dump": + from torchinductor.compile_fx import compile_fx_aot_dump + + return compile_fx_aot_dump + elif isinstance(compiler_fn, str): from .optimizations import BACKENDS diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index b433999298..d897ef030c 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -429,13 +429,29 @@ def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Ten model_ = normalize_ir(model_, example_inputs_) num_example_inputs = len(example_inputs_) - # def fw_compiler(model: torch.fx.GraphModule, example_inputs): - # fixed = len(example_inputs) - num_example_inputs - # return compile_fx_inner(model, example_inputs, num_fixed=fixed) + def fw_compiler(model: torch.fx.GraphModule, example_inputs): + fixed = len(example_inputs) - num_example_inputs + return compile_fx_inner(model, example_inputs, num_fixed=fixed) + + def bw_compiler(model: torch.fx.GraphModule, example_inputs): + fixed = count_tangents(model) + return compile_fx_inner(model, example_inputs, num_fixed=fixed) + + + return aot_autograd( + model_, + example_inputs_, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + decompositions=decompositions, + partition_fn=min_cut_rematerialization_partition, + ) + + +def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): + """Main entrypoint to a compile given FX graph""" + model_ = normalize_ir(model_, example_inputs_) - # def bw_compiler(model: torch.fx.GraphModule, example_inputs): - # fixed = count_tangents(model) - # return compile_fx_inner(model, example_inputs, num_fixed=fixed) def fw_compiler(model: torch.fx.GraphModule, example_inputs): # import pickle diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py index 43aeb88a39..92e621fd52 100644 --- a/torchinductor/decomposition.py +++ b/torchinductor/decomposition.py @@ -67,6 +67,7 @@ aten.threshold_backward, aten.transpose.int, aten.upsample_nearest2d_backward, + aten.lift_fresh_copy, ] ) decompositions.update(aot_autograd_decompositions) From 1e8941ce64d803dac40a52ab593932840614d102 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 23 Jul 2022 02:48:56 +0000 Subject: [PATCH 07/19] clean --- torchinductor/compile_fx.py | 62 ++++--------------------------------- 1 file changed, 6 insertions(+), 56 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index d897ef030c..86fc0370f4 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -21,6 +21,7 @@ from . import ir from .codegen.cpp import CppScheduling from .codegen.triton import TritonScheduling +from torch.fx.passes.tools_common import legalize_graph from . import config from .decomposition import decompositions @@ -165,58 +166,6 @@ def compile_fx_inner( raise -def legalize_graph(graph: torch.fx.Graph): - """ - Replace the graph of the given GraphModule with one that contains the same nodes as the - original, but in topologically sorted order. - - This is used by the merge_matmul transformation below, which disturbs the topologically sorted - order of its input GraphModule, so that this order is restored before further transformation. - - Arguments: - gm: The graph module to topologically sort. It is modified in-place. - - """ - # Build an adjacency list representation of node dependencies in the graph. This also - # serves as a list of nodes that still need to be inserted into the new, topologically - # sorted graph. - dependencies = {node: node.all_input_nodes.copy() for node in graph.nodes} - - # Construct a new graph that will contain all nodes in topologically sorted order. - new_graph = torch.fx.Graph() - value_remap = {} - - # Copy over all nodes with no dependencies. - for node, deps in dependencies.items(): - if not deps: - value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) - - # Remove the copied over nodes from the adjacency list. - for copied_node in value_remap.keys(): - del dependencies[copied_node] - - # While there are still nodes to insert into the new graph: - while dependencies: - copied_this_round = [] - - # Copy over all nodes whose dependencies already exist in the new graph. - for node, deps in dependencies.items(): - all_deps_copied = True - for dep in deps: - if dep not in value_remap: - all_deps_copied = False - - if all_deps_copied: - value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) - copied_this_round.append(node) - - # Delete all nodes copied over in this iteration from dependencies. - for copied_node in copied_this_round: - del dependencies[copied_node] - - # Replace the old graph with the new, topologically sorted one. - return new_graph - def get_fake_func(name): def func1(*args): @@ -303,13 +252,14 @@ def create_fx_graph(nodes, fname, backend = "triton", print_graph = False): if len(v.users) == 0: outputs.append(v) graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) - graph = legalize_graph(graph) - graph.lint() - # if print_graph: - # print(graph) + + if print_graph: + print(graph) print("starting creating module") gm = GraphModule({}, graph) + graph = legalize_graph(gm) + gm.graph.lint() # print(gm) print("starting drawing") draw_graph(gm, fname, clear_meta=False) From 9d2ee6027a346a278584661580d5aacebf290537 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 23 Jul 2022 02:58:23 +0000 Subject: [PATCH 08/19] clean --- torchinductor/compile_fx.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 86fc0370f4..b58d46e4c4 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -359,18 +359,6 @@ def is_not_gradout(x): assert static_arg_idxs == list(range(len(static_arg_idxs))) return len(static_arg_idxs) -# def get_input_meta(args): -# input_meta = [] -# if len(args) > 0 and isinstance(args[0], tuple): # joint input -# input_meta += get_input_meta(args[0]) -# input_meta += get_input_meta(args[1]) -# return input_meta -# for arg in args: -# if(type(arg) == int or type(arg) == float): -# input_meta.append((type(arg),)) -# else: -# input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) -# return input_meta model_name = "hf_Bert" @@ -404,20 +392,14 @@ def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torc def fw_compiler(model: torch.fx.GraphModule, example_inputs): - # import pickle - # model.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen - # model.to_folder("hf_Bert_forward_0") - # input_meta = get_input_meta(example_inputs) - # pickle.dump(input_meta, open("hf_Bert_forward_0/hf_Bert_forward_0.input", "wb")) # noqa: E501 global model_name - draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=True) + draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=False) return model def bw_compiler(model: torch.fx.GraphModule, example_inputs): - # print(model) global model_name - draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=True) + draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=False) return model return aot_autograd( From 627356f95b453722ea76631d699282dd191c6505 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Mon, 25 Jul 2022 18:36:47 +0000 Subject: [PATCH 09/19] revert decomposition --- torchinductor/decomposition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py index 92e621fd52..43aeb88a39 100644 --- a/torchinductor/decomposition.py +++ b/torchinductor/decomposition.py @@ -67,7 +67,6 @@ aten.threshold_backward, aten.transpose.int, aten.upsample_nearest2d_backward, - aten.lift_fresh_copy, ] ) decompositions.update(aot_autograd_decompositions) From b358b7fd626fd523442aecbc2ef8f49ff60e63c3 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 16:53:25 +0000 Subject: [PATCH 10/19] change path --- benchmarks/common.py | 33 +------- torchinductor/compile_fx.py | 153 ------------------------------------ torchinductor/scheduler.py | 107 +++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 184 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index da6f6391d1..0dc39306dd 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -331,30 +331,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs): return format_speedup(speedup, pvalue, is_correct=is_correct) -def dump_experiment(args, model_iter_fn, model, example_inputs): - """ - Run the model to dump the graph - """ - timings = np.zeros((1, 2), np.float64) - # if we randomize the input, we should also check the result is correct - should_check_result = should_randomize_input = args.randomize_input - is_correct = True - - inputs = ( - randomize_input(copy.deepcopy(example_inputs)) - if should_randomize_input - else example_inputs - ) - - with torchdynamo.run(): - timed( - model, model_iter_fn, inputs, return_result=True - ) - - - return current_name - - def overhead_experiment(*args, model_iter_fn): """ Measure overheads of TorchDynamo by running with no backend (only @@ -1118,13 +1094,8 @@ def main(runner, original_dir=None): else: torchinductor.config.dynamic_shapes = False - if args.inductor_dump: - optimize_ctx = torchdynamo.optimize("inductor_dump", nopython=args.nopython) - experiment = dump_experiment - output_filename = "inductor_dump.csv" - else: - optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython) - experiment = speedup_experiment + optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython) + experiment = speedup_experiment output_filename = "inductor.csv" elif args.online_autotune: optimize_ctx = torchdynamo.optimize(online_autotuner, nopython=args.nopython) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index b58d46e4c4..3f175a505b 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -15,13 +15,6 @@ from torchdynamo.testing import same from torchdynamo.utils import identity from torchdynamo.utils import init_logging -from functorch._src.partitioners import draw_graph -from torch.fx.graph_module import GraphModule -from torch.fx.passes.shape_prop import TensorMetadata -from . import ir -from .codegen.cpp import CppScheduling -from .codegen.triton import TritonScheduling -from torch.fx.passes.tools_common import legalize_graph from . import config from .decomposition import decompositions @@ -167,126 +160,6 @@ def compile_fx_inner( raise -def get_fake_func(name): - def func1(*args): - return 0 - func1.__name__ = name - return func1 - - -def create_fx_graph(nodes, fname, backend = "triton", print_graph = False): - - func_dict = {} - # import pprint - # pprint.pprint(nodes) - name_to_fx_node = {} - graph = torch.fx.Graph() - first_node = None - - if backend == "triton": - group_fn = TritonScheduling(None).group_fn - group_fn_NHW_C = TritonScheduling(None).group_fn_NHW_C - else: - group_fn = CppScheduling(None).group_fn - - # create call_function node for each Buffer and Kernel - for node in nodes: - name = node.get_name() - node_type = str(type(node)).split(".")[-1].replace("'>","") - if node_type in func_dict: - fake_f = func_dict[node_type] - else: - fake_f = get_fake_func(node_type) - func_dict[node_type] = fake_f - fx_node = graph.call_function(fake_f, args=(), kwargs=None) - fx_node.name = name - - # gather meta data - dtype = None - if isinstance(node, ir.ComputedBuffer): - dtype = node.data.dtype - - try: - stride = node.get_stride() - layout = type(node.layout) - sizes = node.get_size() - if isinstance(node, ir.ComputedBuffer): - sizes, _ = node.simplify_reorder_and_tile() - elif isinstance(node, ir.ExternKernel): - sizes, _ = node.get_group_stride() - - if isinstance(node, ir.Convolution): - group = group_fn_NHW_C(sizes) - else: - group = group_fn(sizes) - except: - group = torch.Size([0]) - - metadata = TensorMetadata(group, dtype, False, stride, layout, None, None) - fx_node.meta["tensor_meta"] = metadata - - name_to_fx_node[name] = fx_node - if first_node is None: - first_node = fx_node - - # create edges between nodes - for node in nodes: - name = node.get_name() - deps = node.get_reads() - fx_node = name_to_fx_node[node.name] - - new_args = [] - for dep in deps: - if dep.name in name_to_fx_node: - dep_node = name_to_fx_node[dep.name] - else: - with graph.inserting_before(first_node): - dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox - name_to_fx_node[dep.name] = dep_node - new_args.append(dep_node) - - fx_node.args = tuple(new_args) - - outputs = [] - for _,v in name_to_fx_node.items(): - if len(v.users) == 0: - outputs.append(v) - graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) - - - if print_graph: - print(graph) - print("starting creating module") - gm = GraphModule({}, graph) - graph = legalize_graph(gm) - gm.graph.lint() - # print(gm) - print("starting drawing") - draw_graph(gm, fname, clear_meta=False) - - -def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image", print_graph = False): - """ - Dump the graph of a compute box to a file with fname. - """ - init_logging() - wrap=identity - - try: - graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs)) - with V.set_graph_handler(graph): - wrap(graph.run)(*example_inputs) - # import pprint - # pprint.pprint(graph.buffers) - # breakpoint() - create_fx_graph(graph.buffers, fname, print_graph=print_graph) - except Exception: - if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1": - wrap(functools.partial(dump_to_repro, gm))(*example_inputs) - - raise - - def cudagraphify(model, inputs, static_input_idxs=()): """ Assumes inputs[static_input_idxs[i]] are always the same memory address @@ -386,32 +259,6 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs): ) -def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): - """Main entrypoint to a compile given FX graph""" - model_ = normalize_ir(model_, example_inputs_) - - - def fw_compiler(model: torch.fx.GraphModule, example_inputs): - global model_name - draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=False) - return model - - - def bw_compiler(model: torch.fx.GraphModule, example_inputs): - global model_name - draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=False) - return model - - return aot_autograd( - model_, - example_inputs_, - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - decompositions=decompositions, - partition_fn=min_cut_rematerialization_partition, - ) - - def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): """Main entrypoint to a compile given FX graph""" if config.aot_autograd: diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 4e4ab300cf..3db08a4d83 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -9,6 +9,7 @@ import numpy as np import torch +import os from . import config from . import dependencies @@ -18,6 +19,10 @@ from .dependencies import StarDep from .sizevars import SimplifyIndexing from .virtualized import V +from torch.fx.passes.tools_common import legalize_graph +from functorch._src.partitioners import draw_graph +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import TensorMetadata template_kernels = [ir.Convolution] @@ -409,6 +414,103 @@ def get_name(self): return self.node.get_name() +def get_fake_func(name): + def func1(*args): + return 0 + func1.__name__ = name + return func1 + + +def create_fx_graph(nodes, fname, print_graph = False): + """ + Draw a graph in fname.svg. + + nodes is a list of SchedulerNode objects. + """ + func_dict = {} + name_to_fx_node = {} + graph = torch.fx.Graph() + first_node = None + + # create call_function node for each Buffer and Kernel + for snode in nodes: + node = snode.node + name = node.get_name() + node_type = str(type(node)).split(".")[-1].replace("'>","") + + if node_type in func_dict: + fake_f = func_dict[node_type] + else: + fake_f = get_fake_func(node_type) + func_dict[node_type] = fake_f + fx_node = graph.call_function(fake_f, args=(), kwargs=None) + fx_node.name = name + + # gather meta data + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + # try: + stride = node.get_stride() + layout = type(node.layout) + + if isinstance(snode, NopKernelSchedulerNode): + group = "nop" + elif isinstance(snode, ExternKernelSchedulerNode): + if should_use_template(node): + group = snode.group[1] + else: + group = "extern" + else: # SchedulerNode + group = snode.group[1] + # except: + # group = torch.Size([0]) + + metadata = TensorMetadata(group, dtype, False, stride, layout, None, None) + fx_node.meta["tensor_meta"] = metadata + + name_to_fx_node[name] = fx_node + if first_node is None: + first_node = fx_node + + # create edges between nodes + for snode in nodes: + node = snode.node + name = node.get_name() + deps = node.get_reads() + fx_node = name_to_fx_node[name] + + new_args = [] + for dep in deps: + if dep.name in name_to_fx_node: + dep_node = name_to_fx_node[dep.name] + else: + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox + name_to_fx_node[dep.name] = dep_node + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + outputs = [] + for _,v in name_to_fx_node.items(): + if len(v.users) == 0: + outputs.append(v) + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + + + if print_graph: + print(graph) + print("starting creating module") + gm = GraphModule({}, graph) + graph = legalize_graph(gm) + gm.graph.lint() + print("starting drawing") + draw_graph(gm, fname, clear_meta=False) + +graph_dump_index = 0 + class Scheduler: def __init__(self, nodes): super(Scheduler, self).__init__() @@ -449,6 +551,11 @@ def __init__(self, nodes): assert False, node self.name_to_node = {node.get_name(): node for node in self.nodes} + if bool(os.environ.get('INDUCTOR_DEBUG', False)): + global graph_dump_index + create_fx_graph(self.nodes, f"compute_buffer_{graph_dump_index}", print_graph=True) + graph_dump_index += 1 + # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys()) From ddb863e445a15d3ce4e85102a98f0c07277e4298 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 16:56:36 +0000 Subject: [PATCH 11/19] revert redundant --- benchmarks/common.py | 7 +------ torchdynamo/convert_frame.py | 5 ----- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index 0dc39306dd..992db7c06c 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -822,11 +822,6 @@ def parse_args(): action="store_true", help="Use same settings as --inductor for baseline comparisons", ) - parser.add_argument( - "--inductor-dump", - action="store_true", - help="Dump the graphs of computebuffers", - ) parser.add_argument( "--raise-on-assertion-error", action="store_true", @@ -1081,7 +1076,7 @@ def main(runner, original_dir=None): args.isolate = True # TODO(whc) should we move this to a more general part of the script? torch.backends.cuda.matmul.allow_tf32 = True - elif args.inductor or args.inductor_dynamic or args.inductor_dump: + elif args.inductor or args.inductor_dynamic: import torchinductor.config torchinductor.config.debug = args.verbose diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index 28bd3e05f9..ddc3312fdc 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -84,11 +84,6 @@ def _wrap_compiler_fn(compiler_fn): return compile_fx - elif compiler_fn == "inductor_dump": - from torchinductor.compile_fx import compile_fx_aot_dump - - return compile_fx_aot_dump - elif isinstance(compiler_fn, str): from .optimizations import BACKENDS From d51fa1966d2cb5803e2568feb0f54c136e6ac9c8 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 16:57:35 +0000 Subject: [PATCH 12/19] revert --- torchdynamo/convert_frame.py | 1 - torchinductor/compile_fx.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index ddc3312fdc..6d967ab606 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -83,7 +83,6 @@ def _wrap_compiler_fn(compiler_fn): from torchinductor.compile_fx import compile_fx return compile_fx - elif isinstance(compiler_fn, str): from .optimizations import BACKENDS diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 3f175a505b..91345d1192 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -233,8 +233,6 @@ def is_not_gradout(x): return len(static_arg_idxs) -model_name = "hf_Bert" - def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): """Main entrypoint to a compile given FX graph""" model_ = normalize_ir(model_, example_inputs_) @@ -248,7 +246,6 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs): fixed = count_tangents(model) return compile_fx_inner(model, example_inputs, num_fixed=fixed) - return aot_autograd( model_, example_inputs_, From d1c816aea973f4060ef9a963edd8a53830f37439 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 17:03:49 +0000 Subject: [PATCH 13/19] lint --- torchinductor/scheduler.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 3db08a4d83..110814c5fe 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -417,11 +417,12 @@ def get_name(self): def get_fake_func(name): def func1(*args): return 0 + func1.__name__ = name return func1 -def create_fx_graph(nodes, fname, print_graph = False): +def create_fx_graph(nodes, fname, print_graph=False): """ Draw a graph in fname.svg. @@ -435,8 +436,8 @@ def create_fx_graph(nodes, fname, print_graph = False): # create call_function node for each Buffer and Kernel for snode in nodes: node = snode.node - name = node.get_name() - node_type = str(type(node)).split(".")[-1].replace("'>","") + name = node.get_name() + node_type = str(type(node)).split(".")[-1].replace("'>", "") if node_type in func_dict: fake_f = func_dict[node_type] @@ -451,7 +452,6 @@ def create_fx_graph(nodes, fname, print_graph = False): if isinstance(node, ir.ComputedBuffer): dtype = node.data.dtype - # try: stride = node.get_stride() layout = type(node.layout) @@ -464,9 +464,7 @@ def create_fx_graph(nodes, fname, print_graph = False): group = "extern" else: # SchedulerNode group = snode.group[1] - # except: - # group = torch.Size([0]) - + metadata = TensorMetadata(group, dtype, False, stride, layout, None, None) fx_node.meta["tensor_meta"] = metadata @@ -477,29 +475,28 @@ def create_fx_graph(nodes, fname, print_graph = False): # create edges between nodes for snode in nodes: node = snode.node - name = node.get_name() + name = node.get_name() deps = node.get_reads() fx_node = name_to_fx_node[name] - + new_args = [] for dep in deps: if dep.name in name_to_fx_node: dep_node = name_to_fx_node[dep.name] else: with graph.inserting_before(first_node): - dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox + dep_node = graph.placeholder(dep.name) name_to_fx_node[dep.name] = dep_node new_args.append(dep_node) fx_node.args = tuple(new_args) - + outputs = [] - for _,v in name_to_fx_node.items(): + for _, v in name_to_fx_node.items(): if len(v.users) == 0: outputs.append(v) graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) - - + if print_graph: print(graph) print("starting creating module") @@ -509,8 +506,10 @@ def create_fx_graph(nodes, fname, print_graph = False): print("starting drawing") draw_graph(gm, fname, clear_meta=False) + graph_dump_index = 0 + class Scheduler: def __init__(self, nodes): super(Scheduler, self).__init__() @@ -551,9 +550,11 @@ def __init__(self, nodes): assert False, node self.name_to_node = {node.get_name(): node for node in self.nodes} - if bool(os.environ.get('INDUCTOR_DEBUG', False)): + if bool(os.environ.get("INDUCTOR_DEBUG", False)): global graph_dump_index - create_fx_graph(self.nodes, f"compute_buffer_{graph_dump_index}", print_graph=True) + create_fx_graph( + self.nodes, f"compute_buffer_{graph_dump_index}", print_graph=True + ) graph_dump_index += 1 # some new constants could have been created above From c535eb0a022f65a51c3d5748c93aa3cfe7a083d0 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 17:13:18 +0000 Subject: [PATCH 14/19] isort lint --- torchinductor/scheduler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 110814c5fe..7c67e0a626 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -3,13 +3,17 @@ import dataclasses import functools import itertools +import os from typing import Any from typing import Dict from typing import List import numpy as np import torch -import os +from functorch._src.partitioners import draw_graph +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import TensorMetadata +from torch.fx.passes.tools_common import legalize_graph from . import config from . import dependencies @@ -19,10 +23,6 @@ from .dependencies import StarDep from .sizevars import SimplifyIndexing from .virtualized import V -from torch.fx.passes.tools_common import legalize_graph -from functorch._src.partitioners import draw_graph -from torch.fx.graph_module import GraphModule -from torch.fx.passes.shape_prop import TensorMetadata template_kernels = [ir.Convolution] From 512c206f3f296ec36380b93e47d8d8143ee729b8 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 17:15:56 +0000 Subject: [PATCH 15/19] resolve with main --- torchinductor/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 7c67e0a626..2d644229cc 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -3,6 +3,7 @@ import dataclasses import functools import itertools +import logging import os from typing import Any from typing import Dict From 4480aad46531a9942a0d00a3eca6c2797b9adbc6 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 20:54:15 +0000 Subject: [PATCH 16/19] change import and env var name --- torchinductor/scheduler.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 00b2bf5eed..fddaae19c3 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -11,10 +11,6 @@ import numpy as np import torch -from functorch._src.partitioners import draw_graph -from torch.fx.graph_module import GraphModule -from torch.fx.passes.shape_prop import TensorMetadata -from torch.fx.passes.tools_common import legalize_graph from . import config from . import dependencies @@ -437,6 +433,12 @@ def create_fx_graph(nodes, fname, print_graph=False): nodes is a list of SchedulerNode objects. """ + + from functorch._src.partitioners import draw_graph + from torch.fx.graph_module import GraphModule + from torch.fx.passes.shape_prop import TensorMetadata + from torch.fx.passes.tools_common import legalize_graph + func_dict = {} name_to_fx_node = {} graph = torch.fx.Graph() @@ -559,7 +561,7 @@ def __init__(self, nodes): assert False, node self.name_to_node = {node.get_name(): node for node in self.nodes} - if bool(os.environ.get("INDUCTOR_DEBUG", False)): + if bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", 0)==1): global graph_dump_index create_fx_graph( self.nodes, f"compute_buffer_{graph_dump_index}", print_graph=True From b24d2ba44eea02a1959b199f8201c795e02ab5bd Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 21:38:49 +0000 Subject: [PATCH 17/19] use functorch's global variable --- benchmarks/common.py | 4 ++++ torchinductor/scheduler.py | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index 1ed5f03b9e..467ca45ad5 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1414,10 +1414,14 @@ def main(runner, original_dir=None): current_name = name current_device = device + from functorch._src.aot_autograd import set_model_name + set_model_name(name) + if args.float32: model, example_inputs = cast_to_fp32(model, example_inputs) elif args.float16: model, example_inputs = cast_to_fp16(model, example_inputs) + runner.run_one_model( name, diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index fddaae19c3..92dcaf2088 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -26,6 +26,9 @@ log = logging.getLogger(__name__) +INDUCTOR_SCHEDULER_GRAPH = bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", None)=='1') + + def cmp(a, b): return int(a > b) - int(a < b) @@ -518,9 +521,6 @@ def create_fx_graph(nodes, fname, print_graph=False): draw_graph(gm, fname, clear_meta=False) -graph_dump_index = 0 - - class Scheduler: def __init__(self, nodes): super(Scheduler, self).__init__() @@ -561,12 +561,12 @@ def __init__(self, nodes): assert False, node self.name_to_node = {node.get_name(): node for node in self.nodes} - if bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", 0)==1): - global graph_dump_index + if INDUCTOR_SCHEDULER_GRAPH: + from functorch._src.aot_autograd import get_graph_being_compiled + graph_name = get_graph_being_compiled() create_fx_graph( - self.nodes, f"compute_buffer_{graph_dump_index}", print_graph=True + self.nodes, graph_name, print_graph=True ) - graph_dump_index += 1 # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys()) From 89e458b1bf476efaad968476f4e05c82f3ab9167 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 27 Jul 2022 22:21:13 +0000 Subject: [PATCH 18/19] lint --- benchmarks/common.py | 2 +- torchinductor/scheduler.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index 467ca45ad5..e6e0c97102 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1415,13 +1415,13 @@ def main(runner, original_dir=None): current_name = name current_device = device from functorch._src.aot_autograd import set_model_name + set_model_name(name) if args.float32: model, example_inputs = cast_to_fp32(model, example_inputs) elif args.float16: model, example_inputs = cast_to_fp16(model, example_inputs) - runner.run_one_model( name, diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 92dcaf2088..7d0291c496 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -26,7 +26,7 @@ log = logging.getLogger(__name__) -INDUCTOR_SCHEDULER_GRAPH = bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", None)=='1') +INDUCTOR_SCHEDULER_GRAPH = bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", None) == "1") def cmp(a, b): @@ -563,10 +563,9 @@ def __init__(self, nodes): if INDUCTOR_SCHEDULER_GRAPH: from functorch._src.aot_autograd import get_graph_being_compiled + graph_name = get_graph_being_compiled() - create_fx_graph( - self.nodes, graph_name, print_graph=True - ) + create_fx_graph(self.nodes, graph_name, print_graph=True) # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys()) From 336bd81c4f4a52aef10f835572da662e5b0590ed Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 28 Jul 2022 01:04:57 +0000 Subject: [PATCH 19/19] rename --- torchinductor/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 7d0291c496..b77612ffa8 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -430,7 +430,7 @@ def func1(*args): return func1 -def create_fx_graph(nodes, fname, print_graph=False): +def create_fx_from_buffers(nodes, fname, print_graph=False): """ Draw a graph in fname.svg. @@ -565,7 +565,7 @@ def __init__(self, nodes): from functorch._src.aot_autograd import get_graph_being_compiled graph_name = get_graph_being_compiled() - create_fx_graph(self.nodes, graph_name, print_graph=True) + create_fx_from_buffers(self.nodes, graph_name, print_graph=True) # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys())