
fix the problem of sigmoid gradient generating NaN #1140

Merged

louisfd merged 6 commits from the sigmoid-backward branch into tracel-ai:main on Jan 16, 2024

Conversation

@wcshds (Contributor) commented Jan 13, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #1139

Changes

Use sigmoid's derivative formula directly to avoid differentiating log and exp in autodiff.
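
For context, the closed-form derivative in question: since sigmoid(x) = 1 / (1 + exp(-x)), its derivative is sigmoid(x) * (1 - sigmoid(x)), so the backward pass only needs the forward output that was saved in the state and never has to differentiate the log/exp chain. A minimal plain-Rust sketch of the idea (scalar f32 stand-ins rather than Burn's tensor ops; function names are illustrative only):

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Gradient computed from the saved forward output:
// d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
fn sigmoid_backward(output: f32, grad: f32) -> f32 {
    output * (1.0 - output) * grad
}

fn main() {
    let x = -100.0_f32;
    let y = sigmoid(x); // exp(-x) overflows to inf, but the result is still a clean 0.0
    // The closed-form gradient stays finite (0.0 here) instead of turning into NaN.
    println!("sigmoid({x}) = {y}, grad = {}", sigmoid_backward(y, 1.0));
}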

codecov bot commented Jan 13, 2024

Codecov Report

Attention: 1 line in your changes is missing coverage. Please review.

Comparison is base (76c9358) 85.67% compared to head (197cf27) 85.95%.
Report is 4 commits behind head on main.

Files                                  Patch %   Lines
burn-autodiff/src/ops/activation.rs    95.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1140      +/-   ##
==========================================
+ Coverage   85.67%   85.95%   +0.27%     
==========================================
  Files         513      518       +5     
  Lines       57006    57724     +718     
==========================================
+ Hits        48841    49616     +775     
+ Misses       8165     8108      -57     


Comment on lines 929 to 943
/// Returns a new tensor with sigmoid values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sigmoid of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with sigmoid values.
fn sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
    B::exp(B::neg(B::log(B::add_scalar(
        B::exp(B::neg(tensor)),
        1.0_f32.elem(),
    ))))
}
Member:

We can add this function in burn-tensor/src/tensor/ops/activation.rs instead.

Comment on lines 81 to 88
match B::FloatElem::precision() {
    Precision::Half => {
        let tensor_full = tensor.to_full_precision();
        let tensor_tmp = tensor_full.sigmoid();
        Tensor::from_full_precision(tensor_tmp)
    }
    _ => tensor.sigmoid(),
}
Member:

I don't think we need full precision here, as it is now declared as a method in the backend. The backend implementations can choose to use full precision regardless of the circumstances. Perhaps we can consider incorporating full precision into the default implementation.

Comment on lines +117 to +125
let tensor_full = B::to_full_precision(&tensor);
let tensor_tmp = B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(
    B::FullPrecisionBackend::log(B::FullPrecisionBackend::add_scalar(
        B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(tensor_full)),
        1.0.elem(),
    )),
));

B::from_full_precision(tensor_tmp)
Member:

Not sure @louisfd if there is a more numerically stable implementation possible here.

Member:

Looked it up; I think it's the best we can do.
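
For reference on the failure mode this PR works around (an illustrative sketch in plain f32, not code from this PR): the forward expression exp(-log(1 + exp(-x))) is fine even when exp(-x) overflows, but differentiating it with the chain rule multiplies that overflowed exp(-x) by a factor that has gone to zero, and 0 * inf is NaN in IEEE arithmetic:

// Hand-applied chain rule for y = exp(-ln(1 + exp(-x))).
fn chained_sigmoid_grad(x: f32) -> f32 {
    let e = (-x).exp();       // overflows to +inf once -x exceeds roughly 88.7 in f32
    let u = 1.0 + e;          // +inf
    let y = (-u.ln()).exp();  // forward value is still fine: exp(-inf) = 0
    // dy/dx = y * e / u, which evaluates as 0 * inf = NaN
    y * e / u
}

fn main() {
    assert!(chained_sigmoid_grad(-100.0).is_nan());
    // The closed-form sigmoid(x) * (1 - sigmoid(x)) stays finite for the same input.
}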

@louisfd (Member) left a comment:

Hi, can you check my comment about the naming of the argument to backward? Afterwards it will be ready.

///
/// The output tensor.
fn sigmoid_backward<const D: usize>(
    x: FloatTensor<B, D>,
Member:

I think x should have a better name, like output, because it's actually sigmoid(x) that was saved in the state, not the original input.

Contributor (author):

Done.

@louisfd (Member) left a comment:

Thanks a lot

@louisfd merged commit a5bdf38 into tracel-ai:main on Jan 16, 2024
14 checks passed
@wcshds deleted the sigmoid-backward branch on January 17, 2024 at 02:11
Successfully merging this pull request may close these issues:

Small negative values cause the gradient of sigmoid to become NaN (#1139)