Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pos_weight hyperparameter for dealing with imbalanced dataset #2703 #8557

Closed
wants to merge 5 commits into from

Conversation

hulkds
Copy link
Contributor

@hulkds hulkds commented Feb 29, 2024

model = YOLO("yolov8s.pt")
model.train(data="coco.yaml", pos_weight=[0.5, 2]) # pos_weight must be the same length as class dimension.

I have read the CLA Document and I hereby sign the CLA

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Addition of class imbalance handling with pos_weight in loss calculation.

📊 Key Changes

  • Introduced a new configuration option pos_weight in the default settings (default.yaml).
  • Updated the loss calculation logic in loss.py to consider the pos_weight for the binary cross-entropy calculation.

🎯 Purpose & Impact

  • 🎯 Purpose: To improve model performance on imbalanced datasets by adjusting the weight of positive samples in loss computation.
  • 👥 Impact: Users working with datasets that have class imbalances may notice improved model training results as the model now accounts for this imbalance, potentially leading to better generalization and performance.

Copy link

github-actions bot commented Feb 29, 2024

CLA Assistant Lite bot All Contributors have signed the CLA. ✅

@hulkds
Copy link
Contributor Author

hulkds commented Feb 29, 2024

I have read the CLA Document and I sign the CLA

Copy link

codecov bot commented Feb 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 75.75%. Comparing base (36408c9) to head (605e503).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #8557      +/-   ##
==========================================
+ Coverage   75.74%   75.75%   +0.01%     
==========================================
  Files         117      117              
  Lines       14693    14695       +2     
==========================================
+ Hits        11129    11132       +3     
+ Misses       3564     3563       -1     
Flag Coverage Δ
Benchmarks 36.29% <0.00%> (-0.01%) ⬇️
GPU 39.02% <100.00%> (+0.02%) ⬆️
Tests 70.86% <100.00%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Burhan-Q
Copy link
Member

Burhan-Q commented Mar 1, 2024

@hulkds thank you for your work!

I don't have final say on if your changes will get integrated or not, but I'd like to provide some feedback for you to help improve the chances that it could get accepted. I have not tested the new code, these are just pointers from my review of the changes made.

  1. You have included a new initialization for self.bce on L171 however the original on L156 remains, which means it would be initialized twice. It' probably better to move your new line up to replace the old one and insert the line for self.pos_weight above.

  2. I see you added this new argument to the configuration YAML, but have you verified that it works? When you add the argument for training like you show in your example

    model = YOLO("yolov8s.pt")
    model.train(
        data="coco.yaml",
        pos_weight=[0.5, 2],  # pos_weight must be the same length as class dimension.
        )

    are you certain that this change is applied? If so, can you share your results that demonstrate the applied change?

  3. Have you tested what implications there are when changing the pos_weight argument for a balanced dataset?

  4. Changes that will update arguments should also be incorporated into the documentation. Please be sure to add your changes to docs as well.

@satyrmipt
Copy link

@hulkds , please answer Burhan-Q's questions. We all are waiting for your improvement to be implemented. In particular those of us who use Google Colab right now have no workarounds except for over\undersampling. For his 3rd question i think it's user responsibility to use correct weights. Default weights must be ones anyway.

@hulkds
Copy link
Contributor Author

hulkds commented Mar 3, 2024

@Burhan-Q thanks for your feedback!

  1. This is my bad, I'll fix it :)
  2. I confirmed that the pos_weight argument is functioning as expected. After printing self.hyp, I noticed the inclusion of pos_weight. Furthermore, I can access self.hyp.pos_weight without any issues, indicating that the addition of pos_weight to the configuration YAML file is indeed effective.
  3. I successfully trained my model incorporating this modification and observed a significant improvement in detecting vehicles and license plates (with a car to license plate ratio in my dataset is nearly 6:1), thanks to the pos_weight parameter.
  4. I plan to submit another pull request soon. This PR will address your initial suggestion and include updates to the documentation to reflect these changes.

@Burhan-Q
Copy link
Member

Burhan-Q commented Mar 3, 2024

@hulkds thanks for the follow up! I checked out this PR and was able to verify that it works as well for both the case when adding weights or using the default value. I tested with the coco128 dataset and did not encounter any issues. I did not test with any other task segment, classify, obb, or pose but since the CI tests are passing, it seems that it's unlikely to be an issue for those tasks. Obviously with a limited test like mine, it's difficult to observe the differences. Later I will test with your changes versus without your changes to ensure that for the default there is no change in outcomes.

It would be preferable if you could checkout this PR or push commits to your fork for the documentation updates so that everything is included with a single PR. This will help to ensure your PR has the best chance to get accepted (I can't say for certain it will tho).

@hulkds
Copy link
Contributor Author

hulkds commented Mar 3, 2024

@Burhan-Q
The change was made specifically to the v8DetectionLoss class, so I don't expect it to affect other tasks, but it's still good to double-check. I've created a new branch and pushed a commit there. If everything looks fine, I will then create a new PR like so, everything will be included with a single PR.

@Burhan-Q
Copy link
Member

Burhan-Q commented Mar 3, 2024

@hulkds yes your updates look like they should be okay. The reason I mentioned the other tasks is because they inherit from the v8DetectionLoss for example, the Segmentation Loss

class v8SegmentationLoss(v8DetectionLoss):

I checked the results from the existing training (without pos_weight) code against the training code with pos_weight=[1] and the results from my limited test seem to match. Seems good to me overall. However you'd like to commit your changes, it's up to you, but if you make a PR it would be good to close this PR and mention it in the new one. 🚀

@hulkds
Copy link
Contributor Author

hulkds commented Mar 3, 2024

@Burhan-Q
Here is the PR: #8620

@hulkds hulkds closed this Mar 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants