Description
🚀 The feature, motivation and pitch
Similar rationale to #73135
The only difference between `_softmax` and `softmax` is that `softmax` takes an optional `dtype` argument, while `_softmax` takes a `half_to_float` argument. However, this `half_to_float` arg is not needed in the backwards pass, and even if it were, we could simply reconstruct it from the argument dtypes.
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L427
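For illustration, a minimal sketch of how `half_to_float` could be reconstructed in the backward from dtypes alone, rather than threaded through the schema. The helper name `infer_half_to_float` is hypothetical, not an existing ATen function; it assumes the saved forward input and output are available to the backward:

```cpp
#include <ATen/ATen.h>

// Hypothetical helper: recover half_to_float from saved tensors' dtypes.
// half_to_float was true iff the forward upcast a Half input to a Float output.
bool infer_half_to_float(const at::Tensor& input, const at::Tensor& output) {
  return input.scalar_type() == at::ScalarType::Half &&
         output.scalar_type() == at::ScalarType::Float;
}
```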
So, to remove these, we can simply turn the current `_softmax` and `_log_softmax` into C++ functions (not operators), and mostly preserve the same logic as today. The backwards pass would need to be modified slightly to take account of the cast. A rough sketch of this shape is below.
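A rough sketch, under stated assumptions: the public `softmax` derives `half_to_float` itself (assuming the current CUDA + Half input + Float dtype condition) and calls a plain C++ helper. `softmax_impl` is a hypothetical name, and its body here is a naive reference implementation standing in for the existing kernel dispatch, not the actual patch:

```cpp
#include <ATen/ATen.h>
#include <tuple>

namespace {

// Plain C++ function (no operator registration) replacing the _softmax op.
// The body is a naive stand-in for the existing kernel dispatch logic.
at::Tensor softmax_impl(const at::Tensor& input, int64_t dim, bool half_to_float) {
  auto x = half_to_float ? input.to(at::kFloat) : input;
  // Subtract the per-slice max for numerical stability, then normalize.
  auto shifted = x - std::get<0>(x.max(dim, /*keepdim=*/true));
  auto e = shifted.exp();
  return e / e.sum(dim, /*keepdim=*/true);
}

} // namespace

// Public softmax: derives half_to_float from the dtype argument, so no
// separate _softmax operator is needed.
at::Tensor softmax(const at::Tensor& input, int64_t dim,
                   c10::optional<at::ScalarType> dtype) {
  const bool half_to_float = input.is_cuda() &&
      input.scalar_type() == at::ScalarType::Half &&
      dtype == at::ScalarType::Float;
  if (half_to_float) {
    return softmax_impl(input, dim, /*half_to_float=*/true);
  }
  auto converted = dtype.has_value() ? input.to(*dtype) : input;
  return softmax_impl(converted, dim, /*half_to_float=*/false);
}
```

With this shape, autograd only ever sees `softmax(input, dim, dtype)`, and the backward can account for any upcast by comparing dtypes as sketched above.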
cc: @albanD @zou3519 @ngimel @ezy
Alternatives
No response
Additional context
No response