-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Class Weight to Detection Trainer #7655
base: main
Are you sure you want to change the base?
Conversation
🔍 Existing Issues For ReviewYour pull request is modifying functions with the following pre-existing issues: 📄 File: ultralytics/engine/trainer.py
📄 File: ultralytics/utils/loss.py (Click to Expand)
Did you find this useful? React with a 👍 or 👎 |
CLA Assistant Lite bot: I have read the CLA Document and I sign the CLA 3 out of 4 committers have signed the CLA. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👋 Hello @zshureih, thank you for submitting an Ultralytics YOLOv8 🚀 PR! To allow your work to be integrated as seamlessly as possible, we advise you to:
- ✅ Verify your PR is up-to-date with
ultralytics/ultralytics
main
branch. If your PR is behind you can update your code by clicking the 'Update branch' button or by runninggit pull
andgit merge main
locally. - ✅ Verify all YOLOv8 Continuous Integration (CI) checks are passing.
- ✅ Update YOLOv8 Docs for any new or updated features.
- ✅ Reduce changes to the absolute minimum required for your bug fix or feature addition. "It is not daily increase but daily decrease, hack away the unessential. The closer to the source, the less wastage there is." — Bruce Lee
See our Contributing Guide for details and let us know if you have any questions!
I have read the CLA Document and I sign the CLA |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7655 +/- ##
==========================================
- Coverage 76.88% 76.68% -0.20%
==========================================
Files 117 117
Lines 14854 14901 +47
==========================================
+ Hits 11420 11427 +7
- Misses 3434 3474 +40
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
…hting Simplify class weighting code
I have read the CLA Document and I sign the CLA |
Feature/fix enable focal loss
Add focal loss
This PR exists but has not been on-track with master for 8+ months. As well, it focuses solely on the classification task while mine is implemented for the detection task
Haven't seen any issues specifically related to the lack of this feature
I noticed a lack of ability to modify specific weights when training detection models. It seems that there was support for this feature previously but it has been removed from the repository at some point, so I put it back in.
This feature implements a new, optional config/CLI parameter called
cls_weights
. If the user does provide any option, there is no change in model behavior. Otherwise, the user may provide one of three options:args.cls_weights
args.cls_weights
args.cls_weights
When the DetectionLoss object is instantiated, if
args.cls_weights
is None, then the BCELoss object is instantiated with no pos_weights parameter, meaning it isn't weighted. Otherwise, the tensor is put on the training device and used to weight the per-class loss of each sample.The implementation has a key flaw in repeating the weight calculation on each GPU when it should be done in the main process. I expect to address it soon.
Ultralytics Contributor License Agreement (CLA): To uphold the quality and integrity of our project, we require all contributors to sign the CLA. Please confirm your agreement by commenting below:
I have read the CLA Document and I hereby sign the CLA
🛠️ PR Summary
Made with ❤️ by Ultralytics Actions
📊 Key Changes
'cls_weights'
has been added to the default YAML file to specify class weights.get_median_frequency_weights
) and inverse class frequency (get_inverse_class_frequency_weights
).calc_weights
), and are then applied to theBCEWithLogitsLoss
in the loss module.🎯 Purpose & Impact
🌟 Summary
This PR introduces configurable class weights to improve object detection training, especially for imbalanced datasets. 🏋️♀️💡