In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import scipy.stats as st
from scipy.special import betainc
from scipy.optimize import minimize

In [None]:
class MultiSoftHistogram(nn.Module):
    """
    Differentiable ("soft") histogram of class-probability vectors.

    Every value is softly assigned with sigmoids (rather than hard
    comparisons) so gradients can flow back to the input:

    * a per-class "high" bin  (value above threshold_1),
    * a per-class "middle" bin (value between threshold_2 and threshold_1),
    * one shared "low" bin holding, per sample, whatever mass is left after
      subtracting all high and middle assignments.
    """
    def __init__(self, threshold_1, threshold_2, sigma=500):
        """
        Parameters
        ----------
        threshold_1 : float
            Upper threshold; values above it fall in the "high" bins.
        threshold_2 : float
            Lower threshold; values between threshold_2 and threshold_1 fall
            in the "middle" bins. Assumed to be <= threshold_1, otherwise
            the middle bins take negative values.
        sigma : float
            Slope of the sigmoid used as a soft step function; larger values
            make the histogram closer to a hard one.
        """
        super().__init__()
        self.threshold_1 = threshold_1
        self.threshold_2 = threshold_2
        self.sigma = sigma

    def forward(self, x):
        """
        Compute the normalized soft histogram.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (..., C). The last dimension is the class axis;
            all leading dimensions are flattened and treated as samples, and
            a 1-D tensor is treated as a single sample.

        Returns
        -------
        torch.Tensor
            Shape (2 * C + 1,): C high bins, then C middle bins, then the low
            bin, averaged over samples. An input with zero samples yields NaN.
        """
        # Collapse every leading dim into one sample axis so (C,), (N, C) and
        # (B, N, C) inputs share the same (N, C) code path; the original
        # dim=1 reductions made torch.cat fail for anything but 2-D input.
        samples = x.reshape(-1, x.shape[-1])

        high_hist = self.higher_than_threshold_1(samples)
        middle_hist = self.between_threshold_1_and_2(samples)
        # Per-sample remainder: mass not assigned to any high or middle bin.
        low_hist = (1.
                    - high_hist.sum(dim=1, keepdim=True)
                    - middle_hist.sum(dim=1, keepdim=True))

        hist = torch.cat((high_hist, middle_hist, low_hist), dim=1)
        # Each row of `hist` sums to exactly 1 by construction, so dividing by
        # the number of samples equals dividing by hist.sum().
        return hist.sum(dim=0) / hist.shape[0]

    def higher_than_threshold_1(self, x):
        """Soft indicator of x > threshold_1, elementwise."""
        return torch.sigmoid(self.sigma * (x - self.threshold_1))

    def between_threshold_1_and_2(self, x):
        """Soft indicator of threshold_2 < x < threshold_1, elementwise."""
        return torch.sigmoid(self.sigma * (x - self.threshold_2)) - torch.sigmoid(self.sigma * (x - self.threshold_1))

In [None]:
# "High" bins take values above ~0.95, "middle" bins values between ~0.8 and ~0.95; sigma stays at its default.
msh = MultiSoftHistogram(threshold_2=0.8, threshold_1=0.95)

In [None]:
# Two identical copies of the same three 3-class probability rows -> shape (2, 3, 3).
x = torch.tensor(
    [[[0.02, 0.10, 0.88], [0.96, 0.01, 0.03], [0.4, 0.4, 0.2]]] * 2,
    requires_grad=True,
)

In [None]:
# Soft histogram of the batched (2, 3, 3) input; needs forward() to accept inputs with more than two dims.
msh(x)

In [None]:
# Three 5-class probability rows: one peaked at 0.86, one at 0.97, one uniform.
x = torch.tensor(
    [
        [0.02, 0.10, 0.01, 0.01, 0.86],
        [0.97, 0.01, 0.01, 0.005, 0.005],
        [0.2] * 5,
    ],
    requires_grad=True,
)

In [None]:
# Soft histogram of the (3, 5) input: 5 high bins, then 5 middle bins, then 1 low bin.
msh(x)