FX feature extraction #800
This is a script to check that features from the pre-existing extraction methods match those from FX extraction. All but one model passes the test:
"""
NOTE: You'll need to patch timm.helpers.build_model_with_cfg with
if kwargs.pop('use_fx'):
if feature_cfg is None or 'out_indices' not in feature_cfg:
feature_cfg = {'feature_cls': 'fx'}
else:
feature_cfg = {'feature_cls': 'fx', 'out_indices': feature_cfg['out_indices']}
"""
import random
import torch
import timm
import numpy as np
from tqdm import tqdm
TARGET_FFEAT_SIZE = 96
def _get_input_size(model=None, model_name='', target=None):
    if model is None:
        assert model_name, "One of model or model_name must be provided"
        input_size = timm.get_model_default_value(model_name, 'input_size')
        fixed_input_size = timm.get_model_default_value(model_name, 'fixed_input_size')
        min_input_size = timm.get_model_default_value(model_name, 'min_input_size')
    else:
        default_cfg = model.default_cfg
        input_size = default_cfg['input_size']
        fixed_input_size = default_cfg.get('fixed_input_size', None)
        min_input_size = default_cfg.get('min_input_size', None)
    assert input_size is not None

    if fixed_input_size:
        return input_size

    if min_input_size:
        if target and max(input_size) > target:
            input_size = min_input_size
    else:
        if target and max(input_size) > target:
            input_size = tuple([min(x, target) for x in input_size])
    return input_size
def seed_everything(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
def compare_features(model_name, batch_size):
    print(model_name)

    seed_everything()
    model = timm.create_model(model_name, pretrained=False, features_only=True, use_fx=False)
    model.eval()

    seed_everything()
    fx_model = timm.create_model(model_name, pretrained=False, features_only=True, use_fx=True)
    fx_model.eval()

    def basic_checks(model):
        expected_channels = model.feature_info.channels()
        assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6

        input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)

        seed_everything()
        outputs = model(torch.randn((batch_size, *input_size)))
        assert len(expected_channels) == len(outputs)
        for e, o in zip(expected_channels, outputs):
            try:
                assert e == o.shape[1]
            except AssertionError:
                print("Channels mismatch! Maybe due to multiple outputs from submodule.")
            assert o.shape[0] == batch_size
            assert not torch.isnan(o).any()
        return outputs

    model_outputs = basic_checks(model)
    fx_model_outputs = basic_checks(fx_model)

    for model_output, fx_model_output, info in zip(model_outputs, fx_model_outputs, model.feature_info.info):
        try:
            assert torch.allclose(model_output, fx_model_output)
        except RuntimeError as e:
            print(f"{model_name} torch.allclose failed. {e}")
        except AssertionError:
            print("torch.allclose failed")
            print(info)
NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']

EXCLUDE_FEAT_FILTERS = [
    '*pruned*',  # hopefully fix at some point
] + NON_STD_FILTERS

for model_name in tqdm(timm.list_models(exclude_filters=EXCLUDE_FEAT_FILTERS)):
    compare_features(model_name=model_name, batch_size=1)
For info: How I made the models symbolically traceable

Here's a list of changes, somewhat ranked from most impactful to least impactful.

Make timm layers into leaf modules

A leaf module doesn't get traced through; just the reference to it is recorded. This means we can avoid whatever issues come up within. Mostly these have been flow control, and ...

Note: By default, standard torch modules are leaf modules. Maybe there could be consequences of treating timm modules as leaf modules. Not sure.
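To illustrate the leaf-module idea outside of timm, here is a minimal, self-contained sketch (the Inner/Outer/LeafTracer names are illustrative, not timm code): a submodule with data-dependent control flow breaks plain symbolic tracing, but marking it as a leaf lets tracing record a single call_module node instead.

import torch
import torch.fx

class Inner(torch.nn.Module):
    def forward(self, x):
        # Data-dependent control flow: plain symbolic tracing can't handle this,
        # because converting an FX Proxy to bool raises a TraceError.
        if x.sum() > 0:
            return x * 2
        return x

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x) + 1

class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Record a single call_module node for Inner instead of tracing into it.
        return isinstance(m, Inner) or super().is_leaf_module(m, module_qualified_name)

root = Outer()
graph = LeafTracer().trace(root)    # succeeds; `inner` stays opaque in the graph
gm = torch.fx.GraphModule(root, graph)
print(gm(torch.ones(2)))            # tensor([3., 3.])
# torch.fx.symbolic_trace(Outer())  # would raise on the if-statement above

torchvision's feature-extraction utilities expose the same idea via tracer_kwargs={'leaf_modules': [...]}, as used in the debugging snippets later in this thread.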
@alexander-soare with the merged functionality in torchvision (congrats!), timm integration should probably leverage what's going to be in the next PyTorch 1.10 + torchvision release with any timm specific extras/interfaces... or is a specialized version needed?
@rwightman thanks! I don't think you need a specialised version. Some things to consider:
... annnd maybe that's it. Happy to drop another PR with what's needed. How do you plan to resolve the issue of PyTorch/Torchvision version? Will you wait a while till you bump the minimum required version for the whole lib? Or add this now and throw up a version error when applicable?
@alexander-soare if you have time (no rush), I wouldn't mind rebooting this with your PR in torchvision now released with 1.10 + 0.11.1. The fx_features would try to import the relevant symbols from torchvision with an import try/except guard, and if the FeatureGraphNet gets used and the import wasn't valid... I assume it would look pretty similar up to the create_feature_extractor, and then all the mechanics of that live in tv? Regarding the mods to support tracing, ...
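The guarded import described here might look roughly like the sketch below. This is an assumption, not timm's actual code: the helper name build_fx_feature_extractor and the error message are illustrative.

import torch

try:
    from torchvision.models.feature_extraction import create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    create_feature_extractor = None
    has_fx_feature_extraction = False

def build_fx_feature_extractor(model: torch.nn.Module, return_nodes):
    # Only fail when FX feature extraction is actually requested, so the rest of
    # the library keeps working on older PyTorch/torchvision versions.
    assert has_fx_feature_extraction, \
        'Please update to PyTorch 1.10+ and torchvision 0.11+ to use FX feature extraction'
    return create_feature_extractor(model, return_nodes)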
@rwightman sounds good. Yeah I'm pretty sure we can just piggyback on what's in TV now. Answering your Qs:

- Yeah, agreed NormFreeBlock was weird so I just magicked it away with my decorator. Will have another dig though.
- Afaik there's still nothing in FX to help with int, but on another glance I realised I can just leafify the few places in which it's found in timm.
With the bulk of the FX magic moved to TV, what's left over in this PR is just:

1. the mechanisms to make everything traceable (torch._assert, leaf modules, leaf functions)
2. a thin adaptor between create_feature_extractor and the logic of build_model_with_config (https://github.com/alexander-soare/pytorch-image-models/blob/ac53dedec19db508feeba6ab9e72b58f267be910/timm/models/helpers.py#L468-L477)
Can you confirm if you want to enforce ALL models be traceable (see the 3 tests I added), or leave it on an as-and-when-needed basis?

From this PR then, whenever someone adds a new model or wants to add feature extraction to an existing model, they can do so with the tools provided. On the other hand, this PR doesn't tackle the whole idea of making a unified interface for feature extraction across all models. IMO though, that's a separate PR. Thoughts?
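For orientation, that thin adaptor can be imagined roughly as below. This is an assumption-heavy sketch rather than the PR's actual code: it assumes a timm model whose feature_info entries carry a 'module' key, and it omits the tracer_kwargs (leaf modules / autowrap functions) that many timm models need to trace.

import torch
from torchvision.models.feature_extraction import create_feature_extractor

class FeatureGraphNet(torch.nn.Module):
    """Sketch: build an FX feature extractor from a timm model's feature_info."""
    def __init__(self, model, out_indices=None):
        super().__init__()
        self.feature_info = model.feature_info
        # feature_info records the module that produces each feature map
        module_names = [info['module'] for info in model.feature_info.info]
        if out_indices is not None:
            module_names = [module_names[i] for i in out_indices]
        return_nodes = {name: str(i) for i, name in enumerate(module_names)}
        # NOTE: real models may also need tracer_kwargs with leaf_modules/autowrap_functions
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        # Return features as a list, similar to timm's other feature wrappers
        return list(self.graph_module(x).values())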
To start, I think having full coverage of all the vision transformers and mlp / mixer models is a good idea. Also the main resnet.py models and efficientnet variants would cover a lot of popular models initially. Can skip the `tf_` models with same padding for now, I believe they all have a tracing warning from the pad.
And I agree that the unified interface is a follow-up PR. This is to get the groundwork in.
@alexander-soare I pulled _assert into timm due to the linked issue above; I had added the torch._assert to just patch_embed so I could do some experiments, but it broke PyTorch < 1.8. I made a trace_utils under layers that can be used from layers down into models; other tracing-specific workaround helpers should end up there.
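As a generic illustration of why that matters (not timm's actual PatchEmbed code): a plain Python assert on a traced tensor comparison forces an FX Proxy to a bool and raises during tracing, while torch._assert dispatches through __torch_function__ and is simply recorded as a graph node.

import torch
import torch.fx

class PatchEmbedLike(torch.nn.Module):
    # Illustrative module only, not timm's PatchEmbed.
    def __init__(self, img_size=224):
        super().__init__()
        self.img_size = img_size

    def forward(self, x):
        # `assert x.shape[-1] == self.img_size` would fail under symbolic tracing,
        # because the comparison yields a Proxy; torch._assert traces cleanly.
        torch._assert(x.shape[-1] == self.img_size, 'Input width does not match img_size')
        return x

gm = torch.fx.symbolic_trace(PatchEmbedLike())
print(gm.graph)  # the check appears as a call_function node for torch._assert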
@rwightman this is now ready to review. Some notes to add to the above:
@alexander-soare looking through... A comment on the decorators: while 'autowrap_function' and 'leaf_module' make sense in terms of the mechanics of the implementation, they don't make much sense as a self-descriptive (library user perspective) API. How about something more self-descriptive? Also, outside of layers, _assert should be imported via layers, not needing the full layers.trace_utils path.
@rwightman sorted. Thanks!
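For reference, the registries behind those decorators amount to something like the sketch below, using the original 'leaf_module' / 'autowrap_function' naming from this discussion (the merged names differ, per the comment above). The two sets are what later get handed to the tracer as tracer_kwargs.

# Minimal sketch of decorator-based registries for tracing workarounds.
_leaf_modules = set()
_autowrap_functions = set()

def leaf_module(module_cls):
    # Treat this module class as an FX leaf: don't trace through it.
    _leaf_modules.add(module_cls)
    return module_cls

def autowrap_function(func):
    # Record calls to this function as graph nodes rather than tracing into it.
    _autowrap_functions.add(func)
    return func

These would be consumed as tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}, exactly as in the debugging script below.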
@alexander-soare so, you may or may not have noticed, it's been a REAL PITA getting the FX tests to pass on GitHub actions. The memory is constrained and having the original model def plus the train and eval trace graphs just blows it up. This part is within torchvision now, but what do you think about having an option to return just one trace (train or eval) instead of always two... so with an extra flag you get a SingleGraph wrapper instead of Dual? Think it's worth opening an issue on tv? Creating both seems unnecessary if you know you'll only use it in one mode (and would likely cause fewer problems with the tests here)...
@rwightman hmm I think I did something to make sure params are not duplicated. And I just cheekily stepped through this code while staring at my resource monitor to confirm that:

import torch
from timm.models.fx_features import _leaf_modules, _autowrap_functions
from tests.test_models import _get_input_size, TARGET_BWD_SIZE, TARGET_FWD_SIZE
import timm
from torchvision.models.feature_extraction import NodePathTracer, get_graph_node_names, create_feature_extractor
import pdb; pdb.set_trace()
model_name = 'nfnet_f2'
model = timm.create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
    model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
    model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
    tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((1, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
    outputs = torch.cat(outputs)
fx_outputs = tuple(fx_model(inputs).values())
if isinstance(fx_outputs, tuple):
    fx_outputs = torch.cat(fx_outputs)

Resources only increase on initial model creation and forward passes (unless I put ...).

import torch
from torchvision.models.feature_extraction import create_feature_extractor
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        x += self.a
        x *= 1
        return x
model = Model()
print(model.a)
fx_model = create_feature_extractor(model, return_nodes=['mul'])
with torch.no_grad():
    model.a += 1
print(model.a)
print(fx_model.a)

So I believe the real issue in the tests is the fact that we are doing forward passes with grad twice. Using no_grad should sort it for the forward test. I noticed most of the tests in the commit history failed on forward, but the very last one failed on the backward pass - I tried the same experiment with that code and didn't notice any other places where lots of memory is being used redundantly. Maybe it's just barely tipping it over the edge?
@alexander-soare I'm not exactly sure what's happening, but the memory (CPU memory we're talking about here, so it could easily be a mix of parameters and Python overhead) is significant vs the non-FX tests. The number of iterations I've gone through to filter out models is getting ridiculous. Sigh. Think I'm getting close. I did fix the forward FX tests with no_grad, but the backward test is the limiting case now; still a lot of memory compared to the non-FX tests. I think the max model size is maybe roughly 1/3 (based on param count) of the non-FX tests?
@rwightman weird, I even tried my exercise with the non-FX version and memory consumption is about the same. Is it possible that memory is not being released properly and by the time you get to the FX tests there's a buildup? I'm running tests now on my machine and seeing a lot more memory usage than any individual test should take. Sorry about that. Honestly, I'd rather have been responsible for cleaning up the mess 😞. In fact, I'd be happy to investigate next chance I get, if there's an easy way for me to do it without getting in the way.
@alexander-soare Hmm, I think if you make a new branch and open as a PR here it'll run the tests within that PR...
Yep, I was just wondering if there was extra access to see memory consumption and the like. Anyway, feel free to comment stuff out and I'll find time to look at it towards the end of the week.
yeah I'm not sure what a good way of debugging these tests is, much easier when running locally but then I don't have the problem running locally :) If the next run fails I may just disable the backward FX test for now...
This PR adds timm.models.fx_features.FeatureGraphNet as another option for feature extraction (works as a standalone commit).

Caveat - Right now we can only safely say it works in eval mode. Control flow that depends on the value of model.training is frozen into place by the tracing operation. So if the model was traced in eval mode, it stays that way (actually, only those parts that were traced through; leaf modules and leaf functions respect the training mode). Therefore, we cannot expect model.train() to have the desired effect. This is a TODO, so right now there is a warning when the user tries to do model.train().

All local tests passed.

EDIT - This feature has been added to torchvision: pytorch/vision@72d650a
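A minimal sketch of the warning behaviour described above (a hypothetical wrapper, not the exact code in this PR): calling train() surfaces the caveat while still propagating the mode to leaf modules via super().train().

import warnings
import torch

class EvalTracedWrapper(torch.nn.Module):
    """Hypothetical wrapper around a feature-extraction graph module traced in eval mode."""
    def __init__(self, graph_module: torch.nn.Module):
        super().__init__()
        self.graph_module = graph_module

    def train(self, mode: bool = True):
        if mode:
            warnings.warn(
                'This model was traced in eval() mode; control flow that depended on '
                'self.training is frozen into the graph. Only leaf modules/functions '
                'will respect train().')
        return super().train(mode)  # leaf modules still receive the mode change

    def forward(self, x):
        return self.graph_module(x)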