In [None]:
import ot

In [None]:

"""
Smooth and Sparse Optimal Transport solvers (KL and L2 reg.)

Implementation of :
Smooth and Sparse Optimal Transport.
Mathieu Blondel, Vivien Seguy, Antoine Rolet.
In Proc. of AISTATS 2018.
https://arxiv.org/abs/1710.06276

[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
Transport. Proceedings of the Twenty-First International Conference on
Artificial Intelligence and Statistics (AISTATS).

Original code from https://github.com/mblondel/smooth-ot/

"""

import cupy as np
import torch

def projection_simplex(V, z=1, axis=None):
    """Project a torch tensor onto the simplex, scaled by z.

        P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2

    Parameters
    ----------
    V: torch.Tensor
        Values to project. Must be 2-D when axis is 0 or 1.
    z: float or array-like
        Simplex scale. If array-like, it must hold one entry per projected
        vector.
    axis: None or int
        - axis=None: project V by P(V.ravel(); z)
        - axis=1: project each V[i] by P(V[i]; z[i])
        - axis=0: project each V[:, j] by P(V[:, j]; z[j])

    Returns
    -------
    torch.Tensor
        The projection. Same shape as V for axis 0 or 1, flattened to 1-D
        otherwise.
    """
    # Reduce every case to "project each row of a 2-D tensor" so that the
    # function never has to call itself by name.
    if axis == 1:
        rows = V
    elif axis == 0:
        rows = V.T
    else:
        rows = V.reshape(1, -1)

    # The projection is real-valued; integer input would truncate z and theta.
    if not torch.is_floating_point(rows):
        rows = rows.to(torch.get_default_dtype())

    n_rows, n_features = rows.shape
    # torch.sort returns a (values, indices) pair and torch tensors do not
    # support negative-step slicing, so ask for descending order directly.
    U = torch.sort(rows, dim=1, descending=True).values
    z = torch.ones(n_rows, dtype=rows.dtype, device=rows.device) * \
        torch.as_tensor(z, dtype=rows.dtype, device=rows.device)
    cssv = torch.cumsum(U, dim=1) - z[:, None]
    ind = torch.arange(1, n_features + 1, dtype=rows.dtype,
                       device=rows.device)
    cond = U - cssv / ind > 0
    # Number of entries that stay strictly positive after the projection.
    rho = cond.sum(dim=1)
    theta = cssv[torch.arange(n_rows, device=rows.device), rho - 1] / rho
    projected = torch.clamp(rows - theta[:, None], min=0)

    if axis == 1:
        return projected
    if axis == 0:
        return projected.T
    return projected.reshape(-1)


def projection_simplex(V, z=1, axis=None):
    """Euclidean projection onto the simplex of total mass z.

        P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2

    Parameters
    ----------
    V: array
        Values to project. Must be 2-D when axis is 0 or 1.
    z: float or array
        Simplex scale. If array, len(z) must be compatible with V.
    axis: None or int
        - axis=None: project V by P(V.ravel(); z)
        - axis=1: project each V[i] by P(V[i]; z[i])
        - axis=0: project each V[:, j] by P(V[:, j]; z[j])

    Returns
    -------
    array
        The projection. Same shape as V for axis 0 or 1, flattened to 1-D
        otherwise.
    """
    # Bring the vectors to project onto the rows of a 2-D array.
    if axis == 1:
        rows = V
    elif axis == 0:
        rows = V.T
    else:
        rows = V.ravel().reshape(1, -1)

    width = rows.shape[1]
    n_rows = len(rows)

    # Sort each row in decreasing order.
    descending = np.sort(rows, axis=1)[:, ::-1]
    mass = np.ones(n_rows) * z
    shifted_cumsum = np.cumsum(descending, axis=1) - mass[:, None]
    positions = np.arange(1, width + 1)
    keeps_positive = descending - shifted_cumsum / positions > 0
    # Count of entries that remain positive in each projected row.
    support = np.count_nonzero(keeps_positive, axis=1)
    threshold = shifted_cumsum[np.arange(n_rows), support - 1] / support
    projected = np.maximum(rows - threshold[:, None], 0)

    # Undo the reshaping applied on entry.
    if axis == 1:
        return projected
    if axis == 0:
        return projected.T
    return projected.ravel()

