In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

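# 1. Differentiable Soft Argmax Layer

`argmax` has no useful gradient, so this layer replaces it with the *expected* index under a temperature-scaled softmax along a chosen dimension: $\operatorname{softargmax}(x) = \sum_i i \cdot \operatorname{softmax}(x / T)_i$. As the softmax becomes more peaked, the result approaches the hard argmax.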

In [6]:
class DifferentiableSoftArgmaxLayer(nn.Module):
    """Differentiable (soft) replacement for ``torch.argmax``.

    Computes the expected index under a temperature-scaled softmax along
    ``dim``::

        soft_argmax(x) = sum_i i * softmax(x / temperature)_i

    Unlike ``argmax`` this is differentiable w.r.t. ``x``. The result is a
    floating-point index that approaches the hard argmax as the softmax gets
    more peaked (larger logit gaps or smaller ``temperature``).

    Args:
        dim: Dimension to reduce over; negative values count from the end.
        temperature: Divisor applied to the logits before the softmax. Smaller
            values give a sharper distribution. A value of 0 divides by zero.

    Shape:
        Input: a floating-point tensor with at least one dimension.
        Output: same shape as the input with ``dim`` removed.
    """

    def __init__(self, dim=-1, temperature=1.0):
        super().__init__()
        self.dim = dim
        self.temperature = temperature

    def forward(self, x):
        """Return the soft-argmax index of ``x`` along ``self.dim``.

        Defined as ``forward`` (not ``__call__``) so ``nn.Module`` hooks run.
        """
        probs = F.softmax(x / self.temperature, dim=self.dim)
        # x.size() validates `dim`, raising IndexError on a 0-dim or out-of-range axis.
        indices = torch.arange(x.size(self.dim), device=x.device, dtype=x.dtype)
        # Move the reduced axis last so the 1-D index vector lines up with it for
        # any `dim`, not only -1; the other axes keep their original order.
        return (probs.movedim(self.dim, -1) * indices).sum(dim=-1)

    def computation(self, x):
        """Backward-compatible alias for :meth:`forward`.

        A method rather than a lambda attribute so the module stays picklable.
        """
        return self.forward(x)

In [7]:
# Example logits of shape (1, 6, 7). Each of the 6 rows has a single dominant
# entry, so the soft argmax over the last axis should land on the hard argmax.
tensor_data = torch.tensor([[[ -10.4249,  12.3542, -10.5049, -10.4770, -14.2411, -11.1246, -10.6761],
                              [-20.8229,  -7.6676,  15.1833,  -7.6492,  -7.5959,  -7.9427,  -8.7889],
                              [-13.4050, -13.1791, -13.2362,   9.5195, -13.0957, -14.9319, -23.5854],
                              [-13.4611, -17.0730,  -7.6502,  -7.2442,  15.3833,  -7.8079,  -7.5043],
                              [ -4.7446,  -5.5294,  -4.8128,  -5.7451,  -4.8391,  18.0267,  -7.5303],
                              [ -4.6189,  -2.7684,  -3.2782,  -4.6394,  -2.5019,  -3.3491,  19.2203]]])

soft_argmax_layer = DifferentiableSoftArgmaxLayer(dim=-1, temperature=1.0)

positions = soft_argmax_layer(tensor_data)

# Sanity check: with logits this peaked, the differentiable result should agree
# with the non-differentiable torch.argmax to within float tolerance.
assert torch.allclose(
    positions, tensor_data.argmax(dim=-1).to(positions.dtype), atol=1e-4
), "soft argmax diverged from hard argmax"
positions

tensor([[1., 2., 3., 4., 5., 6.]])
