Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Class Weight to Detection Trainer #7655

Open
wants to merge 31 commits into
base: main
Choose a base branch
from

Conversation

zshureih
Copy link

@zshureih zshureih commented Jan 18, 2024

  1. Check for Existing Contributions: Before submitting, kindly explore existing PRs to ensure your contribution is unique and complementary.

This PR exists but has not been on-track with master for 8+ months. As well, it focuses solely on the classification task while mine is implemented for the detection task

  1. Link Related Issues: If your PR addresses an open issue, please link it in your submission. This helps us better understand the context and impact of your contribution.

Haven't seen any issues specifically related to the lack of this feature

  1. Elaborate Your Changes: Clearly articulate the purpose of your PR. Whether it's a bug fix or a new feature, a detailed description aids in a smoother integration process.

I noticed a lack of ability to modify specific weights when training detection models. It seems that there was support for this feature previously but it has been removed from the repository at some point, so I put it back in.

This feature implements a new, optional config/CLI parameter called cls_weights. If the user does provide any option, there is no change in model behavior. Otherwise, the user may provide one of three options:

  1. User provides a list of float values as long as the number of classes as defined in their dataset. This list is converted to a tensor and stored as args.cls_weights
  2. User provides the string 'median'. The median class frequency weights are computed from the provided training set and stored as args.cls_weights
  3. User provides the string 'inverse'. The inverse class frequency weights are computer from the provided training set and stored as args.cls_weights

When the DetectionLoss object is instantiated, if args.cls_weights is None, then the BCELoss object is instantiated with no pos_weights parameter, meaning it isn't weighted. Otherwise, the tensor is put on the training device and used to weight the per-class loss of each sample.

The implementation has a key flaw in repeating the weight calculation on each GPU when it should be done in the main process. I expect to address it soon.

  1. Ultralytics Contributor License Agreement (CLA): To uphold the quality and integrity of our project, we require all contributors to sign the CLA. Please confirm your agreement by commenting below:

    I have read the CLA Document and I hereby sign the CLA

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

📊 Key Changes

  • Configuration Addition: A new configuration setting 'cls_weights' has been added to the default YAML file to specify class weights.
  • Methods for Weight Calculation: Two methods are added in the training module to calculate class weights based on median frequency (get_median_frequency_weights) and inverse class frequency (get_inverse_class_frequency_weights).
  • Class Weight Application: Class weights can be user defined or calculated within the trainer (calc_weights), and are then applied to the BCEWithLogitsLoss in the loss module.

🎯 Purpose & Impact

  • Better Model Performance: Including class weights helps in dealing with class imbalance in datasets, possibly improving model performance on underrepresented classes.
  • Customization and Flexibility: Users can now either define their own weights or let the system calculate them, offering more control over the training process.

🌟 Summary

This PR introduces configurable class weights to improve object detection training, especially for imbalanced datasets. 🏋️‍♀️💡

Copy link

sentry-io bot commented Jan 18, 2024

🔍 Existing Issues For Review

Your pull request is modifying functions with the following pre-existing issues:

📄 File: ultralytics/engine/trainer.py

Function Unhandled Issue
_setup_ddp ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable MASTER_PORT ex... ...
Event Count: 34
_setup_train [**SyntaxError: '�[31m�[1mfl_gamma�0m' is not a valid YOLO argument. ** ultralytics.cfg in c...
Event Count: 16
_setup_train Error: An unexpected error occurred wandb.sdk.wan...
Event Count: 5
_setup_train Error: An unexpected error occurred wandb.sdk.wan...
Event Count: 4
_setup_train AssertionError: /kaggle/working/runs/classify/train2/weights/last.pt training to 50 epochs is finished, nothing t... ...
Event Count: 3
📄 File: ultralytics/utils/loss.py (Click to Expand)
Function Unhandled Issue
__init__ AttributeError: 'RTDETRDecoder' object has no attribute 'stride' torch.nn.modules.modul...
Event Count: 2
__init__ TypeError: cannot assign 'int' object to buffer 'pos_weight' (torch Tensor or None required) ...
Event Count: 1
__init__ AttributeError: 'RTDETRDecoder' object has no attribute 'no' torch.nn.modules.module in...
Event Count: 1
__init__ AttributeError: 'RTDETRDecoder' object has no attribute 'reg_max' torch.nn.modules.modu...
Event Count: 1
__init__ AttributeError: 'RTDETRDecoder' object has no attribute 'reg_max' torch.nn.modules.modu...
Event Count: 1
---

Did you find this useful? React with a 👍 or 👎

Copy link

github-actions bot commented Jan 18, 2024

CLA Assistant Lite bot:
Thank you for your submission, we really appreciate it. Like many open-source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution. You can sign the CLA by just posting a Pull Request Comment same as the below format.


I have read the CLA Document and I sign the CLA


3 out of 4 committers have signed the CLA.
✅ (zshureih)[https://github.com/zshureih]
✅ (davegrays)[https://github.com/davegrays]
✅ (UltralyticsAssistant)[https://github.com/UltralyticsAssistant]
@david Grayson
David Grayson seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You can retrigger this bot by commenting recheck in this Pull Request

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👋 Hello @zshureih, thank you for submitting an Ultralytics YOLOv8 🚀 PR! To allow your work to be integrated as seamlessly as possible, we advise you to:

  • ✅ Verify your PR is up-to-date with ultralytics/ultralytics main branch. If your PR is behind you can update your code by clicking the 'Update branch' button or by running git pull and git merge main locally.
  • ✅ Verify all YOLOv8 Continuous Integration (CI) checks are passing.
  • ✅ Update YOLOv8 Docs for any new or updated features.
  • ✅ Reduce changes to the absolute minimum required for your bug fix or feature addition. "It is not daily increase but daily decrease, hack away the unessential. The closer to the source, the less wastage there is." — Bruce Lee

See our Contributing Guide for details and let us know if you have any questions!

@zshureih
Copy link
Author

I have read the CLA Document and I sign the CLA

@zshureih zshureih changed the title feat/add tested class weight hyper-param to detection trainer Add Class Weight to Detection Trainer Jan 18, 2024
Copy link

codecov bot commented Jan 18, 2024

Codecov Report

Attention: Patch coverage is 33.92857% with 37 lines in your changes are missing coverage. Please review.

Project coverage is 76.68%. Comparing base (58a05f8) to head (6446c5c).
Report is 5 commits behind head on main.

Files Patch % Lines
ultralytics/engine/trainer.py 20.93% 34 Missing ⚠️
ultralytics/utils/loss.py 76.92% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7655      +/-   ##
==========================================
- Coverage   76.88%   76.68%   -0.20%     
==========================================
  Files         117      117              
  Lines       14854    14901      +47     
==========================================
+ Hits        11420    11427       +7     
- Misses       3434     3474      +40     
Flag Coverage Δ
Benchmarks 36.53% <8.92%> (-0.22%) ⬇️
GPU 38.76% <26.78%> (-0.07%) ⬇️
Tests 71.75% <33.92%> (-0.17%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@davegrays
Copy link

I have read the CLA Document and I sign the CLA

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants