Use FX to have a more robust intermediate feature extraction #3597
Conversation
Does this approach resolve both of those constraints?
@nairbv yes, this FX-based approach addresses both of the aforementioned constraints of the current implementation in torchvision (under the assumption that FX can appropriately symbolically trace the model).
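For readers unfamiliar with the idea, here is a minimal sketch of what FX-based feature extraction looks like. This is not the PR's actual code; `get_features_sketch` and `LeafTracer` are hypothetical names, and the real implementation handles many more cases.

```python
import torch
import torchvision
from torch import fx

class LeafTracer(fx.Tracer):
    """Treat the requested modules as leaves so each one appears as a
    single call_module node in the traced graph."""
    def __init__(self, leaves):
        super().__init__()
        self.leaves = set(leaves)

    def is_leaf_module(self, m, qualname):
        return qualname in self.leaves or super().is_leaf_module(m, qualname)

def get_features_sketch(model, return_layers):
    tracer = LeafTracer(return_layers)
    graph = tracer.trace(model)
    modules = {n.target: n for n in graph.nodes if n.op == "call_module"}
    out = {new: modules[old] for old, new in return_layers.items()}
    # Rewire the graph so it returns the requested intermediates...
    for n in reversed(graph.nodes):
        if n.op == "output":
            n.args = (out,)
            break
    # ...and drop everything downstream of them (e.g. avgpool / fc).
    graph.eliminate_dead_code()
    return fx.GraphModule(model, graph)

model = torchvision.models.resnet18(pretrained=False)
extractor = get_features_sketch(model, {"layer2": "feat1", "layer4": "feat2"})
feats = extractor(torch.rand(1, 3, 224, 224))
print({k: v.shape for k, v in feats.items()})
# {'feat1': torch.Size([1, 128, 28, 28]), 'feat2': torch.Size([1, 512, 7, 7])}
```

Unlike a hook-based solution, rewriting the graph means the unused tail of the model is never executed at all.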
Sorry for snooping around before you mark the PR as complete; hope you don't mind. This feature is going to be super useful for some of the things I'm looking after, so I wanted to have an early sneak peek. 😄
I like the approach. Below I've highlighted a few corner cases. Let me know what you think.
```python
# Get output node
orig_output_node: Optional[torch.fx.Node] = None
for n in reversed(m.graph.nodes):
    if n.op == "output":
```
What happens in cases where we have multiple outputs (for example Inception3, which has auxiliary outputs)? It seems that FX produces another node called `inception_outputs`:

```python
>>> list(reversed(m.graph.nodes))
[output, inception_outputs, fc, flatten_1, dropout, ...]
```

You can see this by replacing the input in your test:

```python
model = torchvision.models.inception_v3(pretrained=False)
return_layers = {'Mixed_7c': '0', 'avgpool': '1'}
```
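The observation above can be reproduced with a couple of lines (assuming the model traces as-is; the model is in train mode by default, which is what makes the namedtuple-packing node appear before `output`):

```python
import torch.fx
import torchvision

model = torchvision.models.inception_v3(pretrained=False)  # train mode by default
m = torch.fx.symbolic_trace(model)
print(list(reversed(m.graph.nodes))[:5])
# [output, inception_outputs, fc, flatten_1, dropout]
```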
```python
def test_old_new_match(self):
    model = torchvision.models.resnet18(pretrained=False)

    return_layers = {'layer2': '5', 'layer4': 'pool'}
```
FYI, it fails when we include the final output in the return layers:

```diff
- return_layers = {'layer2': '5', 'layer4': 'pool'}
+ return_layers = {'layer2': '5', 'layer4': 'pool', 'fc': 'fc1'}
```

with:

```
E   RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1 and 512x1000)
```
Thanks for the catch! I need to check more carefully, but the old implementation doesn't work in this case because of the `torch.flatten` call (which is not an `nn.Module`), but I believe this should work with the new implementation. To be verified.
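A minimal repro of what's described above, using the existing module-based extractor: `IntermediateLayerGetter` re-runs the model's direct children in order, so the functional `torch.flatten` call between `avgpool` and `fc` in `ResNet.forward` is silently skipped.

```python
import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

model = torchvision.models.resnet18(pretrained=False)
return_layers = {'layer2': '5', 'layer4': 'pool', 'fc': 'fc1'}
getter = IntermediateLayerGetter(model, return_layers)

# `fc` receives the unflattened [N, 512, 1, 1] avgpool output:
out = getter(torch.rand(1, 3, 224, 224))
# RuntimeError: mat1 and mat2 shapes cannot be multiplied ...
```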
```python
    >>> [('feat1', torch.Size([1, 64, 56, 56])),
    >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """
    # TODO come up with a better name for this
```
I think `return_layers` is fine. I understand it remaps, but it still stores the mapping of the returned layers.
I was exploring ways to achieve something similar and found this. This one looks simpler: https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6. Is there an advantage/disadvantage to this PR's approach?
Hi @mthrok, that's a good question, and in fact I implemented a similar solution 4 years ago (although it looks like I don't have the code lying around anymore). I discussed using hooks to get intermediate features in pytorch/pytorch#21064, but using hooks has a few drawbacks compared to the approach in this PR.
From that perspective, hooks can get us many of the things we are looking for, but if the model can be FX-traced, the FX-based approach can be a bit more powerful.
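For reference, the hook-based alternative being compared looks roughly like this (a minimal sketch; note that the full forward pass still runs, and the unused `avgpool`/`fc` tail cannot be pruned away as it can with FX):

```python
import torch
import torchvision

model = torchvision.models.resnet18(pretrained=False)
features = {}

def save_output(name):
    def hook(module, inputs, output):
        features[name] = output
    return hook

handles = [
    model.layer2.register_forward_hook(save_output("feat1")),
    model.layer4.register_forward_hook(save_output("feat2")),
]
model(torch.rand(1, 3, 224, 224))  # whole forward runs; hooks capture on the way
print({k: v.shape for k, v in features.items()})
for h in handles:
    h.remove()
```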
I see. Thanks for the clarification.
This looks pretty useful! Any plans on pushing it further?
@datumbox @fmassa I'm trying to do something similar for the timm models. Unfortunately, only 42% of the models are traceable in their current state, with control flow being the most frequent blocker (and that's just from catching the first error). Other frequent issues are tensor constructors, which need concrete arguments, and other dynamic constructs.

Do you know of any workarounds for these limitations of symbolic tracing that could be useful without having to touch the models (much)? For instance, using concrete args (I tried, but the "concreteness" gets washed away when we trace through non-custom modules), forwarding a fully representative set of inputs to build up control-flow paths, or customising the tracer class to deal with problem nodes? I also wonder if there are any near-future developments in the pipeline that will help with this. Thanks!
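For context on the `concrete_args` point: for a function you control, `torch.fx` does let you burn a flag into the trace to remove control flow, as in this small sketch (`forward_fn` and `use_aux` are made-up names):

```python
import torch
import torch.fx

def forward_fn(x, use_aux):
    # Control flow on a non-tensor flag blocks symbolic tracing...
    if use_aux:
        return x * 2
    return x

# ...unless the flag is specialized away at trace time:
traced = torch.fx.symbolic_trace(forward_fn, concrete_args={"use_aux": False})
print(traced.code)  # only the False-branch computation remains in the graph
```

The issue described above is that this specialization doesn't survive once tracing descends into standard (non-custom) modules.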
Hi, sorry @ppwwyyxx and @alexander-soare for the delay in replying; I missed the notifications as I was on holidays. @ppwwyyxx yes, we would like to get this finalized and merged in torchvision sometime soon. Currently all classification models in torchvision work with this approach (and thus detection / segmentation models can be adapted as well), but as @alexander-soare pointed out there are many models in the community that wouldn't work out of the box. I do have some ideas on how to push the FX-based approach to work for all models (with some caveats on what is possible to be obtained). The main idea is, roughly: whenever a (sub)module can't be symbolically traced, keep it as a leaf node in the graph, pushing this down so that only the smallest untraceable pieces become leaves.
This approach would enable all models to be traced, with the caveat that in the worst possible case the whole model would be a leaf node (and thus we would only be able to get its output and no other intermediate activations; this case should be rare though). Thoughts on this approach?
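A rough sketch of what that fallback could look like (`LeafListTracer` and `trace_with_fallback` are hypothetical names; a real implementation would recurse more carefully rather than stopping at the outermost failing module):

```python
import torch
import torch.fx

class LeafListTracer(torch.fx.Tracer):
    """Tracer that additionally treats the given qualified names as leaf nodes."""
    def __init__(self, leaves):
        super().__init__()
        self.leaves = set(leaves)

    def is_leaf_module(self, m, qualname):
        return qualname in self.leaves or super().is_leaf_module(m, qualname)

def trace_with_fallback(model):
    # Find submodules that fail to trace in isolation; they become leaves.
    failing = []
    for name, submodule in model.named_modules():
        if not name:
            continue
        try:
            torch.fx.symbolic_trace(submodule)
        except Exception:
            failing.append(name)
    # Keep only the outermost failing modules, since tracing never descends
    # into a leaf. (Coarser than ideal: a parent that fails only because of a
    # child also becomes a leaf; refining that is the "recursive" part.)
    leaves = [n for n in failing
              if not any(n.startswith(p + ".") for p in failing)]
    # Worst case, if this still fails, the whole model would become one leaf.
    graph = LeafListTracer(leaves).trace(model)
    return torch.fx.GraphModule(model, graph)
```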
@fmassa thanks for that, I actually went ahead and implemented some of those ideas. In case you find it useful, here's a somewhat outdated write-up from a different branch than the one where I'm currently working on this. One thing I'm still trying to work out is how to make the extractor behave correctly in both train and eval modes, where the traced forward pass can differ.
@alexander-soare Nice! Do you think you could work on getting your code from https://github.com/alexander-soare/pytorch-image-models/blob/fx-feature-extract-new/timm/models/fx_features.py submitted as a PR to torchvision when it's ready?
About your question: I can reach out to some folks on the FX team to figure out a possible approach. It would probably involve tracing the model twice with different flags, and stitching the results together in FX somehow.
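One way that "trace twice" idea could look, as an assumed design rather than anything settled in this thread: trace once in train mode and once in eval mode (so mode-dependent control flow is captured in each graph), then dispatch on `self.training`:

```python
import torch
import torch.fx
import torchvision

class DualGraphExtractor(torch.nn.Module):
    """Hypothetical wrapper holding one traced graph per mode.
    Parameters are shared with the original modules."""
    def __init__(self, model):
        super().__init__()
        model.train()
        self.train_graph = torch.fx.symbolic_trace(model)
        model.eval()
        self.eval_graph = torch.fx.symbolic_trace(model)

    def forward(self, x):
        return self.train_graph(x) if self.training else self.eval_graph(x)

model = torchvision.models.inception_v3(pretrained=False)
extractor = DualGraphExtractor(model).eval()
out = extractor(torch.rand(1, 3, 299, 299))  # eval graph: plain logits, no aux
```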
@fmassa would love to make that PR. What are the reasons this one hasn't gone forward, so I can make sure I'm addressing them?
@fmassa this approach for feature extraction feels like it's close to usable; it would be great to hash out the final details and smooth out some of the wrinkles (mostly due to tracing limitations, control flow, etc). I'd like to add @alexander-soare's work here to timm.
There were only minor reasons that we didn't get this merged in torchvision yet (for classification models at least).
I think we can get started by just fixing the first point I mentioned, and collaboratively work on getting the more robust tracing working. @rwightman ultimately I would love to see a generic solution for pytorch/pytorch#21064, and I think FX can be the way to get there, including for detection / segmentation models.
I've been trying to avoid specifying / exposing myself what the different "levels" of a model should be in torchvision, because it is a rather arbitrary decision and ultimately it's up to the user to decide what works best for their application. My take on this is that the user should just specify a list of strings corresponding to the modules they want to gather the information from, and maybe we provide a helper function that prints / returns all possible layer names for a given model:

```python
model = resnet50()
possible_layers = get_all_layer_names_in_execution_order(model)
# now take some of the layers, proportional to the number of layers
n = len(possible_layers)
layers = [possible_layers[int(n * frac)] for frac in [0.25, 0.5, 0.75]]
new_model = get_intermediate_layers(model, layers)
```

which allows us to be "somewhat" generic. For models that have some more structure, this metadata can also be present within the model itself (as an attribute). Thoughts?
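The helpers in the snippet above are hypothetical. For what it's worth, an FX trace already yields such a list, since `call_module` nodes appear in the graph in execution order; one plausible implementation (assuming the model is traceable):

```python
import torch.fx

def get_all_layer_names_in_execution_order(model):
    # One possible implementation of the helper sketched above.
    traced = torch.fx.symbolic_trace(model)
    return [node.target for node in traced.graph.nodes
            if node.op == "call_module"]
```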
@fmassa I believe my implementation covers your point 1; I actually got rid of this line from your implementation. Still though, there are probably many unknown things to smooth out which will only become apparent when it's applied in a variety of use cases. So, I'd suggest we iterate on it in timm first and upstream it once it has settled. @rwightman @fmassa does that arrangement sound like it could work?
This makes sense to me if it means you'll be able to move faster on this front. Depending on how generic the implementation is, I think it could even live within the PyTorch FX folder as a set of helper functions, as this is the type of feature that could also be used in torchtext / torchaudio / etc.
Superseded by #4302.