From 95f0b1eefc8cb9782b008b9cf7a51bb9c25592f0 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Fri, 22 Jul 2022 00:15:08 +0000
Subject: [PATCH 1/9] add dump graph utils

---
 torchinductor/compile_fx.py | 81 +++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 91345d1192..ab267b1861 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -15,6 +15,9 @@ from torchdynamo.testing import same
 from torchdynamo.utils import identity
 from torchdynamo.utils import init_logging
+from functorch._src.partitioners import draw_graph
+from torch.fx.graph_module import GraphModule
+
 from . import config
 from .decomposition import decompositions
@@ -160,6 +163,84 @@ def compile_fx_inner(
     raise
 
 
+class FakeModule(object):
+    def __init__(self, _name):
+        self.__name__ = _name
+
+
+def get_fake_func(name):
+    def func1(*args):
+        return 0
+
+    func1.__name__ = name
+    return func1
+
+
+def create_fx_graph(nodes, fname):
+    name_to_fx_node = {}
+    fake_root = {}
+    graph = torch.fx.Graph()
+    first_node = None
+    for node in nodes:
+        name = node.get_name()
+        # fx_node = graph.call_module(name, args=(), kwargs=None)
+        fake_f = get_fake_func(name)
+        fx_node = graph.call_function(fake_f, args=(), kwargs=None)
+        fake_root[name] = torch.fx.GraphModule({}, torch.fx.Graph())
+        name_to_fx_node[name] = fx_node
+        if first_node is None:
+            first_node = fx_node
+
+    for node in nodes:
+        name = node.get_name()
+        deps = node.get_reads()
+        fx_node = name_to_fx_node[node.name]
+
+        new_args = []
+        for dep in deps:
+            if dep.name in name_to_fx_node:
+                dep_node = name_to_fx_node[dep.name]
+            else:
+                fake_root[dep.name] = FakeModule(dep.name)
+                with graph.inserting_before(first_node):
+                    dep_node = graph.placeholder(dep.name)  # assume it's a placeholder if it's not a compute buffer
+            new_args.append(dep_node)
+
+        fx_node.args = tuple(new_args)
+
+    outputs = []
+    for k,v in name_to_fx_node.items():
+        if len(v.users) == 0:
+            outputs.append(v)
+
+    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
+
+    print(graph)
+    gm = GraphModule({}, graph)
+    draw_graph(gm, fname)
+
+
+def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image"):
+    """
+    Dump the graph of compute buffers to a file named fname.
+    """
+    init_logging()
+    wrap = identity
+
+    try:
+        graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs))
+        with V.set_graph_handler(graph):
+            wrap(graph.run)(*example_inputs)
+        # import pprint
+        # pprint.pprint(graph.buffers)
+        # breakpoint()
+        create_fx_graph(graph.buffers, fname)
+    except Exception:
+        if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1":
+            wrap(functools.partial(dump_to_repro, gm))(*example_inputs)
+
+        raise
+
+
 def cudagraphify(model, inputs, static_input_idxs=()):
     """
     Assumes inputs[static_input_idxs[i]] are always the same memory address
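The utility added in patch 1 builds a torch.fx graph purely for visualization: each inductor buffer becomes a call_function node whose target is a fake function carrying the buffer's name, unknown reads become placeholders, and buffers with no users become outputs. The same idea as a self-contained sketch on plain torch.fx (the names arg0/buf0/buf1 are made up for illustration, not taken from the patch):

    import torch.fx

    def get_fake_func(name):
        def func1(*args):
            return 0
        func1.__name__ = name
        return func1

    graph = torch.fx.Graph()
    arg0 = graph.placeholder("arg0")                                 # external input
    buf0 = graph.call_function(get_fake_func("buf0"), args=(arg0,))  # "buffer" node
    buf1 = graph.call_function(get_fake_func("buf1"), args=(buf0,))  # reads buf0
    graph.output(buf1)                                               # sole sink node

    gm = torch.fx.GraphModule({}, graph)
    print(gm.graph)  # arg0 -> buf0 -> buf1 -> output, ready for draw_graph
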
+ """ + init_logging() + wrap=identity + + try: + graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs)) + with V.set_graph_handler(graph): + wrap(graph.run)(*example_inputs) + # import pprint + # pprint.pprint(graph.buffers) + # breakpoint() + create_fx_graph(graph.buffers, fname) + except Exception: + if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1": + wrap(functools.partial(dump_to_repro, gm))(*example_inputs) + + raise + + def cudagraphify(model, inputs, static_input_idxs=()): """ Assumes inputs[static_input_idxs[i]] are always the same memory address From f4dee87046d3eae114b4c78d617485b8048c5b85 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 22 Jul 2022 00:19:11 +0000 Subject: [PATCH 2/9] cleanup --- torchinductor/compile_fx.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index ab267b1861..6e8c6ff910 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -163,11 +163,6 @@ def compile_fx_inner( raise -class FakeModule(object): - def __init__(self, _name): - self.__name__ = _name - - def get_fake_func(name): def func1(*args): return 0 @@ -177,15 +172,12 @@ def func1(*args): def create_fx_graph(nodes, fname): name_to_fx_node = {} - fake_root = {} graph = torch.fx.Graph() first_node = None for node in nodes: name = node.get_name() - # fx_node = graph.call_module(name, args=(), kwargs=None) fake_f = get_fake_func(name) fx_node = graph.call_function(fake_f, args=(), kwargs=None) - fake_root[name] = torch.fx.GraphModule({}, torch.fx.Graph()) name_to_fx_node[name] = fx_node if first_node is None: first_node = fx_node @@ -200,7 +192,6 @@ def create_fx_graph(nodes, fname): if dep.name in name_to_fx_node: dep_node = name_to_fx_node[dep.name] else: - fake_root[dep.name] = FakeModule(dep.name) with graph.inserting_before(first_node): dep_node = graph.placeholder(dep.name) # assume it's a placeholder if not a computebox new_args.append(dep_node) @@ -208,7 +199,7 @@ def create_fx_graph(nodes, fname): fx_node.args = tuple(new_args) outputs = [] - for k,v in name_to_fx_node.items(): + for _,v in name_to_fx_node.items(): if len(v.users) == 0: outputs.append(v) From 761b8d79c5fe458046f3f1d2a851444a2ea9c90a Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 22 Jul 2022 01:13:06 +0000 Subject: [PATCH 3/9] add meta data --- torchinductor/compile_fx.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 6e8c6ff910..379566dd7c 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -17,7 +17,10 @@ from torchdynamo.utils import init_logging from functorch._src.partitioners import draw_graph from torch.fx.graph_module import GraphModule - +from torch.fx.passes.shape_prop import TensorMetadata +from . import ir +from .codegen.cpp import CppScheduling +from .codegen.triton import TritonScheduling from . 
From 761b8d79c5fe458046f3f1d2a851444a2ea9c90a Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Fri, 22 Jul 2022 01:13:06 +0000
Subject: [PATCH 3/9] add meta data

---
 torchinductor/compile_fx.py | 34 +++++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 6e8c6ff910..379566dd7c 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -17,7 +17,10 @@ from torchdynamo.utils import init_logging
 from functorch._src.partitioners import draw_graph
 from torch.fx.graph_module import GraphModule
-
+from torch.fx.passes.shape_prop import TensorMetadata
+from . import ir
+from .codegen.cpp import CppScheduling
+from .codegen.triton import TritonScheduling
 from . import config
 from .decomposition import decompositions
@@ -170,14 +173,39 @@ def func1(*args):
     return func1
 
 
-def create_fx_graph(nodes, fname):
+def create_fx_graph(nodes, fname, backend = "triton"):
     name_to_fx_node = {}
     graph = torch.fx.Graph()
     first_node = None
+
+    if backend == "triton":
+        group_fn = TritonScheduling(None).group_fn
+        group_fn_NHW_C = TritonScheduling(None).group_fn_NHW_C
+    else:
+        group_fn = CppScheduling(None).group_fn
+
     for node in nodes:
         name = node.get_name()
         fake_f = get_fake_func(name)
         fx_node = graph.call_function(fake_f, args=(), kwargs=None)
+        dtype = None
+        if isinstance(node, ir.ComputedBuffer):
+            dtype = node.data.dtype
+
+        sizes = node.get_size()
+        if isinstance(node, ir.ComputedBuffer):
+            sizes, _ = node.simplify_reorder_and_tile()
+        elif isinstance(node, ir.ExternKernel):
+            sizes, _ = node.get_group_stride()
+
+        if isinstance(node, ir.Convolution):
+            group = group_fn_NHW_C(sizes)
+        else:
+            group = group_fn(sizes)
+
+        metadata = TensorMetadata(group, dtype, False, node.get_stride(), type(node.layout), None, None)
+        fx_node.meta["tensor_meta"] = metadata
+
         name_to_fx_node[name] = fx_node
         if first_node is None:
             first_node = fx_node
@@ -207,7 +235,7 @@ def create_fx_graph(nodes, fname):
 
     print(graph)
     gm = GraphModule({}, graph)
-    draw_graph(gm, fname)
+    draw_graph(gm, fname, clear_meta=False)
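Patch 3 reuses the tensor_meta slot that torch.fx shape propagation normally populates, so that draw_graph(gm, fname, clear_meta=False) can label each node with its iteration group, dtype, stride, and layout. A minimal sketch of that annotation, assuming TensorMetadata still has its usual seven fields (shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams); note the patch repurposes the shape slot to carry the scheduling group and the memory_format slot to carry type(node.layout):

    import torch
    import torch.fx
    from torch.fx.passes.shape_prop import TensorMetadata

    graph = torch.fx.Graph()
    n = graph.placeholder("buf0")  # hypothetical buffer node
    n.meta["tensor_meta"] = TensorMetadata(
        shape=torch.Size([8, 16]),               # the patch stores the group here
        dtype=torch.float32,
        requires_grad=False,
        stride=(16, 1),
        memory_format=torch.contiguous_format,   # the patch stores type(node.layout) here
        is_quantized=False,
        qparams={},
    )
    print(n.meta["tensor_meta"])
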
From cafe8d8a672eab8aa9b8957b6132015e4e0a52a7 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Fri, 22 Jul 2022 23:45:55 +0000
Subject: [PATCH 4/9] add node type to target

---
 torchinductor/compile_fx.py | 154 ++++++++++++++++++++++++++++++------
 1 file changed, 131 insertions(+), 23 deletions(-)

diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 379566dd7c..485e2c4a6b 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -165,6 +165,58 @@ def compile_fx_inner(
     raise
 
 
+def legalize_graph(graph: torch.fx.Graph):
+    """
+    Return a new graph that contains the same nodes as the original, but in
+    topologically sorted order.
+
+    Adapted from torch.fx's legalize_graph utility, which is used by transformations
+    (such as merge_matmul) that disturb the topologically sorted order of their input
+    graph and need that order restored before further transformation.
+
+    Arguments:
+        graph: The graph to topologically sort. It is not modified in place; the
+            sorted copy is returned.
+
+    """
+    # Build an adjacency list representation of node dependencies in the graph. This also
+    # serves as a list of nodes that still need to be inserted into the new, topologically
+    # sorted graph.
+    dependencies = {node: node.all_input_nodes.copy() for node in graph.nodes}
+
+    # Construct a new graph that will contain all nodes in topologically sorted order.
+    new_graph = torch.fx.Graph()
+    value_remap = {}
+
+    # Copy over all nodes with no dependencies.
+    for node, deps in dependencies.items():
+        if not deps:
+            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
+
+    # Remove the copied over nodes from the adjacency list.
+    for copied_node in value_remap.keys():
+        del dependencies[copied_node]
+
+    # While there are still nodes to insert into the new graph:
+    while dependencies:
+        copied_this_round = []
+
+        # Copy over all nodes whose dependencies already exist in the new graph.
+        for node, deps in dependencies.items():
+            all_deps_copied = True
+            for dep in deps:
+                if dep not in value_remap:
+                    all_deps_copied = False
+
+            if all_deps_copied:
+                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
+                copied_this_round.append(node)
+
+        # Delete all nodes copied over in this iteration from dependencies.
+        for copied_node in copied_this_round:
+            del dependencies[copied_node]
+
+    # Return the new, topologically sorted graph.
+    return new_graph
+
 
 def get_fake_func(name):
     def func1(*args):
@@ -173,7 +225,11 @@ def func1(*args):
     return func1
 
 
-def create_fx_graph(nodes, fname, backend = "triton"):
+def create_fx_graph(nodes, fname, backend = "triton", print_graph = False):
+
+    func_dict = {}
+    # import pprint
+    # pprint.pprint(nodes)
     name_to_fx_node = {}
     graph = torch.fx.Graph()
     first_node = None
@@ -184,32 +240,47 @@ def create_fx_graph(nodes, fname, backend = "triton"):
     else:
         group_fn = CppScheduling(None).group_fn
 
+    # create call_function node for each Buffer and Kernel
     for node in nodes:
-        name = node.get_name()
-        fake_f = get_fake_func(name)
+        name = node.get_name()
+        node_type = str(type(node)).split(".")[-1].replace("'>","")
+        if node_type in func_dict:
+            fake_f = func_dict[node_type]
+        else:
+            fake_f = get_fake_func(node_type)
+            func_dict[node_type] = fake_f
         fx_node = graph.call_function(fake_f, args=(), kwargs=None)
+        fx_node.name = name
+
+        # gather meta data
         dtype = None
         if isinstance(node, ir.ComputedBuffer):
             dtype = node.data.dtype
-        sizes = node.get_size()
-        if isinstance(node, ir.ComputedBuffer):
-            sizes, _ = node.simplify_reorder_and_tile()
-        elif isinstance(node, ir.ExternKernel):
-            sizes, _ = node.get_group_stride()
-
-        if isinstance(node, ir.Convolution):
-            group = group_fn_NHW_C(sizes)
-        else:
-            group = group_fn(sizes)
+        try:
+            stride = node.get_stride()
+            layout = type(node.layout)
+            sizes = node.get_size()
+            if isinstance(node, ir.ComputedBuffer):
+                sizes, _ = node.simplify_reorder_and_tile()
+            elif isinstance(node, ir.ExternKernel):
+                sizes, _ = node.get_group_stride()
+
+            if isinstance(node, ir.Convolution):
+                group = group_fn_NHW_C(sizes)
+            else:
+                group = group_fn(sizes)
+        except:
+            group = torch.Size([0])
 
-        metadata = TensorMetadata(group, dtype, False, node.get_stride(), type(node.layout), None, None)
+        metadata = TensorMetadata(group, dtype, False, stride, layout, None, None)
         fx_node.meta["tensor_meta"] = metadata
 
         name_to_fx_node[name] = fx_node
         if first_node is None:
             first_node = fx_node
 
+    # create edges between nodes
     for node in nodes:
         name = node.get_name()
         deps = node.get_reads()
@@ -222,6 +293,7 @@ def create_fx_graph(nodes, fname, backend = "triton", print_graph = False):
             else:
                 with graph.inserting_before(first_node):
                     dep_node = graph.placeholder(dep.name)  # assume it's a placeholder if it's not a compute buffer
+                name_to_fx_node[dep.name] = dep_node
             new_args.append(dep_node)
 
         fx_node.args = tuple(new_args)
@@ -230,15 +302,20 @@ def create_fx_graph(nodes, fname, backend = "triton", print_graph = False):
     for _,v in name_to_fx_node.items():
         if len(v.users) == 0:
             outputs.append(v)
-
     graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
+    graph = legalize_graph(graph)
+    graph.lint()
 
-    print(graph)
+    # if print_graph:
+    #     print(graph)
+    print("starting creating module")
     gm = GraphModule({}, graph)
+    print(gm)
+    print("starting drawing")
     draw_graph(gm, fname, clear_meta=False)
 
 
-def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image"):
+def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image", print_graph = False):
     """
     Dump the graph of compute buffers to a file named fname.
     """
@@ -252,7 +329,7 @@ def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fname = "image", print_graph = False):
         # import pprint
         # pprint.pprint(graph.buffers)
         # breakpoint()
-        create_fx_graph(graph.buffers, fname)
+        create_fx_graph(graph.buffers, fname, print_graph=print_graph)
     except Exception:
         if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1":
             wrap(functools.partial(dump_to_repro, gm))(*example_inputs)
@@ -332,19 +409,50 @@ def is_not_gradout(x):
     assert static_arg_idxs == list(range(len(static_arg_idxs)))
     return len(static_arg_idxs)
 
+# def get_input_meta(args):
+#     input_meta = []
+#     if len(args) > 0 and isinstance(args[0], tuple):  # joint input
+#         input_meta += get_input_meta(args[0])
+#         input_meta += get_input_meta(args[1])
+#         return input_meta
+#     for arg in args:
+#         if(type(arg) == int or type(arg) == float):
+#             input_meta.append((type(arg),))
+#         else:
+#             input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device))
+#     return input_meta
+
+model_name = "alexnet"
 
 def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
     """Main entrypoint to compile a given FX graph"""
     model_ = normalize_ir(model_, example_inputs_)
     num_example_inputs = len(example_inputs_)
 
+    # def fw_compiler(model: torch.fx.GraphModule, example_inputs):
+    #     fixed = len(example_inputs) - num_example_inputs
+    #     return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+
+    # def bw_compiler(model: torch.fx.GraphModule, example_inputs):
+    #     fixed = count_tangents(model)
+    #     return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+
     def fw_compiler(model: torch.fx.GraphModule, example_inputs):
-        fixed = len(example_inputs) - num_example_inputs
-        return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+        # import pickle
+        # model.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
+        # model.to_folder("hf_Bert_forward_0")
+        # input_meta = get_input_meta(example_inputs)
+        # pickle.dump(input_meta, open("hf_Bert_forward_0/hf_Bert_forward_0.input", "wb"))  # noqa: E501
+        global model_name
+        draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=True)
+        return model
 
     def bw_compiler(model: torch.fx.GraphModule, example_inputs):
-        fixed = count_tangents(model)
-        return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+        # print(model)
+        global model_name
+        draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=True)
+        return model
 
     return aot_autograd(
         model_,
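The legalize_graph helper above (lifted from torch.fx, and swapped for the upstream import in patch 7) is a Kahn-style topological sort: each round copies every node whose inputs have already been copied. The same loop over an ordinary dependency dict, as a self-contained sketch:

    def topo_sort(dependencies):
        # dependencies: {node: [nodes it reads from]} -> names in topological order
        pending = {n: list(deps) for n, deps in dependencies.items()}
        ordered = []
        while pending:
            # everything whose dependencies were already emitted is ready
            ready = [n for n, deps in pending.items() if all(d in ordered for d in deps)]
            if not ready:
                raise ValueError("dependency cycle")
            ordered.extend(ready)
            for n in ready:
                del pending[n]
        return ordered

    print(topo_sort({"buf1": ["buf0"], "buf0": ["arg0"], "arg0": []}))
    # -> ['arg0', 'buf0', 'buf1']
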
List[torch.Tensor], fname = "image", print_graph = False): """ Dump the graph of a compute box to a file with fname. """ @@ -252,7 +329,7 @@ def draw_compute_box(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor # import pprint # pprint.pprint(graph.buffers) # breakpoint() - create_fx_graph(graph.buffers, fname) + create_fx_graph(graph.buffers, fname, print_graph=print_graph) except Exception: if os.environ.get("TORCHINDUCTOR_DUMP_REPRO") == "1": wrap(functools.partial(dump_to_repro, gm))(*example_inputs) @@ -332,19 +409,50 @@ def is_not_gradout(x): assert static_arg_idxs == list(range(len(static_arg_idxs))) return len(static_arg_idxs) +# def get_input_meta(args): +# input_meta = [] +# if len(args) > 0 and isinstance(args[0], tuple): # joint input +# input_meta += get_input_meta(args[0]) +# input_meta += get_input_meta(args[1]) +# return input_meta +# for arg in args: +# if(type(arg) == int or type(arg) == float): +# input_meta.append((type(arg),)) +# else: +# input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) +# return input_meta + +model_name = "alexnet" def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): """Main entrypoint to a compile given FX graph""" model_ = normalize_ir(model_, example_inputs_) num_example_inputs = len(example_inputs_) + # def fw_compiler(model: torch.fx.GraphModule, example_inputs): + # fixed = len(example_inputs) - num_example_inputs + # return compile_fx_inner(model, example_inputs, num_fixed=fixed) + + # def bw_compiler(model: torch.fx.GraphModule, example_inputs): + # fixed = count_tangents(model) + # return compile_fx_inner(model, example_inputs, num_fixed=fixed) + def fw_compiler(model: torch.fx.GraphModule, example_inputs): - fixed = len(example_inputs) - num_example_inputs - return compile_fx_inner(model, example_inputs, num_fixed=fixed) + # import pickle + # model.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen + # model.to_folder("hf_Bert_forward_0") + # input_meta = get_input_meta(example_inputs) + # pickle.dump(input_meta, open("hf_Bert_forward_0/hf_Bert_forward_0.input", "wb")) # noqa: E501 + global model_name + draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=True) + return model + def bw_compiler(model: torch.fx.GraphModule, example_inputs): - fixed = count_tangents(model) - return compile_fx_inner(model, example_inputs, num_fixed=fixed) + # print(model) + global model_name + draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=True) + return model return aot_autograd( model_, From f7768ae73436a8926ed81af7c3e4726666c533ca Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 23 Jul 2022 01:18:58 +0000 Subject: [PATCH 5/9] add model name --- torchinductor/compile_fx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index 485e2c4a6b..b433999298 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -310,7 +310,7 @@ def create_fx_graph(nodes, fname, backend = "triton", print_graph = False): # print(graph) print("starting creating module") gm = GraphModule({}, graph) - print(gm) + # print(gm) print("starting drawing") draw_graph(gm, fname, clear_meta=False) @@ -422,7 +422,7 @@ def is_not_gradout(x): # input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) # return input_meta -model_name = "alexnet" +model_name = "hf_Bert" def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: 
From c584cdd2fa8df37cc027ee298227850f30eb1489 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Sat, 23 Jul 2022 02:43:18 +0000
Subject: [PATCH 6/9] refactor

---
 benchmarks/common.py           | 40 +++++++++++++++++++++++++++++++---
 torchdynamo/convert_frame.py   |  6 +++++
 torchinductor/compile_fx.py    | 28 +++++++++++++++++++-----
 torchinductor/decomposition.py |  1 +
 4 files changed, 66 insertions(+), 9 deletions(-)

diff --git a/benchmarks/common.py b/benchmarks/common.py
index 992db7c06c..da6f6391d1 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -331,6 +331,30 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs):
     return format_speedup(speedup, pvalue, is_correct=is_correct)
 
 
+def dump_experiment(args, model_iter_fn, model, example_inputs):
+    """
+    Run the model to dump the graph
+    """
+    timings = np.zeros((1, 2), np.float64)
+    # if we randomize the input, we should also check the result is correct
+    should_check_result = should_randomize_input = args.randomize_input
+    is_correct = True
+
+    inputs = (
+        randomize_input(copy.deepcopy(example_inputs))
+        if should_randomize_input
+        else example_inputs
+    )
+
+    with torchdynamo.run():
+        timed(
+            model, model_iter_fn, inputs, return_result=True
+        )
+
+    return current_name
+
+
 def overhead_experiment(*args, model_iter_fn):
     """
     Measure overheads of TorchDynamo by running with no backend (only
@@ -822,6 +846,11 @@ def parse_args():
         action="store_true",
         help="Use same settings as --inductor for baseline comparisons",
     )
+    parser.add_argument(
+        "--inductor-dump",
+        action="store_true",
+        help="Dump the graphs of compute buffers",
+    )
     parser.add_argument(
         "--raise-on-assertion-error",
@@ -1076,7 +1105,7 @@ def main(runner, original_dir=None):
         args.isolate = True
         # TODO(whc) should we move this to a more general part of the script?
         torch.backends.cuda.matmul.allow_tf32 = True
-    elif args.inductor or args.inductor_dynamic:
+    elif args.inductor or args.inductor_dynamic or args.inductor_dump:
         import torchinductor.config
 
         torchinductor.config.debug = args.verbose
@@ -1089,8 +1118,13 @@ def main(runner, original_dir=None):
         else:
             torchinductor.config.dynamic_shapes = False
 
-        optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
-        experiment = speedup_experiment
+        if args.inductor_dump:
+            optimize_ctx = torchdynamo.optimize("inductor_dump", nopython=args.nopython)
+            experiment = dump_experiment
+            output_filename = "inductor_dump.csv"
+        else:
+            optimize_ctx = torchdynamo.optimize("inductor", nopython=args.nopython)
+            experiment = speedup_experiment
         output_filename = "inductor.csv"
     elif args.online_autotune:
         optimize_ctx = torchdynamo.optimize(online_autotuner, nopython=args.nopython)
diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py
index 6d967ab606..28bd3e05f9 100644
--- a/torchdynamo/convert_frame.py
+++ b/torchdynamo/convert_frame.py
@@ -83,6 +83,12 @@ def _wrap_compiler_fn(compiler_fn):
         from torchinductor.compile_fx import compile_fx
 
         return compile_fx
+
+    elif compiler_fn == "inductor_dump":
+        from torchinductor.compile_fx import compile_fx_aot_dump
+
+        return compile_fx_aot_dump
+
     elif isinstance(compiler_fn, str):
         from .optimizations import BACKENDS
diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index b433999298..d897ef030c 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -429,13 +429,29 @@ def compile_fx_aot(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
     """Main entrypoint to compile a given FX graph"""
     model_ = normalize_ir(model_, example_inputs_)
     num_example_inputs = len(example_inputs_)
 
-    # def fw_compiler(model: torch.fx.GraphModule, example_inputs):
-    #     fixed = len(example_inputs) - num_example_inputs
-    #     return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+    def fw_compiler(model: torch.fx.GraphModule, example_inputs):
+        fixed = len(example_inputs) - num_example_inputs
+        return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+
+    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
+        fixed = count_tangents(model)
+        return compile_fx_inner(model, example_inputs, num_fixed=fixed)
+
+    return aot_autograd(
+        model_,
+        example_inputs_,
+        fw_compiler=fw_compiler,
+        bw_compiler=bw_compiler,
+        decompositions=decompositions,
+        partition_fn=min_cut_rematerialization_partition,
+    )
+
+
+def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
+    """Main entrypoint to compile a given FX graph"""
+    model_ = normalize_ir(model_, example_inputs_)
 
-    # def bw_compiler(model: torch.fx.GraphModule, example_inputs):
-    #     fixed = count_tangents(model)
-    #     return compile_fx_inner(model, example_inputs, num_fixed=fixed)
     def fw_compiler(model: torch.fx.GraphModule, example_inputs):
         # import pickle
diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py
index 43aeb88a39..92e621fd52 100644
--- a/torchinductor/decomposition.py
+++ b/torchinductor/decomposition.py
@@ -67,6 +67,7 @@
         aten.threshold_backward,
         aten.transpose.int,
         aten.upsample_nearest2d_backward,
+        aten.lift_fresh_copy,
     ]
 )
 decompositions.update(aot_autograd_decompositions)
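After patch 6 the dump path is reachable in two ways: through the benchmark harness via --inductor-dump, or directly through torchdynamo, since _wrap_compiler_fn now maps the string "inductor_dump" to compile_fx_aot_dump. A hypothetical direct use, assuming the context-manager form of torchdynamo.optimize that this era of the codebase used (the same call the benchmark makes):

    import torch
    import torchdynamo

    def fn(x):
        return torch.relu(x) + 1

    # "inductor_dump" resolves to compile_fx_aot_dump via _wrap_compiler_fn,
    # which draws {model_name}_fw / {model_name}_bw and returns the model uncompiled.
    with torchdynamo.optimize("inductor_dump", nopython=True):
        fn(torch.randn(8))
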
From 1e8941ce64d803dac40a52ab593932840614d102 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Sat, 23 Jul 2022 02:48:56 +0000
Subject: [PATCH 7/9] clean

---
 torchinductor/compile_fx.py | 62 ++++----------------------------------
 1 file changed, 6 insertions(+), 56 deletions(-)

diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index d897ef030c..86fc0370f4 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -21,6 +21,7 @@
 from . import ir
 from .codegen.cpp import CppScheduling
 from .codegen.triton import TritonScheduling
+from torch.fx.passes.tools_common import legalize_graph
 
 from . import config
 from .decomposition import decompositions
@@ -165,58 +166,6 @@ def compile_fx_inner(
     raise
 
 
-def legalize_graph(graph: torch.fx.Graph):
-    """
-    Return a new graph that contains the same nodes as the original, but in
-    topologically sorted order.
-
-    Adapted from torch.fx's legalize_graph utility, which is used by transformations
-    (such as merge_matmul) that disturb the topologically sorted order of their input
-    graph and need that order restored before further transformation.
-
-    Arguments:
-        graph: The graph to topologically sort. It is not modified in place; the
-            sorted copy is returned.
-
-    """
-    # Build an adjacency list representation of node dependencies in the graph. This also
-    # serves as a list of nodes that still need to be inserted into the new, topologically
-    # sorted graph.
-    dependencies = {node: node.all_input_nodes.copy() for node in graph.nodes}
-
-    # Construct a new graph that will contain all nodes in topologically sorted order.
-    new_graph = torch.fx.Graph()
-    value_remap = {}
-
-    # Copy over all nodes with no dependencies.
-    for node, deps in dependencies.items():
-        if not deps:
-            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
-
-    # Remove the copied over nodes from the adjacency list.
-    for copied_node in value_remap.keys():
-        del dependencies[copied_node]
-
-    # While there are still nodes to insert into the new graph:
-    while dependencies:
-        copied_this_round = []
-
-        # Copy over all nodes whose dependencies already exist in the new graph.
-        for node, deps in dependencies.items():
-            all_deps_copied = True
-            for dep in deps:
-                if dep not in value_remap:
-                    all_deps_copied = False
-
-            if all_deps_copied:
-                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
-                copied_this_round.append(node)
-
-        # Delete all nodes copied over in this iteration from dependencies.
-        for copied_node in copied_this_round:
-            del dependencies[copied_node]
-
-    # Return the new, topologically sorted graph.
-    return new_graph
-
 
 def get_fake_func(name):
     def func1(*args):
@@ -303,13 +252,14 @@ def create_fx_graph(nodes, fname, backend = "triton", print_graph = False):
         if len(v.users) == 0:
             outputs.append(v)
     graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
-    graph = legalize_graph(graph)
-    graph.lint()
 
-    # if print_graph:
-    #     print(graph)
+    if print_graph:
+        print(graph)
     print("starting creating module")
     gm = GraphModule({}, graph)
+    graph = legalize_graph(gm)
+    gm.graph.lint()
     # print(gm)
     print("starting drawing")
     draw_graph(gm, fname, clear_meta=False)
From 9d2ee6027a346a278584661580d5aacebf290537 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Sat, 23 Jul 2022 02:58:23 +0000
Subject: [PATCH 8/9] clean

---
 torchinductor/compile_fx.py | 22 ++--------------------
 1 file changed, 2 insertions(+), 20 deletions(-)

diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 86fc0370f4..b58d46e4c4 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -359,18 +359,6 @@ def is_not_gradout(x):
     assert static_arg_idxs == list(range(len(static_arg_idxs)))
     return len(static_arg_idxs)
 
-# def get_input_meta(args):
-#     input_meta = []
-#     if len(args) > 0 and isinstance(args[0], tuple):  # joint input
-#         input_meta += get_input_meta(args[0])
-#         input_meta += get_input_meta(args[1])
-#         return input_meta
-#     for arg in args:
-#         if(type(arg) == int or type(arg) == float):
-#             input_meta.append((type(arg),))
-#         else:
-#             input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device))
-#     return input_meta
 
 model_name = "hf_Bert"
 
@@ -404,20 +392,14 @@ def compile_fx_aot_dump(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
 
     def fw_compiler(model: torch.fx.GraphModule, example_inputs):
-        # import pickle
-        # model.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
-        # model.to_folder("hf_Bert_forward_0")
-        # input_meta = get_input_meta(example_inputs)
-        # pickle.dump(input_meta, open("hf_Bert_forward_0/hf_Bert_forward_0.input", "wb"))  # noqa: E501
         global model_name
-        draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=True)
+        draw_compute_box(model, example_inputs, f"{model_name}_fw", print_graph=False)
         return model
 
     def bw_compiler(model: torch.fx.GraphModule, example_inputs):
-        # print(model)
         global model_name
-        draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=True)
+        draw_compute_box(model, example_inputs, f"{model_name}_bw", print_graph=False)
         return model
 
     return aot_autograd(
From 627356f95b453722ea76631d699282dd191c6505 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Mon, 25 Jul 2022 18:36:47 +0000
Subject: [PATCH 9/9] revert decomposition

---
 torchinductor/decomposition.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py
index 92e621fd52..43aeb88a39 100644
--- a/torchinductor/decomposition.py
+++ b/torchinductor/decomposition.py
@@ -67,7 +67,6 @@
         aten.threshold_backward,
         aten.transpose.int,
         aten.upsample_nearest2d_backward,
-        aten.lift_fresh_copy,
     ]
 )
 decompositions.update(aot_autograd_decompositions)
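End to end, the series is meant to be driven from the benchmark runner: --inductor-dump selects the "inductor_dump" backend together with the dump_experiment harness, and the hard-coded model_name = "hf_Bert" names the output drawings. A hypothetical invocation (the exact script name and flag set depend on the benchmarks checkout):

    python benchmarks/torchbench.py --inductor-dump --training --only hf_Bert
    # draw_graph defaults to .svg output, so hf_Bert_fw.svg and hf_Bert_bw.svg
    # are the expected artifacts.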