Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] Add API for ignoring arbitrary module attributes #45262

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 51 additions & 0 deletions test/jit/test_type_sharing.py
Expand Up @@ -560,3 +560,54 @@ def forward(self, x):
self.assertDifferentType(top1_s, top2_s.sub)
self.assertDifferentType(top2_s, top2_s.sub)
self.assertDifferentType(top2_s, top1_s.sub)

def test_type_shared_ignored_attributes(self):
    """
    Two modules whose only type-level difference lies in an attribute
    listed in ``__jit_ignored_attributes__`` are expected to share a type.
    """
    class A(torch.nn.Module):
        # "a" is declared as ignored, so its type should not take part
        # in the type comparison below.
        __jit_ignored_attributes__ = ["a"]

        def __init__(self, a, b):
            super().__init__()
            self.a = a
            self.b = b

        def forward(self, x):
            return x

    # `a` is a submodule in one instance and a plain string in the other;
    # `b` is an int in both.
    linear_variant = A(a=torch.nn.Linear(5, 5), b=5)
    string_variant = A(a="string", b=10)

    # The attribute whose type differs is ignored and the remaining
    # attribute has the same type, so both instances should share a type.
    self.assertSameType(linear_variant, string_variant)

def test_type_not_shared_ignored_attributes(self):
    """
    Scripting the same module under two different sets of ignored
    attributes is expected to produce two different types.
    """
    class A(torch.nn.Module):
        __jit_ignored_attributes__ = ["a"]

        def __init__(self, a, b, c):
            super().__init__()
            self.a = a
            self.b = b
            self.c = c

        def forward(self, x):
            return x

    module = A(torch.nn.Linear(5, 5), 5, "string")

    # First scripting pass: only "a" is ignored.
    scripted_before = torch.jit.script(module)

    # Widen the ignored set on the class, then script the very same
    # instance a second time.
    A.__jit_ignored_attributes__ = ["a", "b"]
    scripted_after = torch.jit.script(module)

    # Both results come from the same instance of A, but
    # __jit_ignored_attributes__ was modified between the two calls, so
    # the set of ignored attributes differs and the types should too.
    self.assertDifferentType(scripted_before, scripted_after)
50 changes: 50 additions & 0 deletions test/test_jit_py3.py
Expand Up @@ -678,6 +678,56 @@ def attr(self):
with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"):
scripted_mod.ignored_attr

def test_ignoring_module_attributes(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. can we add tests to test_type_sharing.py to make sure we are not improperly sharing types with different ignored attributes sets?
  2. I think we should stop adding to test_jit_py3.py. This file originally existed because it let us use py3 syntax, but that no longer has a purpose now that we are py3 only.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I added tests to test_type_sharing.py that check two things: types are shared if ignoring attributes makes them equal, and types are not shared if the ignored attributes list is modified between scripting calls (the module types before and after the modification should differ).
  2. Where should I move them?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For 2: not sure, I wish we had the tests that exercise "compilation customization" APIs, but they have not yet been factored out of test_jit.py.

"""
Test that module attributes can be ignored.
"""
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a: int) -> int:
return sum([a])

class ModuleWithIgnoredAttr(torch.nn.Module):
__jit_ignored_attributes__ = ["a", "sub"]

def __init__(self, a: int, b: int):
super().__init__()
self.a = a
self.b = b
self.sub = Sub()

def forward(self) -> int:
return self.b

@torch.jit.ignore
def ignored_fn(self) -> int:
return self.sub.forward(self.a)

mod = ModuleWithIgnoredAttr(1, 4)
scripted_mod = torch.jit.script(mod)
self.assertEqual(scripted_mod(), 4)
self.assertEqual(scripted_mod.ignored_fn(), 1)

# Test the error message for ignored attributes.
class ModuleUsesIgnoredAttr(torch.nn.Module):
__jit_ignored_attributes__ = ["a", "sub"]

def __init__(self, a: int):
super().__init__()
self.a = a
self.sub = Sub()

def forward(self) -> int:
return self.sub(self.b)

mod = ModuleUsesIgnoredAttr(1)

with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
scripted_mod = torch.jit.script(mod)


def test_export_opnames_interface(self):
global OneTwoModule

