From b3243bf4fbfa8e07cead9ff2be15feeb3d0aeb30 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 30 Aug 2024 17:20:50 +0000 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- .../triton/device_interface.py | 5 +- torch/_dynamo/device_interface.py | 11 ++-- torch/_dynamo/variables/builder.py | 12 ++-- torch/_dynamo/variables/torch.py | 3 +- torch/_streambase.py | 56 +++++-------------- torch/cuda/streams.py | 5 +- torch/xpu/streams.py | 8 +-- 7 files changed, 34 insertions(+), 66 deletions(-) diff --git a/test/inductor/extension_backends/triton/device_interface.py b/test/inductor/extension_backends/triton/device_interface.py index c7cabf31dc67e..9ca96e71a7d5a 100644 --- a/test/inductor/extension_backends/triton/device_interface.py +++ b/test/inductor/extension_backends/triton/device_interface.py @@ -2,6 +2,7 @@ import time +import torch from torch._dynamo import device_interface # noqa: PLC2701 import-private-name @@ -13,9 +14,7 @@ def __init__(self) -> None: class DeviceInterface(device_interface.DeviceInterface): - class Event( - device_interface._EventBase - ): # pyright: ignore [reportPrivateImportUsage] + class Event(torch.Event): def __init__( self, enable_timing: bool = False, diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 5670172c49c52..9554b08c65ba1 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union import torch -from torch._streambase import _EventBase, _StreamBase get_cuda_stream: Optional[Callable[[int], int]] @@ -24,12 +23,12 @@ def __new__(metacls, *args, **kwargs): class_member = args[2] if "Event" in class_member: assert inspect.isclass(class_member["Event"]) and issubclass( - class_member["Event"], _EventBase - ), "DeviceInterface member Event should be inherit from _EventBase" + class_member["Event"], torch.Event + ), "DeviceInterface member Event should be inherit from torch.Event" if "Stream" in class_member: assert inspect.isclass(class_member["Stream"]) and issubclass( - class_member["Stream"], _StreamBase - ), "DeviceInterface member Stream should be inherit from _StreamBase" + class_member["Stream"], torch.Stream + ), "DeviceInterface member Stream should be inherit from torch.Stream" return super().__new__(metacls, *args, **kwargs) @@ -155,7 +154,7 @@ class CudaInterface(DeviceInterface): device = torch.cuda.device # register Event and Stream class into the backend interface - # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase + # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream Event = torch.cuda.Event Stream = torch.cuda.Stream diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index e0d355a7fb2b3..facd4068d680f 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -34,7 +34,6 @@ from torch._guards import GuardSource, TracingContext from torch._higher_order_ops.torchbind import call_torchbind from torch._ops import HigherOrderOperator -from torch._streambase import _EventBase, _StreamBase from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode from torch._subclasses.meta_utils import is_sparse_any, safe_grad from torch._utils_internal import justknobs_check @@ -818,7 +817,7 @@ def build_key_value(i, k, v): stream_source = AttrSource(self.source, "stream") stream_var = VariableBuilder(self.tx, stream_source)(value.stream) return StreamContextVariable.create(self.tx, stream_var) - elif isinstance(value, _StreamBase): + elif isinstance(value, torch.Stream): self.install_guards(GuardBuilder.ID_MATCH) stream_proxy = self.tx.output.create_proxy( "call_function", @@ -840,7 +839,7 @@ def build_key_value(i, k, v): elif isinstance(value, (torch._C._SDPAParams)): self.install_guards(GuardBuilder.TYPE_MATCH) return SDPAParamsVariable.create(self.tx, value, self.source) - elif isinstance(value, _EventBase): + elif isinstance(value, torch.Event): self.install_guards(GuardBuilder.ID_MATCH) torch._dynamo.utils.store_user_object_weakref(value) event_proxy = self.tx.output.create_proxy( @@ -2215,7 +2214,7 @@ def _clone_input(value): return SymNodeVariable(proxy, example_value, **options) elif ( inspect.isclass(proxy.node.target) - and issubclass(proxy.node.target, _StreamBase) + and issubclass(proxy.node.target, torch.Stream) ) or proxy.node.target in [ device_interface.current_stream for _, device_interface in get_registered_device_interfaces() @@ -2223,7 +2222,8 @@ def _clone_input(value): set_example_value(proxy.node, example_value) return StreamVariable(proxy, example_value, example_value.device, **options) elif ( - inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase) + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, torch.Event) ) or proxy.node.target in [ device_interface.Event for _, device_interface in get_registered_device_interfaces() @@ -2235,7 +2235,7 @@ def _clone_input(value): return ConstantVariable(example_value, **options) elif ( example_value is not None - and isinstance(example_value, _EventBase) + and isinstance(example_value, torch.Event) and proxy.node.target == "record_event" and proxy.node.op == "call_method" ): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1d86c16a6e09e..cf39caaae0894 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -14,7 +14,6 @@ import torch.onnx.operators from torch._guards import TracingContext from torch._logging import warning_once -from torch._streambase import _StreamBase from .. import config, polyfills, variables from ..codegen import PyCodegen @@ -257,7 +256,7 @@ def call_function( assert len(args) <= 1 and len(kwargs) == 0 inf_mode = args[0].as_python_constant() if len(args) == 1 else True return InferenceModeVariable.create(tx, inf_mode) - elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase): + elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream): from torch._dynamo.variables.builder import wrap_fx_proxy_cls return wrap_fx_proxy_cls( diff --git a/torch/_streambase.py b/torch/_streambase.py index 85e203a3d9938..9d71120c959b1 100644 --- a/torch/_streambase.py +++ b/torch/_streambase.py @@ -1,46 +1,20 @@ -# mypy: allow-untyped-defs -from abc import ABC, abstractmethod +from typing_extensions import deprecated +import torch -class _StreamBase(ABC): - r"""Base stream class abstraction for multi backends Stream to herit from""" - @abstractmethod - def wait_event(self, event) -> None: - raise NotImplementedError +# Preserved only for BC reasons +@deprecated( + "`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.", + category=FutureWarning, +) +class _StreamBase(torch.Stream): + pass - @abstractmethod - def wait_stream(self, stream) -> None: - raise NotImplementedError - @abstractmethod - def record_event(self, event=None) -> None: - raise NotImplementedError - - @abstractmethod - def query(self) -> bool: - raise NotImplementedError - - @abstractmethod - def synchronize(self) -> None: - raise NotImplementedError - - @abstractmethod - def __eq__(self, stream) -> bool: - raise NotImplementedError - - -class _EventBase(ABC): - r"""Base Event class abstraction for multi backends Event to herit from""" - - @abstractmethod - def wait(self, stream=None) -> None: - raise NotImplementedError - - @abstractmethod - def query(self) -> bool: - raise NotImplementedError - - @abstractmethod - def synchronize(self) -> None: - raise NotImplementedError +@deprecated( + "`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.", + category=FutureWarning, +) +class _EventBase(torch.Event): + pass diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index d4ee6eb68d689..6ef0baeeaf4ec 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -2,7 +2,6 @@ import ctypes import torch -from torch._streambase import _EventBase, _StreamBase from torch._utils import _dummy_type @@ -12,7 +11,7 @@ torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase") -class Stream(torch._C._CudaStreamBase, _StreamBase): +class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. A CUDA stream is a linear sequence of execution that belongs to a specific @@ -138,7 +137,7 @@ def __new__(cls, stream_ptr, device=None, **kwargs): return super().__new__(cls, stream_ptr=stream_ptr, **kwargs) -class Event(torch._C._CudaEventBase, _EventBase): +class Event(torch._C._CudaEventBase): r"""Wrapper around a CUDA event. CUDA events are synchronization markers that can be used to monitor the diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index 19a7cda162f45..beb438be466d9 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -2,9 +2,7 @@ import ctypes import torch -from torch._streambase import _EventBase, _StreamBase - -from .._utils import _dummy_type +from torch._utils import _dummy_type if not hasattr(torch._C, "_XpuStreamBase"): @@ -13,7 +11,7 @@ torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase") -class Stream(torch._C._XpuStreamBase, _StreamBase): +class Stream(torch._C._XpuStreamBase): r"""Wrapper around a XPU stream. A XPU stream is a linear sequence of execution that belongs to a specific @@ -98,7 +96,7 @@ def __repr__(self): return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})" -class Event(torch._C._XpuEventBase, _EventBase): +class Event(torch._C._XpuEventBase): r"""Wrapper around a XPU event. XPU events are synchronization markers that can be used to monitor the From cf5b73e18643dda1db902e89f32e02dbd12f35e1 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 30 Aug 2024 20:21:37 +0000 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- torch/_dynamo/device_interface.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 9554b08c65ba1..b1510d8f19df8 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import inspect from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union import torch @@ -18,21 +17,7 @@ caching_worker_current_devices: Dict[str, int] = {} -class DeviceInterfaceMeta(type): - def __new__(metacls, *args, **kwargs): - class_member = args[2] - if "Event" in class_member: - assert inspect.isclass(class_member["Event"]) and issubclass( - class_member["Event"], torch.Event - ), "DeviceInterface member Event should be inherit from torch.Event" - if "Stream" in class_member: - assert inspect.isclass(class_member["Stream"]) and issubclass( - class_member["Stream"], torch.Stream - ), "DeviceInterface member Stream should be inherit from torch.Stream" - return super().__new__(metacls, *args, **kwargs) - - -class DeviceInterface(metaclass=DeviceInterfaceMeta): +class DeviceInterface: """ This is a simple device runtime interface for Inductor. It enables custom backends to be integrated with Inductor in a device-agnostic semantic. @@ -42,6 +27,18 @@ class device: def __new__(cls, device: _device_t): raise NotImplementedError + class Event: + def __new__(cls, *args, **kwargs): + raise NotImplementedError( + "Please ensure member Event is inherited from torch.Event" + ) + + class Stream: + def __new__(cls, *args, **kwargs): + raise NotImplementedError( + "Please ensure member Stream is inherited from torch.Stream" + ) + class Worker: """ Worker API to query device properties that will work in multi processing From 8d4b185bd3cc734219be9473c0c4949b625ef06d Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 30 Aug 2024 20:27:18 +0000 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- torch/_dynamo/device_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index b1510d8f19df8..f6a8e35adeb50 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -30,13 +30,13 @@ def __new__(cls, device: _device_t): class Event: def __new__(cls, *args, **kwargs): raise NotImplementedError( - "Please ensure member Event is inherited from torch.Event" + "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo." ) class Stream: def __new__(cls, *args, **kwargs): raise NotImplementedError( - "Please ensure member Stream is inherited from torch.Stream" + "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo." ) class Worker: From 1121ca0047140af6c082d185031794103f1cc9be Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 27 Sep 2024 13:01:35 +0000 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- torch/_dynamo/device_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index c22b0469f20b7..f1d5426273381 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -299,7 +299,7 @@ class CpuDeviceProperties: class CpuInterface(DeviceInterface): - class Event(_EventBase): + class Event(torch.Event): def __init__(self, enable_timing=True): self.time = 0.0 From da7e099ded83bb77c766bd2bb693722b07d33293 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 27 Sep 2024 13:33:44 +0000 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- torch/_dynamo/device_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index f1d5426273381..baa26c6478988 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -306,7 +306,7 @@ def __init__(self, enable_timing=True): def elapsed_time(self, end_event) -> float: return (end_event.time - self.time) * 1000 - def record(self): + def record(self, stream=None): self.time = time.perf_counter() @staticmethod From b9e07a6ebf72068f015e0be02a281a58ba776e78 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Sun, 29 Sep 2024 08:54:00 +0000 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- torch/_dynamo/device_interface.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 36eb576512a2b..392f17e5b3a51 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import inspect from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union import torch