Skip to content

Commit

Permalink
typing scheduler.py [2/2]: Apply types (#126656)
Browse files Browse the repository at this point in the history
Add `# mypy: disallow-untyped-defs` to scheduler.py and then fix the resulting fallout.

We should probably eventually add a new node type between BaseSchedulerNode and all the non-FusedSchedulerNode types to indicate the split between nodes that have a valid `self.node` and ones that don't. That would cause a lot of the `assert self.node is not None` churn to go away — but it was a bigger change, because a lot of code makes assumptions about types that aren't reflected in the types themselves.

Pull Request resolved: #126656
Approved by: https://github.com/eellison
  • Loading branch information
aorenste authored and pytorchmergebot committed May 22, 2024
1 parent 3591bce commit e4623de
Show file tree
Hide file tree
Showing 13 changed files with 385 additions and 239 deletions.
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3247,7 +3247,7 @@ def run(kernel):
for node in nodes:
if node.group[1] in [
(group, reduction_group),
(group + reduction_group, ()),
(tuple(itertools.chain(group, reduction_group)), ()),
]:
assert not in_suffix
node.run(vars, reduction_vars)
Expand Down
3 changes: 2 additions & 1 deletion torch/_inductor/codegen/cpp_wrapper_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,12 +1273,13 @@ def generate_user_defined_triton_kernel(
break
assert grid_decision is not None

current_device = V.graph.scheduler.get_current_device_or_throw()
self.generate_kernel_call(
kernel_name,
args,
arg_types=arg_types,
grid=grid_decision,
device_index=V.graph.scheduler.current_device.index,
device_index=current_device.index,
cuda=True,
triton=True,
triton_meta=triton_meta,
Expand Down
6 changes: 4 additions & 2 deletions torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import cast, List
from typing import cast, Sequence

from ...._dynamo.utils import counters

Expand Down Expand Up @@ -73,7 +73,9 @@ def define_kernel(self, src_code: str, node_schedule) -> str:
return kernel_name

def codegen_template(
self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
self,
template_node: BaseSchedulerNode,
epilogue_nodes: Sequence[BaseSchedulerNode],
):
"""
Codegen a CUDA template, possibly with fused epilogues
Expand Down
3 changes: 2 additions & 1 deletion torch/_inductor/codegen/cuda/cuda_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,11 @@ def call_kernel(
else:
call_args.append("None")

current_device = V.graph.scheduler.get_current_device_or_throw()
wrapper.generate_kernel_call(
name,
call_args,
device_index=V.graph.scheduler.current_device.index,
device_index=current_device.index,
cuda=True,
triton=False,
arg_types=arg_types,
Expand Down
6 changes: 4 additions & 2 deletions torch/_inductor/codegen/cuda_combined_scheduling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Sequence, Union

from ..scheduler import (
BaseSchedulerNode,
Expand Down Expand Up @@ -50,7 +50,9 @@ def group_fn(self, sizes):
return self._triton_scheduling.group_fn(sizes)

def codegen_template(
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
self,
template_node: BaseSchedulerNode,
epilogue_nodes: Sequence[BaseSchedulerNode],
):
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
assert epilogue_nodes is None or len(epilogue_nodes) == 0
Expand Down
3 changes: 2 additions & 1 deletion torch/_inductor/codegen/multi_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ def call_kernel(self, kernel_name):
)

grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
current_device = V.graph.scheduler.get_current_device_or_throw()
V.graph.wrapper_code.generate_kernel_call(
kernel_name,
final_call_args,
grid,
V.graph.scheduler.current_device.index,
current_device.index,
arg_types=arg_types,
)

Expand Down
6 changes: 4 additions & 2 deletions torch/_inductor/codegen/simd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Expand Down Expand Up @@ -557,7 +558,7 @@ def set_ranges(self, *lengths):

@staticmethod
def _split_iteration_ranges(
groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]]
groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
):
sv = V.graph.sizevars
new_ranges: List[List[sympy.Expr]] = [[] for _ in groups]
Expand Down Expand Up @@ -625,7 +626,7 @@ def getter(flat_vars):

@classmethod
def is_compatible(
cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]]
cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
):
try:
cls._split_iteration_ranges(groups, lengths)
Expand Down Expand Up @@ -1544,6 +1545,7 @@ def codegen_node_schedule(
name = node.get_name()
if name not in live_outs:
continue
assert node.node is not None
origin_node = node.node.get_origin_node()
if origin_node is not None:
counters["inductor"]["intermediate_hooks"] += 1
Expand Down
14 changes: 8 additions & 6 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,8 @@ def codegen_kernel_benchmark(self, num_gb, grid=None):
grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})"
else:
grid_arg = f"grid={grid}"
index = V.graph.scheduler.current_device.index
current_device = V.graph.scheduler.get_current_device_or_throw()
index = current_device.index
with result.indent():
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
with result.indent():
Expand Down Expand Up @@ -1948,7 +1949,9 @@ def codegen_kernel(self, name=None):
)
triton_meta = {
"signature": triton_meta_signature,
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"device": DeviceProperties.create(
V.graph.scheduler.get_current_device_or_throw()
),
"constants": {},
}

Expand Down Expand Up @@ -2115,7 +2118,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None):
call_args, arg_types = self.get_call_args()
grid: List[Any] = []
self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
current_device = V.graph.scheduler.current_device
current_device = V.graph.scheduler.get_current_device_or_throw()

if self.args.workspace_arg is not None:
ws = self.args.workspace_arg
Expand Down Expand Up @@ -2225,9 +2228,8 @@ def define_kernel(self, src_code, node_schedule, kernel):
compile_wrapper = IndentedBuffer()
compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
compile_wrapper.splice(src_code, strip=True)
compile_wrapper.writeline(
f"''', device_str='{V.graph.scheduler.current_device.type}')"
)
current_device = V.graph.scheduler.get_current_device_or_throw()
compile_wrapper.writeline(f"''', device_str='{current_device.type}')")

metadata_comment = f"# kernel path: {kernel_path}"
origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
Expand Down
11 changes: 6 additions & 5 deletions torch/_inductor/codegen/triton_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def jit_lines(self):
_, _, signature, _ = self.args.python_argdefs()
triton_meta = {
"signature": signature_to_meta(signature, size_dtype=size_dtype),
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"device": DeviceProperties.create(
V.graph.scheduler.get_current_device_or_throw()
),
"constants": {},
}
triton_meta["configs"] = [config_of(signature)]
Expand Down Expand Up @@ -230,20 +232,19 @@ def call_kernel(self, code, name: str):
for i in range(len(call_args)):
if V.graph.is_unspec_arg(call_args[i]):
call_args[i] = call_args[i] + ".item()"
current_device = V.graph.scheduler.get_current_device_or_throw()
if V.graph.cpp_wrapper:
V.graph.wrapper_code.generate_kernel_call(
name,
call_args,
device_index=V.graph.scheduler.current_device.index,
device_index=current_device.index,
grid=self.grid(),
arg_types=arg_types,
)
else:
# TODO: refactor generate_kernel_call
call_args_str = ", ".join(call_args)
stream_name = code.write_get_raw_stream(
V.graph.scheduler.current_device.index
)
stream_name = code.write_get_raw_stream(current_device.index)
code.writeline(
f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
)
19 changes: 9 additions & 10 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,9 +685,8 @@ def generate_user_defined_triton_kernel(
for line in code.split("\n"):
self.writeline(line)

stream_name = self.write_get_raw_stream(
V.graph.scheduler.current_device.index, V.graph
)
current_device = V.graph.scheduler.get_current_device_or_throw()
stream_name = self.write_get_raw_stream(current_device.index, V.graph)
self.writeline(
f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})"
)
Expand Down Expand Up @@ -1149,7 +1148,9 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
size_dtype=index_dtype,
indices=non_constant_indices,
),
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"device": DeviceProperties.create(
V.graph.scheduler.get_current_device_or_throw()
),
# Triton compiler includes equal_to_1 args into constants even
# when they are not constexpr. otherwise there may be a segfault
# during launching the Inductor-compiled Triton kernel.
Expand Down Expand Up @@ -1270,9 +1271,8 @@ def traverse(cur_kernel):

traverse(kernel)

compile_wrapper.writeline(
f"''', device_str='{V.graph.scheduler.current_device.type}')"
)
current_device = V.graph.scheduler.get_current_device_or_throw()
compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
_, lineno = inspect.getsourcelines(kernel.fn)
srcfile = inspect.getsourcefile(kernel.fn)
metadata = f"# Original path: {srcfile}:{lineno}"
Expand Down Expand Up @@ -1387,9 +1387,8 @@ def generate_kernel_call(
"""
if cuda:
call_args_str = ", ".join(pexpr(item) for item in call_args)
stream_name = self.write_get_raw_stream(
V.graph.scheduler.current_device.index, V.graph
)
current_device = V.graph.scheduler.get_current_device_or_throw()
stream_name = self.write_get_raw_stream(current_device.index, V.graph)
if triton:
grid_str = ", ".join(pexpr(item) for item in grid)
grid_str = f"{grid_fn}({grid_str})"
Expand Down

0 comments on commit e4623de

Please sign in to comment.