Expand Down
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
Expand Up @@ -284,6 +284,8 @@ class ConcreteModuleTypeBuilder:
def add_builtin_function(self, name: str, symbol_name: str): ...
def add_failed_attribute(self, name: str, failure_reason: str): ...
def add_function_attribute(self, name: str, ty: JitType, func: Callable[..., Any]): ...
def add_ignored_attribute(self, name: str): ...
def add_ignored_attributes(self, names: List[str]): ...

class ConcreteModuleType:
def get_constants(self) -> Dict[str, Any]: ...
Expand Down
9 changes: 9 additions & 0 deletions torch/csrc/jit/frontend/concrete_module_type.cpp
Expand Up @@ -106,6 +106,7 @@ bool ConcreteModuleTypeBuilder::equals(
bool equal =
pyClass_.is(other.pyClass_) &&
iterableModuleKind_ == other.iterableModuleKind_ &&
ignoredAttributes_ == other.ignoredAttributes_ &&
constants_ == other.constants_ &&
attributes_ == other.attributes_ &&
overloads_ == other.overloads_ &&
Expand Down Expand Up @@ -186,6 +187,10 @@ c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
return c10::nullopt;
}

bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to edit ConcreteModuleTypeBuilder::equals so that this field gets used to disambiguate instances of a concrete type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this but it's hard to test because ignoring an attribute changes the list of attributes on the ConcreteModuleTypeBuilder instance. The equality check between ConcreteModuleTypeBuilder instances (also ConcreteModuleType and Builder instances) fails when an attribute is ignored even if there is no ignoredAttributes_ == other.ignoredAttributes_ condition added to ConcreteModuleTypeBuilder::equals because attributes_ == other.attributes_ is false.

I'm not saying we shouldn't have it, but I can't prove that it works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, ConcreteModuleType::equals is to disambiguate between different instantiations of an nn.Module which can result in a different type. __jit_ignored_attributes__ is a class-level attribute, I don't see how it could differ between different instantiations.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's my understanding too, but as per one of my tests, the class attribute can be edited between invocations to torch.jit.script. I don't think it's a realistic use case, but it is possible. But as I mentioned in my earlier comment, unintended type sharing does not happen because modifying that class attribute changes which attributes are included/excluded in ConcreteModuleTypeBuilder.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be edited, or we can set it in the body of the __init__, e.g.

def __init__(self, ignored):
    self.__ignored__ = ignored

Both happen with some degree of regularity for other elements of ConcreteModuleType. But good point that it changes the available attributes, so checking ignored attrs is somewhat redundant. I think we should do it still for consistency, since all other attributes of ConcreteModuleType should participate in equality.

return data_.ignoredAttributes_.count(name) > 0;
}

std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
findSubmoduleConcreteType(const std::string& name) const {
const auto it = std::find_if(
Expand Down Expand Up @@ -281,6 +286,10 @@ void ConcreteModuleTypeBuilder::addFailedAttribute(
failedAttributes_.emplace(std::move(name), std::move(failureReason));
}

// Records `name` in the builder's set of ignored attributes. The set
// de-duplicates, so adding the same name again leaves it unchanged.
void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) {
  // `name` is taken by value, so it can be moved into the set.
  ignoredAttributes_.insert(std::move(name));
}

void ConcreteModuleType::dump() const {
std::cout << "ConcreteModuleType for: "
<< py::getattr(data_.pyClass_, "__name__") << "\n";
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/jit/frontend/concrete_module_type.h
Expand Up @@ -81,6 +81,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
std::vector<std::string> overloadedMethodNames);
void addBuiltinFunction(std::string name, std::string symbol_name);
void addFailedAttribute(std::string name, std::string failureReason);
void addIgnoredAttribute(std::string name);
void setIterableModuleKind(IterableModuleKind kind);

// If a ConcreteModuleType is poisoned, it will never compare equal to any
Expand Down Expand Up @@ -150,6 +151,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
// Any attributes we failed to convert to TorchScript, along with a hint as to
// why
std::unordered_map<std::string, std::string> failedAttributes_;
// Any attributes that were marked as ignored. They cannot be used in
// TorchScript but can still be used in ignored functions in Python.
std::unordered_set<std::string> ignoredAttributes_;
// Any function attributes. These are special right now because functions are
// not first-class in the type system.
std::unordered_map<std::string, FunctionAttribute> functionAttributes_;
Expand Down Expand Up @@ -191,6 +195,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
const std::string& name) const;
c10::optional<std::string> findFailedAttribute(const std::string& name) const;
bool isIgnoredAttribute(const std::string& name) const;

