
Add SWAG model weight that only the linear head is finetuned to ImageNet1K #5793

Merged: 7 commits into pytorch:main on Apr 11, 2022

Conversation

@YosuaMichael (Contributor) commented Apr 8, 2022

Subtask of #5708.

Model Description

This model uses trunk weights from the weakly supervised pre-training (SWAG) described in https://arxiv.org/pdf/2201.08371.pdf. The linear head is fine-tuned on the IMAGENET1K dataset while the pre-trained trunk weights are kept frozen.

This model is suitable for users who want to fine-tune the pre-trained trunk on other downstream datasets.

Linear-head fine-tuning parameters on IMAGENET1K:

RegNet models (all sizes: 16GF, 32GF, 128GF):

  • Num epochs: 28
  • Trained on 1 node with 8 Volta GPUs (32GB) each
  • Batch size per GPU: 32
  • Image size: 224
  • SGD optimizer with params:
    • weight decay: 0.001
    • momentum: 0.9
    • use Nesterov: True
  • Learning rate params:
    • scheduler: CosineAnnealingLR
    • start value: 0.001
  • Image augmentation transforms:
    • RandomResizedCrop of size 224 with interpolation 3 (bicubic)
    • RandomHorizontalFlip
    • Normalize
  • Note: trained with PyTorch mixed precision

Vision Transformer models (all sizes: B/16, L/16, H/14):

  • Num epochs: 28
  • Trained on 4 nodes with 8 Volta GPUs (32GB) each
  • Batch size per GPU: 32
  • Image size: 224
  • SGD optimizer with params:
    • weight decay: 1e-9
    • momentum: 0.9
    • use Nesterov: True
  • Learning rate params:
    • scheduler: CosineAnnealingLR
    • start value: 0.04
  • Image augmentation transforms:
    • RandomResizedCrop of size 224 with interpolation 3 (bicubic)
    • RandomHorizontalFlip
    • Normalize
  • Note: trained with PyTorch mixed precision

Validation script and result

## RegNet_Y_16GF
python -u ~/script/run_with_submitit.py --timeout 3000 --ngpus 1 --nodes 1 --partition train --model regnet_y_16gf --data-path="/datasets01_ontap/imagenet_full_size/061417" --test-only --batch-size=1 --weights="RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1"
# Acc@1 83.976 Acc@5 97.244

## RegNet_Y_32GF
python -u ~/script/run_with_submitit.py --timeout 3000 --ngpus 1 --nodes 1 --partition train --model regnet_y_32gf --data-path="/datasets01_ontap/imagenet_full_size/061417" --test-only --batch-size=1 --weights="RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1"
# Acc@1 84.622 Acc@5 97.480

## RegNet_Y_128GF
python -u ~/script/run_with_submitit.py --timeout 3000 --ngpus 1 --nodes 1 --partition train --model regnet_y_128gf --data-path="/datasets01_ontap/imagenet_full_size/061417" --test-only --batch-size=1 --weights="RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1"
# Acc@1 86.068 Acc@5 97.844

## ViT_B_16
python -u ~/script/run_with_submitit.py --timeout 3000 --ngpus 1 --nodes 1 --partition train --model vit_b_16 --data-path="/datasets01_ontap/imagenet_full_size/061417" --test-only --batch-size=1 --weights="ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1"
# Acc@1 81.886 Acc@5 96.180

## ViT_L_16
python -u ~/script/run_with_submitit.py --timeout 3000 --ngpus 1 --nodes 1 --partition train --model vit_l_16 --data-path="/datasets01_ontap/imagenet_full_size/061417" --test-only --batch-size=1 --weights="ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1"
# Acc@1 85.146 Acc@5 97.422

## ViT_H_14
python -u ~/script/run_with_submitit.py --timeout 3000 --ngpus 1 --nodes 1 --partition train --model vit_h_14 --data-path="/datasets01_ontap/imagenet_full_size/061417" --test-only --batch-size=1 --weights="ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1"
# Acc@1 85.708 Acc@5 97.730

Sample script to load model

from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights

m = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1)

@datumbox datumbox mentioned this pull request Apr 10, 2022
@datumbox datumbox linked an issue Apr 10, 2022 that may be closed by this pull request

@datumbox (Contributor) left a comment:


Looks good! Only a couple of thoughts. Let me know what you think.

Review comments on torchvision/models/regnet.py: all resolved.
@YosuaMichael YosuaMichael marked this pull request as ready for review April 11, 2022 10:03
@lauragustafson commented:

I would add, for both, the transforms used during training:

  • RandomResizedCrop of size 224 with interpolation 3
  • RandomHorizontalFlip
  • Normalize

I would add this just because using interpolation 3 isn't always standard. Alternatively, you could just mention interpolation 3 in the size section.

@lauragustafson commented:

Also, it might be worth mentioning that they were trained with PyTorch mixed precision.

@lauragustafson commented:

The nodes are 8 GPUs/node, 32GB Voltas.


@datumbox (Contributor) left a comment:


LGTM, thanks!

@datumbox datumbox merged commit 3fa2414 into pytorch:main Apr 11, 2022
facebook-github-bot pushed a commit that referenced this pull request May 5, 2022
… to ImageNet1K (#5793)

Summary:
* Add SWAG model that only the linear classifier head is finetuned with frozen trunk weight

* Add accuracy from experiments

* Change name from SWAG_LC to SWAG_LINEAR

* Add comment on SWAG_LINEAR weight

* Remove the comment docs (moved to PR description), and add the PR url as recipe. Also change name of previous swag model to SWAG_E2E_V1

(Note: this ignores all push blocking failures!)

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095704

fbshipit-source-id: 4a134412cd8f1366cfb55584a1cdecc568e1a78f
Successfully merging this pull request may close: Add the SWAG pre-trained weights in TorchVision