Adding entropy function analogous to SciPy #43255
Labels
feature
A request for a proper, new feature.
module: numpy
Related to numpy support and to numpy compatibility of our operators
triaged
This issue has been looked at by a team member, triaged, and prioritized into an appropriate module
🚀 Feature
Add an entropy function to PyTorch that computes `entropy = -sum(p * log(p), axis=axis)`, analogous to `scipy.stats.entropy`.
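For reference, this is the SciPy behavior the request mirrors: `scipy.stats.entropy` normalizes `pk` to sum to 1 and uses the natural logarithm by default.

```python
import numpy as np
from scipy.stats import entropy

p = np.array([0.5, 0.25, 0.25])

# Natural-log entropy: -sum(p * log(p))
print(entropy(p))               # 1.0397207708399179
print(-np.sum(p * np.log(p)))   # same value
```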
Motivation
Computing the entropy of a distribution `p` is a common need in many use cases. Currently there are several ways to compute entropy in PyTorch, listed below (a runnable comparison follows the list).

1. `entropy = -((p * p.log()).sum(dim=-1))`. This works in many cases, but is neither elegant nor particularly efficient.
2. `entropy = -(torch.bmm(p.view(M, 1, N), p.log().view(M, N, 1))).squeeze()`. This has better forward-pass performance than the 1st approach on CPU when MKL is enabled, but its backward pass is slower because of the backward computation of `bmm`.
3. `entropy = torch.distributions.categorical.Categorical(p).entropy()`. This looks the most direct, but it is much slower than both previous approaches in both the forward and backward pass, even though we don't need the logits at all.

The 1st and 2nd approaches are not a direct entropy function like SciPy's. So we could either add a SciPy-like entropy function or optimize the implementation of `torch.distributions.categorical.Categorical.entropy`.
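A minimal, runnable sketch of the three approaches side by side (the shapes `M` and `N` here are illustrative, not from the original), verifying that they agree numerically:

```python
import torch

M, N = 4, 10
p = torch.softmax(torch.randn(M, N), dim=-1)  # a batch of M distributions over N outcomes

# 1st approach: elementwise p * log(p), reduced over the last dim
e1 = -((p * p.log()).sum(dim=-1))

# 2nd approach: the same reduction expressed as a batched matrix product
e2 = -(torch.bmm(p.view(M, 1, N), p.log().view(M, N, 1))).squeeze()

# 3rd approach: via the Categorical distribution
e3 = torch.distributions.categorical.Categorical(probs=p).entropy()

assert torch.allclose(e1, e2, atol=1e-6)
assert torch.allclose(e1, e3, atol=1e-6)
```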
cc @mruberry @rgommers