Improve the accuracy of Classification models by using SOTA recipes and primitives #3995

Closed
datumbox opened this issue Jun 7, 2021 · 14 comments · Fixed by #4444, #4493, #4734, #4811 or #4836

Comments

@datumbox
Contributor

datumbox commented Jun 7, 2021

🚀 Feature

Update the weights of all pre-trained models to improve their accuracy.

Motivation

New Recipe + FixRes mitigations

torchrun --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--train-crop-size 176 --model-ema --val-resize-size 232

Using a recipe that includes Warmup, Cosine Annealing, Label Smoothing, Mixup, CutMix, Random Erasing, TrivialAugment, No BN weight decay, EMA, long training cycles and optional FixRes mitigations, we are able to improve the ResNet50 accuracy by over 4.5 points (Acc@1 from 76.130 to 80.674). For more information on the training recipe, check here:

Old ResNet50:
Acc@1 76.130 Acc@5 92.862

New ResNet50:
Acc@1 80.674 Acc@5 95.166
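The Warmup and Cosine Annealing parts of the recipe above define a per-epoch learning-rate curve. The following is a minimal closed-form sketch, assuming the schedule is stepped once per epoch, that warmup starts at lr_warmup_decay=0.01 times the base LR (the default visible in the training Namespace later in this thread), and that cosine annealing decays to zero over the remaining epochs. The helper name lr_at_epoch is hypothetical and is not part of the reference scripts.

```python
import math


def lr_at_epoch(epoch, base_lr=0.5, epochs=600, warmup_epochs=5,
                warmup_decay=0.01, eta_min=0.0):
    """Learning rate used during `epoch` (0-indexed).

    Hypothetical helper: linear warmup followed by cosine annealing,
    assuming the scheduler is stepped once per epoch.
    """
    if epoch < warmup_epochs:
        # Linear warmup: the factor grows from warmup_decay towards 1.0
        factor = warmup_decay + (1.0 - warmup_decay) * epoch / warmup_epochs
        return base_lr * factor
    # Cosine annealing over the remaining epochs, restarting from base_lr
    t = epoch - warmup_epochs
    t_max = epochs - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


if __name__ == "__main__":
    for e in (0, 4, 5, 300, 599):
        print(e, lr_at_epoch(e))
```

With --lr 0.5 and --epochs 600, epoch 599 gives roughly 3.48e-06, which agrees with the lr printed in the epoch-599 training log near the end of this thread.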

Running other models through the same recipe achieves the following improved accuracies:

ResNet101:
Acc@1 81.728 Acc@5 95.670

ResNet152:
Acc@1 82.042 Acc@5 95.926

ResNeXt50_32x4d:
Acc@1 81.116 Acc@5 95.478

ResNeXt101_32x8d:
Acc@1 82.834 Acc@5 96.228

MobileNetV3 Large:
Acc@1 74.938 Acc@5 92.496

Wide ResNet50_2:
Acc@1 81.602 Acc@5 95.758 (@prabhat00155)

Wide ResNet101_2:
Acc@1 82.492 Acc@5 96.110 (@prabhat00155)

regnet_x_400mf:
Acc@1 74.864 Acc@5 92.322 (@kazhang)

regnet_x_800mf:
Acc@1 77.522 Acc@5 93.826 (@kazhang)

regnet_x_1_6gf:
Acc@1 79.668 Acc@5 94.922 (@kazhang)

New Recipe (without FixRes mitigations)

torchrun --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--model-ema --val-resize-size 232

Removing the optional FixRes mitigations seems to yield better results for some deeper architectures and variants with larger receptive fields:

ResNet101:
Acc@1 81.886 Acc@5 95.780

ResNet152:
Acc@1 82.284 Acc@5 96.002

ResNeXt50_32x4d:
Acc@1 81.198 Acc@5 95.340

ResNeXt101_32x8d:
Acc@1 82.812 Acc@5 96.226

MobileNetV3 Large:
Acc@1 75.152 Acc@5 92.634

Wide ResNet50_2:
Acc@1 81.452 Acc@5 95.544 (@prabhat00155)

Wide ResNet101_2:
Acc@1 82.510 Acc@5 96.020 (@prabhat00155)

regnet_x_3_2gf:
Acc@1 81.196 Acc@5 95.430

regnet_x_8gf:
Acc@1 81.682 Acc@5 95.678

regnet_x_16gf:
Acc@1 82.716 Acc@5 96.196

regnet_x_32gf:
Acc@1 83.014 Acc@5 96.288

regnet_y_400mf:
Acc@1 75.804 Acc@5 92.742

regnet_y_800mf:
Acc@1 78.828 Acc@5 94.502

regnet_y_1_6gf:
Acc@1 80.876 Acc@5 95.444

regnet_y_3_2gf:
Acc@1 81.982 Acc@5 95.972

regnet_y_8gf:
Acc@1 82.828 Acc@5 96.330

regnet_y_16gf:
Acc@1 82.886 Acc@5 96.328

regnet_y_32gf:
Acc@1 83.368 Acc@5 96.498

New Recipe + Regularization tuning

torchrun --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00001 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--model-ema --val-resize-size 232

Slightly adjusting the regularization (a lower --weight-decay) improves the following model:

MobileNetV3 Large:
Acc@1 75.274 Acc@5 92.566

In addition to the regularization adjustment, we can also apply the Repeated Augmentation trick (--ra-sampler --ra-reps 4):

MobileNetV2:
Acc@1 72.154 Acc@5 90.822

Post-Training Quantized models

ResNet50:
Acc@1 80.282 Acc@5 94.976

ResNeXt101_32x8d:
Acc@1 82.574 Acc@5 96.132

New Recipe (LR+weight_decay+train_crop_size tuning)

torchrun --nnodes=1 --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 1 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.000002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--train-crop-size 208 --model-ema --val-crop-size 240 --val-resize-size 255

EfficientNet-B1:
Acc@1 79.838 Acc@5 94.934

Pitch

To improve the pre-trained model accuracy, we need to complete the "Batteries Included" work described in #3911. Moreover, we will need to extend our existing model builders to support multiple weights, as described in #4611. Then we will be able to:

  • Update our reference scripts for classification to support the new primitives added by the "Batteries Included" initiative.
  • Find a good training recipe for the most important pre-trained models and re-train them. Note that different training configurations might be required for different types of models (for example, mobile models are less likely to overfit compared to bigger models and thus may need different recipes/primitives).
  • Update the weights of the models in the library.

cc @datumbox @vfdev-5

@datumbox datumbox self-assigned this Jun 7, 2021
@datumbox datumbox moved this from Blocked to To Do in Batteries Included - Phase 1 Sep 18, 2021
@datumbox datumbox reopened this Sep 21, 2021
@datumbox datumbox moved this from To Do to In Progress in Batteries Included - Phase 1 Sep 28, 2021
@datumbox datumbox reopened this Oct 22, 2021
@datumbox datumbox linked a pull request Oct 25, 2021 that will close this issue
@datumbox datumbox linked a pull request Nov 1, 2021 that will close this issue
@datumbox datumbox linked a pull request Nov 2, 2021 that will close this issue
@datumbox datumbox linked a pull request Nov 5, 2021 that will close this issue
@xiaohu2015
Contributor

@datumbox Could you release the training code, or at least the training configs? The reference training code already implements the training tricks.

@datumbox
Contributor Author

datumbox commented Nov 9, 2021

@xiaohu2015 Of course! I'm in the middle of writing a blogpost that will include the configs, the training methodology, detailed ablations etc. It should be out next week. :)

Edit: Here is the blogpost that documents the training recipe.

@netw0rkf10w
Contributor

@datumbox For the commands that start with torchrun --nproc_per_node=8 train.py, could you tell me how many nodes were used for training? Thanks.

@netw0rkf10w
Contributor

Hi @datumbox. I have tried your New Recipe (without FixRes mitigations) on ResNet101 and obtained only a peak top-1 accuracy of 81.328 (at epoch 418), which is 0.558 behind your result (81.886).

I launched the following command on 64 GPUs:

python train.py --model $MODEL_NAME --batch-size 64 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 \
--weight-decay 0.00002 --norm-weight-decay 0.0 \
--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --model-ema \
--val-resize-size 232 --data-path /home/user/data/imagenet

I only used a batch size of 64 because 128 led to out-of-memory errors (on 16GB GPUs). Thus the effective batch size in my training is 64*64 = 4096. Could you please tell me how many GPUs you used in your training? Or even better, could you please share the training configurations shown in your training logs? For example, mine is:

Namespace(amp=False, auto_augment='ta_wide', batch_size=64, cache_dataset=False, clip_grad_norm=None, clonefuse=1, cutmix_alpha=1.0, data_path='/home/user/data/imagenet', dataset='imagenet', device='cuda', dist_backend='nccl', dist_url='env://', distributed=True, epochs=600, gpu=0, interpolation='bilinear', label_smoothing=0.1, lr=0.5, lr_gamma=0.1, lr_scheduler='cosineannealinglr', lr_step_size=30, lr_warmup_decay=0.01, lr_warmup_epochs=5, lr_warmup_method='linear', mixup_alpha=0.2, model='resnet101', model_ema=True, model_ema_decay=0.99998, model_ema_steps=32, momentum=0.9, norm_weight_decay=0.0, opt='sgd', output_dir='/home/user/experiments/output/imagenet_resnet101_new', pretrained=False, print_freq=10, ra_reps=3, ra_sampler=False, random_erase=0.1, rank=0, resume='', resume_from_fused=False, skip_resumed_lr_steps=False, start_epoch=0, sync_bn=False, test_only=False, train_crop_size=224, train_samples=-1, use_deterministic_algorithms=False, val_crop_size=224, val_resize_size=232, val_samples=-1, weight_decay=2e-05, weights=None, workers=16, world_size=64)

The number of GPUs is important information because it affects the effective batch size. I would need to scale my learning rate accordingly to match your results (and for that I would need to know the number of GPUs, and the learning rate, that you used).
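The arithmetic behind this concern is short: the effective batch size is the number of GPUs times the per-GPU batch size, and a common heuristic, the linear scaling rule, scales the learning rate in proportion to it. The rule is an assumption here, not something the recipe prescribes, and the helper names below are hypothetical. The sketch uses the 8 x 128 = 1024 setup that datumbox reports further down the thread as the reference.

```python
def effective_batch_size(num_gpus, batch_size_per_gpu):
    """Total number of samples per optimizer step across all processes."""
    return num_gpus * batch_size_per_gpu


def linearly_scaled_lr(reference_lr, reference_batch, new_batch):
    """Linear scaling rule: keep lr / batch_size constant (a heuristic, not a guarantee)."""
    return reference_lr * new_batch / reference_batch


if __name__ == "__main__":
    ref = effective_batch_size(8, 128)   # 8 GPUs x 128 images per GPU
    mine = effective_batch_size(64, 64)  # 64 GPUs x 64 images per GPU
    print(ref, mine, linearly_scaled_lr(0.5, ref, mine))  # prints: 1024 4096 2.0
```

Note that a 4-GPU node with 256 images per GPU also gives 1024, which is why such a setup can be treated as equivalent to 8 x 128.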

FYI the following file contains the metrics values at each epoch of my training.

imagenet_resnet101.txt

Unfortunately the training log file is too big (700MB) to be shared. It is filled with the following annoying warning message:

/home/user/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py:138: UserWarning: An output with one or more elements was resized since it had shape [9633792], which does not match the required output shape [64, 3, 224, 224].This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize(0). (Triggered internally at ../aten/src/ATen/native/Resize.cpp:24.)
return torch.stack(batch, 0, out=out)

(By the way, do you know how to get rid of this kind of messages please? Should I create a GitHub issue somewhere?)

Thank you very much in advance for your reply!

@datumbox
Contributor Author

datumbox commented Jan 10, 2022

Hi @netw0rkf10w .

> Unfortunately the training log file is too big (700MB) to be shared.

This is exactly why it's hard for me to share the training log files as well. We are working on improving the model documentation and figuring out how to share these more easily.

Here is the full command used to train the model; it should contain all the information you need to reproduce this:

PYTHONPATH=$PYTHONPATH:`pwd` python -u run_with_submitit.py --ngpus 8 --nodes 1 --model resnext101_32x8d --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.00002 --norm-weight-decay 0.0 --train-crop-size 176 --model-ema --val-resize-size 232

Note that we used submitit and a custom script to launch our jobs.

> Thus the effective batch size in my training is 64*64 = 4096. Could you please tell me how many GPUs you used in your training?

This is probably why you don't match my results. I used an effective batch size of 1024: 8x A100 GPUs with a batch size of 128 per GPU. I would recommend keeping the total batch size at 1024 to avoid having to adapt the rest of the parameters.
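Not every hyperparameter needs manual retuning when the total batch size changes. As far as we can tell, the reference train.py rescales the EMA decay by world_size * batch_size * model_ema_steps / epochs, so that the EMA averaging window covers roughly the same fraction of training regardless of the GPU count. The sketch below reproduces that arithmetic in plain Python with the defaults model_ema_decay=0.99998 and model_ema_steps=32 from the Namespace shown earlier in the thread; the helper name is hypothetical, so verify the formula against your own torchvision checkout.

```python
def adjusted_ema_decay(world_size, batch_size, model_ema_steps=32, epochs=600,
                       model_ema_decay=0.99998):
    """Per-update EMA decay after rescaling for the training setup (hypothetical helper)."""
    # Samples processed between two EMA updates, normalized by the number of epochs
    adjust = world_size * batch_size * model_ema_steps / epochs
    # Scale the "forgetting rate" alpha and clamp it to at most 1.0
    alpha = min(1.0, (1.0 - model_ema_decay) * adjust)
    return 1.0 - alpha


if __name__ == "__main__":
    print(adjusted_ema_decay(8, 128))  # 8 GPUs x 128: the recommended setup
    print(adjusted_ema_decay(64, 64))  # 64 GPUs x 64: fewer EMA updates, each weighted more
```

The larger setup performs fewer EMA updates per epoch, so each update gets a smaller decay to keep the overall averaging horizon comparable.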

> By the way, do you know how to get rid of this kind of messages please? Should I create a GitHub issue somewhere?

Concerning the warning message, I would recommend opening a GitHub issue on the main PyTorch repository with a minimal snippet that reproduces it, so it can be investigated further.

Let me know if you face further problems reproducing the results.

@netw0rkf10w
Contributor

@datumbox Great, thanks a lot for your reply! I'll try again and keep you informed about the results.

@tbennun
Contributor

tbennun commented Jan 16, 2022

@datumbox As per the discussion in #5084, below is a recipe that achieved the following result on ResNet-50 and ImageNet:
Acc@1 80.858 Acc@5 95.434

torchrun --nproc_per_node=8 train.py --model resnet50 --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--train-crop-size 176 --model-ema --val-resize-size 232 \
--ra-sampler --ra-reps=4

Overview of changes to the current recipe (New Recipe + FixRes mitigations):

  • Repeated Augmentation (--ra-sampler --ra-reps=4): In each batch, we sample 1/4 of the original batch size and reuse each sample 4 times with different data augmentations (taken from the same set of augmentations as the original recipe). Repeated Augmentation (RA, also called Batch Augmentation) has been successfully used to boost generalization on various models and datasets via gradient variance reduction. In particular, RA with four repetitions was used on ImageNet in prior literature.
    • Reference used for four repeated augmentations: E. Hoffer et al. "Augment Your Batch: Improving Generalization Through Instance Repetition", CVPR 2020.
  • Since I ran this on a 4-GPU node, I changed the number of processes per node to 4 and the per-GPU batch size to 256. This should be equivalent to the 8x128 batch size in the current recipe.
  • --cache-dataset (omitted) was also used to speed up the initial dataset loading time. It should have no effect on the recipe.
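The repeated-augmentation sampling described above can be sketched in plain Python. This is a simplified sketch modeled on the DeiT-style scheme tbennun mentions later in the thread, not the sampler shipped in the reference scripts: the function name ra_indices and its arguments are hypothetical, and any extra rounding of the epoch length is omitted. Each shuffled index is repeated reps times, the copies are spread across ranks, and each rank keeps only about dataset_len / num_replicas indices, so an epoch has the usual length but visits only about 1/reps of the distinct images.

```python
import math
import random


def ra_indices(dataset_len, num_replicas, rank, reps=4, epoch=0, seed=0):
    """Indices for one rank and one epoch under repeated augmentation (hypothetical helper)."""
    # Same seed on every rank, so all ranks agree on the shuffled order
    rng = random.Random(seed + epoch)
    order = list(range(dataset_len))
    rng.shuffle(order)

    # Repeat every index `reps` times back to back; each copy is augmented
    # independently later by the dataset's random transforms
    repeated = [idx for idx in order for _ in range(reps)]

    # Pad so the repeated list splits evenly across ranks
    per_rank = math.ceil(len(repeated) / num_replicas)
    total = per_rank * num_replicas
    repeated += repeated[: total - len(repeated)]

    # Strided split: consecutive copies of one image land on different ranks
    shard = repeated[rank:total:num_replicas]

    # Keep the usual epoch length per rank, so only ~1/reps of the
    # distinct images are visited in one epoch
    return shard[: dataset_len // num_replicas]


if __name__ == "__main__":
    seen = [i for r in range(4) for i in ra_indices(16, num_replicas=4, rank=r)]
    print(len(seen), len(set(seen)))  # 16 samples drawn from 4 distinct images
```

Because every rank truncates its shard, the number of iterations per epoch does not grow with reps, which matches the observation in this thread that RA did not slow down training.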

@netw0rkf10w
Contributor

@tbennun Great contributions. I guess increasing the number of repetitions also leads to slower training. Could you tell me how much slower it was for your training?

I am about to launch a few training runs, and if --ra-reps=4 does not make them much slower, I will go for it instead of the default option.

Thanks in advance for your reply.

@tbennun
Contributor

tbennun commented Jan 16, 2022

@netw0rkf10w Actually, this didn't slow down training at all. The current version of RA in the classification reference scripts uses the DeiT scheme, in which each epoch visits only 1/reps of the distinct images, so the epoch length stays the same. I assume more repetitions will have diminishing returns, though; it would be really cool to know!

@netw0rkf10w
Contributor

@tbennun I see, thanks. Let me try --ra-reps=4 and see what will happen.

@netw0rkf10w
Contributor

I was able to reach Acc@1 81.901 Acc@5 95.772 for ResNet101 with

--batch-size 64 --lr 2.0 --ra-sampler --ra-reps 4 on 64 GPUs.

The effective batch size is 64*64 = 4096, so I scaled the learning rate to 2.0 (4x the 0.5 used with a batch size of 1024). This seems to work, though I am not sure how much --ra-sampler --ra-reps 4 contributed.

@datumbox You said in #5084 that you were about to launch a new set of trainings with --ra-sampler --ra-reps 4. Have you observed any improvements? Thanks.

@datumbox
Contributor Author

@netw0rkf10w Thanks for confirming, good to know you matched the accuracy.

No plans to retrain all the models for now. It's very expensive and time-consuming to train everything from scratch, and I'm not sure it makes sense, as the improvement is expected to be on the order of 0.1-0.2 points.

@datumbox datumbox changed the title Improve the accuracy of models by using SOTA recipes and primitives Improve the accuracy of Classification models by using SOTA recipes and primitives Jan 28, 2022
@datumbox
Contributor Author

I have modified the scope of the ticket to focus on Classification so that we can conclude phase 1 of our Batteries Included project. We will focus on Detection and Segmentation in phase 2.

Big thanks to everyone involved in this project for helping us keep TorchVision fresh!

@zjykzj

zjykzj commented Sep 27, 2022


Hi @tbennun, following your recipe, I tried to reproduce the result. I downloaded the latest pytorch/vision code and ran the following:

torchrun --nproc_per_node=8 train.py --model resnet50 --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 --norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4 --data-path /data/imagenet/ --output-dir outputs/vision/

The only modification is the way resnet50 is loaded:

    print("Creating model")
    # model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    model = torchvision.models.resnet50(weights=args.weights, num_classes=num_classes)
    model.to(device)

>>> import torch
>>> torch.__version__
'1.12.1+cu113'
>>> import torchvision
>>> torchvision.__version__
'0.13.1+cu113'

The result is far from ideal, even worse than the original training.

Best Epoch: [599] Acc@1 72.642 Acc@5 90.316
Best Epoch: [515] ENA Acc@1 73.768 Acc@5 91.384

I'm not sure whether I need to load the trained resnet50 weights. I'm sharing this result as a reference for everyone.

...
...
Epoch: [599]  [1180/1251]  eta: 0:00:16  lr: 3.484775606299451e-06  img/s: 517.4125772089047  loss: 2.2298 (2.3932)  acc1: 85.9375 (76.3144)  acc5: 94.5312 (91.6808)  time: 0.2357  data: 0.0002  max mem: 10121
Epoch: [599]  [1190/1251]  eta: 0:00:14  lr: 3.484775606299451e-06  img/s: 610.9803893006064  loss: 2.8320 (2.3956)  acc1: 69.5312 (76.2647)  acc5: 92.1875 (91.6431)  time: 0.2352  data: 0.0002  max mem: 10121
Epoch: [599]  [1200/1251]  eta: 0:00:11  lr: 3.484775606299451e-06  img/s: 568.8505523521321  loss: 2.5181 (2.3951)  acc1: 80.4688 (76.3088)  acc5: 95.3125 (91.6651)  time: 0.2331  data: 0.0002  max mem: 10121
Epoch: [599]  [1210/1251]  eta: 0:00:09  lr: 3.484775606299451e-06  img/s: 564.9036928541966  loss: 2.5683 (2.3970)  acc1: 80.4688 (76.2515)  acc5: 95.3125 (91.6598)  time: 0.2320  data: 0.0002  max mem: 10121
Epoch: [599]  [1220/1251]  eta: 0:00:07  lr: 3.484775606299451e-06  img/s: 589.8034087188851  loss: 2.6939 (2.3984)  acc1: 75.0000 (76.2272)  acc5: 93.7500 (91.6289)  time: 0.2401  data: 0.0008  max mem: 10121
Epoch: [599]  [1230/1251]  eta: 0:00:04  lr: 3.484775606299451e-06  img/s: 621.6561009895624  loss: 2.6298 (2.4012)  acc1: 71.0938 (76.1747)  acc5: 89.8438 (91.6074)  time: 0.2256  data: 0.0007  max mem: 10121
Epoch: [599]  [1240/1251]  eta: 0:00:02  lr: 3.484775606299451e-06  img/s: 622.9863095721596  loss: 2.7580 (2.4038)  acc1: 71.0938 (76.1269)  acc5: 89.8438 (91.5876)  time: 0.2062  data: 0.0001  max mem: 10121
Epoch: [599]  [1250/1251]  eta: 0:00:00  lr: 3.484775606299451e-06  img/s: 621.1914377617846  loss: 2.5694 (2.4009)  acc1: 79.6875 (76.1847)  acc5: 93.7500 (91.6155)  time: 0.2068  data: 0.0001  max mem: 10121
Epoch: [599] Total time: 0:04:50
Test:   [ 0/49]  eta: 0:03:19  loss: 1.9507 (1.9507)  acc1: 91.4062 (91.4062)  acc5: 96.8750 (96.8750)  time: 4.0614  data: 3.9483  max mem: 10121
Test:  Total time: 0:00:09
Test:  Acc@1 72.642 Acc@5 90.316
Best Epoch: [599] Acc@1 72.642 Acc@5 90.316
Test: EMA  [ 0/49]  eta: 0:02:57  loss: 1.9203 (1.9203)  acc1: 91.4062 (91.4062)  acc5: 96.0938 (96.0938)  time: 3.6158  data: 3.5251  max mem: 10121
Test: EMA Total time: 0:00:10
Test: EMA Acc@1 72.856 Acc@5 90.554
Best Epoch: [515] ENA Acc@1 73.768 Acc@5 91.384
Training time 1 day, 3:13:27
