diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index 6c19335b0edc..098e2f9cc0ec 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -678,6 +678,56 @@ def attr(self): with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"): scripted_mod.ignored_attr + def test_ignoring_module_attributes(self): + """ + Test that module attributes can be ignored. + """ + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: int) -> int: + return sum([a]) + + class ModuleWithIgnoredAttr(torch.nn.Module): + __ignored_attributes__ = ["a", "sub"] + + def __init__(self, a: int, b: int): + super().__init__() + self.a = a + self.b = b + self.sub = Sub() + + def forward(self) -> int: + return self.b + + @torch.jit.ignore + def ignored_fn(self) -> int: + return self.sub.forward(self.a) + + mod = ModuleWithIgnoredAttr(1, 4) + scripted_mod = torch.jit.script(mod) + self.assertEqual(scripted_mod(), 4) + self.assertEqual(scripted_mod.ignored_fn(), 1) + + # Test the error message for ignored attributes. + class ModuleUsesIgnoredAttr(torch.nn.Module): + __ignored_attributes__ = ["a", "sub"] + + def __init__(self, a: int): + super().__init__() + self.a = a + self.sub = Sub() + + def forward(self) -> int: + return self.sub(self.b) + + mod = ModuleUsesIgnoredAttr(1) + + with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"): + scripted_mod = torch.jit.script(mod) + + def test_export_opnames_interface(self): global OneTwoModule diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index 169b589cbfff..dc619aa79ccb 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -186,6 +186,10 @@ c10::optional ConcreteModuleType::findFailedAttribute( return c10::nullopt; } +bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const { + return data_.ignoredAttributes_.count(name) > 0; +} + std::shared_ptr ConcreteModuleType:: findSubmoduleConcreteType(const std::string& name) const { const auto it = std::find_if( @@ -281,6 +285,10 @@ void ConcreteModuleTypeBuilder::addFailedAttribute( failedAttributes_.emplace(std::move(name), std::move(failureReason)); } +void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) { + ignoredAttributes_.emplace(std::move(name)); +} + void ConcreteModuleType::dump() const { std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n"; diff --git a/torch/csrc/jit/frontend/concrete_module_type.h b/torch/csrc/jit/frontend/concrete_module_type.h index 0410693d439c..1b1a9142f73e 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.h +++ b/torch/csrc/jit/frontend/concrete_module_type.h @@ -81,6 +81,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { std::vector overloadedMethodNames); void addBuiltinFunction(std::string name, std::string symbol_name); void addFailedAttribute(std::string name, std::string failureReason); + void addIgnoredAttribute(std::string name); void setIterableModuleKind(IterableModuleKind kind); // If a ConcreteModuleType is poisoned, it will never compare equal to any @@ -150,6 +151,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { // Any attributes we failed to convert to TorchScript, along with a hint as to // why std::unordered_map failedAttributes_; + // Any attributes that were marked as ignored. They cannot be used in + // TorchScript but can still be used in ignored function in Python. + std::unordered_set ignoredAttributes_; // Any function attributes. These are special right now because functions are // not first-class in the type system. std::unordered_map functionAttributes_; @@ -191,6 +195,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType { std::shared_ptr findSubmoduleConcreteType( const std::string& name) const; c10::optional findFailedAttribute(const std::string& name) const; + bool isIgnoredAttribute(const std::string& name) const; // These getters are only here to return things as types that can be // automatically converted by pybind. diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index ba94d33f37b3..66f3892b053b 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -586,12 +586,14 @@ std::shared_ptr ModuleValue::attr( std::string hint; if (auto failureReason = concreteType_->findFailedAttribute(field)) { hint = *failureReason; + } else if (concreteType_->isIgnoredAttribute(field)) { + hint = "attribute was ignored during compilation"; } throw ErrorReport(loc) << "Module '" << concreteType_->getJitType()->expect()->name()->name() << "'" - << " has no attribute '" << field << "' " << hint; + << " has no attribute '" << field << "', " << hint; } SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) { diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 95d041fe315b..cf5d1fb6c982 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1579,6 +1579,9 @@ void initJitScriptBindings(PyObject* module) { .def( "add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute) + .def( + "add_ignored_attribute", + &ConcreteModuleTypeBuilder::addIgnoredAttribute) .def( "set_module_dict", [](ConcreteModuleTypeBuilder& self) { @@ -1604,6 +1607,7 @@ void initJitScriptBindings(PyObject* module) { .def("get_attributes", &ConcreteModuleType::getAttributesPy) .def("get_modules", &ConcreteModuleType::getModulesPy) .def("dump", &ConcreteModuleType::dump) + .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute) .def( "equals", [](const ConcreteModuleType& self, const ConcreteModuleType& other) { diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 85853cd1b1ee..143fd1db3846 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -91,6 +91,8 @@ def infer_concrete_type_builder(nn_module, share_types=True): concrete_type_builder.set_module_list() class_annotations = getattr(nn_module, '__annotations__', {}) + # Get user-annotated ignored attributes. + user_annotated_ignored_attributes = getattr(nn_module, "__ignored_attributes__", set()) # try to infer the type from type annotation or from the object itself def infer_type(name, item): @@ -110,6 +112,10 @@ def infer_type(name, item): added_names = set() for name, item in nn_module._parameters.items(): + if name in user_annotated_ignored_attributes: + concrete_type_builder.add_ignored_attribute(name) + continue + assert item is None or isinstance(item, torch.Tensor) attr_type = infer_type(name, item) # We currently have the invariant in various places in our code @@ -121,12 +127,20 @@ def infer_type(name, item): added_names.add(name) for name, item in nn_module._buffers.items(): + if name in user_annotated_ignored_attributes: + concrete_type_builder.add_ignored_attribute(name) + continue + assert item is None or isinstance(item, torch.Tensor) attr_type = infer_type(name, item) concrete_type_builder.add_attribute(name, attr_type, False, True) added_names.add(name) for name, item in nn_module._modules.items(): + if name in user_annotated_ignored_attributes: + concrete_type_builder.add_ignored_attribute(name) + continue + attr_type = infer_type(name, item) if item is None: # Modules can be None. We don't have direct support for optional @@ -192,6 +206,10 @@ def infer_type(name, item): # PyTorch adds a few more. Prevent these from getting compiled. continue + if name in user_annotated_ignored_attributes: + concrete_type_builder.add_ignored_attribute(name) + continue + if name in added_names: # Don't re-add anything we already added continue @@ -377,7 +395,7 @@ def init_fn(script_module): cpp_module.setattr(name, scripted) script_module._modules[name] = scripted - # 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule. + # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule. # This ensures we can access these Python methods on the ScriptModule. for name in dir(nn_module): item = getattr(nn_module, name, None) @@ -385,6 +403,8 @@ def init_fn(script_module): unbound_function = getattr(type(nn_module), name) bound_method = unbound_function.__get__(script_module) setattr(script_module, name, bound_method) + elif concrete_type.is_ignored_attribute(name): + setattr(script_module, name, item) # For convenience, attach the concrete type to the new ScriptModule script_module._concrete_type = concrete_type