# SoftMax Implementation details

[torch 里默认使用 logsoftmax](https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#Softmax)

```

note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).
```

## Safe-Softmax

先实现基本的[safe-softmax](https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html)

$$
sm(x_i) = \dfrac{e^{x_i - c}}{\sum_{j=1}^{d} e^{x_j - c}}, \qquad c = \max_{j} x_j
$$

## Safe Softmax implementation

In [1]:
import torch
def SoftMax(logits):
    '''
    Numerically stable ("safe") softmax over the last dimension.

    The maximum along the last dimension is subtracted before
    exponentiating, so the largest exponent is exp(0) = 1 and exp()
    cannot overflow.

    logits: tensor of shape [..., dim], e.g. [batch_size, dim]
    output: tensor of the same shape; entries along the last dimension
            sum to 1
    '''
    # keepdim=True keeps the reduced axis, so the shift broadcasts
    # correctly for inputs of any rank, not just 2-D.
    logits_max, _ = logits.max(dim=-1, keepdim=True)
    exp_shifted = (logits - logits_max).exp()
    # exp() is never negative, so the ratio needs no abs().
    return exp_shifted / exp_shifted.sum(dim=-1, keepdim=True)
    
# Run SoftMax on a random batch and inspect the first row: the raw logits,
# the probabilities, their sum, and their log.
logits = torch.randn(8, 10)
prob = SoftMax(logits)
first_logits, first_prob = logits[0], prob[0]
print('原始数据:\n', first_logits)
print('softmax:\n', first_prob)
print('prob sum:\n', first_prob.sum())
print('logprob:\n', first_prob.log())

# LogSoftmax

## Softmax overflow

在MLE（最大似然估计）中，通常要计算 logprob。以下例子中，softmax 得到的概率下溢为 0，取 log 后得到 `-inf`:

In [2]:
# Taking the log of these softmax probabilities is numerically unstable:
# one logit is far larger than the rest.
logits = torch.tensor([[10, 2, 10000, 4]])
prob = SoftMax(logits)
for shown in (logits, prob, prob.log()):
    print(shown)

## LogSoftmax Implementation

在pytorch的实现中，使用logsoftmax代替softmax，避免计算logprob产生溢出

可以在原始的Softmax上加入log, 可以推导出logsoftmax

$$
\begin{aligned}
\text{softmax}(x_i) &= \dfrac{e^{x_i - c}}{\sum_{j=1}^{d} e^{x_j - c}} \\
\text{log\_softmax}(x_i) &= \log \dfrac{e^{x_i - c}}{\sum_{j=1}^{d} e^{x_j - c}} \\
\text{log\_softmax}(x_i) &= x_i - c - \log \sum_{j=1}^{d} e^{x_j - c}
\end{aligned}
$$

此时的logsoftmax得出的logprob不会溢出，同样可以将logprob转化成prob

$$
\text{softmax}(x_i) = \dfrac{e^{\text{log\_softmax}(x_i)}}{\sum_{j=1}^{d} e^{\text{log\_softmax}(x_j)}}
$$

In [3]:
import torch
def LogSoftMax(logits, recover_prob = True):
    '''
    Numerically stable log-softmax over the last dimension.

    Computes x_i - c - log(sum_j exp(x_j - c)) with c = max_j x_j, so the
    log-probabilities stay finite even when plain softmax probabilities
    would underflow to 0.

    logits: tensor of shape [..., dim], e.g. [batch_size, dim]
    recover_prob: when truthy, also rebuild probabilities from the
                  log-probabilities
    output: (probs, log_probs), each with the same shape as logits;
            probs is None when recover_prob is falsy
    '''
    # keepdim=True keeps the reduced axis, so the shift broadcasts
    # correctly for inputs of any rank, not just 2-D.
    logits_max, _ = logits.max(dim=-1, keepdim=True)
    safe_logits = logits - logits_max
    log_sum_exp = safe_logits.exp().sum(dim=-1, keepdim=True).log()
    log_probs = safe_logits - log_sum_exp

    # Always bind probs so the return below is valid when recovery is skipped.
    probs = None
    if recover_prob:
        exp_log_probs = log_probs.exp()
        # Renormalise so each row sums to 1 despite rounding in exp().
        probs = exp_log_probs / exp_log_probs.sum(dim=-1, keepdim=True)

    return probs, log_probs
    
# Compare both implementations on the same random batch.
logits = torch.randn(2, 5)

# Plain softmax, then the log of its probabilities.
print('--------------softmax------------------')
softmax_probs = SoftMax(logits)
softmax_log_probs = torch.log(softmax_probs)
print(f'softmax probs: \n{softmax_probs}')
print(f'softmax_log_probs: \n{softmax_log_probs}')

# Log-softmax, with probabilities recovered from the log-probabilities.
print('--------------log softmax------------------')
lsoftmax_probs, lsoftmax_log_probs = LogSoftMax(
    logits,
    recover_prob=True,
)
print(f'logsoftmax probs: \n{lsoftmax_probs}')
print(f'logsoftmax_log_probs: \n{lsoftmax_log_probs}')

## Softmax VS LogSoftmax

In [4]:
# Same comparison on an input with one dominant logit.
logits = torch.tensor([[10, 2, 10000, 4]])

# Plain softmax, then the log of its probabilities.
softmax_probs = SoftMax(logits)
softmax_log_probs = torch.log(softmax_probs)
print('--------------softmax------------------')
print(f'softmax probs: \n{softmax_probs}')
print(f'softmax_log_probs: \n{softmax_log_probs}')

# Log-softmax, with probabilities recovered from the log-probabilities.
lsoftmax_probs, lsoftmax_log_probs = LogSoftMax(
    logits,
    recover_prob=True,
)
print('--------------log softmax------------------')
print(f'logsoftmax probs: \n{lsoftmax_probs}')
print(f'logsoftmax_log_probs: \n{lsoftmax_log_probs}')

使用`LogSoftmax()` 得到的`logprob` 不会产生`-inf`

# Pytorch LogSoftmax

使用pytorch测试softmax是否会溢出

In [1]:
import torch

# float32 logits with one dominant entry, used to compare the two built-ins.
x = torch.tensor([[10, 2, 10000, 4]], dtype=torch.float32)

# Pass dim explicitly: omitting it relies on a deprecated implicit choice
# and emits a warning. dim=-1 normalises over the last axis of each row.
x_softmax = torch.nn.functional.softmax(x, dim=-1)
x_logsoftmax = torch.nn.functional.log_softmax(x, dim=-1)

# Probabilities, the log of those probabilities, and the direct log-softmax.
print(x_softmax)
print(x_softmax.log())
print(x_logsoftmax)

注意: 调用 `torch.nn.functional.softmax(x)` 和 `torch.nn.functional.log_softmax(x)` 时如果不显式指定 `dim`，PyTorch 会给出弃用警告；建议显式传入 `dim=-1`。
