diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 4df2e040ceae..3f7b2562ae25 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1675,6 +1675,24 @@ def foo_backward(ctx, saved, grad_out): (gx,) = torch.autograd.grad(y, x) self.assertEqual(gx, x.cos()) + @parametrize( + "tags", + [ + subtest(torch.Tag.pointwise, "single"), + subtest((torch.Tag.pointwise,), "tuple"), + subtest([torch.Tag.pointwise], "list"), + ], + ) + def test_define_with_tags(self, tags): + lib = self.lib() + tags = (torch.Tag.pointwise,) + torch.library.define( + f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags + ) + actual = self.ns().foo.default.tags + self.assertTrue(isinstance(actual, list)) + self.assertEqual(actual, list(tags)) + def test_define_and_impl(self): lib = self.lib() torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 423386cc6536..25280c09e5fd 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -368,18 +368,20 @@ void initDispatchBindings(PyObject* module) { "define", [](const py::object& self, const char* schema, - const char* alias_analysis) { + const char* alias_analysis, + const std::vector& tags) { auto parsed_schema = torch::schema(schema, parseAliasAnalysisKind(alias_analysis)); self.cast().def( - std::move(parsed_schema), {}, register_or_verify()); + std::move(parsed_schema), tags, register_or_verify()); // TODO: this is dumb, had to make a second copy return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)) .name(); }, "", py::arg("schema"), - py::arg("alias_analysis") = "") + py::arg("alias_analysis") = "", + py::arg("tags") = std::vector()) .def( "fallback_fallthrough", [](py::object self, const char* dispatch) { diff --git a/torch/library.py b/torch/library.py index d3573056877f..1a082e9f2e90 100644 --- a/torch/library.py +++ b/torch/library.py @@ -71,13 +71,18 @@ def __init__(self, ns, kind, dispatch_key=""): def __repr__(self): return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>" - def define(self, schema, alias_analysis=""): + def define(self, schema, alias_analysis="", *, tags=()): r'''Defines a new operator and its semantics in the ns namespace. Args: schema: function schema to define a new operator. alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not ("CONSERVATIVE"). + tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this + operator. Tagging an operator changes the operator's behavior + under various PyTorch subsystems; please read the docs for the + torch.Tag carefully before applying it. + Returns: name of the operator as inferred from the schema. @@ -91,7 +96,9 @@ def define(self, schema, alias_analysis=""): if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}") assert self.m is not None - return self.m.define(schema, alias_analysis) + if isinstance(tags, torch.Tag): + tags = (tags,) + return self.m.define(schema, alias_analysis, tuple(tags)) def impl(self, op_name, fn, dispatch_key=''): r'''Registers the function implementation for an operator defined in the library. @@ -171,7 +178,7 @@ def _del_library(captured_impls, op_impls, registration_handles): @functools.singledispatch -def define(qualname, schema, *, lib=None): +def define(qualname, schema, *, lib=None, tags=()): r"""Defines a new operator. In PyTorch, defining an op (short for "operator") is a two step-process: @@ -191,9 +198,13 @@ def define(qualname, schema, *, lib=None): avoid name collisions; a given operator may only be created once. If you are writing a Python library, we recommend the namespace to be the name of your top-level module. - schema (str): The schema of the operator + schema (str): The schema of the operator. lib (Optional[Library]): If provided, the lifetime of this operator will be tied to the lifetime of the Library object. + tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this + operator. Tagging an operator changes the operator's behavior + under various PyTorch subsystems; please read the docs for the + torch.Tag carefully before applying it. Example:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY) @@ -222,7 +233,7 @@ def define(qualname, schema, *, lib=None): if lib is None: lib = Library(namespace, "FRAGMENT") _keep_alive.append(lib) - lib.define(name + schema, alias_analysis="") + lib.define(name + schema, alias_analysis="", tags=tags) @define.register