Feature: support timm features_only functionality #373

Closed
rwightman opened this issue Mar 30, 2021 · 18 comments

Comments

@rwightman

I've noticed more and more timm backbones being added here, which is great, but a lot of the effort currently duplicates features of timm, i.e. tracking channel numbers, modifying the networks, etc.

timm has a features_only arg in the model factory that will return a model set up as a backbone to produce pyramid features. It has a feature_info attribute you can query to find out the channels of each output, the approximate reduction factor, etc.

I've adapted the unet and deeplab impl here in the past to use this successfully, although it was quick hack-and-train work, nothing that serves as a clean example.

If this was supported, any timm model (vit excluded right now) could be used as a backbone in a generic fashion, just by passing the model name string to the creation fn, possibly with a small config mapping of model types to index specifications (some models have slightly different out_indices alignment to strides if they happen to be a stride-64 model, or don't have a stride=2 feature, etc). All tap points are the latest possible point for a given feature map stride. Some, but not all, of the timm backbones also support an output_stride= arg that will dilate the blocks appropriately for 8 or 16 network strides.
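
For illustration, a minimal sketch of that usage (model name chosen arbitrarily; the detailed examples further down in this thread show real output):

import timm

# any timm model name works here; resnet50 is just an example
backbone = timm.create_model('resnet50', features_only=True, pretrained=False)
print(backbone.feature_info.channels())   # channels of each returned feature map
print(backbone.feature_info.reduction())  # reduction (stride) of each feature map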

Some references:

For most of the models, the features are extracted by flattening part of the backbone model via a wrapper. A few models where the feature taps are embedded deep within the model use hooks, which causes some issues with torchscript, but that will likely be fixed soon in PyTorch.

@rwightman (Author) commented Mar 30, 2021

Also, the functionality is part of my model CI, so it should be fairly stable: https://github.com/rwightman/pytorch-image-models/blob/master/tests/test_models.py#L177

@JulienMaille (Contributor) commented Mar 30, 2021

Hello @rwightman , big respect for the fantastic work you have done on pytorch-image-models 🙇‍♂️
Are you suggesting we could load the model.encoder from something like timm.create_model('resnest26d', features_only=True, pretrained=True)?
Maybe a generalization of something like what can be found in that PR:

model = create_model(model_name=model,
                     scriptable=True,   # torch.jit scriptable
                     exportable=True,   # onnx export
                     features_only=True)

@rwightman (Author) commented Mar 30, 2021

@JulienMaille basically yes, but you should look at the links, including the usage in my efficientdet impl, and the docs. You can specify out_indices, and any backbone created with features_only=True has a feature_info attribute that reports the channels, reduction (stride), and name for each feature tap. That eliminates the need to keep track of channels for each encoder; you can just query the info attribute during creation.

There is a bit of variation in what out_indices means, as it's a 0-based index and some models have different numbers of possible feature taps. I plan to extend the API in the future to support str-based out_indices=('C2', 'C5'), which would remain absolute in their mapping to feature map stride levels and would error out if the model didn't support them.

It's possible to specify output_stride for many of the most common backbones (resnets, efficientnets, etc.). So for deeplab you could do the following.

>>> import torch
>>> import timm
>>> encoder = timm.create_model('resnest26d', features_only=True, out_indices=(1,4), pretrained=True)
>>> encoder.feature_info.channels()
[256, 2048]
>>> encoder.feature_info.reduction()
[4, 32]
>>> encoder = timm.create_model('resnest26d', features_only=True, out_indices=(1,4), output_stride=8, pretrained=True)
>>> encoder.feature_info.reduction()
[4, 8]
>>> o = encoder(torch.randn(2, 3, 224, 224))
>>> for x in o:
...   print(x.shape)
... 
torch.Size([2, 256, 56, 56])
torch.Size([2, 2048, 28, 28])

Comparing models with different striding, like vgg (which has stride 1-32 features, whereas most other nets have stride 2-32 features):

>>> encoder = timm.create_model('vgg19_bn', features_only=True, out_indices=(0,1,2,3,4,5), pretrained=True)
>>> encoder.feature_info.reduction()
[1, 2, 4, 8, 16, 32]
>>> o = encoder(torch.randn(2, 3, 224, 224))
>>> for x in o:
...   print(x.shape)
... 
torch.Size([2, 64, 224, 224])
torch.Size([2, 128, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 512, 14, 14])
torch.Size([2, 512, 7, 7])


>>> encoder = timm.create_model('gernet_l', features_only=True, out_indices=(0,1,2,3,4), pretrained=True)
>>> encoder.feature_info.reduction()
[2, 4, 8, 16, 32]
>>> o = encoder(torch.randn(2, 3, 224, 224))
>>> for x in o:
...   print(x.shape)
... 
torch.Size([2, 32, 112, 112])
torch.Size([2, 128, 56, 56])
torch.Size([2, 192, 28, 28])
torch.Size([2, 640, 14, 14])
torch.Size([2, 2560, 7, 7])

One gotcha is that some models use forward hooks to grab the features from deep activations in the net. These will not work with torchscript (that will be fixed in PyTorch someday soon). But it's limited to the Xception and TF NASNet models right now, and I may change the default for those to use the non-activated output from a higher level in the net instead.

@rwightman (Author) commented Mar 30, 2021

To try it out, without breaking any of the existing models or timm usage, one could create something like a TimmGeneric/TimmFeaturesEncoder that acts as an adapter for this and lets the user pass any timm model as a string along with which stage features they want.

@JulienMaille (Contributor) commented Mar 30, 2021

> To try it out, without breaking any of the existing models or timm usage, one could create something like a TimmGeneric/TimmFeaturesEncoder that acts as an adapter for this and lets the user pass any timm model as a string along with which stage features they want.

I've been doing exactly this for the past 50 minutes. Will share something so you can tell me if I'm heading in the right direction.

EDIT: JulienMaille@096aff4

Tested with a Unet-EfficientNet-b0.

@rwightman (Author)

@JulienMaille is get_stages() used by any functions besides forward()? That won't really work, as there aren't consistent layer/stage names from model to model.

The idea with features_only is that it's constructed knowing the structure of the original backbone, so it just spits out a list of the feature maps specified by out_indices for you.

The feature wrapper already cuts off the head and later modules that aren't used.

import torch.nn as nn
from timm import create_model
# EncoderMixin import path assumes segmentation_models_pytorch's encoder base module
from segmentation_models_pytorch.encoders._base import EncoderMixin


class TimmUniversalEncoder(nn.Module, EncoderMixin):
    def __init__(self, model, in_channels, depth=5, pretrained=True, **kwargs):
        super().__init__()
        self._depth = depth
        self._in_channels = in_channels

        self.encoder = create_model(model_name=model,
                                    in_chans=in_channels,
                                    exportable=True,   # onnx export
                                    features_only=True,
                                    pretrained=pretrained,
                                    out_indices=tuple(range(depth)))  # FIXME need to handle a few special cases for specific models

        channels = self.encoder.feature_info.channels()
        self._out_channels = (in_channels,) + tuple(channels)

        self.formatted_settings = {}
        self.formatted_settings["input_space"] = "RGB"
        self.formatted_settings["input_range"] = (0, 1)
        self.formatted_settings["mean"] = self.encoder.default_cfg['mean']
        self.formatted_settings["std"] = self.encoder.default_cfg['std']

    def get_stages(self):
        # FIXME cannot currently support returning stages as some models aren't
        # simple sequences of modules
        raise NotImplementedError

    def forward(self, x):
        features = self.encoder(x)  # list of feature maps, one per out_index
        return [x] + features
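
A quick usage sketch of the adapter above (resnet34, depth=5, and the commented channel tuple are illustrative assumptions):

import torch

encoder = TimmUniversalEncoder('resnet34', in_channels=3, depth=5, pretrained=False)
outputs = encoder(torch.randn(2, 3, 224, 224))
print(encoder._out_channels)        # e.g. (3, 64, 64, 128, 256, 512) for resnet34
print([o.shape for o in outputs])   # the input followed by one tensor per feature stage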

@JulienMaille (Contributor) commented Mar 31, 2021

@rwightman I thought that get_stages was required when constructing the decoder from the encoder, but I was wrong, so it can go. See #374
A couple of questions:

  • in this repo we can reduce the number of down-sampling stages with depth; however, in the code above the whole encoder will still run in forward. How can we delete/remove the unused layers to speed up training and avoid wasting memory?
  • do you have encoders pre-trained on multiple different datasets? EDIT: yes, but they have different names.
  • I'm not familiar with DeepLab; in your previous message, did you imply that different architectures will require a different output_stride instead of the current patching done here?

self.encoder.make_dilated(
    stage_list=[4, 5],
    dilation_list=[2, 4]
)

@rwightman (Author)

@JulienMaille Hmm, yeah, make_dilated depends on get_stages, and that would not necessarily work universally. Although I could probably make it work for a number of models by generating stage indices... hmm.

The equivalent functionality in the timm factory is to specify output_stride as one of 32 (the default), 16, or 8. I think 4 might work on a few, but I haven't tested. A number of models are limited to only output_stride=32 though, so it may error out on some backbones.
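
For illustration, a rough factory-level equivalent of the make_dilated() call above (resnet50 as an example backbone that supports dilation; the commented reductions are the expected values, not verified here):

import timm

encoder = timm.create_model('resnet50', features_only=True, output_stride=8,
                            pretrained=False)
print(encoder.feature_info.reduction())   # expected: [2, 4, 8, 8, 8]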

@JulienMaille (Contributor) commented Apr 1, 2021

@rwightman I had a look at all the failures from qubvel's test module

_test_forward(model, sample, test_shape)

Unet+timm-u-adv_inception_v3: Exception Calculated padded input size per channel: (2 x 2). Kernel size: (3 x 3). Kernel size can't be greater than actual input size
Unet+timm-u-cspdarknet53: Exception assert torch.Size([128, 128]) == torch.Size([64, 64])
Unet+timm-u-cspdarknet53_iabn: Exception Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'
Unet+timm-u-cspresnext50_iabn: Exception Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'
Unet+timm-u-darknet53: Exception assert torch.Size([128, 128]) == torch.Size([64, 64])
Unet+timm-u-densenet264d_iabn: Exception Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'
Unet+timm-u-dla***: Exception assert torch.Size([128, 128]) == torch.Size([64, 64])
Unet+timm-u-ecaresnet50d_pruned: Exception Given groups=1, weight of size [16, 3072, 3, 3], expected input[1, 2840, 4, 4] to have 3072 channels, but got 2840 channels instead

inplace_abn fails to install on my system, and some models seem to return a tensor with an unexpected size.

@qubvel (Owner) commented Apr 2, 2021

Hi @rwightman and @JulienMaille
This could be a nice feature for the library, to have a universal encoder for timm backbones!
For now, I don't have enough time to review all the PRs and suggested features, but I will definitely try to find at least a couple of hours in the near future. I appreciate the work you did. Thank you very much!

@rwightman (Author)

@JulienMaille Yes, there are some models that are a bit different in terms of the output strides they support (or that require an extra module, like the '*iabn'/tresnet models). It's a small number relative to the whole that have those exceptions. I feel it would be a net benefit to have the universal timm encoder, but maybe indicate it's beta/alpha or that there could be issues with specific backbones.

For the tests, a subset of known 'good' timm models could be specified (via inclusion/exclusion model names, wildcard/regex, etc).

While it's being vetted, manual encoders could still be defined if desired. Once the corner cases have been dealt with in the universal encoder (both here and by driving some changes on my end), there shouldn't be any need to define custom encoders for timm models in the future.

@JulienMaille (Contributor) commented Apr 15, 2021

@rwightman which models should be excluded from the tests? Right now this is what I'm filtering:

def get_timm_u_encoders():
    timm_exclude_encoders = [
        'vit_*', 'tnt_*', 'pit_*',
        '*iabn*', 'tresnet*',  # models using inplace abn
        'dla*', 'hrnet*',      # hopefully fix at some point
    ]
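
A hypothetical sketch of how such wildcard exclusions could be applied (filter_timm_encoders is an illustrative helper, not part of either library):

import fnmatch
import timm

def filter_timm_encoders(exclude_patterns):
    # keep every timm model name that matches none of the exclusion patterns
    all_names = timm.list_models()
    return [name for name in all_names
            if not any(fnmatch.fnmatch(name, pattern) for pattern in exclude_patterns)]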

@Khadija-bef

@rwightman Creating my model through model = timm.create_model('resnet50', features_only=True, pretrained=True)
returns a list, but I need the features as a Tensor so I can get them by applying model(image).squeeze() to all of my images.
Do you have any suggestions, please?

@qubvel (Owner) commented Jul 5, 2021

merged to master

@qubvel qubvel closed this as completed Jul 5, 2021
@mdabbah commented Aug 3, 2021

Can you please add to the API a way to make the network output both the logits and the last-layer features?

@qubvel (Owner) commented Aug 3, 2021

Maybe I did not understand you, but the model returns the logits of the last layer if the activation function argument is not provided.

@mdabbah commented Aug 3, 2021

I understand, but currently the API, as I understand it, allows the network to:
1) return the logits via: logits = model(img)  # assuming num_classes > 0 and no activation is passed
2) return unpooled features via: features = model.forward_features(img)  # which outputs the unpooled last-layer features, i.e. the features the model has before applying the classification layer

I couldn't find a way in the current API that allows something like:
features, logits = model(img)

It would be really helpful if there was such a thing; it would be useful in many research areas, including training models with custom losses or estimating uncertainty from logits and feature statistics.

One could use PyTorch hooks to achieve this for each model, but there wouldn't be a consistent API for all models, and it would require looking up the name each model gives to its final features layer.
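
For reference, a minimal sketch of that hook-based workaround (assuming a timm resnet50, where the pooled features feeding the classifier come from a module named global_pool; other architectures name this differently):

import torch
import timm

model = timm.create_model('resnet50', pretrained=False)
model.eval()
captured = {}

def save_features(module, inputs, output):
    # stash the pooled features that feed the classification layer
    captured['features'] = output.detach()

# 'global_pool' is the pooling module on timm's resnet50; look up the right name per model
hook = model.global_pool.register_forward_hook(save_features)
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
features = captured['features']
hook.remove()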

Thank you for the quick reply!

@bhaktatejas922

@rwightman what would be needed to use ViT in segmentation models?
