pos_weight should be a vector instead of a scalar #8749

Closed
2 tasks done
seermer opened this issue Jul 27, 2022 · 7 comments
Labels
bug Something isn't working Stale

Comments

@seermer

seermer commented Jul 27, 2022

Search before asking

  • I have searched the YOLOv5 issues and found no similar bug report.

YOLOv5 Component

Training

Bug

In the current implementation, pos_weight is set to a scalar (default 1.0).
However, the official PyTorch documentation for BCEWithLogitsLoss says:

pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with a length equal to the number of classes.
It also gives an example:

For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100=3. The loss would act as if the dataset contains 3 x 100=300 positive examples.

Therefore, pos_weight is essentially a per-class weight for the classification loss.
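A minimal sketch of the per-class weighting the PyTorch documentation describes, assuming hypothetical per-class counts of positive and negative examples for nc = 3 classes (class 0 reproduces the documentation's 100 positive / 300 negative example). Each class's pos_weight is its negative count divided by its positive count, and the resulting length-nc vector is passed to `nn.BCEWithLogitsLoss`.

```python
import torch
import torch.nn as nn

# Hypothetical per-class counts (illustrative only, not from a real dataset)
positives = torch.tensor([100.0, 50.0, 400.0])
negatives = torch.tensor([300.0, 350.0, 400.0])

# pos_weight[c] = negatives[c] / positives[c], one entry per class
pos_weight = negatives / positives
print(pos_weight)  # tensor([3., 7., 1.])

# The vector is broadcast along the class dimension of the targets
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.zeros(4, 3)   # batch of 4 predictions, 3 classes
targets = torch.ones(4, 3)   # every example is a positive for every class
loss = criterion(logits, targets)
print(loss.item())
```

With all-positive targets and zero logits, each element's loss is pos_weight[c] * ln 2, so classes with more negatives per positive contribute proportionally more.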

Environment

  • YOLOv5 6.1

Minimal Reproducible Example

The current implementation passes a single-element tensor (effectively a scalar), regardless of the number of classes:
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
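For illustration, the sketch below (assuming nc = 3 and an illustrative cls_pw of 2.0; the default hyp files use 1.0) shows that the single-element pos_weight is broadcast across the class dimension, so every class receives the same weight. It produces the same loss as an explicit length-nc vector filled with that value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nc = 3          # illustrative number of classes
cls_pw = 2.0    # illustrative value for hyp['cls_pw']

logits = torch.randn(8, nc)
targets = (torch.rand(8, nc) > 0.5).float()

# Current behaviour: a one-element tensor, broadcast to every class
scalar_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([cls_pw]))(logits, targets)

# Explicit vector: the same weight repeated once per class
vector_loss = nn.BCEWithLogitsLoss(pos_weight=torch.full((nc,), cls_pw))(logits, targets)

print(torch.allclose(scalar_loss, vector_loss))
```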

Additional

Related issue: #5604
In that issue, pos_weight is misinterpreted as a scalar that is different from the class weight.

Possible fixes:

  1. Change the default hyp YAML so that cls_pw is a list instead of a scalar. (I personally don't like this, because with many classes the list becomes very long.)
  2. Change the default classification-loss pos_weight to torch.ones([nc]) * hyp['cls_pw']. This gives every class the same weight, and hyp['cls_pw'] can stay a scalar.
  3. Change the default classification-loss pos_weight to model.class_weights * hyp['cls_pw']. I personally prefer this, since class_weights are already computed and easy to reuse, and hyp['cls_pw'] stays a scalar that scales all the class weights.
     Edit: I just realized that the computed model.class_weights is a different quantity from the one described in the PyTorch documentation. To match the documentation, we might need to compute a new class weight, for example (total samples / nc) / (samples per class) * hyp['cls_pw']. That way, every class is trained as if it had (total samples / nc) samples, scaled by hyp['cls_pw'].
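Options 2 and 3 (the edited formula) can be sketched as below. The function names `uniform_pos_weight` and `balanced_pos_weight` and the `class_counts` input (number of labelled instances per class) are hypothetical, and clamping counts to at least 1 is an added assumption to avoid division by zero for classes with no labels. This is an illustration of the proposal, not the YOLOv5 implementation.

```python
import torch

def uniform_pos_weight(nc: int, cls_pw: float) -> torch.Tensor:
    """Option 2: every class gets the same weight, hyp['cls_pw']."""
    return torch.ones(nc) * cls_pw

def balanced_pos_weight(class_counts: torch.Tensor, cls_pw: float) -> torch.Tensor:
    """Option 3 (edited): (total samples / nc) / samples per class * cls_pw.

    class_counts[c] is the (hypothetical) number of labelled instances of class c.
    """
    counts = class_counts.float()
    nc = counts.numel()
    total = counts.sum()
    # Assumption: clamp to >= 1 so a class with no labels does not yield inf
    return (total / nc) / counts.clamp(min=1) * cls_pw

print(uniform_pos_weight(3, 1.0))
print(balanced_pos_weight(torch.tensor([100, 300, 200]), 1.0))
```

Under this definition, a class with exactly the average count receives weight cls_pw, rarer classes are weighted up, and more frequent classes are weighted down.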

Edit: please let me know if I have misunderstood anything. Thanks.

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
seermer added the bug (Something isn't working) label on Jul 27, 2022
@github-actions
Contributor

github-actions bot commented Jul 27, 2022

👋 Hello @seermer, thank you for your interest in YOLOv5 🚀! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://ultralytics.com or email support@ultralytics.com.

Requirements

Python>=3.7.0 with all requirements.txt installed including PyTorch>=1.7. To get started:

git clone https://github.com/ultralytics/yolov5  # clone
cd yolov5
pip install -r requirements.txt  # install


@glenn-jocher
Member

@seermer thanks for the feedback! Yes, I think option 3 is the best solution, or perhaps option 3 multiplied by the hyp pw scalar?

@seermer
Author

seermer commented Jul 29, 2022


Yes, I think so. Option 3 is probably the best, but it probably doesn't work well with the --image-weights argument, since that argument already changes the sampling frequency. So maybe we can use option 2 when --image-weights is enabled, and option 3 otherwise.
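The suggested switch can be sketched as follows. `select_pos_weight`, `class_counts` and the `image_weights` boolean (standing in for the --image-weights training flag) are hypothetical names, and clamping counts to at least 1 is an added assumption. It reflects the proposal in this thread, not merged YOLOv5 code.

```python
import torch

def select_pos_weight(class_counts: torch.Tensor, cls_pw: float, image_weights: bool) -> torch.Tensor:
    """Pick option 2 or option 3 depending on whether image weighting is enabled."""
    counts = class_counts.float()
    nc = counts.numel()
    if image_weights:
        # Image weighting already changes how often images are sampled,
        # so avoid re-weighting classes a second time: option 2
        return torch.full((nc,), cls_pw)
    # Option 3: (total samples / nc) / samples per class * cls_pw
    total = counts.sum()
    return (total / nc) / counts.clamp(min=1) * cls_pw

counts = torch.tensor([100, 300, 200])
print(select_pos_weight(counts, 2.0, image_weights=True))
print(select_pos_weight(counts, 2.0, image_weights=False))
```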

@glenn-jocher
Member

@seermer oh good point! Can you please submit a PR for this change? Would be interested to see what it does on COCO training.

@seermer
Author

seermer commented Jul 29, 2022

I have just submitted a PR. It's my first time doing a PR, so hopefully I'm doing things correctly.

@glenn-jocher
Member

@seermer thank you! I will take a look when I have some time, extremely busy unfortunately.

@github-actions
Contributor

github-actions bot commented Sep 1, 2022

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.


Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcomed!

Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!
