Skip to content

Commit

Permalink
[JIT] Add API for ignoring arbitrary module attributes
Browse files Browse the repository at this point in the history
**Summary**
This commit adds an API for ignoring arbitrary module attributes during
scripting. A class attribute named `ignored_attributes` containing names
of attributes to ignore can be added to the class of the instance being
scripted. Attributes ignored in this fashion cannot be used in
`forward`, methods used by `forward` or by `@exported` methods. They
are, however, copied to the `RecursiveScriptModule` wrapper and can be
used by `@ignored` methods and regular Python code.

**Test Plan**
This commit adds unit tests to `TestScriptPy3` to test this new API.

ghstack-source-id: b9fcadc0f56ffe932250905f1c31743ef9de9a0c
Pull Request resolved: #45262
  • Loading branch information
Meghan Lele committed Sep 24, 2020
1 parent 6449867 commit 2d43bec
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 2 deletions.
50 changes: 50 additions & 0 deletions test/test_jit_py3.py
Expand Up @@ -678,6 +678,56 @@ def attr(self):
with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"):
scripted_mod.ignored_attr

def test_ignoring_module_attributes(self):
"""
Test that module attributes can be ignored.
"""
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a: int) -> int:
return sum([a])

class ModuleWithIgnoredAttr(torch.nn.Module):
__ignored_attributes__ = ["a", "sub"]

def __init__(self, a: int, b: int):
super().__init__()
self.a = a
self.b = b
self.sub = Sub()

def forward(self) -> int:
return self.b

@torch.jit.ignore
def ignored_fn(self) -> int:
return self.sub.forward(self.a)

mod = ModuleWithIgnoredAttr(1, 4)
scripted_mod = torch.jit.script(mod)
self.assertEqual(scripted_mod(), 4)
self.assertEqual(scripted_mod.ignored_fn(), 1)

# Test the error message for ignored attributes.
class ModuleUsesIgnoredAttr(torch.nn.Module):
__ignored_attributes__ = ["a", "sub"]

def __init__(self, a: int):
super().__init__()
self.a = a
self.sub = Sub()

def forward(self) -> int:
return self.sub(self.b)

mod = ModuleUsesIgnoredAttr(1)

with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
scripted_mod = torch.jit.script(mod)


def test_export_opnames_interface(self):
global OneTwoModule

Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/frontend/concrete_module_type.cpp
Expand Up @@ -186,6 +186,10 @@ c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
return c10::nullopt;
}

bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const {
return data_.ignoredAttributes_.count(name) > 0;
}

std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
findSubmoduleConcreteType(const std::string& name) const {
const auto it = std::find_if(
Expand Down Expand Up @@ -281,6 +285,10 @@ void ConcreteModuleTypeBuilder::addFailedAttribute(
failedAttributes_.emplace(std::move(name), std::move(failureReason));
}

void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) {
ignoredAttributes_.emplace(std::move(name));
}

void ConcreteModuleType::dump() const {
std::cout << "ConcreteModuleType for: "
<< py::getattr(data_.pyClass_, "__name__") << "\n";
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/jit/frontend/concrete_module_type.h
Expand Up @@ -81,6 +81,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
std::vector<std::string> overloadedMethodNames);
void addBuiltinFunction(std::string name, std::string symbol_name);
void addFailedAttribute(std::string name, std::string failureReason);
void addIgnoredAttribute(std::string name);
void setIterableModuleKind(IterableModuleKind kind);

// If a ConcreteModuleType is poisoned, it will never compare equal to any
Expand Down Expand Up @@ -150,6 +151,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
// Any attributes we failed to convert to TorchScript, along with a hint as to
// why
std::unordered_map<std::string, std::string> failedAttributes_;
// Any attributes that were marked as ignored. They cannot be used in
// TorchScript but can still be used in ignored function in Python.
std::unordered_set<std::string> ignoredAttributes_;
// Any function attributes. These are special right now because functions are
// not first-class in the type system.
std::unordered_map<std::string, FunctionAttribute> functionAttributes_;
Expand Down Expand Up @@ -191,6 +195,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
const std::string& name) const;
c10::optional<std::string> findFailedAttribute(const std::string& name) const;
bool isIgnoredAttribute(const std::string& name) const;

