
Use FX to have a more robust intermediate feature extraction #3597

Closed
wants to merge 6 commits

Conversation

fmassa
Member

@fmassa fmassa commented Mar 23, 2021

No description provided.

@nairbv
Contributor

nairbv commented Mar 29, 2021

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

does this approach resolve both of those constraints?
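
For context, the quoted constraints come from torchvision's existing IntermediateLayerGetter helper; a minimal usage example of that helper looks like this:

import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

model = torchvision.models.resnet18()
getter = IntermediateLayerGetter(model, return_layers={'layer1': 'feat1', 'layer3': 'feat2'})
out = getter(torch.rand(1, 3, 224, 224))
print([(name, feat.shape) for name, feat in out.items()])
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]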

@fmassa
Member Author

fmassa commented Mar 29, 2021

@nairbv yes, this FX-based approach addresses both of the aforementioned constraints of the current torchvision implementation (under the assumption that FX can symbolically trace the model)
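
As a rough sketch of the idea (illustrative, not the exact code in this PR): trace the model with FX, rewire the graph's output to return the requested intermediate nodes, and prune everything downstream. The helper name here is made up, and this simplified version only matches layers that appear as single call_module nodes (i.e. leaf modules):

import torch
from torch import fx
import torchvision

def extract_features(model, return_layers):
    m = fx.symbolic_trace(model)
    # Map each requested layer's node to its user-facing output name.
    wanted = {}
    for node in m.graph.nodes:
        if node.op == 'call_module' and str(node.target) in return_layers:
            wanted[return_layers[str(node.target)]] = node
    # Rewire the output to return a dict of intermediates instead.
    for node in m.graph.nodes:
        if node.op == 'output':
            node.args = (wanted,)
    m.graph.eliminate_dead_code()  # drop layers nothing depends on anymore
    m.recompile()
    return m

new_model = extract_features(torchvision.models.resnet18(),
                             {'maxpool': 'feat0', 'avgpool': 'feat1'})
out = new_model(torch.rand(1, 3, 224, 224))  # {'feat0': ..., 'feat1': ...}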

Contributor

@datumbox datumbox left a comment

Sorry for snooping around before you mark the PR as complete; hope you don't mind. This feature is going to be super useful for some of the things I'm looking after, so I wanted to have an early sneak peek. 😄

I like the approach. Below I've highlighted a few corner cases. Let me know what you think.

# Get output node
orig_output_node: Optional[torch.fx.Node] = None
for n in reversed(m.graph.nodes):
    if n.op == "output":
Contributor

What happens in cases where we have multiple outputs (e.g., Inception3, which has auxiliary outputs)? It seems that FX has another node called inception_outputs:

>>> list(reversed(m.graph.nodes))
[output, inception_outputs, fc, flatten_1, dropout, ....]

You can see this by replacing the input in your test:

        model = torchvision.models.inception_v3(pretrained=False)
        return_layers = {'Mixed_7c': '0', 'avgpool': '1'}

def test_old_new_match(self):
    model = torchvision.models.resnet18(pretrained=False)

    return_layers = {'layer2': '5', 'layer4': 'pool'}
Contributor

FYI, it fails when we include the final output in the return layers:

Suggested change
return_layers = {'layer2': '5', 'layer4': 'pool'}
return_layers = {'layer2': '5', 'layer4': 'pool', 'fc': 'fc1'}

with:

E       RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1 and 512x1000)

Member Author

Thanks for the catch! I need to check more carefully, but the old implementation doesn't work in this case because of the torch.flatten call (which is not an nn.Module); I believe this should work with the new implementation. To be verified
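
To illustrate the failure mode (a toy reproduction, not code from this PR): the old getter replays only the registered child modules in order, so the functional torch.flatten(x, 1) between avgpool and fc is skipped, and fc receives an unflattened tensor:

import torch

fc = torch.nn.Linear(512, 1000)
x = torch.rand(1, 512, 1, 1)   # shape of the avgpool output
fc(torch.flatten(x, 1))        # OK: (1, 512) @ (512, 1000)
# fc(x)  # -> RuntimeError: mat1 and mat2 shapes cannot be multiplied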

>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
# TODO come up with a better name for this
Contributor

I think return_layers is fine. I understand it remaps but it still stores the mapping of the returned layers.

@mthrok
Contributor

mthrok commented Jun 7, 2021

@fmassa

I was exploring ways to achieve something similar and found this. This one looks simpler.

https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6
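
Roughly, that approach registers forward hooks to stash intermediate outputs, along these lines (a minimal sketch):

import torch
import torchvision

model = torchvision.models.resnet18()
features = {}

def save_output(name):
    def hook(module, inputs, output):
        features[name] = output
    return hook

model.layer2.register_forward_hook(save_output('feat1'))
model.layer4.register_forward_hook(save_output('feat2'))

model(torch.rand(1, 3, 224, 224))  # the full forward still runs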

Is there an advantage/disadvantage in this PR's approach?

@fmassa
Member Author

fmassa commented Jun 7, 2021

Hi @mthrok

That's a good question, and I in fact implemented a similar solution about 4 years ago (although it looks like I no longer have the code lying around).

I discussed using hooks to get intermediate features in pytorch/pytorch#21064; the drawbacks of hooks compared to the approach in this PR are the following:

  • we always have to forward the whole model, even when we don't need all of it (wasting compute)
  • we always keep all of the model's parameters around, even for layers we don't use
  • in principle, we might want to return elements which are not outputs of an nn.Module, but intermediate values

From those perspectives, hooks can get us many of the things we are looking for, but if the model can be FX-traced, the FX-based approach can be a bit more powerful.

@mthrok
Contributor

mthrok commented Jun 7, 2021

    That's a good question, and I in fact implemented a similar solution about 4 years ago (although it looks like I no longer have the code lying around).

    I discussed using hooks to get intermediate features in pytorch/pytorch#21064; the drawbacks of hooks compared to the approach in this PR are the following:

      • we always have to forward the whole model, even when we don't need all of it (wasting compute)
      • we always keep all of the model's parameters around, even for layers we don't use
      • in principle, we might want to return elements which are not outputs of an nn.Module, but intermediate values

    From those perspectives, hooks can get us many of the things we are looking for, but if the model can be FX-traced, the FX-based approach can be a bit more powerful.

I see. Thanks for the clarification.

@ppwwyyxx
Contributor

This looks pretty useful! Any plans on pushing it further?

@alexander-soare
Contributor

alexander-soare commented Jul 15, 2021

@datumbox @fmassa I'm trying to do something similar for the timm library (cc @rwightman). Right now it uses two types of approaches: hooks and something along the lines of IntermediateLayerGetter. But ultimately the FX approach seems the most flexible/robust, as you say.

Unfortunately, only 42% of the models are traceable in their current state, with control flow being the most frequent blocker (and that's just from catching the first error). Other frequent issues are tensor constructors which need concrete arguments, or when we make use of Tensor.shape.

Do you know of any workarounds for these limitations of symbolic tracing which could be useful without having to touch the models (much)? For instance, using concrete args (I tried but the "concreteness" gets washed away when we trace through non-custom modules), forwarding a fully representative set of inputs for building up control flow paths, or customising the tracer class to deal with problem nodes?

I also wonder if there are any near future developments in the pipeline that will help with this.

Thanks!

@fmassa
Member Author

fmassa commented Aug 12, 2021

Hi,

Sorry @ppwwyyxx and @alexander-soare for the delay in replying, I missed the notifications as I was on holidays.

@ppwwyyxx yes, we would like to get this finalized and merged in torchvision sometime soon. Currently all classification models in torchvision work with this approach (and thus detection / segmentation models can be adapted as well), but as @alexander-soare pointed out there are many models in the community that wouldn't work out of the box.

I do have some ideas on how to push the FX-based approach to work for all models (with some caveats on what is possible to be obtained). The main idea is as follows:

  • Use FX to recursively trace each module
  • If a module can't be traced (e.g., due to control flow or an unsupported feature), do not trace inside the module but instead keep the module as a leaf

This approach would enable all models to be traced, with the caveat that in the worst possible case the whole model would be a leaf node (and thus we would only be able to get its output and no other intermediate activation -- this case should be rare though).
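
A rough sketch of what I have in mind (illustrative only; the class and helper names are made up):

import torch
from torch import fx

class FallbackTracer(fx.Tracer):
    # Try to trace each submodule; if tracing fails (control flow,
    # unsupported features, ...), keep it as an opaque leaf instead.
    def is_leaf_module(self, m, module_qualified_name):
        if super().is_leaf_module(m, module_qualified_name):
            return True
        try:
            fx.symbolic_trace(m)  # probe whether the submodule is traceable
            return False          # traceable -> trace through it as usual
        except Exception:
            return True           # not traceable -> treat it as a leaf

def robust_trace(model):
    graph = FallbackTracer().trace(model)
    return fx.GraphModule(model, graph)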

Thoughts on this approach?

@alexander-soare
Contributor

alexander-soare commented Aug 12, 2021

@fmassa thanks for that, I actually went ahead and implemented some of those ideas. In case you find it useful, here's a somewhat outdated write-up from a different branch than the one where I'm currently working on this.

One thing I'm still trying to work out is how to make model.train() and model.eval() retain their effect when there is control flow based on model.training.

@fmassa
Member Author

fmassa commented Aug 12, 2021

@alexander-soare Nice! Do you think you could work on getting your code from https://github.com/alexander-soare/pytorch-image-models/blob/fx-feature-extract-new/timm/models/fx_features.py to be submitted as a PR to torchvision when it's ready?

@fmassa
Member Author

fmassa commented Aug 12, 2021

About your question

    One thing I'm still trying to work out is how to make model.train() and model.eval() retain their effect when there is control flow based on model.training.

I can reach out to some folks on the FX team to figure out a possible approach. It would probably involve tracing the model twice with different flags and stitching the results together in FX somehow.
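
Something in the spirit of the following (purely hypothetical, just to show the shape of the idea):

import torch
from torch import fx

class TrainEvalModule(torch.nn.Module):
    # Hypothetical: bake the train-mode and eval-mode control-flow paths
    # into two separate graphs, then dispatch on self.training at runtime.
    def __init__(self, model):
        super().__init__()
        model.train()
        self.train_graph = fx.symbolic_trace(model)
        model.eval()
        self.eval_graph = fx.symbolic_trace(model)

    def forward(self, x):
        if self.training:
            return self.train_graph(x)
        return self.eval_graph(x)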

@alexander-soare
Contributor

@fmassa would love to make that PR. What are the reasons this one hasn't gone forward, so I can make sure I'm addressing them?

@rwightman
Contributor

@fmassa this approach for feature extraction feels like it's close to usable; it would be great to hash out the final details and smooth out some of the wrinkles (mostly due to tracing limitations, control flow, etc.). I'd like to add @alexander-soare's work here to timm, but there's still some testing to do and a determination of whether there will be any show-stoppers to using it for downstream tasks like object detection and segmentation (undue constraints on users of the downstream models with respect to scripting, training, exporting, checkpoint saving/loading, etc.).

Also on the timm end I need to spend some time figuring out how to better specify the interface for selecting the features to use in different use cases (feature pyramid, attention maps, arbitrary taps, etc.) for each model. It'd be good to know what you'd like the torchvision API for such functionality to cover so I can roughly match it and eventually use the torchvision code w/ timm feature specs...

@fmassa
Member Author

fmassa commented Aug 13, 2021

@alexander-soare

    What are the reasons this one hasn't gone forward, so I can make sure I'm addressing them?

There were only minor reasons that we didn't get this merged in torchvision yet (for classification models at least).

  • the current module-to-node assignment that we do here has a few rough edges. Indeed, one node can belong to multiple modules (i.e., layer0.3.2, layer0.3, and layer0 can represent the same tensor), but in the current approach only the last one is valid (i.e., layer0 will work, but layer0.3 and layer0.3.2 won't be visible). I would have liked to fix this before getting this merged, but I went on holidays for a few weeks and didn't get to finish it.
  • although all models in torchvision work with this approach, I had given it a quick try on timm models and, as you noted, many models wouldn't work out of the box. So I was wondering if I should first adapt the tracing to make it work for timm models, or just go ahead with the v1 version and improve it over time.

I think we can get started with just fixing the first point I mentioned, and collaboratively work to get the more robust tracing working.

@rwightman ultimately I would love to see a generic solution for pytorch/pytorch#21064, and I think using FX can be a way to get there, including for detection / segmentation models.
I do think we should go step by step here though. Even though FX allows for querying arbitrary nodes in the computation graph, it doesn't (yet?) make guarantees that the names will be consistent across versions. So querying arbitrary features (which are not the output of an nn.Module) is possible, but will probably not be "officially" supported until we come up with some more guarantees.

    Also on the timm end I need to spend some time figuring out how to better specify the interface for selecting the features to use in different use cases (feature pyramid, attention maps, arbitrary taps, etc.) for each model. It'd be good to know what you'd like the torchvision API for such functionality to cover so I can roughly match it and eventually use the torchvision code w/ timm feature specs

I've been trying to avoid specifying / exposing what the different "levels" of a model should be in torchvision, because it is a rather arbitrary decision and ultimately it's up to the user to decide what works best for their application.
The question was somewhat easier to answer with resnet-style models because of their distinct "stages" (so one could just assume that the output of a stage is what we want), but there are newer models like ViT where the definition of a stage is much less clear.

My take on this is that the user should just specify a list of strings corresponding to the modules they want to gather information from, and maybe we provide a helper function that prints / returns all possible layer names for a given model.
Something in the lines of:

model = resnet50()

possible_layers = get_all_layer_names_in_execution_order(model)
# now take some of the layers, proportional to the number of layers
n = len(possible_layers)
layers = [possible_layers[int(n * frac)] for frac in [0.25, 0.5, 0.75]]

new_model = get_intermediate_layers(model, layers)

which allows us to be "somewhat" generic.

For models that have more structure, this metadata could also live within the model itself (like an attribute .stages or something similar) that returns the layer names for the different stages.
This leaves room for the user to either use some pre-selected layers or to choose something else if they want.
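
For example (illustrative only, reusing the hypothetical helper from the snippet above):

import torch

class MyBackbone(torch.nn.Module):
    # Illustrative convention: the model advertises the layer names that
    # correspond to its stage outputs.
    stages = ('layer1', 'layer2', 'layer3', 'layer4')

# a user could then do: get_intermediate_layers(model, model.stages)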

Thoughts?

@alexander-soare
Contributor

@fmassa I believe my implementation covers your point 1. I actually got rid of this line from your implementation, meaning you won't get layer0, as it's not a leaf. And then if the user specifies a truncated qualified name like layer0, the intermediate_layer_getter will pick the last one in order of execution (so maybe layer0.3.2). Currently this behaviour is silent, so we might need to come up with a nice way to make sure the user knows it's happening.
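
In other words, something like this (a simplified sketch of the matching rule, not my actual code):

def resolve(graph, query):
    # Pick the last call_module node (in execution order) whose qualified
    # name is the query itself or falls under it as a prefix,
    # e.g. 'layer0' may resolve to 'layer0.3.2'.
    match = None
    for node in graph.nodes:  # graph.nodes iterates in execution order
        name = str(node.target)
        if node.op == 'call_module' and (name == query or name.startswith(query + '.')):
            match = node
    return match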

Regarding your get_all_layer_names_in_execution_order, you can check that with print_graph_node_qualified_names from my branch:

import timm
from timm.models.fx_features import print_graph_node_qualified_names

model = timm.create_model('resnet50')
print_graph_node_qualified_names(model)

Still though, there are probably many unknown things to smooth out which will only become apparent when it's applied in a variety of use cases.

So, I'd suggest that timm could be an iterative test bed to start with. We could implement it with the foresight that there will be a "generic" tool which requires the user to specify the node names. Then around that we can wrap @rwightman's interface. Then, once we're ready and if it makes sense, we can move the generic core to torchvision, and it should just be a matter of changing the import paths in timm to torchvision. @fmassa this would just mean that we need to stay connected on the timm end to make sure it converges towards what's required for torchvision (or not; maybe we decide it differs at some point and we need to fork it, in which case I'd be happy to continue working on it in torchvision as well).

@rwightman @fmassa does that arrangement sound like it could work?

@fmassa
Member Author

fmassa commented Aug 13, 2021

This makes sense to me if it means you'll be able to move faster on this front.
Ultimately I would love it if we could join efforts to get some generic tooling out for our users.

Depending on how generic the implementation is, I think it could even live within PyTorch's fx folder as a set of helper functions, as this is the type of feature that could also be used in torchtext / torchaudio / etc.

@alexander-soare
Contributor

alexander-soare commented Aug 21, 2021

@fmassa nevertheless I've gone ahead and made a draft to help keep this moving along. I know it's a bit of a U-turn from my suggestion above, but I realised it's mostly done. I've moved the convo there.

@fmassa
Member Author

fmassa commented Sep 8, 2021

Superseded by #4302

@fmassa fmassa closed this Sep 8, 2021
@fmassa fmassa deleted the intermediate_layer_getter_2 branch September 8, 2021 08:46