fix the problem of sigmoid gradient generating NaN #1140
Conversation
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1140      +/-   ##
==========================================
+ Coverage   85.67%   85.95%   +0.27%
==========================================
  Files         513      518       +5
  Lines       57006    57724     +718
==========================================
+ Hits        48841    49616     +775
+ Misses       8165     8108      -57
```

☔ View full report in Codecov by Sentry.
burn-tensor/src/tensor/ops/tensor.rs (Outdated)
```rust
/// Returns a new tensor with sigmoid values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sigmoid of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with sigmoid values.
fn sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
    B::exp(B::neg(B::log(B::add_scalar(
        B::exp(B::neg(tensor)),
        1.0_f32.elem(),
    ))))
}
```
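For context, the nested `exp`/`neg`/`log`/`add_scalar` calls evaluate the sigmoid in log space via the identity

$$\sigma(x) = \frac{1}{1 + e^{-x}} = \exp\!\left(-\log\left(1 + e^{-x}\right)\right)$$

which is exactly what the composition above spells out.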
We can add this function in burn-tensor/src/tensor/ops/activation.rs instead.
```rust
match B::FloatElem::precision() {
    Precision::Half => {
        let tensor_full = tensor.to_full_precision();
        let tensor_tmp = tensor_full.sigmoid();
        Tensor::from_full_precision(tensor_tmp)
    }
    _ => tensor.sigmoid(),
}
```
I don't think we need full precision here, as it is now declared as a method in the backend. The backend implementations can choose to use full precision regardless of the circumstances. Perhaps we can consider incorporating full precision into the default implementation.
```rust
let tensor_full = B::to_full_precision(&tensor);
let tensor_tmp = B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(
    B::FullPrecisionBackend::log(B::FullPrecisionBackend::add_scalar(
        B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(tensor_full)),
        1.0.elem(),
    )),
));

B::from_full_precision(tensor_tmp)
```
Not sure @louisfd if there is a more numerically stable implementation possible here.
Looked it up, I think it's the best we can do.
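For reference, the bound that makes the hand-written backward safe: the sigmoid's derivative satisfies

$$\sigma'(x) = \sigma(x)\left(1 - \sigma(x)\right) \in \left[0, \tfrac{1}{4}\right],$$

so evaluating it from the saved output can never overflow, whereas autodiffing through the composed `exp`/`log` ops can produce $\infty \cdot 0 = \text{NaN}$ intermediates for large $|x|$ (this reasoning is inferred from #1139 rather than spelled out in this thread).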
Hi, can you check my comment about the naming of the argument to `backward`? Afterwards it will be ready.
```rust
///
/// The output tensor.
fn sigmoid_backward<const D: usize>(
    x: FloatTensor<B, D>,
```
I think `x` should have a better name, like `output`, because it's actually sigmoid(x) that was saved in the state, not the original input.
Done.
Thanks a lot
Pull Request Template

Checklist

- `run-checks all` script has been executed.

Related Issues/PRs

- fix #1139

Changes

Use sigmoid's derivative formula directly to avoid differentiating `log` and `exp` in autodiff.
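To make "use the derivative formula directly" concrete, here is a minimal scalar sketch. It uses plain `f32` rather than burn's tensor API, and the function names are illustrative, not burn's actual signatures; the backward consumes the saved forward `output` (the naming agreed on above), so no `log` or `exp` node is ever differentiated.

```rust
/// Forward pass in log space, mirroring the diff above:
/// sigmoid(x) = exp(-log(1 + exp(-x))).
fn sigmoid(x: f32) -> f32 {
    (-(1.0 + (-x).exp()).ln()).exp()
}

/// Backward pass using the closed-form derivative applied to the saved
/// forward output: d/dx sigmoid(x) = output * (1 - output).
fn sigmoid_backward(output: f32, grad: f32) -> f32 {
    output * (1.0 - output) * grad
}

fn main() {
    for &x in &[-100.0_f32, 0.0, 100.0] {
        let out = sigmoid(x);
        let grad = sigmoid_backward(out, 1.0);
        // Saturated inputs yield gradient 0.0, not NaN.
        println!("x = {x:>6}, sigmoid = {out}, grad = {grad}");
    }
}
```

At `x = ±100` the forward saturates cleanly to 0 or 1, and the gradient evaluates to 0.0 rather than NaN.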