Add torch.library.register_kernel (#124299)
This mirrors the .register_kernel method on the object produced by the
custom_op decorator.

Test Plan:
- new tests
Pull Request resolved: #124299
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200
zou3519 authored and pytorchmergebot committed Apr 19, 2024
1 parent 3918dfe commit bad8d25
Showing 4 changed files with 170 additions and 5 deletions.
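
For orientation, a minimal sketch of the parity this commit adds. The op name "mylib::add_one" is hypothetical and not part of the patch; the accepted forms of the `op` argument (CustomOpDef, qualname string, OpOverload) follow the tests below.

import torch
from torch import Tensor

# The decorator-provided implementation covers only cuda, as in the tests below.
@torch.library.custom_op("mylib::add_one", mutates_args=(), device_types="cuda")
def add_one(x: Tensor) -> Tensor:
    return (x.cpu() + 1).to(x.device)

# New here: a module-level torch.library.register_kernel that mirrors
# add_one.register_kernel(...). It accepts the CustomOpDef object, the
# "namespace::name" qualname, or an OpOverload.
@torch.library.register_kernel(add_one, "cpu")
def _(x):
    return x + 1

assert torch.equal(add_one(torch.zeros(3)), torch.ones(3))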
8 changes: 5 additions & 3 deletions docs/source/library.rst
@@ -21,12 +21,12 @@ Use :func:`torch.library.custom_op` to create new custom ops.
 Extending custom ops (created from Python or C++)
 -------------------------------------------------
 
-Use the impl methods, such as :func:`torch.library.impl` and
-:func:`torch.library.impl_abstract`, to add implementations
+Use the register.* methods, such as :func:`torch.library.register_kernel` and
+:func:`torch.library.register_fake`, to add implementations
 for any operators (they may have been created using :func:`torch.library.custom_op` or
 via PyTorch's C++ operator registration APIs).
 
-.. autofunction:: impl
+.. autofunction:: register_kernel
 .. autofunction:: register_autograd
 .. autofunction:: register_fake
 .. autofunction:: impl_abstract
@@ -53,3 +53,5 @@ A tutorial that walks you through some examples on how to use this API is available
 .. autofunction:: fallthrough_kernel
 
 .. autofunction:: define
+
+.. autofunction:: impl
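
A quick illustration of the documented register.* workflow; the op "mylib::twice" is hypothetical, not from the patch.

import torch
from torch import Tensor

@torch.library.custom_op("mylib::twice", mutates_args=(), device_types="cpu")
def twice(x: Tensor) -> Tensor:
    return x * 2

# register_fake supplies the output-metadata (FakeTensor) rule that
# torch.compile needs; it never sees real data.
@torch.library.register_fake("mylib::twice")
def _(x):
    return torch.empty_like(x)

x = torch.randn(2)
assert torch.equal(twice(x), x * 2)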
104 changes: 104 additions & 0 deletions test/test_custom_ops.py
@@ -2359,6 +2359,110 @@ def _(x, y):
self.assertEqual(z.shape, x.shape)
self.assertTrue(called)

@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_kernel(self):
modes = ["function", "qualname", "opoverload"]
calls = ["decorator", "function"]
device_types_options = ["cpu", None]

for mode, call, device_types in itertools.product(
modes, calls, device_types_options
):

@torch.library.custom_op(
"_torch_testing::add", mutates_args=(), device_types="cuda"
)
def add(x: Tensor, y: float) -> Tensor:
x_np = x.cpu().numpy()
out_np = x_np + y
return torch.from_numpy(out_np).to(x.device)

if mode == "function":
op = add
elif mode == "qualname":
op = "_torch_testing::add"
else:
assert mode == "opoverload"
op = torch.ops._torch_testing.add.default

called = False

if call == "decorator":

@torch.library.register_kernel(op, device_types)
def _(x, y):
nonlocal called
called = True
x_np = x.numpy()
out_np = x_np + y
return torch.from_numpy(out_np)

else:
assert call == "function"

def add_cpu(x, y):
nonlocal called
called = True
x_np = x.numpy()
out_np = x_np + y
return torch.from_numpy(out_np)

torch.library.register_kernel(op, device_types, add_cpu)

x = torch.randn(3)
y = 3.14
z = add(x, y)
self.assertEqual(z, x + y)
self.assertTrue(called)

@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_kernel_low_level(self):
modes = ["qualname", "opoverload"]
calls = ["decorator", "function"]
device_types_options = [("cpu", "cuda"), "cpu", None]

for mode, call, device_types in itertools.product(
modes, calls, device_types_options
):
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
lib.define("add9(Tensor x, float y) -> Tensor")

if mode == "qualname":
op = "_torch_testing::add9"
else:
assert mode == "opoverload"
op = torch.ops._torch_testing.add9.default

called = False

if call == "decorator":

@torch.library.register_kernel(op, device_types, lib=lib)
def _(x, y):
nonlocal called
called = True
x_np = x.numpy()
out_np = x_np + y
return torch.from_numpy(out_np)

else:
assert call == "function"

def add_cpu(x, y):
nonlocal called
called = True
x_np = x.numpy()
out_np = x_np + y
return torch.from_numpy(out_np)

torch.library.register_kernel(op, device_types, add_cpu, lib=lib)

x = torch.randn(3)
y = 3.14
z = torch.ops._torch_testing.add9.default(x, y)
self.assertEqual(z, x + y)
self.assertTrue(called)

@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_autograd(self):
for mode in ["function", "qualname", "opoverload"]:
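
A standalone sketch of the low-level path the second test drives, substituting the public torch.library.Library for the test-only _scoped_library helper; the "mylib" namespace and the add9 schema are illustrative.

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("add9(Tensor x, float y) -> Tensor")

def add9_cpu(x, y):
    return x + y

# Passing a qualname plus an explicit lib= routes the registration through
# torch.library.impl (see torch/library.py below).
torch.library.register_kernel("mylib::add9", "cpu", add9_cpu, lib=lib)

x = torch.randn(3)
assert torch.allclose(torch.ops.mylib.add9(x, 3.14), x + 3.14)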
2 changes: 1 addition & 1 deletion torch/_library/custom_ops.py
@@ -191,7 +191,7 @@ def register_kernel(
 >>> from torch.library import custom_op
 >>> import numpy as np
 >>>
->>> # Example of split cpu and cuda definitions
+>>> # Create a custom op that works on cpu
 >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
 >>> def numpy_sin(x: Tensor) -> Tensor:
 >>>     x_np = x.numpy()
61 changes: 60 additions & 1 deletion torch/library.py
@@ -9,7 +9,7 @@
 import contextlib
 import sys
 import warnings
-from torch._library.custom_ops import custom_op, _maybe_get_opdef
+from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t
 import torch._library as _library


@@ -424,6 +424,65 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
_op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]


def register_kernel(
    op: _op_identifier,
    device_types: device_types_t,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
    This API may be used as a decorator.

    Args:
        func (Callable): The function to register as the implementation for
            the given device types.
        device_types (None | str | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> # Create a custom op that works on cpu
        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> # Add implementations for the cuda device
        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
        >>> def _(x):
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x_cpu = torch.randn(3)
        >>> x_cuda = x_cpu.cuda()
        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_kernel(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_kernel(device_types, func)
    assert isinstance(op, str)
    if device_types is None:
        device_types = "CompositeExplicitAutograd"
    return impl(op, device_types, func, lib=lib)
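
A sketch of the fallback branch above; the op "mylib::scale" is hypothetical. For a qualname with no Python CustomOpDef and device_types=None, the kernel is registered to the CompositeExplicitAutograd key via impl, so one kernel serves every backend.

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("scale(Tensor x, float s) -> Tensor")

# device_types=None -> "CompositeExplicitAutograd"
torch.library.register_kernel("mylib::scale", None, lambda x, s: x * s, lib=lib)

x = torch.randn(2)
assert torch.allclose(torch.ops.mylib.scale(x, 2.0), x * 2.0)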


def register_fake(
    op: _op_identifier,
