diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py
index 6f4067b5ba..049aafb246 100644
--- a/torchinductor/scheduler.py
+++ b/torchinductor/scheduler.py
@@ -4,6 +4,7 @@
 import functools
 import itertools
 import logging
+import os
 from typing import Any
 from typing import Dict
 from typing import List
@@ -24,6 +25,8 @@
 log = logging.getLogger(__name__)
 
+INDUCTOR_SCHEDULER_GRAPH = bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", None) == "1")
+
 
 def cmp(a, b):
     return int(a > b) - int(a < b)
 
@@ -418,6 +421,108 @@ def get_name(self):
         return self.node.get_name()
 
 
+def get_fake_func(name):
+    def func1(*args):
+        return 0
+
+    func1.__name__ = name
+    return func1
+
+
+def create_fx_from_buffers(nodes, fname, print_graph=False):
+    """
+    Draw a graph in fname.svg.
+    nodes is a list of SchedulerNode objects.
+    """
+
+    from functorch._src.partitioners import draw_graph
+    from torch.fx.graph_module import GraphModule
+    from torch.fx.passes.shape_prop import TensorMetadata
+    from torch.fx.passes.tools_common import legalize_graph
+
+    func_dict = {}
+    name_to_fx_node = {}
+    graph = torch.fx.Graph()
+    first_node = None
+
+    # create call_function node for each Buffer and Kernel
+    for snode in nodes:
+        node = snode.node
+        name = node.get_name()
+        node_type = str(type(node)).split(".")[-1].replace("'>", "")
+
+        if node_type in func_dict:
+            fake_f = func_dict[node_type]
+        else:
+            fake_f = get_fake_func(node_type)
+            func_dict[node_type] = fake_f
+        fx_node = graph.call_function(fake_f, args=(), kwargs=None)
+        fx_node.name = name
+
+        # gather meta data
+        dtype = None
+        if isinstance(node, ir.ComputedBuffer):
+            dtype = node.data.dtype
+
+        try:
+            stride = node.get_stride()
+        except AttributeError:
+            stride = None
+
+        layout = type(node.layout)
+
+        if isinstance(snode, NopKernelSchedulerNode):
+            group = "nop"
+        elif isinstance(snode, ExternKernelSchedulerNode):
+            if should_use_template(node):
+                group = snode.group[1]
+            else:
+                group = "extern"
+        else:  # SchedulerNode
+            group = snode.group[1]
+
+        metadata = TensorMetadata(group, dtype, False, stride, layout, None, None)
+        fx_node.meta["tensor_meta"] = metadata
+
+        name_to_fx_node[name] = fx_node
+        if first_node is None:
+            first_node = fx_node
+
+    # create edges between nodes
+    for snode in nodes:
+        node = snode.node
+        name = node.get_name()
+        deps = node.get_reads()
+        fx_node = name_to_fx_node[name]
+
+        new_args = []
+        for dep in deps:
+            if dep.name in name_to_fx_node:
+                dep_node = name_to_fx_node[dep.name]
+            else:
+                with graph.inserting_before(first_node):
+                    dep_node = graph.placeholder(dep.name)
+                name_to_fx_node[dep.name] = dep_node
+            new_args.append(dep_node)
+
+        fx_node.args = tuple(new_args)
+
+    outputs = []
+    for _, v in name_to_fx_node.items():
+        if len(v.users) == 0:
+            outputs.append(v)
+    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
+
+    if print_graph:
+        print(graph)
+    print("starting creating module")
+    gm = GraphModule({}, graph)
+    graph = legalize_graph(gm)
+    gm.graph.lint()
+    print("starting drawing")
+    draw_graph(gm, fname, clear_meta=False)
+
+
 class Scheduler:
     def __init__(self, nodes):
         super(Scheduler, self).__init__()
@@ -458,6 +563,21 @@ def __init__(self, nodes):
                 assert False, node
         self.name_to_node = {node.get_name(): node for node in self.nodes}
 
+        if INDUCTOR_SCHEDULER_GRAPH:
+
+            try:
+                from functorch._src.aot_autograd import get_graph_being_compiled
+
+                graph_name = get_graph_being_compiled()
+            except ImportError:
+                log.warning(
+                    "Could not get graph name from `get_graph_being_compiled` "
+                    "in functorch, using 'model' as default"
+                )
+                graph_name = "model"
+
+            create_fx_from_buffers(self.nodes, graph_name, print_graph=True)
+
         # some new constants could have been created above
         self.available_buffer_names.update(V.graph.constants.keys())
 
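
For anyone who wants to try this: a minimal sketch of how the new flag is meant to be exercised, assuming the standalone torchdynamo-era API that this diff targets (the model function below is made up). Note that `INDUCTOR_SCHEDULER_GRAPH` is read once at `torchinductor.scheduler` import time, so it must be set before the first compilation.

```python
# Minimal sketch, not part of the diff; assumes the torchdynamo-era layout.
import os

# The flag is read when torchinductor.scheduler is imported, so set it first.
os.environ["INDUCTOR_SCHEDULER_GRAPH"] = "1"

import torch
import torchdynamo


@torchdynamo.optimize("inductor")
def fn(x):
    return (x.sin() + x.cos()).sum()


fn(torch.randn(64, 64))
# For each graph the inductor Scheduler builds, create_fx_from_buffers() now
# runs and draw_graph() writes <graph_name>.svg, where <graph_name> comes from
# functorch's get_graph_being_compiled(), falling back to "model".
```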