Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,10 @@ def main(runner, original_dir=None):

current_name = name
current_device = device
from functorch._src.aot_autograd import set_model_name

set_model_name(name)

if args.float32:
model, example_inputs = cast_to_fp32(model, example_inputs)
elif args.float16:
Expand Down
109 changes: 109 additions & 0 deletions torchinductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import itertools
import logging
import os
from typing import Any
from typing import Dict
from typing import List
Expand All @@ -25,6 +26,9 @@
log = logging.getLogger(__name__)


INDUCTOR_SCHEDULER_GRAPH = bool(os.environ.get("INDUCTOR_SCHEDULER_GRAPH", None) == "1")


def cmp(a, b):
return int(a > b) - int(a < b)

Expand Down Expand Up @@ -418,6 +422,105 @@ def get_name(self):
return self.node.get_name()


def get_fake_func(name):
def func1(*args):
return 0

func1.__name__ = name
return func1


def create_fx_from_buffers(nodes, fname, print_graph=False):
"""
Draw a graph in fname.svg.

nodes is a list of SchedulerNode objects.
"""

from functorch._src.partitioners import draw_graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

func_dict = {}
name_to_fx_node = {}
graph = torch.fx.Graph()
first_node = None

# create call_function node for each Buffer and Kernel
for snode in nodes:
node = snode.node
name = node.get_name()
node_type = str(type(node)).split(".")[-1].replace("'>", "")

if node_type in func_dict:
fake_f = func_dict[node_type]
else:
fake_f = get_fake_func(node_type)
func_dict[node_type] = fake_f
fx_node = graph.call_function(fake_f, args=(), kwargs=None)
fx_node.name = name

# gather meta data
dtype = None
if isinstance(node, ir.ComputedBuffer):
dtype = node.data.dtype

stride = node.get_stride()
layout = type(node.layout)

if isinstance(snode, NopKernelSchedulerNode):
group = "nop"
elif isinstance(snode, ExternKernelSchedulerNode):
if should_use_template(node):
group = snode.group[1]
else:
group = "extern"
else: # SchedulerNode
group = snode.group[1]

metadata = TensorMetadata(group, dtype, False, stride, layout, None, None)
fx_node.meta["tensor_meta"] = metadata

name_to_fx_node[name] = fx_node
if first_node is None:
first_node = fx_node

# create edges between nodes
for snode in nodes:
node = snode.node
name = node.get_name()
deps = node.get_reads()
fx_node = name_to_fx_node[name]

new_args = []
for dep in deps:
if dep.name in name_to_fx_node:
dep_node = name_to_fx_node[dep.name]
else:
with graph.inserting_before(first_node):
dep_node = graph.placeholder(dep.name)
name_to_fx_node[dep.name] = dep_node
new_args.append(dep_node)

fx_node.args = tuple(new_args)

outputs = []
for _, v in name_to_fx_node.items():
if len(v.users) == 0:
outputs.append(v)
graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))

if print_graph:
print(graph)
print("starting creating module")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stray prints?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is within create_fx_graph, which is for debugging purposes, so i think it's fine.

gm = GraphModule({}, graph)
graph = legalize_graph(gm)
gm.graph.lint()
print("starting drawing")
draw_graph(gm, fname, clear_meta=False)


class Scheduler:
def __init__(self, nodes):
super(Scheduler, self).__init__()
Expand Down Expand Up @@ -458,6 +561,12 @@ def __init__(self, nodes):
assert False, node
self.name_to_node = {node.get_name(): node for node in self.nodes}

if INDUCTOR_SCHEDULER_GRAPH:
from functorch._src.aot_autograd import get_graph_being_compiled

graph_name = get_graph_being_compiled()
create_fx_from_buffers(self.nodes, graph_name, print_graph=True)

# some new constants could have been created above
self.available_buffer_names.update(V.graph.constants.keys())

Expand Down