// These getters are only here to return things as types that can be
// automatically converted by pybind.
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/python/python_sugared_value.cpp
Expand Up @@ -586,6 +586,8 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
std::string hint;
if (auto failureReason = concreteType_->findFailedAttribute(field)) {
hint = *failureReason;
} else if (concreteType_->isIgnoredAttribute(field)) {
hint = "attribute was ignored during compilation";
}

throw ErrorReport(loc)
Expand Down
12 changes: 12 additions & 0 deletions torch/csrc/jit/python/script_init.cpp
Expand Up @@ -1564,6 +1564,17 @@ void initJitScriptBindings(PyObject* module) {
.def(
"add_failed_attribute",
&ConcreteModuleTypeBuilder::addFailedAttribute)
.def(
"add_ignored_attribute",
&ConcreteModuleTypeBuilder::addIgnoredAttribute)
.def(
"add_ignored_attributes",
[](ConcreteModuleTypeBuilder& self,
const std::vector<std::string>& names) {
for (auto& name : names) {
self.addIgnoredAttribute(name);
}
})
.def(
"set_module_dict",
[](ConcreteModuleTypeBuilder& self) {
Expand All @@ -1589,6 +1600,7 @@ void initJitScriptBindings(PyObject* module) {
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
.def("get_modules", &ConcreteModuleType::getModulesPy)
.def("dump", &ConcreteModuleType::dump)
.def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
.def(
"equals",
[](const ConcreteModuleType& self, const ConcreteModuleType& other) {
Expand Down
20 changes: 19 additions & 1 deletion torch/jit/_recursive.py
Expand Up @@ -105,6 +105,10 @@ def infer_concrete_type_builder(nn_module, share_types=True):

class_annotations = getattr(nn_module, '__annotations__', {})

# Get user-annotated ignored attributes.
user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)

# try to infer the type from type annotation or from the object itself
def infer_type(name, item):
# The forward function from Module is special; never use its annotations; we
Expand All @@ -123,6 +127,9 @@ def infer_type(name, item):
added_names = set()

for name, item in nn_module._parameters.items():
if name in user_annotated_ignored_attributes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to check user_annotated_ignored_attributes for each present attribute, we can just add it all to the concrete type right away, e.g.

concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)

The only difference would be that ignored attributes would be added even if they aren't present in the module object, but that won't make a difference in practice I think.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we still need to skip adding attributes if they are ignored, right? So instead of

if name in user_annotated_ignored_attributes:
    concrete_type_builder.add_ignored_attribute(name)
    continue

we'd have

if name in user_annotated_ignored_attributes:
    continue

or

if {existing_condition} and name not in user_annotated_ignored_attributes:
     concrete_type_builder.add_xxx(...)

continue

assert item is None or isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
# We currently have the invariant in various places in our code
Expand All @@ -134,12 +141,18 @@ def infer_type(name, item):
added_names.add(name)

for name, item in nn_module._buffers.items():
if name in user_annotated_ignored_attributes:
continue

assert item is None or isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
concrete_type_builder.add_attribute(name, attr_type, False, True)
added_names.add(name)

for name, item in nn_module._modules.items():
if name in user_annotated_ignored_attributes:
continue

attr_type = infer_type(name, item)
if item is None:
# Modules can be None. We don't have direct support for optional
Expand Down Expand Up @@ -205,6 +218,9 @@ def infer_type(name, item):
# PyTorch adds a few more. Prevent these from getting compiled.
continue

if name in user_annotated_ignored_attributes:
continue

if name in added_names:
# Don't re-add anything we already added
continue
Expand Down Expand Up @@ -390,14 +406,16 @@ def init_fn(script_module):
cpp_module.setattr(name, scripted)
script_module._modules[name] = scripted

# 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule.
# 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
# This ensures we can access these Python methods on the ScriptModule.
for name in dir(nn_module):
item = getattr(nn_module, name, None)
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
unbound_function = getattr(type(nn_module), name)
bound_method = unbound_function.__get__(script_module)
setattr(script_module, name, bound_method)
elif concrete_type.is_ignored_attribute(name):
setattr(script_module, name, item)

# For convenience, attach the concrete type to the new ScriptModule
script_module._concrete_type = concrete_type
Expand Down