In [None]:
import torch
from torch import nn

import numpy as np
import scipy.stats as st

import matplotlib.pyplot as plt

In [None]:
class AdaptiveSoftHistogram(nn.Module):
    """
    Differentiable ("soft") histogram of samples over fixed, possibly
    non-uniform bins.

    Each sample contributes to bin ``k`` with weight
    ``sigmoid(sigma * (x - c_k + d_k)) - sigmoid(sigma * (x - c_k - d_k))``,
    where ``c_k`` is the bin centre and ``d_k`` its half-width. This is a
    smooth approximation of the indicator ``|x - c_k| < d_k``.
    """

    def __init__(self, bin_edges, sigma=500):
        """
        Parameters
        ----------
        bin_edges : torch.Tensor or array-like, shape (K + 1,)
            Strictly increasing bin edges defining K bins. Lists and NumPy
            arrays are converted with ``torch.as_tensor``.
        sigma : float
            Steepness multiplier of the sigmoids forming each bin's edges.
            At a bin centre the weight is ``tanh(sigma * d_k / 2)``, so
            narrow bins need a larger ``sigma`` to approach a hard histogram.

        Raises
        ------
        ValueError
            If ``bin_edges`` is not 1-D, has fewer than two entries, or is
            not strictly increasing.
        """
        super().__init__()
        bin_edges = torch.as_tensor(bin_edges)
        if bin_edges.dim() != 1 or bin_edges.numel() < 2:
            raise ValueError("bin_edges must be a 1-D sequence with at least two edges")

        self.sigma = sigma
        lower = bin_edges[:-1]
        upper = bin_edges[1:]
        bin_widths = upper - lower
        if not bool((bin_widths > 0).all()):
            raise ValueError("bin_edges must be strictly increasing")
        bin_centers = lower + 0.5 * bin_widths

        # Kept as frozen Parameters (requires_grad=False) so they still appear
        # in .parameters()/state_dict under the same names. delta has shape
        # (K, 1) so each half-width broadcasts along the trailing sample axis.
        self.delta = nn.Parameter(bin_widths.unsqueeze(1) / 2, requires_grad=False)
        self.centers = nn.Parameter(bin_centers, requires_grad=False)

    def forward(self, x):
        """
        Compute the normalised soft histogram of ``x``.

        Parameters
        ----------
        x : torch.Tensor, shape (..., N)
            Samples; the histogram is taken over the last axis. Any number of
            leading (batch) axes is allowed, including none.

        Returns
        -------
        torch.Tensor, shape (..., K)
            Per-bin soft counts, normalised to sum to 1 over the last axis.
        """
        # (..., 1, N) - (K, 1) -> (..., K, N): offset of every sample from every bin centre.
        offsets = x.unsqueeze(-2) - self.centers.unsqueeze(-1)
        rising = torch.sigmoid(self.sigma * (offsets + self.delta))
        falling = torch.sigmoid(self.sigma * (offsets - self.delta))
        counts = (rising - falling).sum(dim=-1) + 1e-6  # epsilon for zero bins
        return counts / counts.sum(dim=-1, keepdim=True)

In [None]:
# Draw 32 rows of 256 samples from Beta(3, 3) as test data for the soft histogram.
d = st.beta(3, 3)
x = d.rvs(size=(32, 256), random_state=0)  # fixed seed so Restart & Run All reproduces results

fig, ax = plt.subplots()
ax.hist(x.ravel(), bins=50)
ax.set(title="Beta(3, 3) samples", xlabel="value", ylabel="count");

In [None]:
# Non-uniform bins: narrow at both tails, one wide bin over the bulk [0.10, 0.90].
bin_edges = [
    0.0, 0.01, 0.05, 0.10,
    0.90, 0.95, 0.99, 1.00,
]
h = AdaptiveSoftHistogram(bin_edges=torch.tensor(bin_edges))

In [None]:
# Bin centres (midpoints of consecutive edges), stored as a frozen nn.Parameter.
h.centers

In [None]:
# Bin half-widths, shape (K, 1) so each broadcasts over the sample axis in forward().
h.delta

In [None]:
# One normalised soft histogram per row of x (as_tensor shares memory with the NumPy array).
y = h(torch.as_tensor(x))

In [None]:
# Average the per-row histograms across rows (dim 0) to get one estimate per bin.
torch.mean(y, dim=0)

In [None]:
# Theoretical probability mass per bin, F(edge[i+1]) - F(edge[i]): exactly
# len(bin_edges) - 1 values, one per histogram bin.
np.diff(d.cdf(bin_edges))