
Replace L1 loss with beta smooth L1 loss to achieve better performance #2113

Merged · 3 commits merged into pytorch:master on Apr 24, 2020

Conversation

@potterhsu (Contributor) commented Apr 17, 2020

As discussed in #2083

@potterhsu changed the title from “Replace L1 loss with beta smooth L1 loss to achieve better performance.” to “Replace L1 loss with beta smooth L1 loss to achieve better performance” on Apr 20, 2020
@fmassa (Member) left a comment

Thanks a lot for the PR @potterhsu !

Tests are failing because TorchScript expects type annotations for all function arguments that are not tensors; can you fix that?

@@ -346,3 +346,16 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):

pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
matches[pred_inds_to_update] = all_matches[pred_inds_to_update]


def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
@fmassa (Member)

Can you add type annotations for beta and size_average, using Python 3 annotation syntax?

@potterhsu (Contributor, Author)

I've added type annotations, but some tests still failed. Is there anything I should pay attention to?
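For context, a sketch of what the annotated beta smooth L1 loss might look like — an assumption based on the signature in the diff above, not necessarily the exact merged code:

```python
import torch
from torch import Tensor

def smooth_l1_loss(input: Tensor, target: Tensor,
                   beta: float = 1. / 9, size_average: bool = True) -> Tensor:
    # Quadratic for residuals below beta, linear (L1-like) above it;
    # the two pieces meet with matching value and slope at |x| == beta.
    n = torch.abs(input - target)
    loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()
```

With the non-tensor arguments (`beta`, `size_average`) annotated, TorchScript no longer has to infer their types as `Tensor`.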

@fmassa (Member) left a comment

This is good to merge, thanks a lot @potterhsu !

I have just one question that I'd like to understand a bit better, and then I'll merge the PR.

box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
reduction="sum",
beta=1 / 9,
@fmassa (Member)

Quick question: were the numbers you reported obtained with this value as well?
Because IIRC the original implementation of Faster R-CNN uses beta=1 for the box heads, and 1 / 9 for the RPN, see https://github.com/facebookresearch/maskrcnn-benchmark/blob/57eec25b75144d9fb1a6857f32553e1574177daf/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py#L163

I'm curious whether making them both 1/9 is better than making them 1 and 1/9, respectively.
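To make the beta difference concrete — a hypothetical scalar example, not from the PR: with beta=1 a box-regression residual of 0.5 is still in the gentle quadratic region, while with beta=1/9 it is already penalized at the full linear L1 rate:

```python
def beta_smooth_l1(x: float, beta: float) -> float:
    # Scalar beta smooth L1: quadratic for |x| < beta, linear beyond.
    ax = abs(x)
    return 0.5 * ax * ax / beta if ax < beta else ax - 0.5 * beta

# Same residual, very different loss depending on beta:
print(beta_smooth_l1(0.5, 1.0))      # 0.125 (quadratic region)
print(beta_smooth_l1(0.5, 1. / 9))   # 0.5 - 1/18, linear region
```

So a smaller beta pushes harder on moderate regression errors, which is one plausible reason the choice matters for final box AP.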

@potterhsu (Contributor, Author)

Making beta 1/9 for the RPN and 1.0 for the ROI head seems to harm performance (box AP 0.357); see experiment 4 in the results below.

Hardware

  • NVIDIA Tesla V100 32G x 4

Software

  • Python 3.7
  • torch 1.4.0
  • torchvision 0.5.0

Results

1. With beta smooth l1 loss (beta=1/9 for both RPN and ROI)

Script: $ python -m torch.distributed.launch --nproc_per_node=4 --use_env train.py --data-path /path/to/COCO2017 --dataset coco --model fasterrcnn_resnet50_fpn --epochs 13 --lr-steps 8 11 --aspect-ratio-group-factor 3 --lr 0.01 --batch-size 2 --world-size 4

| Repo | Network | box AP | scheduler | epochs | lr-steps | batch size | lr | Group Factor |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| vision | R-50 FPN | 0.370 | 1x | 13 | 8, 11 | 8 | 0.01 | 3 |
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.370
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.579
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.398
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.214
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.406
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.476
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.306
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.495
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.522
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.327
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.561
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.657
Training time 12:24:10
2. With beta smooth l1 loss (beta=1/9 for both RPN and ROI), BS=16

Script: $ python -m torch.distributed.launch --nproc_per_node=4 --use_env train.py --data-path /path/to/COCO2017 --dataset coco --model fasterrcnn_resnet50_fpn --epochs 13 --lr-steps 8 11 --aspect-ratio-group-factor 3 --lr 0.02 --batch-size 4 --world-size 4

| Repo | Network | box AP | scheduler | epochs | lr-steps | batch size | lr | Group Factor |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| vision | R-50 FPN | 0.368 | 1x | 13 | 8, 11 | 16 | 0.02 | 3 |
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.368
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.579
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.393
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.212
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.405
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.476
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.307
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.494
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.520
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.324
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.558
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.649
Training time 11:32:59
3. With beta smooth l1 loss (beta=1/9 for both RPN and ROI), BS=16, GF=0

Script: $ python -m torch.distributed.launch --nproc_per_node=4 --use_env train.py --data-path /path/to/COCO2017 --dataset coco --model fasterrcnn_resnet50_fpn --epochs 13 --lr-steps 8 11 --aspect-ratio-group-factor 0 --lr 0.02 --batch-size 4 --world-size 4

| Repo | Network | box AP | scheduler | epochs | lr-steps | batch size | lr | Group Factor |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| vision | R-50 FPN | 0.369 | 1x | 13 | 8, 11 | 16 | 0.02 | 0 |
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.369
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.577
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.398
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.213
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.409
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.475
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.309
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.499
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.524
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.332
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.565
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.659
Training time 11:53:13
4. With beta smooth l1 loss (beta=1/9 for RPN and 1.0 for ROI), BS=16, GF=0

Script: $ python -m torch.distributed.launch --nproc_per_node=4 --use_env train.py --data-path /path/to/COCO2017 --dataset coco --model fasterrcnn_resnet50_fpn --epochs 13 --lr-steps 8 11 --aspect-ratio-group-factor 0 --lr 0.02 --batch-size 4 --world-size 4

| Repo | Network | box AP | scheduler | epochs | lr-steps | batch size | lr | Group Factor |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| vision | R-50 FPN | 0.357 | 1x | 13 | 8, 11 | 16 | 0.02 | 0 |
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.357
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.576
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.384
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.210
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.393
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.462
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.300
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.482
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.507
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.316
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.550
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637
Training time 11:48:34

@fmassa (Member) commented Apr 24, 2020

This is very helpful and insightful, thanks a lot @potterhsu !

@fmassa fmassa merged commit 8df3f29 into pytorch:master Apr 24, 2020
@hukkai commented Apr 26, 2020

@fmassa Hi, is there any plan for updating the reported performances in the document https://pytorch.org/docs/stable/torchvision/models.html?

@fmassa (Member) commented Apr 27, 2020

@KaiHoo the pre-trained models haven't changed, so the reported performances are still correct.

But if users train new models, they will see better performances.
