ENH: special: Add the softmax function #8872
Conversation
A couple of minor, optional changes.
@@ -640,7 +641,7 @@
 from ._ufuncs import *

 from .basic import *
-from ._logsumexp import logsumexp
+from ._logsumexp import logsumexp, softmax
Is it possible to rename this to `_expfuncs.py`?
Since it is a private module, I'm not too worried about the name. And once you realize that the softmax function is the gradient of `logsumexp`, it isn't so bad leaving the name as `_logsumexp.py`. 😃 That also suggests that a better name for `softmax` is something like `logsumexp_grad`, but `softmax` is what people will be looking for.
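Here is a quick numerical sketch of that identity (not part of the PR; it assumes `softmax` is importable from `scipy.special` once this lands):

import numpy as np
from scipy.special import logsumexp, softmax

# Central finite differences: the gradient of logsumexp should match softmax.
x = np.array([1.0, 2.0, 3.0])
eps = 1e-6
grad = np.array([(logsumexp(x + eps*e) - logsumexp(x - eps*e)) / (2*eps)
                 for e in np.eye(x.size)])
print(np.allclose(grad, softmax(x)))  # True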
> And once you realize that the softmax function is the gradient of `logsumexp`
This seems like an interesting enough observation that it could go into the notes section of the docs
> This seems like an interesting enough observation that it could go into the notes section of the docs
Done (if somewhat tersely).
scipy/special/_logsumexp.py
Outdated
----------
x : array_like
    Input array.
axis : int, optional
Can be a tuple of `int`s as well.
Done, thanks.
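For example, a reduction over multiple axes might look like this (an illustrative sketch, not from the PR; the shape and axes are arbitrary):

import numpy as np
from scipy.special import softmax

x = np.arange(24.0).reshape(2, 3, 4)
s = softmax(x, axis=(0, 1))    # normalize over a tuple of axes
print(s.sum(axis=(0, 1)))      # each reduced slice sums to 1: [1. 1. 1. 1.]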
Force-pushed from 5907f8a to 161d72f
Over in #8556, @person142 was against adding this to `scipy.special`.
Since the majority are in favor, I would hardly want to hold this back.
Some more minor comments.
scipy/special/tests/test_softmax.py
Outdated
@@ -0,0 +1,59 @@
from __future__ import division, print_function, absolute_import
It'd be nice if this could be merged into `test_logsumexp.py`; since the source file is the same, the test file should be the same. It's nice if the tests mirror the structure of the actual code.
Agreed. I was this close to making this change earlier anyway.
Done.
scipy/special/_logsumexp.py
Outdated
    # compute in log space for numerical stability
    sigma = np.exp(x - logsumexp(x, axis=axis, keepdims=True))
    return sigma
Nit by @eric-wieser that I agree with: merge these two lines.
I don't have a strong preference here, but since folks are commenting on it in a pull request (and I seem to recall seeing it in another one), I'd like to get some clarification. @hameerabbasi, @eric-wieser, can you give explicit reasons for making this change? Is it for better performance? I see about a 5 ns improvement between
def func1(x):
y = 2*x
return y
and
def func2(x):
return 2*x
The argument against the change is that assigning the return value to a variable, and then returning that variable in a separate `return` statement, eases future debugging. One can set a breakpoint at the return statement (or add a print statement before it) to inspect the result before it is returned. Indeed, this pattern is taught in some courses on programming style.
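For reference, a rough way to reproduce that comparison (a sketch; absolute numbers will vary by machine and interpreter):

import timeit

setup = """
def func1(x):
    y = 2*x
    return y

def func2(x):
    return 2*x
"""
n = 10**6
# per-call time in seconds; the difference is on the order of nanoseconds
print(timeit.timeit("func1(3)", setup=setup, number=n) / n)
print(timeit.timeit("func2(3)", setup=setup, number=n) / n)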
Like I said, it's a nit. Feel free to ignore. I tend to only store a variable if it's not a one-liner.
Done.
The argument was only ever one of brevity, not of performance. Also, it's weird to see:
def thing_one():
thing_two = ...
return thing_two
because it suggests the author wasn't really sure what to call their function.
> to inspect the result before it is returned
Note that pdb provides `__return__` for this purpose. Typing `return` (execute and run the return statement) followed by `p __return__` does exactly this.
> Indeed, this pattern is taught in some courses on programming style.
These courses would do better to teach how to use the debugger usefully.
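A minimal sketch of that workflow (the function here is just a stand-in, and the session commands are typed at the `(Pdb)` prompt):

import pdb

def double(x):
    return 2*x

# At the (Pdb) prompt: `step` into double, then `return` to run until the
# function is about to return, then `p __return__` to print the pending
# return value; no temporary variable needed.
pdb.run("double(21)")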
Adds the softmax function, commonly used in machine learning and statistics, to scipy.special.
* Update the docstring:
  * Copy-edit a bit.
  * Added "Examples" section.
  * Move LaTeX notation to the Notes section.
  * Add versionadded annotation.
* Tests:
  * More explicit tests of "normal" cases and extreme cases.
  * More tests using the `axis` argument.
  * Use `assert_allclose` instead of `assert_almost_equal`.
Force-pushed from 161d72f to 12dd12e
Force-pushed from 0d9f6f9 to 607be34
LGTM. Thanks for the patience, @WarrenWeckesser!
""" | ||
|
||
# compute in log space for numerical stability | ||
return np.exp(x - logsumexp(x, axis=axis, keepdims=True)) |
Are we sure that this is as stable as the one suggested here?
def stablesoftmax(x, axis=None):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x, axis=axis, keepdims=True)
exps = np.exp(shiftx)
return exps / np.sum(exps, axis=axis, keepdims=True)
That implementation has the advantage that translating the input (without precision loss) results in exactly the same result, not just a to-within-floating-point-accuracy result.
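A small experiment comparing the two (a sketch; the shift of 100.0 is arbitrary but exactly representable, so the translation itself loses no precision):

import numpy as np
from scipy.special import logsumexp

def softmax_log(x, axis=None):
    # the implementation under review: subtract logsumexp in log space
    return np.exp(x - logsumexp(x, axis=axis, keepdims=True))

def softmax_shift(x, axis=None):
    # the alternative: shift by the max, then normalize
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.array([1.0, 2.0, 3.0])
for f in (softmax_log, softmax_shift):
    # exact translation invariance would print 0.0; the log-domain version
    # may be off by a rounding error or two
    print(f.__name__, np.abs(f(x) - f(x + 100.0)).max())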
I tend to disagree... There will be floating-point inaccuracies incurred by division, `np.exp`, and `np.sum`. The key point here is that even subtraction has floating-point inaccuracies if the exponents of the two operands aren't exactly equal.

This is just the softmax done in the log domain. Agreed that `logsumexp` could introduce some floating-point errors, but the other implementation doesn't necessarily produce the same result as just the last line. It could be that the errors cancel out, but I'm not well-versed enough in floating-point arithmetic to see how that'd work.
Can you add a test that all of the following give […]

I'm not sure there's a clear definition for what cases with repeated infinities like […]

@eric-wieser I don't understand why you're saying you expect […]

Either way, that's about […] Thanks all!

@rgommers I think he meant […]

Ah, that makes more sense. That's not the case though for either the current implementation or the […]

Of course. In one case, you're doing […]
This pull request is a continuation of #8556.