
How to improve recall rate on specific class? #1995

Closed
zero90169 opened this issue Jan 20, 2021 · 18 comments
Labels
question Further information is requested

Comments

@zero90169

zero90169 commented Jan 20, 2021

❔Question

First, thanks to the author for sharing this amazing repo.

I trained yolov5s on my custom dataset. For inference-speed reasons I set the image size to 320x320 (I know it is too small...). The overall precision (70%) looks fine, about the same as (or even higher than) another model, MobileDet. However, the overall recall is poor: MobileDet reaches 70% recall, while yolov5s only gets 60%.

I tried focal loss, but it made training much slower. Is there another trick to improve recall, or a way to improve recall on a specific class (for example, by weighting classes differently)?

Additional context

@zero90169 zero90169 added the question Further information is requested label Jan 20, 2021
@github-actions
Contributor

github-actions bot commented Jan 20, 2021

👋 Hello @zero90169, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

@ZhWL123456

Maybe data augmentation for that specific class would help.

@glenn-jocher
Member

@zero90169 mAP is by far the best metric to track, as it provides absolute performance across all confidence levels. Precision and recall are relative metrics that vary as a function of confidence threshold. Therefore do not concern yourself with P and R, focus on mAP for an apples to apples comparison.
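To make the threshold dependence concrete, here is a minimal stdlib sketch (not YOLOv5's own metric code) that sweeps the confidence threshold over a made-up list of already-matched detections for one class, reports P and R at each level, and then integrates the curve into a single AP number. The detection list, `num_gt=6`, and the helper names are hypothetical, and the all-point interpolation used here is a common AP definition that is not necessarily identical to how YOLOv5 computes mAP.

```python
def pr_curve(detections, num_gt):
    """Return (threshold, precision, recall) at each detection's confidence.

    detections: (confidence, is_true_positive) pairs for ONE class, already
        matched to ground truth (hypothetical input format).
    num_gt: number of ground-truth boxes of that class.
    """
    ranked = sorted(detections, key=lambda d: d[0], reverse=True)
    curve, tp = [], 0
    for i, (conf, is_tp) in enumerate(ranked, start=1):
        tp += int(is_tp)
        # Keeping every detection with confidence >= conf gives i detections.
        curve.append((conf, tp / i, tp / num_gt))
    return curve


def average_precision(curve):
    """All-point interpolated AP: area under the precision envelope vs recall."""
    precisions = [p for _, p, _ in curve]
    recalls = [r for _, _, r in curve]
    # Envelope: precision at recall r becomes the best precision at any recall >= r.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    ap, prev_recall = 0.0, 0.0
    for p, r in zip(precisions, recalls):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


# Made-up detections for one class (no tied confidences), 6 ground-truth boxes.
detections = [(0.95, True), (0.90, True), (0.80, False), (0.70, True),
              (0.60, False), (0.40, True), (0.30, False), (0.20, True)]
curve = pr_curve(detections, num_gt=6)
for conf, p, r in curve:
    print(f"conf >= {conf:.2f}: P={p:.3f} R={r:.3f}")
print(f"AP = {average_precision(curve):.3f}")
```

With this data, lowering the threshold from 0.70 to 0.20 raises recall from 0.50 to about 0.83 while precision falls from 0.75 to 0.625, whereas AP is a single number that does not depend on any one threshold.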

Best results are achieved when training and detecting at the same --img size, so if you want to deploy with 320 inference size, you should train at 320 as well.

Individual class weights can be specified in the class_weights vector, and weighted classification loss can then employ the class weights in the ComputeLoss() class. Note that class weights are defined as inverse frequencies by default.

yolov5/train.py

Line 219 in d921214

model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
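The inverse-frequency idea can be sketched without the YOLOv5 codebase. This is a hypothetical stdlib re-implementation, not the YOLOv5 `labels_to_class_weights` source: the name `inverse_frequency_class_weights` is made up, and it takes a flat list of class ids rather than the per-image label arrays YOLOv5 works with. It counts labels per class, takes reciprocals, and rescales so the weights sum to `nc`, mirroring the `* nc` factor in the line above.

```python
from collections import Counter


def inverse_frequency_class_weights(class_ids, nc):
    """Hypothetical sketch: per-class weights proportional to 1 / label count.

    class_ids: flat list with one class index per labelled object.
    Classes with no labels are counted as 1 so their weight stays finite.
    The weights are rescaled to sum to nc, so their mean over classes is 1.0.
    """
    counts = Counter(class_ids)
    inverse = [1.0 / max(counts.get(c, 0), 1) for c in range(nc)]
    total = sum(inverse)
    return [nc * w / total for w in inverse]


# Made-up, heavily imbalanced labels: 90 of class 0, 9 of class 1, 1 of class 2.
weights = inverse_frequency_class_weights([0] * 90 + [1] * 9 + [2], nc=3)
print([round(w, 3) for w in weights])
```

Here the rarest class ends up with 90 times the weight of the most common one, while the weights still average 1.0 across classes.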

@zero90169
Author

@glenn-jocher Thanks for your suggestions. I will try the class_weights vector.

I did train at 320x320 and ran inference with my trained model at 320x320. I know that R and P are affected by the threshold, but in my application I have to focus on R for some classes (at threshold=0.5).

p.s. I found that in detect.py the input image is resized to 192x320, not 320x320, which differs from the 320x320 images used during training. I don't know whether this affects performance.
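The 192x320 shape is consistent with rectangular letterboxing at inference time. A sketch, assuming detect.py scales the longer side to `--img` and then pads the shorter side only up to the next multiple of the model stride (assumed to be 32 here) instead of to a full square; `letterbox_shape` is a hypothetical helper that only computes the output shape, not the padded image.

```python
import math


def letterbox_shape(height, width, img_size=320, stride=32):
    """Output (h, w) of a minimum-rectangle letterbox (hypothetical helper).

    The longer side is scaled to img_size, keeping the aspect ratio, and each
    side is then padded only up to the next multiple of stride.
    """
    scale = img_size / max(height, width)
    new_h, new_w = round(height * scale), round(width * scale)
    padded_h = math.ceil(new_h / stride) * stride
    padded_w = math.ceil(new_w / stride) * stride
    return padded_h, padded_w


print(letterbox_shape(1080, 1920))  # a 16:9 frame
print(letterbox_shape(640, 640))    # a square image
```

For a 16:9 frame such as 1920x1080, the short side scales to 180 and is padded to 192, giving 192x320, while a square input stays at 320x320.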

@glenn-jocher
Member

@zero90169 if you really need to focus on R, make sure you set a reasonable threshold for your data here (the default is 0.1; this value is separate from --conf):

pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
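A sketch of what evaluating recall per class at a fixed score threshold looks like, assuming detections have already been matched to ground truth. The tuple format and `per_class_recall` are hypothetical; `score_threshold` plays the role of `pr_score` above.

```python
from collections import defaultdict


def per_class_recall(detections, num_gt_per_class, score_threshold=0.5):
    """Recall per class, counting only detections with confidence >= score_threshold.

    detections: (class_id, confidence, is_true_positive) tuples, already matched
        to ground truth (hypothetical input format).
    num_gt_per_class: dict mapping class_id -> number of ground-truth boxes.
    """
    true_positives = defaultdict(int)
    for class_id, confidence, is_tp in detections:
        if is_tp and confidence >= score_threshold:
            true_positives[class_id] += 1
    return {c: true_positives[c] / n for c, n in num_gt_per_class.items() if n > 0}


# Made-up matched detections for two classes.
matched = [(0, 0.90, True), (0, 0.40, True),
           (1, 0.60, True), (1, 0.55, False), (1, 0.30, True)]
gt_counts = {0: 2, 1: 3}
print(per_class_recall(matched, gt_counts, score_threshold=0.5))
print(per_class_recall(matched, gt_counts, score_threshold=0.1))
```

At 0.5, class 1 recovers only one of its three boxes; at 0.1 it recovers two, so the reported R for a class depends directly on which threshold is chosen.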

@ggyybb

ggyybb commented Feb 18, 2021


I am interested in model.class_weights. I know how to compute this weight vector, but I can't find where model.class_weights is used in loss.py. Could you please help me understand this when convenient? Thanks a lot!

@glenn-jocher
Member

@ggyybb model.class_weights can be applied in the classification BCE loss function here:

BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))

The weights can be passed to the loss function to define a value for the pos_weight argument. See the documentation for details:
https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
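For reference, the per-element formula given in the PyTorch `BCEWithLogitsLoss` documentation can be written out in plain Python to see what each argument scales. This is a hypothetical scalar restatement, not a call into torch: `pos_weight` multiplies only the term for positive targets (y = 1), while `weight` rescales the whole element loss.

```python
import math


def bce_with_logits(x, y, weight=1.0, pos_weight=1.0):
    """Per-element BCE on a logit x with target y in [0, 1] (hypothetical helper).

    Restates the formula from the PyTorch BCEWithLogitsLoss docs:
        loss = -weight * (pos_weight * y * log(sigmoid(x))
                          + (1 - y) * log(1 - sigmoid(x)))
    """
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        log_sig = -math.log1p(math.exp(-x))
    else:
        log_sig = x - math.log1p(math.exp(x))
    # log(1 - sigmoid(x)) = log(sigmoid(-x)) = log(sigmoid(x)) - x
    log_one_minus_sig = log_sig - x
    return -weight * (pos_weight * y * log_sig + (1 - y) * log_one_minus_sig)


print(bce_with_logits(0.3, 1), bce_with_logits(0.3, 1, pos_weight=2.0))
print(bce_with_logits(0.3, 0), bce_with_logits(0.3, 0, pos_weight=2.0))
```

With a negative target (y = 0) `pos_weight` has no effect at all, which is why it does not behave like a per-class loss weight.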

@ggyybb

ggyybb commented Feb 19, 2021


(screenshot)
I think cls_pw in the hyp file doesn't actually work: it's just a single number. If I modify cls_pw in hyp.yaml, the pos_weight of BCEcls changes, but it has no real effect. Is my understanding right?
So I used model.class_weights as pos_weight instead, and it works. I hope I'm right; am I? Thanks so much for your response!

@glenn-jocher
Member

glenn-jocher commented Feb 19, 2021

@ggyybb that's a good question. The interpretation of the pos_weight argument seems to be up for debate, based on a few forum responses I saw. We were previously treating it as a 'positive weight' that I assumed was applied to positive samples, but the current documentation shows this as 'position weight'.

You may want to experiment with this to see what the true effect is.

@glenn-jocher
Copy link
Member

glenn-jocher commented Feb 19, 2021

@ggyybb BTW I'm testing pos_weight=model.class_weights on YOLOv5m COCO, and in early training so far it's trailing the default result, but I'll need a few more days to have a better picture.

To be clear, the model.class_weights vector is definitely created correctly, I spent a lot of time getting this right. It is an inverse frequency vector that sums to the class count (80). The main question is whether it can be applied directly using pos_weight=model.class_weights, or whether you may need to apply it manually by setting reduction differently and multiplying the loss matrix by this row vector before doing the reduction. You'll need to experiment.
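The manual route described above (`reduction='none'`, multiply by the row vector, then reduce) can be sketched in plain Python on a small made-up loss matrix. `class_weighted_mean` is a hypothetical helper; with torch tensors the equivalent would be `(loss * class_weights).mean()` on an `n x nc` unreduced loss and an `nc`-length weight vector, relying on broadcasting over the last dimension.

```python
def class_weighted_mean(loss_matrix, class_weights):
    """Scale column c of an unreduced (n x nc) loss matrix by class_weights[c],
    then average over all n * nc elements (like reduction='mean')."""
    if any(len(row) != len(class_weights) for row in loss_matrix):
        raise ValueError("each row must have one loss value per class")
    weighted = [[loss * w for loss, w in zip(row, class_weights)] for row in loss_matrix]
    num_elements = len(loss_matrix) * len(class_weights)
    return sum(sum(row) for row in weighted) / num_elements


# Made-up unreduced loss for 3 targets x 2 classes.
loss = [[0.2, 0.4],
        [0.6, 0.8],
        [1.0, 1.2]]
print(class_weighted_mean(loss, [1.0, 1.0]))  # unit weights: plain mean
print(class_weighted_mean(loss, [1.5, 0.5]))  # class 0 up-weighted, class 1 down-weighted
```

Doing the multiplication before the reduction is what makes the weighting per class: each column is scaled independently, and the final mean still runs over every element.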

@ggyybb

ggyybb commented Feb 19, 2021


I studied the source code carefully. I think the pos_weight parameter is a positive weight, the same as you thought before; it should be a number (or a hyp), and it will be broadcast to fit the loss function.
The other parameter of nn.BCEWithLogitsLoss, 'weight', is the weight of each class. The torch docs say: 'weight (Tensor, optional): a manual rescaling weight given to the loss of each batch element'. I think 'each batch element' here means 'class'. I ran some experiments with random numbers, and weight=model.class_weights is applied to each class when the loss is computed.
So I think weight=model.class_weights may be the right way to use model.class_weights.

@ggyybb

ggyybb commented Feb 19, 2021


Like this:
(screenshot)

@glenn-jocher
Member

@ggyybb oh good work. That's what I thought originally. So I will cancel my current training, which is doing pos_weight=model.class_weights and start a new one with weight=model.class_weights.

@glenn-jocher
Member

@ggyybb just to check, the behavior you are observing with weight=model.class_weights is that each loss column (1 class per column) is multiplied by its model.class_weights value, right?

An example loss shape for COCO might be 1000x80, where all of the 1000 values in the first column should be multiplied by model.class_weights[0].

@ggyybb

ggyybb commented Feb 21, 2021


I think that's right. I'm looking forward to your new training results, and I will run experiments on my own datasets. Thanks very much!

@glenn-jocher
Member

@ggyybb looking like this so far. cls loss is down by 50% because of the weightings, and mAP is a bit lower too in early training. Much of the current effect may simply be due to a lower cls loss in general rather than the actual weightings, i.e. it may make sense to increase hyp['cls'] to 1.5 or 2.0 to rebalance the losses after the change. In any case it's too early to tell; there are maybe 3 days of training left.

(screenshot, 2021-02-21)
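One rough way to see why the weighted cls loss shrinks overall: although inverse-frequency weights average 1.0 over classes, the average weight seen by a training target is the harmonic mean of the class counts divided by their arithmetic mean, which is at most 1 and much smaller for imbalanced data. The sketch below (hypothetical helper `mean_weight_per_target`, made-up counts) only reasons about the target-class term; YOLOv5's classification loss also includes the other class columns, so treat it as intuition for the suggested hyp['cls'] rebalancing, not an explanation of the exact 50% figure.

```python
def mean_weight_per_target(class_counts):
    """Average inverse-frequency weight over all targets (hypothetical helper).

    class_counts[c] is the number of labels of class c (all assumed > 0).
    Weights are 1 / count, rescaled to sum to nc, so the mean over *classes* is 1.0.
    The mean over *targets* equals harmonic_mean(counts) / arithmetic_mean(counts).
    """
    nc = len(class_counts)
    inverse_sum = sum(1.0 / n for n in class_counts)
    weights = [nc * (1.0 / n) / inverse_sum for n in class_counts]
    total_targets = sum(class_counts)
    return sum(n * w for n, w in zip(class_counts, weights)) / total_targets


balanced = mean_weight_per_target([10, 10, 10])
imbalanced = mean_weight_per_target([90, 9, 1])
print(balanced, imbalanced)
```

For balanced counts the average target weight is 1.0, but for the imbalanced counts it drops to about 0.08, so the weighted loss contributes far less to the total even though the class weights still average 1.0.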

@glenn-jocher
Member

@ggyybb comparison is done. In my experiment, class_weights (red) reduces mAP by about 0.01 for YOLOv5m on COCO vs. the default (green).

(screenshot, 2021-02-24)

@ShirleyHe2020


Hi @ggyybb, could you please share the loss_read.py shown in your screenshot?


5 participants