class Regularization(object):
    """Base class for Regularization objects

        Notes
        -----
        This class is not intended for direct use but as a parent for true
        regularization implementations.
    """

    def __init__(self, gamma=1.0):
        """

        Parameters
        ----------
        gamma: float
            Regularization parameter.
            We recover unregularized OT when gamma -> 0.

        """
        self.gamma = gamma

    # The abstract methods take `self` so that calling them on an instance
    # reaches the NotImplementedError below instead of failing with a
    # TypeError about the number of positional arguments.

    def delta_Omega(self, X):
        """
        Compute delta_Omega(X[:, j]) for each X[:, j].
        delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).

        Parameters
        ----------
        X: array, shape = len(a) x len(b)
            Input array.

        Returns
        -------
        v: array, len(b)
            Values: v[j] = delta_Omega(X[:, j])
        G: array, len(a) x len(b)
            Gradients: G[:, j] = nabla delta_Omega(X[:, j])

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError

    def max_Omega(self, X, b):
        """
        Compute max_Omega_j(X[:, j]) for each X[:, j].
        max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].

        Parameters
        ----------
        X: array, shape = len(a) x len(b)
            Input array.
        b: array, len(b)
            Scaling b[j] applied to column j in the formula above.

        Returns
        -------
        v: array, len(b)
            Values: v[j] = max_Omega_j(X[:, j])
        G: array, len(a) x len(b)
            Gradients: G[:, j] = nabla max_Omega_j(X[:, j])

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError

    def Omega(self, T):
        """
        Compute regularization term.

        Parameters
        ----------
        T: array, shape = len(a) x len(b)
            Input array.

        Returns
        -------
        value: float
            Regularization term.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError


class NegEntropy(Regularization):
    """ NegEntropy regularization: Omega(T) = gamma * sum(T * log(T)) """

    def delta_Omega(self, X):
        """
        Compute delta_Omega and its gradient for each column of X.

        Parameters
        ----------
        X: array, shape = len(a) x len(b)
            Input array.

        Returns
        -------
        val: array, len(b)
            gamma times the column sums of G.
        G: array, len(a) x len(b)
            exp(X / gamma - 1), element-wise.
        """
        G = np.exp(X / self.gamma - 1)
        val = self.gamma * np.sum(G, axis=0)
        return val, G

    def max_Omega(self, X, b):
        """
        Compute max_Omega_j and its gradient for each column of X.

        Parameters
        ----------
        X: array, shape = len(a) x len(b)
            Input array.
        b: array, len(b)
            Scaling of each column; its log, times gamma, is subtracted
            from the values.

        Returns
        -------
        val: array, len(b)
            gamma * (logsumexp(X[:, j] / gamma) - log(b[j])).
        G: array, len(a) x len(b)
            Column-wise softmax of X / gamma.
        """
        # Shift by the column-wise maximum before exponentiating so that
        # exp() cannot overflow; the shift is added back to the log-sum.
        max_X = np.max(X, axis=0) / self.gamma
        exp_X = np.exp(X / self.gamma - max_X)
        val = self.gamma * (np.log(np.sum(exp_X, axis=0)) + max_X)
        val -= self.gamma * np.log(b)
        G = exp_X / np.sum(exp_X, axis=0)
        return val, G

    def Omega(self, T):
        """
        Compute the regularization term gamma * sum(T * log(T)).

        Entries of T that are exactly zero contribute zero (the usual
        0 * log(0) = 0 convention) instead of turning the result into NaN.

        Parameters
        ----------
        T: array, shape = len(a) x len(b)
            Input array.

        Returns
        -------
        value: float
            Regularization term.
        """
        T = np.asarray(T)
        # Replace exact zeros by one inside the log: 0 * log(1) = 0, whereas
        # 0 * log(0) = 0 * -inf would be NaN.
        log_arg = np.where(T == 0, np.ones_like(T), T)
        return self.gamma * np.sum(T * np.log(log_arg))


class SquaredL2(Regularization):
    """ Squared L2 regularization: Omega(T) = 0.5 * gamma * sum(T ** 2) """

    def delta_Omega(self, X):
        """
        Compute delta_Omega and its gradient for each column of X.

        Parameters
        ----------
        X: array, shape = len(a) x len(b)
            Input array.

        Returns
        -------
        val: array, len(b)
            Column sums of max(X, 0) ** 2, divided by 2 * gamma.
        G: array, len(a) x len(b)
            max(X, 0) / gamma, element-wise.
        """
        positive_part = np.maximum(X, 0)
        gradient = positive_part / self.gamma
        value = np.sum(positive_part ** 2, axis=0) / (2 * self.gamma)
        return value, gradient

    def max_Omega(self, X, b):
        """
        Compute max_Omega_j and its gradient for each column of X.

        Parameters
        ----------
        X: array, shape = len(a) x len(b)
            Input array.
        b: array, len(b)
            Scaling of each column.

        Returns
        -------
        val: array, len(b)
            sum(X * G) - 0.5 * gamma * b * sum(G * G), taken column-wise.
        G: array, len(a) x len(b)
            projection_simplex applied to X / (b * gamma) with axis=0.
        """
        scaled = X / (b * self.gamma)
        gradient = projection_simplex(scaled, axis=0)
        value = np.sum(X * gradient, axis=0)
        quadratic = 0.5 * self.gamma * b * np.sum(gradient * gradient, axis=0)
        value -= quadratic
        return value, gradient

    def Omega(self, T):
        """
        Compute the regularization term 0.5 * gamma * sum(T ** 2).

        Parameters
        ----------
        T: array, shape = len(a) x len(b)
            Input array.

        Returns
        -------
        value: float
            Regularization term.
        """
        squared_norm = np.sum(T ** 2)
        return 0.5 * self.gamma * squared_norm

