
Add option for ML-Decoder - an improved classification head #1012

Merged: 4 commits into huggingface:master on Mar 21, 2022

Conversation

@mrT23 (Contributor) commented Nov 30, 2021

While almost every aspect of ImageNet training has improved over the last couple of years (backbones, augmentations, losses, ...), a plain classification head, GAP + fully connected, remains the default option.
In our paper, "ML-Decoder: Scalable and Versatile Classification Head" (https://github.com/Alibaba-MIIL/ML_Decoder),
we propose a new attention-based classification head that not only improves results, but also provides a better speed-accuracy tradeoff on various classification tasks: multi-label, single-label, and zero-shot.


A technical note about the merge request: since each model has a unique coding style, systematically using a different classification head is challenging. This merge request enables the ML-Decoder head for all CNNs (I specifically checked ResNet, ResNetD, EfficientNet, RegNet and TResNet). For transformers, the GAP operation is embedded inside the 'forward_features' pass, so it is hard to use a different classification head without editing each model separately.
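For illustration, here is a minimal sketch of why the CNN case is straightforward (assuming a timm ResNet; the global_pool / fc attributes are the same ones the hasattr checks later in this thread rely on):

    import timm
    import torch

    # Sketch only: for timm CNNs, forward_features() returns the unpooled
    # spatial feature map, and the default head is just global_pool + fc,
    # so the head can be swapped out as a unit.
    model = timm.create_model('resnet50', num_classes=1000)
    x = torch.randn(1, 3, 224, 224)

    feats = model.forward_features(x)   # (1, 2048, 7, 7) spatial features
    pooled = model.global_pool(feats)   # default head, step 1: global average pool
    logits = model.fc(pooled)           # default head, step 2: fully connected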

@rwightman (Collaborator) commented:

@mrT23 thanks for the PR, I'd definitely like to add this. I've had an outstanding TODO to figure out a clean mechanism for using different heads on the models. Obviously the differences across model archs are a challenge, and I probably should move the pooling for the transformer models as you pointed out.

I've also got an alternative to GAP that worked well (but it fixes the resolution of the network, as I imagine yours does?): https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/attention_pool2d.py .. and I wanted to support that too. I was thinking of reworking some of https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/classifier.py , adding alternate head / 'decoder' module support, and some sort of head factory and interface for changing them that works across all models.

@mrT23 (Contributor, Author) commented Dec 1, 2021

Our solution does not require a fixed resolution. In the paper we tested it at 224, 448 and 640, and switched between resolutions seamlessly.

While the name is similar, I think that our proposed head (ML-Decoder) is very different from the attention_pool2d.py head:

  • If I understand correctly, attention_pool2d.py is a simple attention-based processing layer on the spatial features (encoder-like).

  • With ML-Decoder, the predicted classes are initiated by external queries to the decoder (perhaps more similar to DETR).
    To avoid a quadratic dependence on the number of input classes, we propose a scheme called group-decoding (see the sketch below).
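A very rough sketch of the group-decoding idea (illustrative names and sizes, not the paper's implementation): a small, fixed set of group queries cross-attends to the flattened spatial features, and each group output is expanded to several class logits by its own linear projection.

    import torch
    import torch.nn as nn

    num_classes, num_groups, embed_dim = 1000, 100, 512
    classes_per_group = num_classes // num_groups   # assume it divides evenly

    group_queries = nn.Parameter(torch.randn(num_groups, embed_dim))
    cross_attn = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
    # one small projection per group: embed_dim -> classes_per_group logits
    group_fc = nn.Parameter(torch.randn(num_groups, embed_dim, classes_per_group))

    def group_decode(features):                             # features: (B, H*W, embed_dim)
        b = features.shape[0]
        q = group_queries.unsqueeze(0).expand(b, -1, -1)    # (B, G, D) fixed queries
        out = cross_attn(q, features, features)[0]          # cross-attention to spatial tokens
        logits = torch.einsum('bgd,gdk->bgk', out, group_fc)
        return logits.reshape(b, num_classes)               # (B, num_classes)

    # Works with any spatial resolution: H*W may vary between calls.
    feats = torch.randn(2, 7 * 7, embed_dim)
    print(group_decode(feats).shape)                        # torch.Size([2, 1000])

Because the number of queries scales with the number of groups rather than the number of classes, the decoder cost stays roughly constant as the class count grows.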


In any case, using a factory head scheme is a great idea :-)

@mrT23 (Contributor, Author) commented Dec 16, 2021

@rwightman

If you are working on / planning to refactor the classification-head implementation in all the models, I will close the merge request.

However, this is a major task. Each model has its own quirks and specific details, and you will need to edit the models one by one. It might be better to use this merge request to enable a different classification head, at least for the CNN models, so that people get the chance to experiment with other heads (GAP vs. ML-Decoder vs. attention-based) and compare performance and speed-accuracy tradeoffs.

@rwightman (Collaborator) commented:

@mrT23 don't close yet, I have quite a few things on the go and just haven't had much time to think about this one yet

@rwightman (Collaborator) commented:

@mrT23 I have an application where I'd like to try this, so I'm going to get the merge going, but I will leave out the factory changes for now since I don't want to support this approach long term.

A question re the MLDecoder impl: it looks a bit hardcoded to a certain size/scale of model? The decoder dim is fixed at 2048, which seems to imply a certain capacity range (medium-large ResNet, TResNet, etc.).

@mrT23 (Contributor, Author) commented Jan 2, 2022

"2048" (dim_forward) is an internal size of the multi-head attention.

num_features (which is 2048 for a plain ResNet50) is taken from the model itself:

    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # most CNN models, like Resnet50
        model.global_pool = nn.Identity()
        del model.fc
        num_classes = model.num_classes
        num_features = model.num_features
        model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)

ML-Decoder works seamlessly with any number of input features. I think EfficientNet has a different number of features, and I tested it there as well:

    elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'):  # EfficientNet
        model.global_pool = nn.Identity()
        del model.classifier
        num_classes = model.num_classes
        num_features = model.num_features
        model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)

Looking forward to your application and comparisons. :-)
ML-Decoder was developed for multi-label classification, but in my tests it also worked well for ImageNet training.

@rwightman (Collaborator) commented:

@mrT23 any comment on the transformer decoder layer?

You have

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

I think the version below makes more sense?

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

tgt = tgt + dropout(tgt) is 2*tgt with some stochasticity... not sure what the objective is?
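For reference, a tiny standalone check of that observation (plain PyTorch, not the PR code):

    import torch
    import torch.nn as nn

    # With dropout disabled (eval) the line tgt = tgt + dropout(tgt) is exactly
    # 2 * tgt; in train mode each element is either tgt (dropped) or
    # tgt + tgt / (1 - p) (kept and rescaled).
    drop = nn.Dropout(p=0.1)
    x = torch.randn(4, 8)

    drop.eval()
    assert torch.allclose(x + drop(x), 2 * x)

    drop.train()
    y = x + drop(x)   # stochastic per-element scaling of x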

@mrT23 (Contributor, Author) commented Jan 14, 2022

TL;DR
I know it looks redundant. I specifically tested with and without it, and it helps.

Long answer:
The transformer architecture design is amazing. The paper gives most of the credit to the attention operation, but I think more credit should go to the general engineering design. A basic transformer decoder unit:

    def forward(self, tgt, memory):
        tgt2 = self.self_attn(tgt, tgt, tgt)[0]                                 # self-attention
        tgt = tgt + self.dropout1(tgt2)                                         # residual + dropout
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]                      # cross-attention
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))   # MLP
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

Notice the nice engineering features:

  1. residual after every operation (self-attention, cross-attention and MLP)
  2. normalization after every residual
  3. dropout on the residual component only
  4. MLP implemented as double linear projection, with activation and dropout between them

None of these choices is trivial, and I am sure that a lot of experiments led to this good design.

For our classification head, the decoder input comes from fixed external queries. Hence the expensive self-attention module just applies a fixed transformation to them, and it can be removed (we discuss this thoroughly in the paper). My initial choice was also to remove the dropout and the normalization, since they also seem redundant. However, the score dropped a bit. Only when I kept them could I remove the self-attention and get exactly the same accuracy.

My guess is that the initial dropout component provides a needed regularization that helps the model converge well.
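A minimal sketch of the "fixed transformation" point above (illustrative only, assuming fixed learned query embeddings; not the paper's code):

    import torch
    import torch.nn as nn

    # With fixed queries, self-attention over the queries does not depend on the
    # image, so at inference it yields the same output for every input and can
    # be dropped or folded away.
    embed_dim, num_queries = 512, 80
    queries = torch.randn(num_queries, 1, embed_dim)   # fixed external queries
    self_attn = nn.MultiheadAttention(embed_dim, num_heads=8)

    self_attn.eval()
    with torch.no_grad():
        out_a = self_attn(queries, queries, queries)[0]
        out_b = self_attn(queries, queries, queries)[0]

    assert torch.allclose(out_a, out_b)   # identical for any two 'images'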

@rwightman (Collaborator) commented:

@mrT23 thanks for the explanation, I'll pull it in as is then (minus the factory additions for now), experiment, and then figure out how to deal with the head generically later.

@rwightman merged commit 72b5716 into huggingface:master on Mar 21, 2022
@alleniver commented:

awesome work! thanks

@fffffgggg54 mentioned this pull request on Dec 18, 2023