Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions torchinductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import itertools
import logging
import os
from typing import Any
from typing import Dict
from typing import List
Expand All @@ -24,6 +25,8 @@

log = logging.getLogger(__name__)

INDUCTOR_SCHEDULER_GRAPH = bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", None) == "1")


def cmp(a, b):
return int(a > b) - int(a < b)
Expand Down Expand Up @@ -418,6 +421,108 @@ def get_name(self):
return self.node.get_name()


def get_fake_func(name):
def func1(*args):
return 0

func1.__name__ = name
return func1


def create_fx_from_buffers(nodes, fname, print_graph=False):
"""
Draw a graph in fname.svg.
nodes is a list of SchedulerNode objects.
"""

from functorch._src.partitioners import draw_graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

func_dict = {}
name_to_fx_node = {}
graph = torch.fx.Graph()
first_node = None

# create call_function node for each Buffer and Kernel
for snode in nodes:
node = snode.node
name = node.get_name()
node_type = str(type(node)).split(".")[-1].replace("'>", "")

if node_type in func_dict:
fake_f = func_dict[node_type]
else:
fake_f = get_fake_func(node_type)
func_dict[node_type] = fake_f
fx_node = graph.call_function(fake_f, args=(), kwargs=None)
fx_node.name = name

# gather meta data
dtype = None
if isinstance(node, ir.ComputedBuffer):
dtype = node.data.dtype

try:
stride = node.get_stride()
except AttributeError:
stride = None

layout = type(node.layout)

if isinstance(snode, NopKernelSchedulerNode):
group = "nop"
elif isinstance(snode, ExternKernelSchedulerNode):
if should_use_template(node):
group = snode.group[1]
else:
group = "extern"
else: # SchedulerNode
group = snode.group[1]

metadata = TensorMetadata(group, dtype, False, stride, layout, None, None)
fx_node.meta["tensor_meta"] = metadata

name_to_fx_node[name] = fx_node
if first_node is None:
first_node = fx_node

# create edges between nodes
for snode in nodes:
node = snode.node
name = node.get_name()
deps = node.get_reads()
fx_node = name_to_fx_node[name]

new_args = []
for dep in deps:
if dep.name in name_to_fx_node:
dep_node = name_to_fx_node[dep.name]
else:
with graph.inserting_before(first_node):
dep_node = graph.placeholder(dep.name)
name_to_fx_node[dep.name] = dep_node
new_args.append(dep_node)

fx_node.args = tuple(new_args)

outputs = []
for _, v in name_to_fx_node.items():
if len(v.users) == 0:
outputs.append(v)
graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))

if print_graph:
print(graph)
print("starting creating module")
gm = GraphModule({}, graph)
graph = legalize_graph(gm)
gm.graph.lint()
print("starting drawing")
draw_graph(gm, fname, clear_meta=False)


class Scheduler:
def __init__(self, nodes):
super(Scheduler, self).__init__()
Expand Down Expand Up @@ -458,6 +563,21 @@ def __init__(self, nodes):
assert False, node
self.name_to_node = {node.get_name(): node for node in self.nodes}

if INDUCTOR_SCHEDULER_GRAPH:

try:
from functorch._src.aot_autograd import get_graph_being_compiled

graph_name = get_graph_being_compiled()
except ImportError:
logging.warning(
"Could not get graph name from `get_graph_being_compiled` \
in functorch, use 'model' as default"
)
graph_name = "model"

create_fx_from_buffers(self.nodes, graph_name, print_graph=True)

# some new constants could have been created above
self.available_buffer_names.update(V.graph.constants.keys())

Expand Down