
Implement Huber loss #1444

Merged
merged 3 commits into tracel-ai:main from huber-loss on Mar 13, 2024

Conversation

@WorldSEnder (Contributor) commented Mar 10, 2024

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Closes #1441, as I think the remaining feature request for a sign function is already tracked in #522.

Changes

Implements the Huber loss function.
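For reference, this is the standard piecewise definition, with delta the transition threshold and r the residual between prediction and target (notation assumed here, not quoted from the PR):

```latex
L_\delta(r) =
  \begin{cases}
    \tfrac{1}{2}\, r^2 & \text{if } |r| \le \delta \\
    \delta \left( |r| - \tfrac{1}{2}\,\delta \right) & \text{otherwise}
  \end{cases}
```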

Instead of strictly following the definition, which uses a sign or abs function, the implementation uses clamping. This computes the same value outside the delta bounds, but is better behaved on the autodiff backend and does not need any extra primitive ops. See also #1441 for my first attempt at implementing this.
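To illustrate the trick, here is a minimal scalar sketch of one clamping identity that reproduces the piecewise definition (a hypothetical f64 helper, not the tensor code from this PR; the tensor version would use the analogous clamp op):

```rust
/// Hypothetical scalar sketch of the clamping identity.
/// With c = clamp(r, -delta, delta):
///   r * c - 0.5 * c^2  ==  0.5 * r^2                  for |r| <= delta
///   r * c - 0.5 * c^2  ==  delta * (|r| - delta / 2)  otherwise
/// i.e. exactly the Huber loss, with no sign() or abs() required.
fn huber_scalar(r: f64, delta: f64) -> f64 {
    let c = r.clamp(-delta, delta);
    r * c - 0.5 * c * c
}

fn main() {
    // Quadratic branch: 0.5 * 0.5^2 = 0.125
    assert!((huber_scalar(0.5, 1.0) - 0.125).abs() < 1e-12);
    // Linear branch: 1.0 * (3.0 - 0.5) = 2.5
    assert!((huber_scalar(-3.0, 1.0) - 2.5).abs() < 1e-12);
    println!("clamp-based Huber matches the piecewise definition");
}
```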

Testing

Test data should cover all relevant branches of the operation, as well as the critical points on the autodiff backend, i.e. zero residuals and the point where the loss switches between branches. Test assertions were generated by executing the equivalent computation in scipy.
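For illustration, the kind of coverage that describes, as a hypothetical scalar test reusing huber_scalar from the sketch above (the PR's actual assertions target tensors and come from scipy):

```rust
#[test]
fn huber_covers_both_branches_and_critical_points() {
    let delta = 1.0;
    // (residual, expected loss) pairs chosen to hit every branch plus
    // the two critical points: zero residual and |r| == delta.
    for &(r, expected) in &[
        (0.0, 0.0),    // zero residual
        (0.5, 0.125),  // quadratic branch: 0.5 * r^2
        (1.0, 0.5),    // switch point |r| == delta
        (-3.0, 2.5),   // linear branch: delta * (|r| - delta / 2)
    ] {
        assert!((huber_scalar(r, delta) - expected).abs() < 1e-12);
    }
}
```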

Note: the test_downsample_interpolation test in nearest_interpolate.rs is failing locally for me. It is not caused by this patch; I ignored it when running run-checks.

Instead of using a sign or abs function, use clamping to compute the
loss outside the bounds. This is better behaved on the autodiff backend.
@antimora (Collaborator)

Submitted a PR for the sign tensor operator: #1446

@WorldSEnder (Contributor, Author)

Note: I think the clamping method works out even better for Huber than using sign, since it avoids a mul_scalar(delta). Though I'm still interested in sign for alternative losses such as SmoothL1.
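A scalar view of why that works out (my reading of the comment, not code from this PR): differentiating r * c - 0.5 * c^2 with c = clamp(r, -delta, delta) yields simply c, whereas the sign-based form needs delta * sign(r) outside the bounds, hence the extra scalar multiply:

```rust
// Hypothetical scalar comparison of the two gradient formulations.
fn huber_grad_clamped(r: f64, delta: f64) -> f64 {
    // d/dr [r * c - 0.5 * c^2] is just c:
    // inside the bounds c = r, outside dc/dr = 0.
    r.clamp(-delta, delta)
}

fn huber_grad_sign(r: f64, delta: f64) -> f64 {
    if r.abs() <= delta {
        r
    } else {
        delta * r.signum() // the extra multiply by delta
    }
}

fn main() {
    for &r in &[-3.0, -1.0, -0.25, 0.0, 0.5, 2.0] {
        assert_eq!(huber_grad_clamped(r, 1.0), huber_grad_sign(r, 1.0));
    }
    println!("both formulations agree; the clamp form skips the scalar multiply");
}
```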

@antimora (Collaborator)

CI failed due to:

```
test tests::jit_fusion::var::tests::test_var_mean_bias ... ok

failures:

---- tests::jit::kernel::normal::tests::subsequent_calls_give_different_tensors stdout ----
thread 'tests::jit::kernel::normal::tests::subsequent_calls_give_different_tensors' panicked at crates/burn-wgpu/src/lib.rs:73:5:
assertion failed: tensor_1.to_data().value[i] != tensor_2.to_data().value[i]

failures:
    tests::jit::kernel::normal::tests::subsequent_calls_give_different_tensors

test result: FAILED. 1341 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 55.86s
```

Rerunning to see if it's a fluke. Tagging @nathanielsimard and @louisfd since they're currently working in this area.

codecov bot commented Mar 10, 2024

Codecov Report

Attention: Patch coverage is 96.93878%, with 3 lines in your changes missing coverage. Please review.

Project coverage is 85.97%. Comparing base (4ed90a9) to head (7d7b181).
Report is 25 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| crates/burn-core/src/nn/loss/huber.rs | 96.93% | 3 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1444      +/-   ##
==========================================
+ Coverage   85.81%   85.97%   +0.15%
==========================================
  Files         610      646      +36
  Lines       70417    71847    +1430
==========================================
+ Hits        60428    61769    +1341
- Misses       9989    10078      +89
```


@antimora (Collaborator)

The sign tensor op PR is merged.

@antimora requested a review from laggui on March 11, 2024 at 15:41
@nathanielsimard (Member) left a comment

LGTM. @louisfd for further review.

@louisfd (Member) left a comment

LGTM, but can you clarify the math in the comments? I think it works out, but it's a bit confusing; in particular, I think r, err, and res are three names for the same thing?

@WorldSEnder (Contributor, Author)

> LGTM, but can you clarify the math in the comments? I think it works out, but it's a bit confusing; in particular, I think r, err, and res are three names for the same thing?

Ah yes, I initially wanted to use "error" for the difference between targets and predictions, then remembered the better term "residuals", and I guess didn't catch a few spots when renaming. Will fix.

@louisfd (Member) commented Mar 12, 2024

Thanks, we can merge once CI passes.

@antimora merged commit 53eb3ec into tracel-ai:main on Mar 13, 2024. 14 checks passed.
@WorldSEnder deleted the huber-loss branch on March 13, 2024 at 18:01.
Development

Successfully merging this pull request may close these issues.

Problematic derivative of Tensor::abs and Huber loss
4 participants