Skip to content
43 changes: 26 additions & 17 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ class DeviceCodegen:
# Module-level table mapping a string key to a DeviceCodegen entry
# (presumably keyed by device name -- confirm against the code that fills it).
device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    """Base class for per-device operation overrides.

    Every method here raises ``NotImplementedError``; a device-specific
    subclass is expected to override the ones it supports.
    """

    def import_get_raw_stream_as(self, name):
        """Not implemented in the base class."""
        raise NotImplementedError

    def set_device(self, device_idx):
        """Not implemented in the base class."""
        raise NotImplementedError

    def synchronize(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def device_guard(self, device_idx):
        """Not implemented in the base class."""
        raise NotImplementedError


# Registry mapping a device name to its DeviceOpOverrides instance; written by
# register_device_op_overrides() and read by get_device_op_overrides().
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
Expand Down Expand Up @@ -133,12 +150,18 @@ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Store ``device_op_overrides`` in the module registry under ``device``.

    An existing entry for the same device name is replaced.
    """
    registry = device_op_overrides_dict
    registry[device] = device_op_overrides


def get_device_op_overrides(device: str):
    """Return the ``DeviceOpOverrides`` instance registered for ``device``.

    The built-in CUDA overrides register themselves when their module is
    imported, so that module is imported lazily here whenever a lookup
    happens while the registry has no "cuda" entry.

    Args:
        device: Device name used as the registry key (for example "cuda").

    Returns:
        The overrides registered for ``device``, or a fresh base
        ``DeviceOpOverrides`` when nothing is registered under that name.
    """
    assert isinstance(device, str)

    # Key the lazy import on the "cuda" entry itself rather than on the
    # registry being empty: a backend that registered earlier must not stop
    # the built-in CUDA overrides from ever being loaded. A "cuda" entry that
    # was registered explicitly beforehand is left untouched.
    if "cuda" not in device_op_overrides_dict:
        from .cuda import device_op_overrides  # noqa: F401

    # Always go through the registry (no hard-coded "cuda" branch) so that
    # whatever was registered for a device is what callers receive.
    if device in device_op_overrides_dict:
        return device_op_overrides_dict[device]

    return DeviceOpOverrides()

Expand Down Expand Up @@ -803,20 +826,6 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
return h


class DeviceOpOverrides:
    """Base class for per-device operation overrides.

    Each method raises ``NotImplementedError`` here; device-specific
    subclasses are expected to provide the real implementations.
    """

    def import_get_raw_stream_as(self, name):
        """Not implemented in the base class."""
        raise NotImplementedError

    def set_device(self, device_idx):
        """Not implemented in the base class."""
        raise NotImplementedError

    def synchronize(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def device_guard(self, device_idx):
        """Not implemented in the base class."""
        raise NotImplementedError

class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

Expand Down
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/cuda/device_op_overrides.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..common import DeviceOpOverrides
from ..common import DeviceOpOverrides, register_device_op_overrides


class CUDADeviceOpOverrides(DeviceOpOverrides):
Expand All @@ -13,3 +13,6 @@ def synchronize(self):

def device_guard(self, device_idx):
    """Return the source text ``torch.cuda._DeviceGuard(<device_idx>)``.

    Only a string is built; nothing from ``torch`` is called here.
    """
    guard_expr = f"torch.cuda._DeviceGuard({device_idx})"
    return guard_expr


# Import-time side effect: register a CUDADeviceOpOverrides instance under the
# "cuda" key so lookups by device name can find it.
register_device_op_overrides("cuda", CUDADeviceOpOverrides())