In [2]:
import torch
from torch import Tensor
class SoftSort(torch.nn.Module):
    def __init__(self, tau=1.0, hard=False, pow=1.0):
        super(SoftSort, self).__init__()
        self.hard = hard
        self.tau = tau
        self.pow = pow

    def forward(self, scores: Tensor):
        """
        scores: elements to be sorted. Typical shape: batch_size x n
        """
        scores = scores.unsqueeze(-1)
        sorted = scores.sort(descending=True, dim=1)[0]
        pairwise_diff = (scores.transpose(1, 2) - sorted).abs().pow(self.pow).neg() / self.tau
        P_hat = pairwise_diff.softmax(-1)

        if self.hard:
            P = torch.zeros_like(P_hat, device=P_hat.device)
            P.scatter_(-1, P_hat.topk(1, -1)[1], value=1)
            P_hat = (P - P_hat).detach() + P_hat
        return P_hat

In [8]:
import numpy
ss = SoftSort(hard=True)

In [16]:
value = torch.tensor([[1.0,2.0,5.0]], dtype=torch.float64)
mat = ss(-value)
mat

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]], dtype=torch.float64)

In [18]:
torch.einsum('blk, bl -> bk', mat, value)

tensor([[1., 2., 5.]], dtype=torch.float64)

In [19]:
import torch

# Assuming dot_prod is your input tensor with shape [128, 49]
dot_prod = torch.randn(128, 49)  # Example input; replace with your actual tensor

# Create an instance of SoftSort
soft_sort = SoftSort(tau=1.0, hard=True)

# Use SoftSort to rearrange the values based on dot_prod
rearranged_values = soft_sort(dot_prod)

# The rearranged_values will have the same shape as dot_prod
print(rearranged_values.shape)  # Sh

torch.Size([128, 49, 49])


In [22]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[0]

tensor([-1.1124,  0.3074, -1.1164,  0.3205, -1.1716,  2.3287, -0.8495,  1.1453,
         0.5923, -1.5569, -0.1205, -0.8293, -0.6210,  0.5117, -0.3681,  0.2268,
         1.3512,  0.9039, -1.2890,  0.0369, -0.4759,  1.0174, -0.2304, -1.6933,
        -1.0110,  0.5619, -1.6541, -1.0059, -0.8064, -0.7143,  0.1221,  0.5179,
         1.4403, -0.7206,  1.5478,  0.7394,  0.2535,  0.7082, -2.4703,  1.3706,
        -0.3273, -0.7106, -1.3301, -2.0515, -0.0367, -1.3309,  0.0812,  1.7324,
        -1.0255])

In [23]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[1]

tensor([-0.6355,  1.7180, -0.5275, -0.8325, -0.7621,  0.5230,  1.1946,  0.2854,
        -0.3085, -0.0471, -2.4254, -0.4359, -0.7054, -1.3043, -0.3422, -0.0800,
         1.1700, -1.1339, -1.0904, -1.1093,  0.3365,  0.6108, -2.7919, -1.8271,
        -0.0098, -0.3422, -0.5166,  1.1289,  1.1492, -1.1392,  2.1992, -1.3139,
         2.1645,  0.2463,  1.5276, -2.4445, -1.2331,  0.8298,  0.0398,  0.0240,
         0.3940,  1.8558,  0.0504,  0.7366, -2.2436, -2.2673, -0.2939, -0.8592,
         1.4809])

In [29]:
dot_prod = torch.randn(3, 4)  # Example input; replace with your actual tensor
dot_prod

tensor([[-0.3010, -0.0103,  0.2229,  0.9927],
        [ 1.4477, -0.4305,  0.3110, -0.3709],
        [ 0.4433, -2.0150,  0.9629, -0.3408]])

In [34]:
import torch

# Assuming dot_prod is your input tensor with shape [128, 49]


# Create an instance of SoftSort
soft_sort = SoftSort(tau=1.0, hard=True)

# Use SoftSort to rearrange the values based on dot_prod
rearranged_values = soft_sort(-1 * dot_prod)

# The rearranged_values will have the same shape as dot_prod
print(rearranged_values.shape)  # Should be [128, 49]

torch.Size([3, 4, 4])


In [35]:
rearranged_values

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]],

        [[0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.]]])

In [37]:
dot_prod

tensor([[-0.3010, -0.0103,  0.2229,  0.9927],
        [ 1.4477, -0.4305,  0.3110, -0.3709],
        [ 0.4433, -2.0150,  0.9629, -0.3408]])

In [33]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[1]

tensor([ 1.4477, -0.3709, -0.4305,  0.3110])

In [36]:
torch.einsum('blk, bl -> bk', rearranged_values, dot_prod)[1]

tensor([-0.3709,  1.4477,  0.3110, -0.4305])