In [2]:
import torch
from torch import Tensor
class SoftSort(torch.nn.Module):
    """Continuous relaxation of the descending-sort permutation matrix.

    For a vector of scores s the module returns a row-stochastic matrix whose
    row i is a softmax over ``-|s_j - sorted(s)_i| ** pow / tau``, so row i
    concentrates its mass on the element that lands at sorted position i.

    Args:
        tau: temperature dividing the (negated) pairwise distances before the
            softmax; smaller positive values give sharper rows.
        hard: if True, the forward value is the one-hot row-wise argmax of the
            relaxed matrix, while gradients still flow through the relaxed
            matrix (straight-through estimator).
        pow: exponent applied to the absolute pairwise differences.
    """

    def __init__(self, tau=1.0, hard=False, pow=1.0):
        super().__init__()
        self.hard = hard
        self.tau = tau
        self.pow = pow

    def forward(self, scores: Tensor):
        """
        scores: elements to be sorted along the last dimension.
            Typical shape: batch_size x n; any number of leading batch
            dimensions (including none) is accepted.

        Returns a tensor of shape ``(*scores.shape, n)``. Entry ``[..., i, j]``
        is the weight linking sorted position i (descending order) to original
        element j; every row sums to 1.
        """
        # (..., n) -> (..., n, 1) so the scores can be broadcast against themselves.
        scores = scores.unsqueeze(-1)
        # Sort along the element axis, addressed from the end so that extra
        # (or missing) batch dimensions do not change which axis is sorted.
        sorted_scores = scores.sort(descending=True, dim=-2)[0]
        # [..., i, j] = -|s_j - sorted_i| ** pow / tau
        pairwise_diff = (
            (scores.transpose(-1, -2) - sorted_scores).abs().pow(self.pow).neg() / self.tau
        )
        P_hat = pairwise_diff.softmax(-1)

        if self.hard:
            # zeros_like already inherits dtype and device from P_hat.
            P = torch.zeros_like(P_hat)
            P.scatter_(-1, P_hat.topk(1, -1)[1], value=1)
            # Straight-through: forward value equals P, backward uses P_hat's gradient.
            P_hat = (P - P_hat).detach() + P_hat
        return P_hat

In [8]:
import numpy
# Hard (one-hot) sorter; temperature and exponent spelled out at their default values.
ss = SoftSort(tau=1.0, hard=True, pow=1.0)

In [16]:
# One batch row whose entries are already in ascending order.
value = torch.tensor([[1.0, 2.0, 5.0]], dtype=torch.float64)
# The module sorts in descending order, so feeding the negated values yields
# the permutation that sorts `value` in ascending order.
mat = ss(value.neg())
mat

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]], dtype=torch.float64)

In [18]:
# Axis l of `mat` indexes sorted positions and axis k indexes the original
# elements, so the sorted values are mat @ value: contract over k and keep l.
# (Contracting over l instead applies the transposed, i.e. inverse, permutation,
# which only coincides with sorting when the input is already in order.)
torch.einsum('blk, bk -> bl', mat, value)

tensor([[1., 2., 5.]], dtype=torch.float64)