diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index d4e314ebff0f2..c7ade8459ca2a 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3247,7 +3247,7 @@ def run(kernel): for node in nodes: if node.group[1] in [ (group, reduction_group), - (group + reduction_group, ()), + (tuple(itertools.chain(group, reduction_group)), ()), ]: assert not in_suffix node.run(vars, reduction_vars) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index a8b3d73765d9e..18a6c9967f2be 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1273,12 +1273,13 @@ def generate_user_defined_triton_kernel( break assert grid_decision is not None + current_device = V.graph.scheduler.get_current_device_or_throw() self.generate_kernel_call( kernel_name, args, arg_types=arg_types, grid=grid_decision, - device_index=V.graph.scheduler.current_device.index, + device_index=current_device.index, cuda=True, triton=True, triton_meta=triton_meta, diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index da058fac50284..5c91736e9abde 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -1,5 +1,5 @@ import logging -from typing import cast, List +from typing import cast, Sequence from ...._dynamo.utils import counters @@ -73,7 +73,9 @@ def define_kernel(self, src_code: str, node_schedule) -> str: return kernel_name def codegen_template( - self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], ): """ Codegen a CUDA template, possibly with fused epilogues diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index dabe66caac24c..8cad41082d64e 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -176,10 +176,11 @@ def call_kernel( else: call_args.append("None") + current_device = V.graph.scheduler.get_current_device_or_throw() wrapper.generate_kernel_call( name, call_args, - device_index=V.graph.scheduler.current_device.index, + device_index=current_device.index, cuda=True, triton=False, arg_types=arg_types, diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index b238cd805c032..f7be73c247fdc 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Sequence, Union from ..scheduler import ( BaseSchedulerNode, @@ -50,7 +50,9 @@ def group_fn(self, sizes): return self._triton_scheduling.group_fn(sizes) def codegen_template( - self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], ): if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): assert epilogue_nodes is None or len(epilogue_nodes) == 0 diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 6c7b2925f1b19..8b4dbb1790160 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -222,11 +222,12 @@ def call_kernel(self, kernel_name): ) grid = 
V.graph.wrapper_code.generate_default_grid(kernel_name, grid) + current_device = V.graph.scheduler.get_current_device_or_throw() V.graph.wrapper_code.generate_kernel_call( kernel_name, final_call_args, grid, - V.graph.scheduler.current_device.index, + current_device.index, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index a1a1ddb6a286d..22eb25b32ced3 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -17,6 +17,7 @@ Iterable, List, Optional, + Sequence, Set, Tuple, TYPE_CHECKING, @@ -557,7 +558,7 @@ def set_ranges(self, *lengths): @staticmethod def _split_iteration_ranges( - groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] ): sv = V.graph.sizevars new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] @@ -625,7 +626,7 @@ def getter(flat_vars): @classmethod def is_compatible( - cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] ): try: cls._split_iteration_ranges(groups, lengths) @@ -1544,6 +1545,7 @@ def codegen_node_schedule( name = node.get_name() if name not in live_outs: continue + assert node.node is not None origin_node = node.node.get_origin_node() if origin_node is not None: counters["inductor"]["intermediate_hooks"] += 1 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0617c21c8eaa1..874e661fee861 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1784,7 +1784,8 @@ def codegen_kernel_benchmark(self, num_gb, grid=None): grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})" else: grid_arg = f"grid={grid}" - index = V.graph.scheduler.current_device.index + current_device = V.graph.scheduler.get_current_device_or_throw() + index = current_device.index with result.indent(): result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") with result.indent(): @@ -1948,7 +1949,9 @@ def codegen_kernel(self, name=None): ) triton_meta = { "signature": triton_meta_signature, - "device": DeviceProperties.create(V.graph.scheduler.current_device), + "device": DeviceProperties.create( + V.graph.scheduler.get_current_device_or_throw() + ), "constants": {}, } @@ -2115,7 +2118,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): call_args, arg_types = self.get_call_args() grid: List[Any] = [] self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) - current_device = V.graph.scheduler.current_device + current_device = V.graph.scheduler.get_current_device_or_throw() if self.args.workspace_arg is not None: ws = self.args.workspace_arg @@ -2225,9 +2228,8 @@ def define_kernel(self, src_code, node_schedule, kernel): compile_wrapper = IndentedBuffer() compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") compile_wrapper.splice(src_code, strip=True) - compile_wrapper.writeline( - f"''', device_str='{V.graph.scheduler.current_device.type}')" - ) + current_device = V.graph.scheduler.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") metadata_comment = f"# kernel path: {kernel_path}" origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) diff --git a/torch/_inductor/codegen/triton_foreach.py b/torch/_inductor/codegen/triton_foreach.py index a676b87ba64e3..8ed909ec823aa 100644 --- a/torch/_inductor/codegen/triton_foreach.py +++ 
b/torch/_inductor/codegen/triton_foreach.py @@ -158,7 +158,9 @@ def jit_lines(self): _, _, signature, _ = self.args.python_argdefs() triton_meta = { "signature": signature_to_meta(signature, size_dtype=size_dtype), - "device": DeviceProperties.create(V.graph.scheduler.current_device), + "device": DeviceProperties.create( + V.graph.scheduler.get_current_device_or_throw() + ), "constants": {}, } triton_meta["configs"] = [config_of(signature)] @@ -230,20 +232,19 @@ def call_kernel(self, code, name: str): for i in range(len(call_args)): if V.graph.is_unspec_arg(call_args[i]): call_args[i] = call_args[i] + ".item()" + current_device = V.graph.scheduler.get_current_device_or_throw() if V.graph.cpp_wrapper: V.graph.wrapper_code.generate_kernel_call( name, call_args, - device_index=V.graph.scheduler.current_device.index, + device_index=current_device.index, grid=self.grid(), arg_types=arg_types, ) else: # TODO: refactor generate_kernel_call call_args_str = ", ".join(call_args) - stream_name = code.write_get_raw_stream( - V.graph.scheduler.current_device.index - ) + stream_name = code.write_get_raw_stream(current_device.index) code.writeline( f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})" ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0f61e7b3e7995..5e92a5e18278f 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -685,9 +685,8 @@ def generate_user_defined_triton_kernel( for line in code.split("\n"): self.writeline(line) - stream_name = self.write_get_raw_stream( - V.graph.scheduler.current_device.index, V.graph - ) + current_device = V.graph.scheduler.get_current_device_or_throw() + stream_name = self.write_get_raw_stream(current_device.index, V.graph) self.writeline( f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})" ) @@ -1149,7 +1148,9 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): size_dtype=index_dtype, indices=non_constant_indices, ), - "device": DeviceProperties.create(V.graph.scheduler.current_device), + "device": DeviceProperties.create( + V.graph.scheduler.get_current_device_or_throw() + ), # Triton compiler includes equal_to_1 args into constants even # when they are not constexpr. otherwise there may be a segfault # during launching the Inductor-compiled Triton kernel. 
@@ -1270,9 +1271,8 @@ def traverse(cur_kernel): traverse(kernel) - compile_wrapper.writeline( - f"''', device_str='{V.graph.scheduler.current_device.type}')" - ) + current_device = V.graph.scheduler.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") _, lineno = inspect.getsourcelines(kernel.fn) srcfile = inspect.getsourcefile(kernel.fn) metadata = f"# Original path: {srcfile}:{lineno}" @@ -1387,9 +1387,8 @@ def generate_kernel_call( """ if cuda: call_args_str = ", ".join(pexpr(item) for item in call_args) - stream_name = self.write_get_raw_stream( - V.graph.scheduler.current_device.index, V.graph - ) + current_device = V.graph.scheduler.get_current_device_or_throw() + stream_name = self.write_get_raw_stream(current_device.index, V.graph) if triton: grid_str = ", ".join(pexpr(item) for item in grid) grid_str = f"{grid_fn}({grid_str})" diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 1c9b8dc31be33..453be88c56bea 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1,3 +1,4 @@ +# mypy: disallow-untyped-defs import collections import dataclasses import functools @@ -8,6 +9,7 @@ import os import pprint import textwrap +import typing from typing import ( Any, Counter, @@ -47,6 +49,7 @@ get_device_tflops, get_dtype_size, get_gpu_dram_gbps, + IndentedBuffer, is_collective, is_gpu, is_wait, @@ -70,18 +73,18 @@ def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"): self.node1 = node1 self.node2 = node2 - def __call__(self, reason, *args): + def __call__(self, reason: str, *args: Any) -> None: self.reason = reason self.args = args fusion_log.debug(self) - def __str__(self): + def __str__(self) -> str: return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( self.reason % self.args ) -def pformat(obj): +def pformat(obj: Any) -> str: if isinstance(obj, set): # pformat has trouble with sets of sympy exprs obj = sorted(obj, key=str) @@ -92,23 +95,25 @@ def pformat(obj): class OutputNode: - def __init__(self, dep): + def __init__(self, dep: StarDep) -> None: self.unmet_dependencies = {dep} - self.inverse_users = [] + self.inverse_users: List[BaseSchedulerNode] = [] - def is_reduction(self): + def is_reduction(self) -> bool: return False - def get_inputs_that_alias_output(self): + def get_inputs_that_alias_output(self) -> Sequence[str]: return () - def get_name(self): + def get_name(self) -> str: return "OUTPUT" __repr__ = get_name -def _prune_redundant_deps(node, name_to_fused_node): +def _prune_redundant_deps( + node: "BaseSchedulerNode", name_to_fused_node: Dict[str, "BaseSchedulerNode"] +) -> None: """ Prunes weakdeps intended for mutation ordering on an upstream fused node if after fusion there is another dependency @@ -123,7 +128,7 @@ def _prune_redundant_deps(node, name_to_fused_node): if not isinstance(dep, WeakDep): name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 - def should_prune(dep): + def should_prune(dep: Dep) -> bool: if isinstance(dep, WeakDep): is_redundant = ( name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 @@ -153,9 +158,11 @@ def should_prune(dep): class BaseSchedulerNode: - def __init__(self, scheduler: "Scheduler", node: ir.Buffer): + group: Tuple[torch.device, Sequence[Sequence[sympy.Expr]]] + + def __init__(self, scheduler: "Scheduler", node: ir.Buffer) -> None: self.scheduler: Scheduler = scheduler - self.node: ir.Buffer = node + self.node: Optional[ir.Buffer] = node self.users: List[NodeUser] = [] 
self.inverse_users: List[BaseSchedulerNode] = [] self.node_users: List[BaseSchedulerNode] = [] @@ -168,7 +175,7 @@ def __init__(self, scheduler: "Scheduler", node: ir.Buffer): ] = set() # buffers that won't be used after this kernel self.written = False - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}(name={self.get_name()!r})" def debug_str(self) -> str: @@ -193,7 +200,7 @@ def debug_str(self) -> str: def debug_str_extra(self) -> str: return "" - def log_details(self): + def log_details(self) -> None: log.info( "%s: unmet_dependencies = %s, writes = %s", self, @@ -201,16 +208,16 @@ def log_details(self): self.read_writes.writes, ) - def update_mutated_names(self, renames: Dict[str, str]): + def update_mutated_names(self, renames: Dict[str, str]) -> None: self.set_read_writes(self.read_writes.rename(renames)) - def add_mutation_dep(self, dep): + def add_mutation_dep(self, dep: Dep) -> None: self.set_read_writes(self.read_writes.with_read(dep)) - def add_fake_dep(self, dep): + def add_fake_dep(self, dep: Dep) -> None: self.set_read_writes(self.read_writes.with_read(dep)) - def set_users(self, users: List["NodeUser"]): + def set_users(self, users: List["NodeUser"]) -> None: # deduplicate result: Dict[int, NodeUser] = {} for use in users: @@ -222,26 +229,28 @@ def set_users(self, users: List["NodeUser"]): def set_last_usage( self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] - ): + ) -> None: used_buffers = self.used_or_aliased_buffer_names() used_buffers = {mutation_real_name.get(k, k) for k in used_buffers} self.last_usage = used_buffers - future_used_buffers - def get_aliases(self): + def get_aliases(self) -> Sequence[str]: + assert self.node is not None return self.node.get_inputs_that_alias_output() - def get_mutations(self): + def get_mutations(self) -> List[str]: + assert self.node is not None return self.node.get_mutation_names() - def has_aliasing_or_mutation(self): + def has_aliasing_or_mutation(self) -> bool: return bool(self.get_aliases() or self.get_mutations()) - def set_read_writes(self, rw: dependencies.ReadWrites): + def set_read_writes(self, rw: dependencies.ReadWrites) -> None: self.read_writes: dependencies.ReadWrites = rw self.unmet_dependencies = self.read_writes.reads self.prune_deps() - def op_counts(self): + def op_counts(self) -> Counter[str]: return self.read_writes.op_counts def used_buffer_names(self) -> Set[str]: @@ -266,25 +275,28 @@ def used_or_aliased_buffer_names(self) -> Set[str]: deps.append(alias) return used_names - def prune_deps(self): + def prune_deps(self) -> None: self.unmet_dependencies = { dep for dep in self.unmet_dependencies if dep.name not in self.scheduler.available_buffer_names } - def prune_weak_deps(self): + def prune_weak_deps(self) -> None: # Prune weak dependencies on buffers that have been removed - def should_prune(dep): + def should_prune(dep: Dep) -> bool: return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)} self.set_read_writes(self.read_writes.remove_reads(to_remove)) - def prune_redundant_deps(self, name_to_fused_node): + def prune_redundant_deps( + self, name_to_fused_node: Dict[str, "BaseSchedulerNode"] + ) -> None: _prune_redundant_deps(self, name_to_fused_node) def get_name(self) -> str: + assert self.node is not None return self.node.get_name() def get_first_name(self) -> str: @@ -296,35 +308,37 @@ def get_names(self) -> Set[str]: def get_nodes(self) -> 
Sequence["BaseSchedulerNode"]: return [self] - def get_device(self): + def get_device(self) -> torch.device: + assert self.node is not None return self.node.get_device() - def is_reduction(self): + def is_reduction(self) -> bool: return False - def is_split_scan(self): + def is_split_scan(self) -> bool: return False - def is_template(self): + def is_template(self) -> bool: return False - def is_extern(self): + def is_extern(self) -> bool: return False - def is_foreach(self): + def is_foreach(self) -> bool: return False - def can_inplace(self, read_dep: dependencies.MemoryDep): + def can_inplace(self, read_dep: dependencies.Dep) -> bool: return False - def has_side_effects(self): + def has_side_effects(self) -> bool: return False - def decide_inplace_update(self): + def decide_inplace_update(self) -> None: """ Decide if there should be inplace updates for the node and record the decision in the active kernel. """ + assert self.node is not None if not self.node.should_allocate(): return @@ -365,6 +379,7 @@ def decide_inplace_update(self): len(remaining_uses) == 1 and remaining_uses[0].can_inplace and remaining_uses[0].node is self + and input_node.node is not None and not isinstance( input_node.node.get_layout(), ( @@ -404,7 +419,8 @@ def decide_inplace_update(self): ] = input_node.get_name() break - def allocate(self): + def allocate(self) -> None: + assert self.node is not None if not self.node.should_allocate(): return @@ -428,8 +444,9 @@ def allocate(self): else: V.graph.wrapper_code.codegen_allocation(self.node) - def can_free(self): + def can_free(self) -> bool: # There's no real allocated buffer, no need to free it + assert self.node is not None if isinstance(self.node.layout, ir.NoneLayout): return False for use in self.users: @@ -437,12 +454,15 @@ def can_free(self): return False return True - def codegen_originating_info(self, buffer, only_once=True): + def codegen_originating_info( + self, buffer: IndentedBuffer, only_once: bool = True + ) -> None: if not config.comment_origin: return if only_once and self.written: return + assert self.node is not None origins = self.node.origins out_lines = [] @@ -509,7 +529,7 @@ def get_read_write_buffers_sizes(self) -> int: # todo: Calculate this - it's kinda annoying. 
return 0 - def try_size_hint(s): + def try_size_hint(s: sympy.Expr) -> int: return V.graph.sizevars.size_hint(s, fallback=0) if isinstance(self, SchedulerNode): @@ -526,7 +546,7 @@ def try_size_hint(s): reads = {dep.name for dep in self.read_writes.reads} writes = {dep.name for dep in self.read_writes.writes} - def is_materialized(buf, snodes): + def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: users = self.scheduler.name_to_node[buf].users buf_uses = {user.node for user in users} return len(buf_uses - set(snodes)) > 0 @@ -549,13 +569,16 @@ def is_materialized(buf, snodes): else: continue - def get_buf_elems(buf): + def get_buf_elems(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: + if not buf: + return 0 # Kind of a lazy way to get the MultiOutput nodes corresponding to # a MultiOutputLayout if isinstance(buf.layout, MultiOutputLayout): users = self.scheduler.name_to_node[buf.get_name()].users tot = 0 for user in users: + assert isinstance(user.node, BaseSchedulerNode) if isinstance(user.node.node, MultiOutput): tot += get_buf_elems(user.node.node) else: @@ -599,6 +622,7 @@ def get_estimated_runtime(self) -> float: # Collective kernels if is_collective(self.node): + assert self.node is not None try: return estimate_nccl_collective_runtime(self.node) except ValueError as e: @@ -673,15 +697,19 @@ def get_estimated_runtime(self) -> float: return 0 + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return None + class ExternKernelSchedulerNode(BaseSchedulerNode): def debug_str_extra(self) -> str: return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" - def is_extern(self): + def is_extern(self) -> bool: return True - def has_side_effects(self): + def has_side_effects(self) -> bool: + assert self.node is not None return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() @@ -691,15 +719,17 @@ class NopKernelSchedulerNode(BaseSchedulerNode): def debug_triton_code(node: Union["SchedulerNode", "FusedSchedulerNode"]) -> List[str]: lines = [] - is_multi_template = node.is_template() and isinstance( - node.get_template_node(), ir.MultiTemplateBuffer - ) - if is_multi_template and node.get_template_node().make_kernel_render is None: + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: lines.append(f"{node.get_name()} Unfinalized multi template buffer") else: + from torch._inductor.codegen.triton import TritonScheduling + snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes device = snodes[0].get_device() backend = node.scheduler.get_backend(device) + assert isinstance(backend, TritonScheduling) V.graph.scheduler.current_device = device # Don't increment kernel count when generating debug string. 
@@ -719,14 +749,14 @@ def __init__( self, scheduler: "Scheduler", node: Union[ir.ComputedBuffer, ir.TemplateBuffer], - ): + ) -> None: super().__init__(scheduler, node) self._compute_attrs() def _compute_attrs( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, - ): + ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) self._sizes, self._body = self.node.simplify_and_reorder( extra_indexing_constraints=extra_indexing_constraints @@ -746,7 +776,7 @@ def _compute_attrs( def recompute_size_and_body( self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] - ): + ) -> None: self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) def debug_str_extra(self) -> str: @@ -768,21 +798,22 @@ def debug_str_extra(self) -> str: lines.append(f"class {name}_loop_body:") lines.append(textwrap.indent(self._body.debug_str(), " ")) + assert self.node is not None if ir.is_triton(self.node.get_device()): lines.extend(debug_triton_code(self)) return "\n".join(lines) - def get_ranges(self): + def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: return self._sizes - def is_reduction(self): + def is_reduction(self) -> bool: assert isinstance( self.node, (ir.ComputedBuffer, ir.TemplateBuffer) ), f"{type(self.node)=}" return bool(self.node.get_reduction_type()) - def is_split_scan(self): + def is_split_scan(self) -> bool: assert isinstance( self.node, (ir.ComputedBuffer, ir.TemplateBuffer) ), f"{type(self.node)=}" @@ -790,21 +821,23 @@ def is_split_scan(self): self.node.data, ir.SplitScan ) - def is_template(self): + def is_template(self) -> bool: return isinstance(self.node, ir.TemplateBuffer) - def get_template_node(self): - return self.node if self.is_template() else None + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return self.node if isinstance(self.node, ir.TemplateBuffer) else None - def run(self, *index_vars): + def run(self, *index_vars: Sequence[sympy.Expr]) -> None: self.decide_inplace_update() self.mark_run() self.codegen(index_vars) - def mark_run(self): + def mark_run(self) -> None: self.allocate() - def ranges_from_index_vars(self, index_vars): + def ranges_from_index_vars( + self, index_vars: Sequence[Sequence[sympy.Expr]] + ) -> Dict[sympy.Expr, sympy.Expr]: sizes = self._sizes assert sum(map(len, sizes)) == sum(map(len, index_vars)) var_ranges = dict( @@ -815,7 +848,7 @@ def ranges_from_index_vars(self, index_vars): ) return var_ranges - def codegen(self, index_vars): + def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: var_ranges = self.ranges_from_index_vars(index_vars) try: with V.set_ops_handler( @@ -826,18 +859,18 @@ def codegen(self, index_vars): log.fatal("Error in codegen for %s", self.node) raise - def pointwise_read_writes(self): + def pointwise_read_writes(self) -> dependencies.ReadWrites: """ Get the memory dependencies in the non-reduction axis. 
""" sizes, reduction_sizes = self._sizes - def fn(index): + def fn(index: Sequence[sympy.Symbol]) -> str: return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) return dependencies.extract_read_writes(fn, sizes) - def can_inplace(self, read_dep: dependencies.MemoryDep): + def can_inplace(self, read_dep: dependencies.Dep) -> bool: if self.get_aliases() or self.is_template(): return False if len(self.read_writes.writes) == 1 and isinstance( @@ -877,18 +910,22 @@ class FusedSchedulerNode(BaseSchedulerNode): """ @classmethod - def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + def fuse( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> "FusedSchedulerNode": assert node1.scheduler is node2.scheduler - assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance( - node2, (SchedulerNode, FusedSchedulerNode) - ) - return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type] + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) + return cls(node1.scheduler, nodes) - def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]): + def __init__( + self, scheduler: "Scheduler", snodes: Sequence[BaseSchedulerNode] + ) -> None: # NB: No need to call super().__init__() because we don't need to re-use any of its logic. self.snodes = snodes self.scheduler = scheduler - self.node: ir.Buffer = None # type: ignore[assignment] + self.node = None self.users: List[NodeUser] = [] self.inverse_users = [] self.node_users = [] @@ -925,7 +962,9 @@ def debug_str_extra(self) -> str: f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" for i, node in enumerate(self.snodes) ] - device = self.snodes[0].node.get_device() + node = self.snodes[0].node + assert node is not None + device = node.get_device() if ir.is_triton(device): lines.extend(debug_triton_code(self)) @@ -933,7 +972,7 @@ def debug_str_extra(self) -> str: def set_last_usage( self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] - ): + ) -> None: # Set self.last_usage using the global information # This will be used for inter-kernel optimisations super().set_last_usage(future_used_buffers, mutation_real_name) @@ -942,7 +981,7 @@ def set_last_usage( future_used_buffers: Set[str] = set() for node in reversed(self.snodes): node.set_last_usage(future_used_buffers, mutation_real_name) - future_used_buffers.update(node.last_usage) # type: ignore[arg-type] + future_used_buffers.update(node.last_usage) @cache_on_self def used_buffer_names(self) -> Set[str]: @@ -952,40 +991,40 @@ def used_buffer_names(self) -> Set[str]: def used_or_aliased_buffer_names(self) -> Set[str]: return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes]) - def get_nodes(self) -> List[SchedulerNode]: + def get_nodes(self) -> Sequence[BaseSchedulerNode]: return self.snodes - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}(nodes={self.get_name()})" @cache_on_self - def is_reduction(self): + def is_reduction(self) -> bool: return any(x.is_reduction() for x in self.snodes) @cache_on_self - def is_split_scan(self): + def is_split_scan(self) -> bool: return any(x.is_split_scan() for x in self.snodes) @cache_on_self - def is_template(self): + def is_template(self) -> bool: return any(x.is_template() for x in self.snodes) @cache_on_self - def get_template_node(self): + def 
get_template_node(self) -> Optional[ir.TemplateBuffer]: for node in self.snodes: if node.is_template(): return node.get_template_node() return None - def get_device(self): + def get_device(self) -> torch.device: return self.group[0] @cache_on_self - def has_aliasing_or_mutation(self): + def has_aliasing_or_mutation(self) -> bool: return any(x.has_aliasing_or_mutation() for x in self.snodes) @cache_on_self - def op_counts(self): + def op_counts(self) -> Counter[str]: op_counts: Counter[str] = collections.Counter() for node in self.snodes: op_counts.update(node.op_counts()) @@ -993,28 +1032,28 @@ def op_counts(self): # None of these need to be implemented, as a FusedSchedulerNode is just an # abstraction for scheduling purposes - def update_mutated_names(self, renames: Dict[str, str]): + def update_mutated_names(self, renames: Dict[str, str]) -> None: raise NotImplementedError - def add_mutation_dep(self, name): + def add_mutation_dep(self, name: Dep) -> None: raise NotImplementedError - def set_users(self, users: List["NodeUser"]): + def set_users(self, users: List["NodeUser"]) -> None: raise NotImplementedError - def get_aliases(self): + def get_aliases(self) -> Sequence[str]: raise NotImplementedError - def get_mutations(self): + def get_mutations(self) -> List[str]: raise NotImplementedError - def can_inplace(self, read_dep: dependencies.MemoryDep): + def can_inplace(self, read_dep: dependencies.Dep) -> bool: raise NotImplementedError - def allocate(self): + def allocate(self) -> None: raise NotImplementedError - def can_free(self): + def can_free(self) -> bool: raise NotImplementedError def debug_str(self) -> str: @@ -1042,13 +1081,17 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): """Scheduler node which consists of a list of scheduler nodes that each operate on a distinct tensor in a list of tensors.""" - def get_consumer_subnode_for(self, producer): + def get_consumer_subnode_for( + self, producer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: if producer.get_name() in self.read_to_node: return self.read_to_node[producer.get_name()] return None - def get_producer_subnode_for(self, consumer): + def get_producer_subnode_for( + self, consumer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: for rd in consumer.read_writes.reads: if rd.name in self.name_to_node: return self.name_to_node[rd.name] @@ -1056,9 +1099,11 @@ def get_producer_subnode_for(self, consumer): return None @classmethod - def can_fuse(cls, producer, consumer): + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: why = WhyNoFuse(producer, consumer) if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) foreach_match = len(producer.snodes) == len(consumer.snodes) if not foreach_match: why("foreach do not have same length") @@ -1067,6 +1112,7 @@ def can_fuse(cls, producer, consumer): for l, r in zip(producer.snodes, consumer.snodes) ) elif consumer.is_foreach(): + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) consumer_subnode = consumer.get_consumer_subnode_for(producer) if consumer_subnode is not None: return consumer.scheduler.can_fuse(producer, consumer_subnode) @@ -1075,6 +1121,7 @@ def can_fuse(cls, producer, consumer): return False elif producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) producer_subnode = producer.get_producer_subnode_for(consumer) if producer_subnode is not None: return 
producer.scheduler.can_fuse(producer_subnode, consumer) @@ -1087,16 +1134,22 @@ def can_fuse(cls, producer, consumer): ) @classmethod - def fuse(cls, producer, consumer): + def fuse( + cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode + ) -> "ForeachKernelSchedulerNode": assert producer.is_foreach() or consumer.is_foreach() prev_node_1 = None prev_node_2 = None + fused_nodes: List[BaseSchedulerNode] if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) fused_nodes = [ FusedSchedulerNode.fuse(l, r) for l, r in zip(producer.snodes, consumer.snodes) ] elif producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) producer_subnode = producer.get_producer_subnode_for(consumer) fused_nodes = [] prev_node_1 = producer @@ -1110,6 +1163,7 @@ def fuse(cls, producer, consumer): fused_nodes.append(node) elif consumer.is_foreach(): + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) consumer_subnode = consumer.get_consumer_subnode_for(producer) fused_nodes = [] prev_node_1 = consumer @@ -1123,15 +1177,15 @@ def fuse(cls, producer, consumer): else: fused_nodes.append(node) - return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined] + return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) def __init__( self, scheduler: "Scheduler", - nodes: List[SchedulerNode], - prev_node_1=None, - prev_node_2=None, - ): + nodes: Sequence[BaseSchedulerNode], + prev_node_1: Optional[BaseSchedulerNode] = None, + prev_node_2: Optional[BaseSchedulerNode] = None, + ) -> None: self.read_to_node = {} self.name_to_node = {} @@ -1147,7 +1201,7 @@ def __init__( else: self.scheduler = scheduler self.snodes = nodes - self.node: ir.Buffer = None # type: ignore[assignment] + self.node = None self.users: List[NodeUser] = [] self.set_read_writes( @@ -1167,8 +1221,12 @@ def __init__( self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) - foreach_node = prev_node_1 if prev_node_1.is_foreach() else prev_node_2 - other_node = prev_node_2 if prev_node_1.is_foreach() else prev_node_1 + if prev_node_1.is_foreach(): + assert isinstance(prev_node_1, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_1, prev_node_2 + else: + assert isinstance(prev_node_2, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_2, prev_node_1 self.ancestors = foreach_node.ancestors self.ancestors.update(other_node.ancestors) @@ -1177,50 +1235,57 @@ def __init__( for name in other_node.get_names(): self.name_to_node[name] = other_node - self.group = (nodes[0].get_device(), "foreach") + self.group = (nodes[0].get_device(), [[sympy.Expr("foreach")]]) self.origins: Set[torch.fx.Node] = set() - def mark_run(self): + def mark_run(self) -> None: raise NotImplementedError - def codegen(self): + def codegen(self) -> None: assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" self.node.get_store_function()(self.node.make_loader()()) - def can_free(self): + def can_free(self) -> bool: raise NotImplementedError - def is_foreach(self): + def is_foreach(self) -> bool: return True - def get_subkernel_nodes(self): + def get_subkernel_nodes(self) -> List[BaseSchedulerNode]: """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists. 
These nodes may be vertically fused.""" return list(self.snodes) - def get_nodes(self): - """Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes.""" + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + """Returns all nodes contained in this kernel, unpacking fused nodes + into their constituent scheduler nodes.""" return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) - def get_first_name(self): + def get_first_name(self) -> str: return self.snodes[0].get_first_name() - def prune_redundant_deps(self, name_to_fused_node): + def prune_redundant_deps( + self, name_to_fused_node: Dict[str, BaseSchedulerNode] + ) -> None: _prune_redundant_deps(self, name_to_fused_node) for node in self.snodes: node.prune_redundant_deps(name_to_fused_node) -def pick_loop_order(stride_lengths, sizes, priority_idx=()): +def pick_loop_order( + stride_lengths: List[List[int]], + sizes: List[sympy.Expr], + priority_idx: Tuple[int, ...] = (), +) -> List[int]: """ A heuristic to decide loop iteration orders. This has not been well tuned and may be something we should autotune. """ @functools.cmp_to_key - def index_cmp(a, b): + def index_cmp(a: int, b: int) -> int: if sizes[a] == 1 or sizes[b] == 1: # 1-sizes don't matter, just move them to the end return cmp(sizes[a] == 1, sizes[b] == 1) @@ -1255,24 +1320,25 @@ def index_cmp(a, b): @dataclasses.dataclass class NodeUser: - node: BaseSchedulerNode + node: Union[BaseSchedulerNode, OutputNode] can_inplace: bool = False # A weak user must be scheduled after a given node, but doesn't actually # use the result is_weak: bool = False - def __hash__(self): + def __hash__(self) -> int: return hash((self.node.get_name(), self.can_inplace, self.is_weak)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return ( - self.get_name() == other.get_name() + isinstance(other, NodeUser) + and self.get_name() == other.get_name() and self.can_inplace == other.can_inplace and self.is_weak == other.is_weak ) - def get_name(self): + def get_name(self) -> str: return self.node.get_name() def merge(self, other: "NodeUser") -> "NodeUser": @@ -1289,11 +1355,10 @@ def merge(self, other: "NodeUser") -> "NodeUser": class Scheduler: @dynamo_timed - def __init__(self, nodes): + def __init__(self, nodes: List[ir.Buffer]) -> None: super().__init__() V.graph.scheduler = self - self.backends = {} - self.fuse_cache = {} + self.backends: Dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) self.available_buffer_names = { @@ -1321,7 +1386,7 @@ def __init__(self, nodes): # If you mutate buf0 inside of buf1's kernel, then: # mutation_real_name = {"buf0" : "buf1"} # all subsequent uses of buf0 become buf1's usage in dependency graph - self.mutation_real_name = {} + self.mutation_real_name: Dict[str, str] = {} # We handle mutation by renaming modified versions of the same # buffer in the dependency graph to prevent cycles. 
@@ -1331,7 +1396,7 @@ def __init__(self, nodes): # If you mutate buf0 inside of buf1's kernel, then: # mutation_renames = {"buf1" : "buf0"} # in codegen we only use buf0, never buf1 - self.mutation_renames = {} + self.mutation_renames: Dict[str, str] = {} self.compute_dependencies() self.topological_sort_schedule() @@ -1346,7 +1411,7 @@ def __init__(self, nodes): self.name_to_fused_node = {n.get_name(): n for n in self.nodes} self.create_foreach_nodes() self.topological_sort_schedule() - self.logged_slow_fusion = set() + self.logged_slow_fusion: Set[Tuple[str, str]] = set() self.fuse_nodes() self.finalize_multi_template_buffers() if config.reorder_for_compute_comm_overlap: @@ -1359,12 +1424,12 @@ def __init__(self, nodes): self.debug_draw_graph() # used during codegen: - self.current_device: torch.device = None # type: ignore[assignment] - self.buffer_names_to_free = set() + self.current_device: Optional[torch.device] = None + self.buffer_names_to_free: Set[str] = set() # fx graph node to the position it appears in the graph # for debug attribution - self.origin_to_index = {} + self.origin_to_index: Dict[torch.fx.Node, int] = {} get_metric_table("graph_stats").add_row( lambda: { @@ -1374,20 +1439,26 @@ def __init__(self, nodes): } ) - def debug_draw_graph(self): + def get_current_device_or_throw(self) -> torch.device: + if device := self.current_device: + return device + else: + raise RuntimeError("No current device") + + def debug_draw_graph(self) -> None: """Generate an image of the graph for debugging""" if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": from .debug import draw_buffers draw_buffers(self.nodes, print_graph=True) - def debug_print_nodes(self, label): + def debug_print_nodes(self, label: str) -> None: if log.isEnabledFor(logging.INFO): log.info("%s:", label) for node in self.nodes: node.log_details() - def create_scheduler_node(self, node): + def create_scheduler_node(self, node: ir.Buffer) -> BaseSchedulerNode: assert ( node.origins is not None ), "All nodes passed to scheduling must have an origin" @@ -1400,7 +1471,7 @@ def create_scheduler_node(self, node): else: raise NotImplementedError(node) - def create_foreach_nodes(self): + def create_foreach_nodes(self) -> None: removed_node_names = set() fe_nodes = [] kept_node_names = self.name_to_fused_node.keys() @@ -1419,7 +1490,7 @@ def create_foreach_nodes(self): removed_node_names.update(names) snodes = [self.name_to_node[name] for name in names] - fe_node = ForeachKernelSchedulerNode(self, snodes) # type: ignore[arg-type] + fe_node = ForeachKernelSchedulerNode(self, snodes) fe_nodes.append(fe_node) @@ -1428,9 +1499,9 @@ def create_foreach_nodes(self): self.nodes = [ node for node in self.nodes if node.get_name() not in removed_node_names - ] + fe_nodes + ] + list(fe_nodes) - def compute_dependencies(self): + def compute_dependencies(self) -> None: """ Create dependency edges between nodes, handling aliasing and mutation properly. @@ -1448,7 +1519,11 @@ class DedupList(Generic[T]): semantics. 
""" - def __init__(self, items=None, membership=None): + def __init__( + self, + items: Optional[List[T]] = None, + membership: Optional[Set[T]] = None, + ) -> None: self.items = items or list() self.membership = membership or set() @@ -1488,12 +1563,12 @@ def __add__(self, other: "DedupList[T]") -> "DedupList[T]": else: name_to_users[node1_name] = name_to_users[node2_name] - def rename(n): + def rename(n: str) -> str: if n in self.mutation_renames: return rename(self.mutation_renames[n]) return n - def dep_closure(node_name): + def dep_closure(node_name: str) -> Set[str]: reachable_names = {node_name} node = self.name_to_node[node_name] write_dep = next(iter(node.read_writes.writes)) @@ -1508,7 +1583,12 @@ def dep_closure(node_name): reachable_names.update(dep_closure(read_dep.name)) return reachable_names - def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): + def add_user( + used_by_name: str, + user_node: Union[BaseSchedulerNode, OutputNode], + can_inplace: bool = False, + is_weak: bool = False, + ) -> None: name_to_users[rename(used_by_name)].append( NodeUser(user_node, can_inplace, is_weak) ) @@ -1527,6 +1607,7 @@ def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): # unbacked symbols don't follow ordinary buffer dependencies, so # we track their def/uses separately + assert node.node is not None unbacked_symbol_defs = sorted( node.node.get_unbacked_symbol_defs(), key=lambda x: x.name ) @@ -1632,9 +1713,9 @@ def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): for user in node.users: user.node.inverse_users.append(node) - def compute_node_users(self): + def compute_node_users(self) -> None: # set up buffer name to (fused)snode mapping - buf_to_snode = {} + buf_to_snode: Dict[str, BaseSchedulerNode] = {} for node in self.nodes: if isinstance(node, FusedSchedulerNode): for x in node.snodes: @@ -1647,7 +1728,7 @@ def compute_node_users(self): # compute inverse_users for node in self.nodes: - inverse_users = [] + inverse_users: List[BaseSchedulerNode] = [] for dep in node.unmet_dependencies: assert dep.name in buf_to_snode dep_node = buf_to_snode[dep.name] @@ -1665,7 +1746,7 @@ def compute_node_users(self): for node, users in node_to_users.items(): node.node_users = users - def dead_node_elimination(self): + def dead_node_elimination(self) -> None: """ Remove any nodes without users """ @@ -1674,7 +1755,7 @@ def dead_node_elimination(self): updated_nodes = [] for node in self.nodes: - def can_eliminate_user(user: NodeUser): + def can_eliminate_user(user: NodeUser) -> bool: return user.is_weak or user.get_name() in V.graph.removed_buffers can_eliminate = not node.has_side_effects() and all( @@ -1695,15 +1776,15 @@ def can_eliminate_user(user: NodeUser): for node in self.nodes: node.prune_weak_deps() - def topological_sort_schedule(self): + def topological_sort_schedule(self) -> None: """ Ensure self.nodes is in topologically sorted order """ - seen: Set[ir.Buffer] = set() - name_to_node: Dict[str, ir.Buffer] = dict() - result: List[ir.Buffer] = [] + seen: Set[BaseSchedulerNode] = set() + name_to_node: Dict[str, BaseSchedulerNode] = dict() + result: List[BaseSchedulerNode] = [] - def visit(n): + def visit(n: BaseSchedulerNode) -> None: if n not in seen: seen.add(n) for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): @@ -1717,7 +1798,7 @@ def visit(n): visit(node) self.nodes = result - def compute_ancestors(self): + def compute_ancestors(self) -> None: """ Populate each node.ancestors """ @@ -1735,7 +1816,7 @@ def 
compute_ancestors(self): node.min_order = order node.max_order = order - def fuse_nodes(self): + def fuse_nodes(self) -> None: """ Mutates self.nodes to combine nodes into FusedSchedulerNodes. """ @@ -1758,7 +1839,9 @@ def fuse_nodes(self): fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) break - def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]: + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> Tuple[float, str]: """ Benchmark fused list of nodes and return the execution time in milliseconds on randomly generated inputs. @@ -1769,8 +1852,10 @@ def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]: backend = self.get_backend(device) return backend.benchmark_fused_nodes(nodes) - def finalize_multi_template_buffers(self): - def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer): + def finalize_multi_template_buffers(self) -> None: + def replace_buffer( + orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer + ) -> None: replaced_name = new_node.name orig_name = orig_node.get_name() assert isinstance(orig_name, str) and isinstance(replaced_name, str) @@ -1819,7 +1904,9 @@ def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer): user.node.inverse_users.remove(node) user.node.inverse_users.append(new_scheduler_node) - def speedup_by_fusion(self, node1, node2): + def speedup_by_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: """ If config.benchmark_fusion is False, always return True. Otherwise, return True if fusion can brings speedup. @@ -1848,13 +1935,14 @@ def speedup_by_fusion(self, node1, node2): return True node_list_2 = node2.get_nodes() - node_list_fused = node_list_1 + node_list_2 + node_list_fused = list(itertools.chain(node_list_1, node_list_2)) # We can not accurately benchmark kernel using atomic_add # due to how we generate random integer inputs. # Skip benchmarking them by allowing fusion. if any( hasattr(n.node, "data") + and n.node is not None and hasattr(n.node.data, "scatter_mode") and n.node.data.scatter_mode == "atomic_add" for n in node_list_fused @@ -1865,7 +1953,7 @@ def speedup_by_fusion(self, node1, node2): why = WhyNoFuse(node1, node2) - def log_fusion(ms_fused, ms1, ms2): + def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: if fusion_log.isEnabledFor(logging.DEBUG): if ms_fused < ms1 + ms2: fusion_log.debug( @@ -1966,7 +2054,7 @@ def log_fusion(ms_fused, ms1, ms2): ) return ms_fused < ms1 + ms2 - def fuse_nodes_once(self): + def fuse_nodes_once(self) -> None: """ Mutates self.nodes to combine nodes into FusedSchedulerNodes. 
@@ -2000,18 +2088,18 @@ def fuse_nodes_once(self): self.topological_sort_schedule() self.prune_redundant_deps() - def prune_redundant_deps(self): + def prune_redundant_deps(self) -> None: for node in self.nodes: node.prune_redundant_deps(self.name_to_fused_node) - def get_possible_fusions(self): + def get_possible_fusions(self) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: """ Helper to find all legal fusion opportunities, sorted by self.score_fusion() """ possible_fusions = [] seen = set() - def check_all_pairs(nodes): + def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None: for node1_index, node1 in enumerate(nodes): for node2 in nodes[node1_index + 1 :]: key = (node1, node2) @@ -2050,13 +2138,17 @@ def check_all_pairs(nodes): fusion_log.debug("found %d possible fusions", len(possible_fusions)) return possible_fusions - def will_fusion_create_cycle(self, node1, node2): + def will_fusion_create_cycle( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: """ Finds whether there's a path from node1 to node2 (or vice-versa) caused indirectly by other fusions. """ - def found_path(node): + visited = set() + + def found_path(node: BaseSchedulerNode) -> bool: # only fused nodes can introduce new ancestors. if isinstance(node, FusedSchedulerNode) and node not in visited: visited.add(node) @@ -2079,7 +2171,6 @@ def found_path(node): ) return False - visited = set() combined_names = node1.get_names() | node2.get_names() combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) @@ -2089,7 +2180,7 @@ def found_path(node): def can_fusion_increase_peak_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode - ): + ) -> bool: """ This function prevents fusion for nodes that can increase memory footprint. This problem is more common in horizontal fusion, where nodes @@ -2114,7 +2205,12 @@ def can_fusion_increase_peak_memory( ) return proximity_score > 64 - def decide_fusion_fail_reason(self, node1, node2, common_buf_names): + def decide_fusion_fail_reason( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + common_buf_names: Tuple[str, ...], + ) -> str: """ Try to decide reasons why fusion fail due to no shared memory even though there are common buffers. @@ -2168,7 +2264,7 @@ def decide_fusion_fail_reason(self, node1, node2, common_buf_names): return str(reasons) - def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: """ Determine if it is possible to combine node1 and node2 into a single fused node. @@ -2262,7 +2358,9 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): return False return self.get_backend(device).can_fuse_horizontal(node1, node2) - def can_fuse_vertical(self, node1, node2): + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: """ Check if it is legal to fuse a consumer (node2) into a producer (node1). 
@@ -2281,7 +2379,7 @@ def can_fuse_vertical(self, node1, node2): # However, broadcasting sometimes strips dimensions, and if that's the case # we still can match unmet dep # if there's indirect indexing, don't match it - def fusable_read_and_write(read: Dep, write: Dep): + def fusable_read_and_write(read: Dep, write: Dep) -> bool: read_name = self.mutation_renames.get(read.name, read.name) write_name = self.mutation_renames.get(write.name, write.name) if ( @@ -2341,7 +2439,9 @@ def fusable_read_and_write(read: Dep, write: Dep): return True - def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + def score_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> Tuple[bool, bool, int, int]: """ Assign a score (higher comes first) to the fusion of node1 and node2. When different fusions conflict with each other, @@ -2363,7 +2463,9 @@ def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): proximity_score, ) - def score_fusion_memory(self, node1, node2): + def score_fusion_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: """ The first term in our fusion score that estimates number of saved memory operations. """ @@ -2375,7 +2477,9 @@ def score_fusion_memory(self, node1, node2): } return sum(dep.numbytes_hint() for dep in common_memory_deps) - def get_possible_fusions_with_highest_priority(self, possible_fusions): + def get_possible_fusions_with_highest_priority( + self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] + ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: # Group the possible fusions based on their priority from the backend. # Only return the group of possible fusions with highest priority. if len(possible_fusions) == 0: @@ -2405,14 +2509,16 @@ def get_possible_fusions_with_highest_priority(self, possible_fusions): assert len(possible_fusions_with_highest_priority) > 0 return possible_fusions_with_highest_priority - def score_fusion_key(self, nodes): + def score_fusion_key( + self, nodes: Tuple[BaseSchedulerNode, BaseSchedulerNode] + ) -> Tuple[bool, bool, int, int]: """ Shim for list.sort(key=...) """ node1, node2 = nodes return self.score_fusion(node1, node2) - def compute_last_usage(self): + def compute_last_usage(self) -> None: """ Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) """ @@ -2423,7 +2529,7 @@ def compute_last_usage(self): node.set_last_usage(future_used_buffers, self.mutation_real_name) future_used_buffers.update(node.last_usage) - def free_buffers(self): + def free_buffers(self) -> None: """Free any buffers that are no longer needed""" for name in sorted( self.buffer_names_to_free @@ -2441,7 +2547,7 @@ def free_buffers(self): self.buffer_names_to_free.clear() - def remove_kernel_local_buffers(self): + def remove_kernel_local_buffers(self) -> None: """ Any buffers that are both created and have a last use in the same kernel can be removed. @@ -2458,7 +2564,7 @@ def remove_kernel_local_buffers(self): if users.issubset(fused_node_names): names_to_remove.append(out_buf) - def remove_filter(n): + def remove_filter(n: str) -> bool: return ( n not in V.kernel.must_keep_buffers and n not in V.kernel.args.input_buffers @@ -2480,7 +2586,7 @@ def remove_filter(n): else: self.remove_buffer(name) - def remove_buffer(self, name): + def remove_buffer(self, name: str) -> None: # Assign a special value instead of deleting the entry # because we still rely on output_buffers's length to # generate unique arg name. 
@@ -2488,7 +2594,7 @@ def remove_buffer(self, name): V.kernel.args.output_buffers[name] = "REMOVED" V.kernel.removed_buffers.add(name) - def remove_inplace_buffer(self, name): + def remove_inplace_buffer(self, name: str) -> None: log.debug("removing_inplace_buffer(%r)", name) inner_name = V.kernel.args.inplace_buffers[name].inner_name V.kernel.args.inplace_buffers[name] = inner_name.replace( @@ -2496,12 +2602,12 @@ def remove_inplace_buffer(self, name): ) V.kernel.removed_buffers.add(name) - def flush(self): + def flush(self) -> None: for backend in self.backends.values(): backend.flush() self.free_buffers() - def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode): + def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None: assert isinstance(scheduler_node, ExternKernelSchedulerNode) # 'decide_inplace_update' stores the inplace update decisions in # the current kernel from where 'allocate' retrieve those decisions. @@ -2516,7 +2622,7 @@ def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode): node.codegen(V.graph.wrapper_code) self.free_buffers() - def create_backend(self, device: torch.device): + def create_backend(self, device: torch.device) -> "BaseScheduling": assert ( not is_gpu(device.type) or device.index is not None ), f"{device} should have been normalized in lowering" @@ -2541,20 +2647,23 @@ def create_backend(self, device: torch.device): return device_scheduling(self) - def get_backend(self, device: torch.device): + def get_backend(self, device: torch.device) -> "BaseScheduling": if device not in self.backends: self.backends[device] = self.create_backend(device) return self.backends[device] - def enter_context(self, node): - def get_order(n): + def enter_context(self, node: BaseSchedulerNode) -> None: + def get_order(n: torch.fx.Node) -> int: if n not in self.origin_to_index: self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) return self.origin_to_index[n] # Use a dict to have ordering origins = { - (get_order(e), e): None for n in node.get_nodes() for e in n.node.origins + (get_order(e), e): None + for n in node.get_nodes() + if n.node is not None + for e in n.node.origins } origins = list(origins.keys()) if origins: @@ -2562,7 +2671,7 @@ def get_order(n): V.graph.wrapper_code.enter_context(last) @dynamo_timed - def codegen(self): + def codegen(self) -> None: for node in self.nodes: try: log.debug( @@ -2602,13 +2711,23 @@ def codegen(self): if node.is_template(): node, *epilogue = node.get_nodes() - self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined] + self.get_backend(device).codegen_template(node, epilogue) elif node.is_extern(): + node = typing.cast(ExternKernelSchedulerNode, node) self.codegen_extern_call(node) elif node.is_foreach(): - self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined] + node = typing.cast(ForeachKernelSchedulerNode, node) + backend_ = self.get_backend(device) + from .codegen.cuda_combined_scheduling import CUDACombinedScheduling + from .codegen.simd import SIMDScheduling + + if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): + backend = backend_ + else: + raise AssertionError(f"{type(self)=}") + backend.codegen_foreach(node) elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): - self.get_backend(device).codegen_node(node) # type: ignore[possibly-undefined] + self.get_backend(device).codegen_node(node) else: assert isinstance(node, NopKernelSchedulerNode) node.allocate() @@ -2617,7 
+2736,7 @@ def codegen(self): V.graph.wrapper_code.generate_inf_and_nan_checker(node) if config.triton.debug_sync_kernel: - self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined] + self.get_backend(device).codegen_sync() self.available_buffer_names.update(node.get_names()) @@ -2635,23 +2754,30 @@ def codegen(self): def get_buffer_layout(self, buf_name: str) -> ir.Layout: node = self.name_to_node[buf_name] + assert node.node is not None return node.node.get_layout() class BaseScheduling: - def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: """ Check whether node1 and node2 can be vertically fused or not. """ raise NotImplementedError - def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: """ Check whether node1 and node2 can be horizontally fused or not. """ raise NotImplementedError - def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + def fuse( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: """ Fuse two nodes """ @@ -2660,15 +2786,19 @@ def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): else: return FusedSchedulerNode.fuse(node1, node2) - def group_fn(self, sizes): + def group_fn( + self, sizes: Sequence[Sequence[sympy.Expr]] + ) -> Sequence[Sequence[sympy.Expr]]: """ Process the iteration sizes in case a transformation needs to be applied. """ raise NotImplementedError def codegen_template( - self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] - ): + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: """ Given a template node, generate a kernel. @@ -2677,13 +2807,13 @@ def codegen_template( """ raise NotImplementedError - def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]): + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: """ Generate a kernel given a list of pre-fused nodes. """ raise NotImplementedError - def codegen_sync(self): + def codegen_sync(self) -> None: """ Generate synchronization code for the kernel. This method depends on the hardware characteristics. """ @@ -2696,20 +2826,24 @@ def ready_to_flush(self) -> bool: """ return False - def flush(self): + def flush(self) -> None: """ Flush the generated kernel and python wrapper code to the source code file. """ raise NotImplementedError - def benchmark_fused_nodes(self, nodes): + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> Tuple[float, str]: """ Benchmark fused list of nodes and return the execution time in milliseconds on randomly generated inputs. """ raise NotImplementedError - def get_fusion_pair_priority(self, node1, node2) -> int: + def get_fusion_pair_priority( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: """ Return an unsigned integer which represents the priority of this fusion pair. The smaller is with higher priority. 
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 0aad631fc56f6..a92f535a9b49a 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -470,6 +470,8 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): if isinstance(call_args[i], sympy.Symbol): call_args[i] = texpr(call_args[i]) + current_device = V.graph.scheduler.get_current_device_or_throw() + if V.graph.cpp_wrapper: # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime # if any dynamic dimension is involved. We rely on the Python version @@ -484,15 +486,13 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): wrapper.generate_kernel_call( name, call_args, - device_index=V.graph.scheduler.current_device.index, + device_index=current_device.index, arg_types=arg_types, grid=grid, triton_meta=self.triton_meta, ) else: - stream_name = wrapper.write_get_raw_stream( - V.graph.scheduler.current_device.index - ) + stream_name = wrapper.write_get_raw_stream(current_device.index) wrapper.add_import_once(f"import {self.grid_fn.__module__}") meta = wrapper.add_meta_once(self.meta) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 59baad51885e0..54ade4b500548 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1470,6 +1470,7 @@ def dump_node_schedule(node_schedule): is_red = node.is_reduction() print(f"{'red' if is_red else 'pw'} scheduler node") if is_red: + assert node.node is not None print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined] print("ReadDep:") for dep in node.read_writes.reads:
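
The recurring change across the call sites in this patch is that direct reads of V.graph.scheduler.current_device (now typed Optional[torch.device]) are replaced by the new Scheduler.get_current_device_or_throw() accessor, so a missing device fails with a clear RuntimeError instead of an AttributeError on None. Below is a minimal standalone sketch of that pattern; the toy Scheduler class and the example device assignment are illustrative stand-ins, while the accessor body and the Optional annotation mirror the diff above:

from typing import Optional

import torch


class Scheduler:
    """Toy stand-in for the Scheduler in torch/_inductor/scheduler.py, showing only the device accessor."""

    def __init__(self) -> None:
        # Only set while codegen is processing nodes for a particular device.
        self.current_device: Optional[torch.device] = None

    def get_current_device_or_throw(self) -> torch.device:
        # Same body as the method added by the patch: return the device if set,
        # otherwise fail loudly instead of letting callers dereference None.
        if device := self.current_device:
            return device
        else:
            raise RuntimeError("No current device")


# Call sites (generate_kernel_call, codegen_kernel_benchmark, call_kernel, etc.) now read:
scheduler = Scheduler()
scheduler.current_device = torch.device("cuda", 0)  # illustrative; set by the real scheduler during codegen
current_device = scheduler.get_current_device_or_throw()
device_index = current_device.index  # previously: V.graph.scheduler.current_device.index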