[WIP/NO_MERGE] Prototype RegularizedShortcut #4549
Conversation
I think overall this looks OK. If I understand correctly, the procedure is:
- Iterate through the named modules in the module hierarchy, and for each module that's part of the `block_types` of interest:
  a. Add the shortcut module
  b. Trace the module and search for a residual connection (i.e. an add node with two inputs, one of which is a placeholder input)
  c. Replace the residual connection with the shortcut module (see the sketch after this list)
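For readers skimming the thread, here is a rough, self-contained sketch of that procedure. The names `_MODULE_NAME`, `RegularizedShortcut`, `regularizer_layer`, and `add_regularized_shortcut` follow the snippets quoted below, but the implementation details are assumptions for illustration, not the PR's exact code:

```python
import operator

import torch
from torch import fx, nn

_MODULE_NAME = "_regularized_shortcut"  # assumed registration name


class RegularizedShortcut(nn.Module):
    """Hypothetical stand-in for the PR's RegularizedShortcut: regularize the branch, then add."""

    def __init__(self, regularizer_layer):
        super().__init__()
        self.regularizer = regularizer_layer()

    def forward(self, input, result):
        return input + self.regularizer(result)


def add_regularized_shortcut(model: nn.Module, block_types, regularizer_layer) -> nn.Module:
    for name, m in list(model.named_modules()):
        if name and isinstance(m, block_types):
            # a. register the shortcut module on the block before tracing
            m.add_module(_MODULE_NAME, RegularizedShortcut(regularizer_layer))
            # b. trace the block and look for the residual connection: an add
            #    node that consumes the block's placeholder input
            graph = fx.symbolic_trace(m).graph
            input_node = next(n for n in graph.nodes if n.op == "placeholder")
            for node in list(graph.nodes):
                is_add = node.op == "call_function" and node.target in (operator.add, operator.iadd, torch.add)
                if is_add and input_node in node.args:
                    # c. reroute the residual add through the shortcut module
                    with graph.inserting_after(node):
                        new_node = graph.call_module(_MODULE_NAME, node.args)
                    node.replace_all_uses_with(new_node)
                    graph.erase_node(node)
            graph.lint()
            # overwrite the original block with the rewritten GraphModule
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, fx.GraphModule(m, graph))
    return model
```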
> I think overall this looks OK. If I understand correctly, the procedure is...
@jamesr66a Thanks a lot for reviewing. Your description of the approach is correct.
I was worried that looping through `named_modules`, independently tracing the graphs of the submodules, and then overwriting the original modules would be problematic. Just to be safe, below I highlight the bits that concerned me. If you have any thoughts on how to improve them, I'm happy to adopt them.
torchvision/prototype/ops/_utils.py
```python
if isinstance(m, block_types):
    # Add the Layer directly on the submodule prior to tracing
    # workaround due to https://github.com/pytorch/pytorch/issues/66197
    m.add_module(_MODULE_NAME, RegularizedShortcut(regularizer_layer))
```
Ideally I wanted to create the layer on the fly and attach it directly to the graph, but pytorch/pytorch#66197 prevents me from doing this. Here I attach it to the module just before tracing, as a workaround. Any concerns?
The workaround will be removed once the issue is fixed.
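To make the workaround concrete, here is a tiny illustrative snippet (the dummy block and the `nn.Identity` stand-in are placeholders, not the PR's code): the shortcut is registered on the submodule before tracing, so a later `graph.call_module(_MODULE_NAME, ...)` rewrite can resolve its target by name.

```python
from torch import fx, nn


class _DummyBlock(nn.Module):
    # Minimal stand-in for a residual block; purely illustrative.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        return x + self.conv(x)


block = _DummyBlock()
# Register the shortcut on the submodule *before* tracing (the workaround):
block.add_module("_regularized_shortcut", nn.Identity())  # Identity stands in for RegularizedShortcut

graph = fx.symbolic_trace(block).graph
# A rewrite can now insert graph.call_module("_regularized_shortcut", args);
# the target resolves by name when fx.GraphModule(block, graph) is rebuilt,
# because the module already exists in the block's hierarchy.
```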
torchvision/prototype/ops/_utils.py
```python
with graph.inserting_after(node):
    # Always put the shortcut value first
    args = node.args if node.args[0] == input else node.args[::-1]
    node.replace_all_uses_with(graph.call_module(_MODULE_NAME, args))
```
Here I'm calling the previously created module by name. Hopefully this will be replaced with something like the following:
```python
import torch.fx
from torch.fx import Proxy
from torch.fx.node import map_arg

fn_impl_traced = torch.fx.symbolic_trace(RegularizedShortcut(regularizer_layer))
args = node.args if node.args[0] == input else node.args[::-1]
fn_impl_output_node = fn_impl_traced(*map_arg(args, Proxy))
node.replace_all_uses_with(fn_impl_output_node.node)
```
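For completeness, the housekeeping that would typically follow such a replacement in FX, continuing the variables from the excerpts above (a sketch, not the PR's exact code):

```python
import torch.fx

# Once nothing references the original add node, remove it and re-validate
# the graph, then rebuild the submodule around the edited graph.
graph.erase_node(node)
graph.lint()
new_submodule = torch.fx.GraphModule(m, graph)
```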
```python
for node in graph.nodes:
    # The isinstance() won't work if the model has already been traced before because it loses
    # the class info of submodules. See https://github.com/pytorch/pytorch/issues/66335
    if node.op == "call_module" and isinstance(model.get_submodule(node.target), block_types):
```
@jamesr66a We just figured out that FX traced models lose their submodule class information. This means that for a model that has been traced before, we can't use `isinstance()` to identify its block type. Is this intentional or a bug?
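To illustrate, a small repro using torchvision's ResNet (chosen only as an example):

```python
import torch.fx as fx
from torchvision.models import resnet18
from torchvision.models.resnet import BasicBlock

model = resnet18()
print(isinstance(model.layer1[0], BasicBlock))  # True on the eager model

traced = fx.symbolic_trace(model)
# The traced GraphModule rebuilds intermediate submodules as plain nn.Module
# containers, so the original class is lost and isinstance() no longer matches.
print(isinstance(traced.get_submodule("layer1.0"), BasicBlock))  # False
```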
This is an early prototype utility based on FX.
The target is to detect residual connections in arbitrary model architectures and modify the network to add regularization blocks (such as `StochasticDepth`).

Example usage:
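The original usage example and its printed output are not reproduced here; roughly, the intended usage looks like the following. The entry-point name `add_regularized_shortcut` and its argument names are assumptions based on the snippets above, not the prototype's confirmed API:

```python
import functools

from torchvision.models import resnet50
from torchvision.models.resnet import Bottleneck
from torchvision.ops import StochasticDepth

# add_regularized_shortcut: hypothetical entry point from torchvision/prototype/ops/_utils.py
model = resnet50()
# Wrap every Bottleneck's residual connection with a StochasticDepth regularizer.
regularizer = functools.partial(StochasticDepth, p=0.1, mode="row")
model = add_regularized_shortcut(model, block_types=(Bottleneck,), regularizer_layer=regularizer)
print(model)  # the RegularizedShortcut modules now appear inside each block
```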
Output:
Before
After addition
After deletion
Also tested with:
Affected by pytorch/pytorch#66197 and pytorch/pytorch#66335