// These getters are only here to return things as types that can be
// automatically converted by pybind.
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/python/python_sugared_value.cpp
Expand Up @@ -586,12 +586,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
std::string hint;
if (auto failureReason = concreteType_->findFailedAttribute(field)) {
hint = *failureReason;
} else if (concreteType_->isIgnoredAttribute(field)) {
hint = "attribute was ignored during compilation";
}

throw ErrorReport(loc)
<< "Module '"
<< concreteType_->getJitType()->expect<ClassType>()->name()->name() << "'"
<< " has no attribute '" << field << "' " << hint;
<< " has no attribute '" << field << "', " << hint;
}

SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/python/script_init.cpp
Expand Up @@ -1579,6 +1579,9 @@ void initJitScriptBindings(PyObject* module) {
.def(
"add_failed_attribute",
&ConcreteModuleTypeBuilder::addFailedAttribute)
.def(
"add_ignored_attribute",
&ConcreteModuleTypeBuilder::addIgnoredAttribute)
.def(
"set_module_dict",
[](ConcreteModuleTypeBuilder& self) {
Expand All @@ -1604,6 +1607,7 @@ void initJitScriptBindings(PyObject* module) {
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
.def("get_modules", &ConcreteModuleType::getModulesPy)
.def("dump", &ConcreteModuleType::dump)
.def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
.def(
"equals",
[](const ConcreteModuleType& self, const ConcreteModuleType& other) {
Expand Down
22 changes: 21 additions & 1 deletion torch/jit/_recursive.py
Expand Up @@ -91,6 +91,8 @@ def infer_concrete_type_builder(nn_module, share_types=True):
concrete_type_builder.set_module_list()

class_annotations = getattr(nn_module, '__annotations__', {})
# Get user-annotated ignored attributes.
user_annotated_ignored_attributes = getattr(nn_module, "__ignored_attributes__", set())

# try to infer the type from type annotation or from the object itself
def infer_type(name, item):
Expand All @@ -110,6 +112,10 @@ def infer_type(name, item):
added_names = set()

for name, item in nn_module._parameters.items():
if name in user_annotated_ignored_attributes:
concrete_type_builder.add_ignored_attribute(name)
continue

assert item is None or isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
# We currently have the invariant in various places in our code
Expand All @@ -121,12 +127,20 @@ def infer_type(name, item):
added_names.add(name)

for name, item in nn_module._buffers.items():
if name in user_annotated_ignored_attributes:
concrete_type_builder.add_ignored_attribute(name)
continue

assert item is None or isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
concrete_type_builder.add_attribute(name, attr_type, False, True)
added_names.add(name)

for name, item in nn_module._modules.items():
if name in user_annotated_ignored_attributes:
concrete_type_builder.add_ignored_attribute(name)
continue

attr_type = infer_type(name, item)
if item is None:
# Modules can be None. We don't have direct support for optional
Expand Down Expand Up @@ -192,6 +206,10 @@ def infer_type(name, item):
# PyTorch adds a few more. Prevent these from getting compiled.
continue

if name in user_annotated_ignored_attributes:
concrete_type_builder.add_ignored_attribute(name)
continue

if name in added_names:
# Don't re-add anything we already added
continue
Expand Down Expand Up @@ -377,14 +395,16 @@ def init_fn(script_module):
cpp_module.setattr(name, scripted)
script_module._modules[name] = scripted

# 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule.
# 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
# This ensures we can access these Python methods on the ScriptModule.
for name in dir(nn_module):
item = getattr(nn_module, name, None)
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
unbound_function = getattr(type(nn_module), name)
bound_method = unbound_function.__get__(script_module)
setattr(script_module, name, bound_method)
elif concrete_type.is_ignored_attribute(name):
setattr(script_module, name, item)

# For convenience, attach the concrete type to the new ScriptModule
script_module._concrete_type = concrete_type
Expand Down

0 comments on commit 2d43bec

Please sign in to comment.