Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torchbind as attribute in torch.fx symbolic tracing #48732

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/test_fx.py
Expand Up @@ -1139,6 +1139,21 @@ def forward(self):
m = M()
self.checkGraphModule(m, ())

def test_torchbind_class_attribute_in_fx(self):
if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS:
self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")

class FooBar1234(torch.nn.Module):
def __init__(self):
super(FooBar1234, self).__init__()
self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

def forward(self):
return self.f.top()

m = FooBar1234()
self.checkGraphModule(m, ())


if __name__ == '__main__':
run_tests()
6 changes: 6 additions & 0 deletions torch/csrc/jit/python/script_init.cpp
Expand Up @@ -782,6 +782,12 @@ void initJitScriptBindings(PyObject* module) {
});
})
.def("__copy__", &Object::copy)
.def(
"__hash__",
[](const Object& self) {
// Similar to Tensor's `__hash__`, which is `id()`.
return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
})
.def(py::pickle(
[](const Object& self)
-> std::tuple<py::object, std::string> { // __getstate__
Expand Down
5 changes: 3 additions & 2 deletions torch/fx/symbolic_trace.py
Expand Up @@ -2,6 +2,7 @@
from types import CodeType, FunctionType
from typing import Any, Dict, Optional, List, Callable, Union
import torch
from torch._C import ScriptObject # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the lint error here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's Error: Module 'torch._C' has no attribute 'ScriptModule' [attr-defined]

Tried to add it in the mypy.ini for ignore import error, but it didn't work. A little bit hesitate to mark it ignore all the errors. So just mark this linne.


from .node import Argument
from .graph import Graph
Expand Down Expand Up @@ -86,7 +87,7 @@ def create_arg(self, a: Any) -> Argument:
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
# tensor value into a special attribute on the Module s.t. we can
# retrieve it with a get_attr.
if isinstance(a, torch.Tensor):
if isinstance(a, (torch.Tensor, ScriptObject)):
qualname : Optional[str] = self.tensor_attrs.get(a)

# Tensor was not found in the Module hierarchy, stow it away in a
Expand Down Expand Up @@ -221,7 +222,7 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:

def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]):
for k, v in m.__dict__.items():
if isinstance(v, torch.Tensor):
if isinstance(v, (torch.Tensor, ScriptObject)):
self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
for k, v in m.named_children():
collect_tensor_attrs(v, prefix_atoms + [k])
Expand Down