FX feature extraction #800

Merged
merged 11 commits into huggingface:master on Nov 19, 2021

Conversation

alexander-soare
Contributor

alexander-soare commented Aug 12, 2021

  • Added timm.models.fx_features.FeatureGraphNet as another option for feature extraction. (works as a standalone commit)
  • Made all models traceable (2nd commit)
  • Tests to enforce all models traceable (3rd commit)

Caveat - Right now we can only safely say it works in eval mode. Control flow that depends on the value of model.training is frozen into place by the tracing operation, so if the model was traced in eval mode it stays that way (strictly speaking, only the parts that were traced through are frozen; leaf modules and leaf functions still respect the training mode). Therefore, we cannot expect model.train() to have the desired effect. This is a TODO, so for now a warning is raised when the user calls model.train().
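
To illustrate, a minimal sketch (toy module, not from timm) of how a branch on self.training gets frozen by plain symbolic tracing:

import torch
import torch.fx

class ToyBlock(torch.nn.Module):
    def forward(self, x):
        # self.training is a plain bool at trace time, so only one branch is recorded
        if self.training:
            return x * 2
        return x

m = ToyBlock().eval()
traced = torch.fx.symbolic_trace(m)  # graph contains only the eval branch
traced.train()                       # flips the flag, but the graph is already fixed
print(traced(torch.ones(1)))         # still tensor([1.]) - the train branch was never traced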

  • This is sorted but hasn't been tested in anger.

All local tests passed.

EDIT - This feature has been added to torchvision pytorch/vision@72d650a

@alexander-soare
Contributor Author

This is a script to check that features from the pre-existing extraction methods match those from FX extraction. All models but one pass the test:

  • dpn fails because the chosen output module returns a tuple. FX doesn't know which output to look at unless you specify. So it's not really a failure.
"""
NOTE: You'll need to patch timm.helpers.build_model_with_cfg with
if kwargs.pop('use_fx'):
    if feature_cfg is None or 'out_indices' not in feature_cfg:
        feature_cfg = {'feature_cls': 'fx'}
    else:
        feature_cfg = {'feature_cls': 'fx', 'out_indices': feature_cfg['out_indices']}
"""

import random

import torch
import timm
import numpy as np
from tqdm import tqdm

TARGET_FFEAT_SIZE = 96


def _get_input_size(model=None, model_name='', target=None):
    if model is None:
        assert model_name, "One of model or model_name must be provided"
        input_size = timm.get_model_default_value(model_name, 'input_size')
        fixed_input_size = timm.get_model_default_value(model_name, 'fixed_input_size')
        min_input_size = timm.get_model_default_value(model_name, 'min_input_size')
    else:
        default_cfg = model.default_cfg
        input_size = default_cfg['input_size']
        fixed_input_size = default_cfg.get('fixed_input_size', None)
        min_input_size = default_cfg.get('min_input_size', None)
    assert input_size is not None

    if fixed_input_size:
        return input_size

    if min_input_size:
        if target and max(input_size) > target:
            input_size = min_input_size
    else:
        if target and max(input_size) > target:
            input_size = tuple([min(x, target) for x in input_size])
    return input_size


def seed_everything(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)


def compare_features(model_name, batch_size):
    print(model_name)
    seed_everything()
    model = timm.create_model(model_name, pretrained=False, features_only=True, use_fx=False)
    model.eval()
    seed_everything()
    fx_model = timm.create_model(model_name, pretrained=False, features_only=True, use_fx=True)
    fx_model.eval()
    
    def basic_checks(model):
        expected_channels = model.feature_info.channels()
        
        assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6

        input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)

        seed_everything()
        outputs = model(torch.randn((batch_size, *input_size)))
        assert len(expected_channels) == len(outputs)
        for e, o in zip(expected_channels, outputs):
            try:
                assert e == o.shape[1]
            except AssertionError:
                print("Channels mismatch! Maybe due to multiple outputs from submodule.")
            assert o.shape[0] == batch_size
            assert not torch.isnan(o).any()

        return outputs

    model_outputs = basic_checks(model)
    fx_model_outputs = basic_checks(fx_model)
    for model_output, fx_model_output, info in zip(model_outputs, fx_model_outputs, model.feature_info.info):
        try:
            assert torch.allclose(model_output, fx_model_output)
        except RuntimeError as e:
            print(f"{model_name} torch.allclose failed. {e}")
        except AssertionError:
            print("torch.allclose failed")
            print(info)


NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']

EXCLUDE_FEAT_FILTERS = [
    '*pruned*',  # hopefully fix at some point
] + NON_STD_FILTERS


for model_name in tqdm(timm.list_models(exclude_filters=EXCLUDE_FEAT_FILTERS)):
    compare_features(model_name=model_name, batch_size=1)

@alexander-soare
Contributor Author

For info:

How I made the models symbolically traceable

Here's a list of changes, roughly ordered from most to least impactful.

Make timm layers into leaf modules

A leaf module doesn't get traced through; only a reference to it is recorded in the graph. This means we can avoid whatever issues come up inside it. Mostly these have been control flow and InplaceAbn. Check is_leaf_module in MyTracer of fx_coverage.py to see which layers I included.

Note: By default, standard torch modules are leaf modules. Maybe there could be consequences of treating timm modules as leaf modules. Not sure.
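
As a rough sketch of the mechanism (illustrative; the real list of leaf types lives in MyTracer in fx_coverage.py):

import torch.fx as fx
import torch.nn as nn

# Placeholder; in the PR this holds InplaceAbn and other timm layers whose internals
# shouldn't be traced through.
TIMM_LEAF_TYPES = ()

class LeafTracer(fx.Tracer):
    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # Standard torch.nn modules are already leaves; extend that to the registered timm layers.
        return super().is_leaf_module(m, module_qualified_name) or isinstance(m, TIMM_LEAF_TYPES)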

assert *, * -> torch._assert(*, *)

Most of the control flow was in assert statements. Luckily we have https://pytorch.org/docs/stable/generated/torch._assert.html

Roughly 10 files needed at least one of these.
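
For example (hypothetical helper, just to show the substitution):

import torch

def check_channels(x: torch.Tensor, expected: int):
    # assert x.shape[1] == expected, 'channel mismatch'  # breaks symbolic tracing on proxies
    torch._assert(x.shape[1] == expected, 'channel mismatch')  # recorded as a traceable graph node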

* @ * -> torch.matmul(*, *)

Roughly 15 files, mostly attention related.

* and * -> fx_and(*, *)

and appears in a few of the torch._asserts, where it triggers the symbolic-trace control-flow error. A custom fx.Tracer class takes this function and wraps it to treat it as a "leaf function" (see the sketch after the next item).

int -> fx_float_to_int

Appears in a few places where we are calculating some tensor dim for a reshape/view. Custom fx.Tracer class takes this function and wraps it to treat it as a "leaf function".
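
A minimal sketch of both helpers and how they become "leaf functions" (assuming torch >= 1.10 for the autowrap_functions kwarg):

from torch import fx

def fx_and(a, b):
    # recorded as a single call in the graph, so `and` never runs on proxies at trace time
    return a and b

def fx_float_to_int(x) -> int:
    # same idea for int() on a computed dimension
    return int(x)

# The custom tracer wraps these so tracing records a call to them instead of tracing into them.
tracer = fx.Tracer(autowrap_functions=(fx_and, fx_float_to_int))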

Don't modify tensor slices in place

... by using torch.cat to reconstruct the full tensor from slices.

In tnt.py

-        patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
+        patch_embed = torch.cat(
+            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
+            dim=1)

In rexnet.py

-            x[:, 0:self.in_channels] += shortcut
+            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)

@rwightman
Collaborator

@alexander-soare with the merged functionality in torchvision (congrats!), timm integration should probably leverage what's going to be in the next PyTorch 1.10 + torchvision release with any timm specific extras/interfaces... or is a specialized version needed?

@alexander-soare
Contributor Author

@rwightman thanks! I don't think you need a specialised version. Some things to consider:

  • You could keep the leaf module registration mechanism that I added here and feed that into create_feature_extractor (see the sketch below). You might also do the same with leaf functions.
  • This means you can go back to using @ instead of matmul.

... annnd maybe that's it. Happy to drop another PR with what's needed.
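
For example, a rough sketch of the thin timm-side glue I have in mind (function name is illustrative; kwargs follow torchvision's create_feature_extractor API):

from torchvision.models.feature_extraction import create_feature_extractor

def make_fx_features(model, return_nodes, leaf_modules=(), autowrap_functions=()):
    # timm keeps its own registries of leaf modules/functions and just forwards them
    # to torchvision's tracer via tracer_kwargs.
    return create_feature_extractor(
        model, return_nodes=list(return_nodes),
        tracer_kwargs={'leaf_modules': list(leaf_modules),
                       'autowrap_functions': list(autowrap_functions)})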

How do you plan to resolve the PyTorch/torchvision version issue? Will you wait a while until you bump the minimum required version for the whole lib, or add this now and raise a version error when applicable?

@rwightman
Collaborator

@alexander-soare if you have time (no rush), I wouldn't mind rebooting this with your PR in torchvision, now released with PyTorch 1.10 + torchvision 0.11.1. The fx_features module would try to import the relevant symbols from torchvision with an import try/except guard, and if FeatureGraphNet gets used while the import wasn't valid (a has_fx_feature_extraction=False type check), it'd throw at that point.
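
i.e. something like this in fx_features (just a sketch, exact names/messages may differ):

try:
    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
    has_fx_feature_extraction = True
except ImportError:
    create_feature_extractor = None
    has_fx_feature_extraction = False

def _check_fx_available():
    if not has_fx_feature_extraction:
        raise RuntimeError('FX feature extraction requires PyTorch >= 1.10 and torchvision >= 0.11')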

I assume it would look pretty similar up to the create_feature_extractor and then all the mechanics of that live in tv?

Regarding the mods to support tracing,

  • it's nice that we can leave the @ in.
  • think it'd be better to break asserts with and into two separate lines and avoid the fx_and, just two torch._asserts
  • curious about some of the leaf_module requirements, like NormFreeBlock, which doesn't make much sense... control flow doesn't depend on input there, just fixed config that determines whether the skipinit_gain or downsample path is needed; that's pretty standard pytorch if-module-not-None type stuff... hmmm
  • is the int workaround still needed?

@alexander-soare
Contributor Author

alexander-soare commented Oct 29, 2021

@rwightman sounds good. Yeah I'm pretty sure we can just piggyback on what's in TV now. Answering your Qs

  • Yeah agreed NormFreeBlock was weird so I just magicked it away with my decorator. Will have another dig though. Were there any other specific ones you were wondering about?
  • Afaik there's still nothing in FX to help with int, but on another glance I realised I can just leafify the few places in which it's found in timm.

With the bulk of the FX magic moved to TV what's left over in this PR is just:

  1. the mechanisms to make everything traceable (torch._assert, leaf modules, leaf functions)
  2. a thin adaptor between create_feature_extractor and the logic of build_model_with_cfg

Can you confirm whether you want to enforce that ALL models be traceable (see the 3 tests I added), or leave it on an as-and-when-needed basis?

From this PR then, whenever someone adds a new model or wants to add feature extraction to an existing model, they can do so with the tools provided. On the other hand this PR doesn't tackle the whole idea of making a unified interface for feature extraction across all models. IMO though, that's a separate PR.

Thoughts?

@rwightman
Collaborator

rwightman commented Nov 1, 2021 via email

@rwightman
Collaborator

rwightman commented Nov 1, 2021 via email

@rwightman
Collaborator

@alexander-soare I pulled _assert into timm due to the linked issue above; I had added torch._assert to just patch_embed so I could do some experiments, but that broke PyTorch < 1.8. I made a trace_utils module under layers that can be used from layers down into models; other tracing-specific workaround helpers should end up there.
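
Roughly this shape (a sketch; the actual helper in layers/trace_utils.py may differ in detail):

try:
    from torch import _assert  # available in PyTorch >= 1.8
except ImportError:
    def _assert(condition: bool, message: str):
        assert condition, message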

@alexander-soare
Contributor Author

alexander-soare commented Nov 13, 2021

@rwightman this is now ready to review. Some notes to add to the above:

  • Tests remain as before: forward_fx, backward_fx, forward_torchscript_fx.
  • I had to exclude some models from forward_torchscript_fx (i.e. they can't be scripted after FX). These can't be FX traced then scripted because there is some control flow involving torch.jit.is_scripting(). They are 'beit_', 'deit_distilled_patch16_224', 'levit*', 'pit*_distilled_224'.
  • With the above accounted for, all tests passed locally.
  • Dropped all _float_to_int by using leaf nodes.
  • tf models were fine. I just used leaf modules for all the same-padding modules.
  • NormFreeBlock needs to be a leaf module because of the mul_. This actually causes FX core to drop a node because it thinks it's unused! See the issue I left them: [FX] [BUG] Tensor.{inplace_method}_(.) is eliminated as dead code pytorch/pytorch#68301. I know there are mul_s in other places of your code, but the tests don't raise errors (for example, see that issue to understand why Swish works).

alexander-soare marked this pull request as ready for review November 13, 2021 00:08
@rwightman
Collaborator

@alexander-soare looking through...

A comment on the decorators: while 'autowrap_function' and 'leaf_module' make sense in terms of the mechanics of the implementation, they don't make much sense as a self-descriptive API from a library user's perspective.

How about register_notrace_module and register_notrace_function ...

Also, outside of layers, _assert should be imported via layers; there's no need for the full layers.trace_utils path.
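
Usage would then read something like this (a sketch; the decorated module here is hypothetical, fx_float_to_int is one of the existing helpers):

import torch.nn as nn
from timm.models.fx_features import register_notrace_module, register_notrace_function

@register_notrace_module  # kept as a leaf: traced as a single call_module node
class HypotheticalSamePadConv(nn.Conv2d):
    pass

@register_notrace_function  # recorded as a leaf call instead of being traced into
def fx_float_to_int(x) -> int:
    return int(x)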

@alexander-soare
Contributor Author

@rwightman sorted. Thanks!

rwightman merged commit 32c9937 into huggingface:master Nov 19, 2021
@rwightman
Collaborator

@alexander-soare so, you may or may not have noticed, it's been a REAL PITA getting the FX tests to pass on GitHub actions. The memory is constrained and having the original model def plus the train and eval trace graphs just blows it up.

This part is within torchvision now, but what do you think about having an option to return just one trace (train or eval) instead of always two... so with an extra flag you get a SingleGraph wrapper instead of a Dual one? Think it's worth opening an issue on tv? Creating both seems unnecessary if you know you'll only use it in one mode (and would likely cause fewer problems with the tests here)...

@alexander-soare
Contributor Author

@rwightman hmm I think I did something to make sure params are not duplicated. And I just cheekily stepped through this code while staring at my resource monitor to confirm that.

import torch
from timm.models.fx_features import _leaf_modules, _autowrap_functions
from tests.test_models import _get_input_size, TARGET_BWD_SIZE, TARGET_FWD_SIZE
import timm
from torchvision.models.feature_extraction import NodePathTracer, get_graph_node_names, create_feature_extractor

import pdb; pdb.set_trace()
model_name = 'nfnet_f2'
model = timm.create_model(model_name, pretrained=False)
model.eval()

input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)

# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
    model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]

fx_model = create_feature_extractor(
    model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
    tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

inputs = torch.randn((1, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
    outputs = torch.cat(outputs)
fx_outputs = tuple(fx_model(inputs).values())
if isinstance(fx_outputs, tuple):
    fx_outputs = torch.cat(fx_outputs)

[Screen recording: resource monitor while stepping through the code above]

Resources only increase on initial model creation and on the forward passes (unless I use torch.no_grad() for the latter). To be clear, there's no visible jump on creation of the fx model. You can also check that the params are indeed shared with:

import torch
from torchvision.models.feature_extraction import create_feature_extractor

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        x += self.a
        x *= 1
        return x


model = Model()
print(model.a)

fx_model = create_feature_extractor(model, return_nodes=['mul'])
with torch.no_grad():
    model.a += 1
print(model.a)
print(fx_model.a)

So I believe the real issue in the tests is that we are doing forward passes with grad twice. Using no_grad should sort it for the forward test. I noticed most of the tests in the commit history failed on forward, but the very last one failed on the backward pass. I tried the same experiment with that code and didn't notice any other places where lots of memory is being used redundantly. Maybe it's just barely tipping it over the edge?
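
i.e. for the forward comparison, something like (sketch, reusing the names from the snippet above):

with torch.no_grad():  # don't retain activations for two full forward passes
    outputs = model(inputs)
    fx_outputs = tuple(fx_model(inputs).values())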

@rwightman
Collaborator

@alexander-soare I'm not exactly sure what's happening, but the memory (CPU memory we're talking about here so it could easily be a mix of parameters and Python overhead) is significant vs the non FX tests.

The number of iterations I've gone through to filter out models is getting ridiculous. Sigh. Think I'm getting close.

I did fix the forward FX tests to use no grad, but the backward test is the limiting case now; it still uses a lot of memory compared to the non-FX tests. I think the max model size is maybe roughly 1/3 (based on param count) of the non-FX tests?

@alexander-soare
Contributor Author

alexander-soare commented Nov 22, 2021

@rwightman weird, I even tried my exercise with the non-FX version and memory consumption is about the same. Is it possible that memory is not being released properly, and by the time you get to the FX tests there's a buildup? I'm running tests now on my machine and seeing a lot more memory usage than any individual test should take. 10% in and it's mostly been monotonically increasing apart from one dip (yeah, not sure about that last observation actually).

Sorry about that. Honestly, I'd rather have been responsible for cleaning up the mess 😞. In fact, I'd be happy to investigate next chance I get if there's an easy way for me to do it without getting in the way.

@rwightman
Collaborator

rwightman commented Nov 22, 2021

@alexander-soare Hmm, I think if you make a new branch and open it as a PR here, it'll run the tests within that PR...

@alexander-soare
Contributor Author

Yep, I was just wondering if there was extra access to see memory consumption and the like. Anyway, feel free to comment stuff out and I'll find time to look at it towards the end of the week.

@rwightman
Collaborator

yeah I'm not sure what a good way of debugging these tests is; much easier when running locally, but then I don't have the problem running locally :)

If the next run fails I may just disable the backward FX test for now...
