
Remove _log_softmax/_softmax in favor of log_softmax and softmax respectively. #76433

@Chillee

Description

🚀 The feature, motivation and pitch

Similar rationale to #73135

The only difference between _softmax and softmax is that softmax takes an optional dtype argument, while _softmax takes a half_to_float argument. However, this half_to_float argument is not needed in the backward pass, and even if it were, it could simply be reconstructed from the argument dtypes.

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L427
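As a rough illustration of the "reconstruct it from the argument dtypes" point, here is a minimal sketch (not the actual ATen code; `softmax_impl` and the exact dispatch details are hypothetical) of a composite wrapper that derives half_to_float purely from the input dtype and the requested dtype:

```cpp
#include <ATen/ATen.h>

// Hypothetical plain C++ helper (not a registered operator).
at::Tensor softmax_impl(const at::Tensor& self, int64_t dim, bool half_to_float);

// Hedged sketch: a composite softmax that recovers half_to_float from the
// argument dtypes instead of threading it through a separate _softmax operator.
at::Tensor softmax(const at::Tensor& input, int64_t dim,
                   c10::optional<at::ScalarType> dtype) {
  // half_to_float is fully determined by the dtypes: it is the case where a
  // Half input is asked to produce a Float result without an up-front cast.
  const bool half_to_float =
      input.scalar_type() == at::ScalarType::Half &&
      dtype == at::ScalarType::Float;
  // Otherwise, perform any requested cast eagerly and compute in that dtype.
  const at::Tensor converted =
      half_to_float ? input : input.to(dtype.value_or(input.scalar_type()));
  return softmax_impl(converted, dim, half_to_float);
}
```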

So, to remove these, we can simply turn the current _softmax and _log_softmax into plain C++ functions (not operators), mostly preserving the same logic as today. The backward pass would need to be modified slightly to account for the cast.
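A minimal sketch of the kind of backward adjustment meant here, assuming a signature along the lines of _softmax_backward_data's input_dtype argument (the function name below is hypothetical; the math is the standard softmax gradient):

```cpp
#include <ATen/ATen.h>

// Hedged sketch of the backward-side adjustment. The final cast is the
// "account for the cast" step: when the forward ran in float for a half
// input, the gradient must be cast back down to match the input's dtype.
at::Tensor softmax_backward_sketch(const at::Tensor& grad_output,
                                   const at::Tensor& output, int64_t dim,
                                   at::ScalarType input_dtype) {
  // d/dx softmax: grad_input_i = y_i * (dy_i - sum_j dy_j * y_j)
  auto grad_input =
      output * (grad_output - (grad_output * output).sum(dim, /*keepdim=*/true));
  // If the forward upcast half -> float, cast the gradient back down.
  if (grad_input.scalar_type() != input_dtype) {
    grad_input = grad_input.to(input_dtype);
  }
  return grad_input;
}
```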

cc: @albanD @zou3519 @ngimel @ezy

Alternatives

No response

Additional context

No response

cc @ezyang @bhosmer @smessmer @ljk53 @bdhirsh


    Labels

module: internals (Related to internal abstractions in c10 and ATen)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
