18 changes: 18 additions & 0 deletions test/test_custom_ops.py
@@ -1675,6 +1675,24 @@ def foo_backward(ctx, saved, grad_out):
(gx,) = torch.autograd.grad(y, x)
self.assertEqual(gx, x.cos())

@parametrize(
"tags",
[
subtest(torch.Tag.pointwise, "single"),
subtest((torch.Tag.pointwise,), "tuple"),
subtest([torch.Tag.pointwise], "list"),
],
)
def test_define_with_tags(self, tags):
lib = self.lib()
tags = (torch.Tag.pointwise,)
torch.library.define(
f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags
)
actual = self.ns().foo.default.tags
self.assertTrue(isinstance(actual, list))
self.assertEqual(actual, list(tags))

def test_define_and_impl(self):
lib = self.lib()
torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
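For readers skimming the diff, here is a minimal sketch of what the new test exercises; the `mylib` namespace and the `foo_*` op names are hypothetical, and `torch.Tag.pointwise` is simply the tag the test happens to use:

```python
import torch
from torch.library import Library

lib = Library("mylib", "FRAGMENT")  # hypothetical fragment library

# The three spellings the test parametrizes over: a bare Tag, a tuple, a list.
forms = {
    "foo_single": torch.Tag.pointwise,
    "foo_tuple": (torch.Tag.pointwise,),
    "foo_list": [torch.Tag.pointwise],
}
for name, tags in forms.items():
    torch.library.define(
        f"mylib::{name}", "(Tensor x) -> Tensor", lib=lib, tags=tags
    )
    # Regardless of the input form, the overload exposes its tags as a list.
    overload = getattr(torch.ops.mylib, name).default
    assert isinstance(overload.tags, list)
    assert overload.tags == [torch.Tag.pointwise]
```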
8 changes: 5 additions & 3 deletions torch/csrc/utils/python_dispatch.cpp
@@ -368,18 +368,20 @@ void initDispatchBindings(PyObject* module) {
"define",
[](const py::object& self,
const char* schema,
-const char* alias_analysis) {
+const char* alias_analysis,
+const std::vector<at::Tag>& tags) {
auto parsed_schema =
torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
self.cast<torch::Library&>().def(
-std::move(parsed_schema), {}, register_or_verify());
+std::move(parsed_schema), tags, register_or_verify());
// TODO: this is dumb, had to make a second copy
return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
.name();
},
"",
py::arg("schema"),
py::arg("alias_analysis") = "")
py::arg("alias_analysis") = "",
py::arg("tags") = std::vector<at::Tag>())
.def(
"fallback_fallthrough",
[](py::object self, const char* dispatch) {
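The binding change above is what `Library.define` ends up calling: the `tags` vector is handed to `torch::Library::def`, so the tags become part of the operator's metadata and are visible from Python as `OpOverload.tags`. A quick check against a built-in op (whether `aten.add.Tensor` actually carries the pointwise tag depends on `native_functions.yaml` in your build, so treat the expected output as an assumption):

```python
import torch

# OpOverload.tags is a list of torch.Tag. Built-in ops get their tags from
# native_functions.yaml; ops defined through torch.library get them from `tags=`.
overload = torch.ops.aten.add.Tensor
print(overload.tags)
print(torch.Tag.pointwise in overload.tags)
```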
21 changes: 16 additions & 5 deletions torch/library.py
@@ -71,13 +71,18 @@ def __init__(self, ns, kind, dispatch_key=""):
def __repr__(self):
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"

def define(self, schema, alias_analysis=""):
def define(self, schema, alias_analysis="", *, tags=()):
r'''Defines a new operator and its semantics in the ns namespace.

Args:
schema: function schema to define a new operator.
alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
inferred from the schema (default behavior) or not ("CONSERVATIVE").
tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
operator. Tagging an operator changes the operator's behavior
under various PyTorch subsystems; please read the docs for
torch.Tag carefully before applying it.

Returns:
name of the operator as inferred from the schema.

@@ -91,7 +96,9 @@ def define(self, schema, alias_analysis=""):
if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
assert self.m is not None
-return self.m.define(schema, alias_analysis)
+if isinstance(tags, torch.Tag):
+    tags = (tags,)
+return self.m.define(schema, alias_analysis, tuple(tags))

def impl(self, op_name, fn, dispatch_key=''):
r'''Registers the function implementation for an operator defined in the library.
@@ -171,7 +178,7 @@ def _del_library(captured_impls, op_impls, registration_handles):


@functools.singledispatch
-def define(qualname, schema, *, lib=None):
+def define(qualname, schema, *, lib=None, tags=()):
r"""Defines a new operator.

In PyTorch, defining an op (short for "operator") is a two-step process:
@@ -191,9 +198,13 @@ def define(qualname, schema, *, lib=None):
avoid name collisions; a given operator may only be created once.
If you are writing a Python library, we recommend the namespace to
be the name of your top-level module.
-schema (str): The schema of the operator
+schema (str): The schema of the operator.
lib (Optional[Library]): If provided, the lifetime of this operator
will be tied to the lifetime of the Library object.
tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
operator. Tagging an operator changes the operator's behavior
under various PyTorch subsystems; please read the docs for
torch.Tag carefully before applying it.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY)
@@ -222,7 +233,7 @@ def define(qualname, schema, *, lib=None):
if lib is None:
lib = Library(namespace, "FRAGMENT")
_keep_alive.append(lib)
-lib.define(name + schema, alias_analysis="")
+lib.define(name + schema, alias_analysis="", tags=tags)


@define.register
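End to end, the Python-facing change reads roughly like this; a minimal sketch, where the `mylib::mysin` op, the CPU-only kernel, and the choice of `torch.Tag.pointwise` are all illustrative assumptions rather than anything this PR ships:

```python
import torch
from torch.library import Library

lib = Library("mylib", "FRAGMENT")  # hypothetical namespace

# Define the schema and attach a tag in one call. `tags` accepts a single
# torch.Tag or any sequence of Tags; Library.define normalizes it to a tuple.
torch.library.define(
    "mylib::mysin", "(Tensor x) -> Tensor", lib=lib, tags=[torch.Tag.pointwise]
)

# Library.impl(op_name, fn, dispatch_key) registers an implementation,
# matching the method signature shown in the diff above.
lib.impl("mysin", torch.sin, "CPU")

x = torch.randn(3)
print(torch.ops.mylib.mysin(x))            # dispatches to torch.sin on CPU
print(torch.ops.mylib.mysin.default.tags)  # expected to contain Tag.pointwise
```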