New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JIT] Add API for ignoring arbitrary module attributes #45262
Changes from all commits
72e4840
e29bc87
2d38d2f
e89e74e
5479ecb
14e998d
7d38b5d
ec913d6
dd8aad0
54fcb1d
7e9137f
f728b0f
ded8d28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,6 +106,7 @@ bool ConcreteModuleTypeBuilder::equals( | |
bool equal = | ||
pyClass_.is(other.pyClass_) && | ||
iterableModuleKind_ == other.iterableModuleKind_ && | ||
ignoredAttributes_ == other.ignoredAttributes_ && | ||
constants_ == other.constants_ && | ||
attributes_ == other.attributes_ && | ||
overloads_ == other.overloads_ && | ||
|
@@ -186,6 +187,10 @@ c10::optional<std::string> ConcreteModuleType::findFailedAttribute( | |
return c10::nullopt; | ||
} | ||
|
||
bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to edit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added this but it's hard to test because ignoring an attribute changes the list of attributes on the I'm not saying we shouldn't have it, but I can't prove that it works. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as I understand, ConcreteMuduleType::equals is to disambiguate between different instantiations of an nn.Module which can result in a different type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's my understanding too, but as per one of my tests, the class attribute can be edited between invocations to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can be edited, or we can set it in the body of the
Both happen with some degree of regularity for other elements of ConcreteModuleType. But good point that it changes the available attributes, so checking ignored attrs is somewhat redundant. I think we should do it still for consistency, since all other attributes of ConcreteModuleType should participate in equality. |
||
return data_.ignoredAttributes_.count(name) > 0; | ||
} | ||
|
||
std::shared_ptr<ConcreteModuleType> ConcreteModuleType:: | ||
findSubmoduleConcreteType(const std::string& name) const { | ||
const auto it = std::find_if( | ||
|
@@ -281,6 +286,10 @@ void ConcreteModuleTypeBuilder::addFailedAttribute( | |
failedAttributes_.emplace(std::move(name), std::move(failureReason)); | ||
} | ||
|
||
void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) { | ||
ignoredAttributes_.emplace(std::move(name)); | ||
} | ||
|
||
void ConcreteModuleType::dump() const { | ||
std::cout << "ConcreteModuleType for: " | ||
<< py::getattr(data_.pyClass_, "__name__") << "\n"; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,10 @@ def infer_concrete_type_builder(nn_module, share_types=True): | |
|
||
class_annotations = getattr(nn_module, '__annotations__', {}) | ||
|
||
# Get user-annotated ignored attributes. | ||
user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list()) | ||
concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes) | ||
|
||
# try to infer the type from type annotation or from the object itself | ||
def infer_type(name, item): | ||
# The forward function from Module is special; never use this annotations; we | ||
|
@@ -123,6 +127,9 @@ def infer_type(name, item): | |
added_names = set() | ||
|
||
for name, item in nn_module._parameters.items(): | ||
if name in user_annotated_ignored_attributes: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to check
The only difference would be that ignored attributes would be added even if they aren't present in the module object, but that won't make a difference in practice I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But we still need to skip adding attributes if they are ignored, right? So instead of
we'd have
or
|
||
continue | ||
|
||
assert item is None or isinstance(item, torch.Tensor) | ||
attr_type = infer_type(name, item) | ||
# We currently have the invariant in various places in our code | ||
|
@@ -134,12 +141,18 @@ def infer_type(name, item): | |
added_names.add(name) | ||
|
||
for name, item in nn_module._buffers.items(): | ||
if name in user_annotated_ignored_attributes: | ||
continue | ||
|
||
assert item is None or isinstance(item, torch.Tensor) | ||
attr_type = infer_type(name, item) | ||
concrete_type_builder.add_attribute(name, attr_type, False, True) | ||
added_names.add(name) | ||
|
||
for name, item in nn_module._modules.items(): | ||
if name in user_annotated_ignored_attributes: | ||
continue | ||
|
||
attr_type = infer_type(name, item) | ||
if item is None: | ||
# Modules can be None. We don't have direct support for optional | ||
|
@@ -205,6 +218,9 @@ def infer_type(name, item): | |
# PyTorch adds a few more. Prevent these from getting compiled. | ||
continue | ||
|
||
if name in user_annotated_ignored_attributes: | ||
continue | ||
|
||
if name in added_names: | ||
# Don't re-add anything we already added | ||
continue | ||
|
@@ -390,14 +406,16 @@ def init_fn(script_module): | |
cpp_module.setattr(name, scripted) | ||
script_module._modules[name] = scripted | ||
|
||
# 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule. | ||
# 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule. | ||
# This ensures we can access these Python methods on the ScriptModule. | ||
for name in dir(nn_module): | ||
item = getattr(nn_module, name, None) | ||
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item): | ||
unbound_function = getattr(type(nn_module), name) | ||
bound_method = unbound_function.__get__(script_module) | ||
setattr(script_module, name, bound_method) | ||
elif concrete_type.is_ignored_attribute(name): | ||
setattr(script_module, name, item) | ||
|
||
# For convenience, attach the concrete type to the new ScriptModule | ||
script_module._concrete_type = concrete_type | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_type_sharing.py
to make sure we are not improperly sharing types with different ignored attributes sets?test_jit_py3.py
. This file originally existed because it let us use py3 syntax, but that no longer has a purpose now that we are py3 only.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_type_sharing.py
to test that types are shared if ignoring attrs make them equal and not shared if the ignored attributes list is modified and the types of modules before and after the modification should be different.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For 2: not sure, I wish we had the tests that exercise "compilation customization" APIs, but they have not yet been factored out of
test_jit.py
.