In [2]:
import math

import torch
import torch.nn.functional as F
from torch import Tensor


def softmax_n_shifted_zeros(input: Tensor, n: int) -> Tensor:
    r"""
    Compute softmax_n along the last dimension:

        $\text{softmax}_n(x_i) = \exp(x_i) / (n + \sum_j \exp(x_j))$

    softmax_n equals an ordinary softmax over the inputs padded with ``n``
    extra logits of value 0 (equivalently, one extra logit of value log(n)),
    with the padded entries dropped. For n > 0 the outputs therefore sum to
    less than 1. softmax_0 is the ordinary softmax.

    Note: for fixed n != 0, softmax_n is _not_ shift-invariant. The usual
    trick of subtracting the row max for numeric stability must also shift
    the implicit zero logits. Any shift c gives the same exact value

        exp(x_i - c) / (n * exp(-c) + sum_j exp(x_j - c))

    Here c = max(max_j x_j, log n) is used: the max over *all* logits,
    including the implicit log(n) one. Every exponential is then <= 1 and,
    for n > 0, the denominator is >= 1. This avoids overflow, ``0 * inf``
    and ``inf / inf``, which subtracting only the input max does not.

    Args:
        input: tensor of logits. softmax_n is taken over the last dimension.
            Non-floating-point tensors are promoted to the default float dtype.
        n: weight of the implicit zero logits. Expected to be a Python number
            or a 0-dim tensor. n == 0 gives the standard softmax.

    Returns:
        A floating-point tensor with the same shape as ``input``.
    """
    # exp() would promote integer inputs to the default float dtype anyway.
    # Do it up front so the shift below can take the fractional value log(n).
    if not input.is_floating_point():
        input = input.to(torch.get_default_dtype())

    input_maxes = input.max(dim=-1, keepdim=True).values
    if n > 0:
        # The n zero logits contribute like a single logit of value log(n),
        # so include it in the max used for the stabilising shift.
        log_n = torch.log(torch.as_tensor(n, dtype=input.dtype, device=input.device))
        shift = torch.maximum(input_maxes, log_n)
    else:
        # n == 0 is the plain softmax. n < 0 is not a meaningful softmax_n;
        # the same formula is still evaluated, shifted by the input max only.
        shift = input_maxes

    numerator = torch.exp(input - shift)
    denominator = numerator.sum(dim=-1, keepdim=True)
    if n != 0:
        # Contribution of the zero logits, shifted like the inputs:
        # n * exp(0 - shift). For n > 0 this is <= 1 because shift >= log(n).
        # It is skipped for n == 0 to avoid 0 * inf when the max is very negative.
        denominator = denominator + n * torch.exp(-shift)
    return numerator / denominator


# Ten equal logits of 10 with n = 1: each entry is exp(10) / (1 + 10 * exp(10)), just under 0.1.
softmax_n_shifted_zeros(input=torch.tensor([10] * 10), n=1)

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])

In [3]:
# Ten logits of -5 with n = 1: the zero logit dominates, so the total mass is
# 10 e^-5 / (1 + 10 e^-5) ≈ 0.0631.
softmax_n_shifted_zeros(torch.tensor([-5] * 10), 1).sum()

tensor(0.0631)

In [4]:
# Ten zero logits with n = 1: total mass is 10 / (1 + 10) ≈ 0.9091.
softmax_n_shifted_zeros(input=torch.tensor([0] * 10), n=1).sum()

tensor(0.9091)

In [5]:
# softmax(log v) = v / sum(v), so this should give [10, 100, 1000] / 1110
# ≈ [0.0090, 0.0901, 0.9009].
F.softmax(torch.tensor([math.log(10), math.log(100), math.log(1000)]), dim=-1)

tensor([0.0090, 0.0901, 0.9009])

In [10]:
# A locally seeded generator makes the displayed comparison reproducible on
# re-run without touching the global RNG state.
t = torch.randn(10, generator=torch.Generator().manual_seed(0))

# softmax_1 equals a standard softmax over the inputs padded with one zero
# logit, with that padded entry dropped afterwards.
a = F.softmax(torch.cat((t, torch.zeros(1))), dim=-1)[:-1]

b = softmax_n_shifted_zeros(t, 1)

assert torch.allclose(a, b), "softmax_n_shifted_zeros(t, 1) should match the zero-padded softmax"
a, b

(tensor([0.0903, 0.1215, 0.0297, 0.0450, 0.0417, 0.3175, 0.0744, 0.1690, 0.0471,
         0.0086]),
 tensor([0.0903, 0.1215, 0.0297, 0.0450, 0.0417, 0.3175, 0.0744, 0.1690, 0.0471,
         0.0086]))