diff --git a/test/test_fx.py b/test/test_fx.py
index 5a47c729f7eb..af11f9615cb6 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -1139,6 +1139,21 @@ def forward(self):
         m = M()
         self.checkGraphModule(m, ())
 
+    def test_torchbind_class_attribute_in_fx(self):
+        if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS:
+            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")
+
+        class FooBar1234(torch.nn.Module):
+            def __init__(self):
+                super(FooBar1234, self).__init__()
+                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
+
+            def forward(self):
+                return self.f.top()
+
+        m = FooBar1234()
+        self.checkGraphModule(m, ())
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index feab73df6d1b..426707e303d3 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -782,6 +782,12 @@ void initJitScriptBindings(PyObject* module) {
             });
           })
       .def("__copy__", &Object::copy)
+      .def(
+          "__hash__",
+          [](const Object& self) {
+            // Similar to Tensor's `__hash__`, which is `id()`.
+            return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
+          })
       .def(py::pickle(
           [](const Object& self)
               -> std::tuple<py::object, std::string> { // __getstate__
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index b2e5b0961114..d48a067f5e56 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -2,6 +2,7 @@
 from types import CodeType, FunctionType
 from typing import Any, Dict, Optional, List, Callable, Union
 import torch
+from torch._C import ScriptObject  # type: ignore
 
 from .node import Argument
 from .graph import Graph
@@ -86,7 +87,7 @@ def create_arg(self, a: Any) -> Argument:
         # a get_attr to retrieve that tensor. Otherwise, we'll store away the
         # tensor value into a special attribute on the Module s.t. we can
         # retrieve it with a get_attr.
-        if isinstance(a, torch.Tensor):
+        if isinstance(a, (torch.Tensor, ScriptObject)):
            qualname : Optional[str] = self.tensor_attrs.get(a)
 
            # Tensor was not found in the Module hierarchy, stow it away in a
@@ -221,7 +222,7 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
 
         def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]):
             for k, v in m.__dict__.items():
-                if isinstance(v, torch.Tensor):
+                if isinstance(v, (torch.Tensor, ScriptObject)):
                     self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
             for k, v in m.named_children():
                 collect_tensor_attrs(v, prefix_atoms + [k])
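
For context, a minimal sketch of what this patch enables; it assumes the `_TorchScriptTesting._StackString` torchbind class from PyTorch's test extension is registered (as in the unit test above), which is not part of the public API:

```python
import torch
import torch.fx

class FooBar1234(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A torchbind (ScriptObject) attribute held directly on the module,
        # previously unsupported by FX symbolic tracing.
        self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

    def forward(self):
        return self.f.top()

m = FooBar1234()
# With this patch, collect_tensor_attrs records the ScriptObject attribute
# (the new identity-based __hash__ lets it serve as a dict key in
# tensor_attrs), and create_arg can reference a ScriptObject via get_attr
# like a tensor constant, so tracing succeeds and the traced module
# reproduces the original's output.
gm = torch.fx.symbolic_trace(m)
assert gm() == m()
```

Note the design choice in the C++ binding: hashing on the underlying `ivalue` pointer rather than the Python wrapper means two Python handles to the same torchbind object resolve to the same key, mirroring `Tensor`'s `id()`-based `__hash__`.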