# Batch Sinkhorn Iteration Wasserstein Distance

Thomas Viehmann

This notebook implements Sinkhorn iteration Wasserstein distance layers.

## Important note: This is under construction and does not yet work as well as it should.

```python
import numpy  # used by WassersteinLossStab below
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable, Function
from torchvision import datasets, transforms
```


The following is a "plain" Sinkhorn implementation of the kind that could be used in
[C. Frogner et al.: Learning with a Wasserstein Loss](https://arxiv.org/abs/1506.05439).

Note that we use a different convention for $\lambda$: here $\lambda$ is the weight of the regularisation term, whereas later versions of the paper above use $\lambda^{-1}$ as the weight.
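For reference, here is the regularised problem in this convention for a single pair, with $a$ = `pred`, $b$ = `target` and cost matrix $M$. This is a summary of the standard entropic formulation, not a quotation from the paper; divisions and the exponential are elementwise:

$$
\min_{P \ge 0,\; P\mathbf{1} = a,\; P^\top\mathbf{1} = b} \; \langle P, M\rangle + \lambda \sum_{ij} P_{ij}\,(\log P_{ij} - 1),
\qquad K = e^{-M/\lambda},
$$

$$
P = \operatorname{diag}(u)\,K\,\operatorname{diag}(v), \qquad
v \leftarrow \frac{b}{K^\top u}, \qquad
u \leftarrow \frac{a}{K v},
$$

$$
\text{loss} = \langle P, M\rangle = \sum_{ij} u_i K_{ij} M_{ij} v_j, \qquad
\text{gradient w.r.t. } a:\;\; \lambda \log u - \frac{1}{n_a}\sum_{i} \lambda \log u_i .
$$

The gradient is only defined up to an additive constant (shifting it does not change the total mass), which is why it is centred.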

The implementation has benefitted from

- Chiyuan Zhang's implementation in [Mocha](https://github.com/pluskid/Mocha.jl),
- Rémi Flamary's implementation of various Sinkhorn algorithms in [Python Optimal Transport](https://github.com/rflamary/POT)

Thank you!

```python
class WassersteinLossVanilla(Function):
    def __init__(self, cost, lam=1e-3, sinkhorn_iter=50):
        super(WassersteinLossVanilla, self).__init__()

        # cost = matrix M = distance matrix (na x nb), diagonal should be 0
        # lam = regularisation weight, float > 0
        # sinkhorn_iter = number of Sinkhorn iterations, > 0
        self.cost = cost
        self.lam = lam
        self.sinkhorn_iter = sinkhorn_iter
        self.na = cost.size(0)
        self.nb = cost.size(1)
        self.K = torch.exp(-self.cost / self.lam)  # Gibbs kernel
        self.KM = self.cost * self.K               # elementwise K * M, for the transport cost
        self.stored_grad = None

    def forward(self, pred, target):
        """pred: batch x na, target: batch x nb; each row is a histogram over the mass points"""
        assert pred.size(1) == self.na
        assert target.size(1) == self.nb

        nbatch = pred.size(0)

        u = self.cost.new(nbatch, self.na).fill_(1.0 / self.na)

        for i in range(self.sinkhorn_iter):
            # batch items are rows, so K^T u becomes u @ K and K v becomes v @ K^T
            v = target / torch.mm(u, self.K)
            u = pred / torch.mm(v, self.K.t())
            # u != u is True exactly for NaN entries
            if (u != u).sum() > 0 or (v != v).sum() > 0 or u.max() > 1e9 or v.max() > 1e9:
                # we have reached machine precision: abort instead of returning garbage
                raise Exception(str(('Warning: numerical errors', i + 1, "u", (u != u).sum(), u.max(), "v", (v != v).sum(), v.max())))

        # transport cost sum_ij u_i K_ij M_ij v_j, averaged over the batch
        loss = (u * torch.mm(v, self.KM.t())).mean(0).sum()
        # gradient w.r.t. pred: dual potential lam*log(u), divided by nbatch for the batch mean
        grad = self.lam * u.log() / nbatch
        # the potential is only defined up to a constant, so centre each row
        # (the second pass is mathematically a no-op and only removes rounding residue)
        grad = grad - torch.mean(grad, dim=1, keepdim=True).expand_as(grad)
        grad = grad - torch.mean(grad, dim=1, keepdim=True).expand_as(grad)
        self.stored_grad = grad

        dist = self.cost.new((loss,))
        return dist

    def backward(self, grad_output):
        return self.stored_grad * grad_output[0], None
```
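With batch items stored as rows, $K^\top u$ and $K v$ become `u @ K` and `v @ K.T`, and this orientation is easy to get wrong. The following is a minimal numpy sketch (numpy instead of torch; the helper `sinkhorn_batch` and the toy data are hypothetical) of the same batched iteration. It uses a non-square cost so that a transposed kernel would raise a shape error, and it checks the marginals of the resulting plans:

```python
import numpy as np


def sinkhorn_batch(a, b, M, lam, n_iter=200):
    """Plain batched Sinkhorn iteration (sketch).

    a: (B, na) and b: (B, nb) hold one histogram per row; M: (na, nb) is the cost.
    Returns the scalings u (B, na), v (B, nb) and the kernel K = exp(-M / lam).
    """
    K = np.exp(-M / lam)
    u = np.full(a.shape, 1.0 / a.shape[1])
    for _ in range(n_iter):
        v = b / (u @ K)    # row form of K^T u: (B, na) @ (na, nb) -> (B, nb)
        u = a / (v @ K.T)  # row form of K v:   (B, nb) @ (nb, na) -> (B, na)
    return u, v, K


# Toy data with a non-square cost (na != nb), so a transposed K would raise a shape error.
rng = np.random.default_rng(0)
xa_toy, xb_toy = np.linspace(0.0, 1.0, 4), np.linspace(0.0, 1.0, 6)
M_toy = (xa_toy[:, None] - xb_toy[None, :]) ** 2
a_toy = rng.random((3, 4)) + 0.1
a_toy /= a_toy.sum(axis=1, keepdims=True)
b_toy = rng.random((3, 6)) + 0.1
b_toy /= b_toy.sum(axis=1, keepdims=True)

u_toy, v_toy, K_toy = sinkhorn_batch(a_toy, b_toy, M_toy, lam=0.5)
# one transport plan per batch item: P[k] = diag(u[k]) K diag(v[k])
P_toy = u_toy[:, :, None] * K_toy[None, :, :] * v_toy[:, None, :]
row_err = np.abs(P_toy.sum(axis=2) - a_toy).max()  # exact after the final u update
col_err = np.abs(P_toy.sum(axis=1) - b_toy).max()  # small once the iteration has converged
# per-pair transport cost sum_ij u_i K_ij M_ij v_j, the quantity the layer averages
pair_cost = (u_toy * (v_toy @ (K_toy * M_toy).T)).sum(axis=1)
print(row_err, col_err, pair_cost)
```

With the symmetric square cost of the test problem further down, `u @ K` and `u @ K.T` give identical results, so a swapped orientation would not show up there.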


The following is a variant of the "log-stabilized Sinkhorn" algorithm as described by [B. Schmitzer: Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/abs/1610.06519).
However, the author (whose application is computing the transport map for a single pair of measures) uses a form that modifies the $K$ matrix. This makes it less suitable for processing (mini-)batches: every pair in the batch would need its own modified $K$, adding a batch dimension to the kernel, which we want to avoid. The variant here keeps a single shared $K$ and only stabilizes the scaling vectors $u$ and $v$ in the log domain.
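The shared-kernel stabilization can be sketched in numpy as follows (a minimal sketch; the helper name `sinkhorn_log_batch` and the toy data are hypothetical). Each row of $u$ (or $v$) is divided by its maximum before the product with $K$, and the maximum is added back in log space. At a moderate $\lambda$ it reproduces the plain iteration:

```python
import numpy as np


def sinkhorn_log_batch(a, b, M, lam, n_iter=200):
    """Batched Sinkhorn on log u and log v with one shared kernel K (sketch).

    Before each product with K, every row of u (or v) is divided by its maximum,
    and the maximum is added back in log space, so the exponentials stay <= 1.
    K itself is never modified, so no per-pair kernel is needed.
    """
    K = np.exp(-M / lam)
    log_a, log_b = np.log(a), np.log(b)
    log_u = np.full(a.shape, -np.log(a.shape[1]))
    for _ in range(n_iter):
        m_u = log_u.max(axis=1, keepdims=True)
        log_v = log_b - np.log(np.exp(log_u - m_u) @ K) - m_u
        m_v = log_v.max(axis=1, keepdims=True)
        log_u = log_a - np.log(np.exp(log_v - m_v) @ K.T) - m_v
    return log_u, log_v


# Toy comparison against the plain updates, at a lambda where both are safe.
rng = np.random.default_rng(1)
x_toy = np.linspace(0.0, 1.0, 5)
M_toy = (x_toy[:, None] - x_toy[None, :]) ** 2
a_toy = rng.random((2, 5)) + 0.1
a_toy /= a_toy.sum(axis=1, keepdims=True)
b_toy = rng.random((2, 5)) + 0.1
b_toy /= b_toy.sum(axis=1, keepdims=True)

log_u_toy, log_v_toy = sinkhorn_log_batch(a_toy, b_toy, M_toy, lam=0.5)

# plain iteration with the same start and the same number of steps
K_toy = np.exp(-M_toy / 0.5)
u_toy = np.full(a_toy.shape, 1.0 / 5)
for _ in range(200):
    v_toy = b_toy / (u_toy @ K_toy)
    u_toy = a_toy / (v_toy @ K_toy.T)
print(np.abs(np.exp(log_u_toy) - u_toy).max(), np.abs(np.exp(log_v_toy) - v_toy).max())
```

Keeping $K$ fixed means one $n_a \times n_b$ kernel serves the whole batch. The trade-off, as we understand Schmitzer's paper, is that kernel entries which already underflow for very small $\lambda$ are not recovered; absorbing the scalings into $K$ is what addresses that.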

To the best of my knowledge, this is the first implementation of a batch stabilized Sinkhorn algorithm. If you find it useful, I would appreciate it if you could credit
*Thomas Viehmann: Batch Sinkhorn Iteration Wasserstein Distance*, [https://github.com/t-vi/pytorch-tvmisc/wasserstein-distance/Pytorch_Wasserstein.ipynb](https://github.com/t-vi/pytorch-tvmisc/wasserstein-distance/Pytorch_Wasserstein.ipynb).


```python
class WassersteinLossStab(Function):
    def __init__(self, cost, lam=1e-3, sinkhorn_iter=50):
        super(WassersteinLossStab, self).__init__()

        # cost = matrix M = distance matrix (na x nb), diagonal should be 0
        # lam = regularisation weight, float > 0
        # sinkhorn_iter = number of Sinkhorn iterations, > 0
        self.cost = cost
        self.lam = lam
        self.sinkhorn_iter = sinkhorn_iter
        self.na = cost.size(0)
        self.nb = cost.size(1)
        self.K = torch.exp(-self.cost / self.lam)  # shared by the whole batch, never modified
        self.KM = self.cost * self.K               # elementwise K * M, for the transport cost
        self.stored_grad = None

    def forward(self, pred, target):
        """pred: batch x na, target: batch x nb; each row is a histogram over the mass points"""
        assert pred.size(1) == self.na
        assert target.size(1) == self.nb

        batch_size = pred.size(0)

        log_a, log_b = torch.log(pred), torch.log(target)
        log_u = self.cost.new(batch_size, self.na).fill_(-numpy.log(self.na))
        log_v = self.cost.new(batch_size, self.nb).fill_(-numpy.log(self.nb))

        for i in range(self.sinkhorn_iter):
            # divide u by its row maximum before multiplying with K, add the maximum back in log space
            log_u_max = torch.max(log_u, dim=1, keepdim=True)[0]
            u_stab = torch.exp(log_u - log_u_max.expand_as(log_u))
            log_v = log_b - torch.log(torch.mm(self.K.t(), u_stab.t()).t()) - log_u_max.expand_as(log_v)
            # same for v
            log_v_max = torch.max(log_v, dim=1, keepdim=True)[0]
            v_stab = torch.exp(log_v - log_v_max.expand_as(log_v))
            log_u = log_a - torch.log(torch.mm(self.K, v_stab.t()).t()) - log_v_max.expand_as(log_u)

        # transport cost sum_ij u_i K_ij M_ij v_j, computed from the stabilized v
        log_v_max = torch.max(log_v, dim=1, keepdim=True)[0]
        v_stab = torch.exp(log_v - log_v_max.expand_as(log_v))
        logcostpart1 = torch.log(torch.mm(self.KM, v_stab.t()).t()) + log_v_max.expand_as(log_u)
        wnorm = torch.exp(log_u + logcostpart1).mean(0).sum()  # use .sum(1) instead for per-pair losses
        # gradient w.r.t. pred: centred dual potential lam*log(u), divided by the batch size
        grad = log_u * self.lam
        grad = grad - torch.mean(grad, dim=1, keepdim=True).expand_as(grad)
        grad = grad - torch.mean(grad, dim=1, keepdim=True).expand_as(grad)  # second pass only removes rounding residue
        grad = grad / batch_size

        self.stored_grad = grad

        return self.cost.new((wnorm,))

    def backward(self, grad_output):
        res = grad_output.new()
        res.resize_as_(self.stored_grad).copy_(self.stored_grad)
        if grad_output[0] != 1:
            res.mul_(grad_output[0])
        return res, None
```


We may test our implementation against Rémi Flamary's algorithms in [Python Optimal Transport](https://github.com/rflamary/POT).

```python
import ot
import numpy
from matplotlib import pyplot
%matplotlib inline
```

```python
# test problem from Python Optimal Transport
n=100
a=ot.datasets.get_1D_gauss(n,m=20,s=10).astype(numpy.float32)
b=ot.datasets.get_1D_gauss(n,m=60,s=30).astype(numpy.float32)
c=ot.datasets.get_1D_gauss(n,m=40,s=20).astype(numpy.float32)
a64=ot.datasets.get_1D_gauss(n,m=20,s=10).astype(numpy.float64)
b64=ot.datasets.get_1D_gauss(n,m=60,s=30).astype(numpy.float64)
c64=ot.datasets.get_1D_gauss(n,m=40,s=20).astype(numpy.float64)
# distance function
x=numpy.arange(n,dtype=numpy.float32)
M=(x[:,numpy.newaxis]-x[numpy.newaxis,:])**2
M/=M.max()
x64=numpy.arange(n,dtype=numpy.float64)
M64=(x64[:,numpy.newaxis]-x64[numpy.newaxis,:])**2
M64/=M64.max()
```
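Readers without POT installed can build comparable test data in plain numpy. The helpers below are hypothetical stand-ins, under the assumption that `get_1D_gauss(n, m, s)` returns a Gaussian bump with mean `m` and standard deviation `s` over the grid `0..n-1`, normalised to sum 1 (not verified against every POT version). The variables get new names so they do not overwrite `a`, `b`, `c` and `M`:

```python
import numpy as np


def gauss_hist(n, m, s, dtype=np.float32):
    """Hypothetical stand-in for ot.datasets.get_1D_gauss: Gaussian bump on 0..n-1, summing to 1."""
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-((x - m) ** 2) / (2.0 * s ** 2))
    return (h / h.sum()).astype(dtype)


def sq_cost(n, dtype=np.float32):
    """Squared distance between grid points, scaled so that the largest entry is 1."""
    x = np.arange(n, dtype=dtype)
    cost = (x[:, np.newaxis] - x[np.newaxis, :]) ** 2
    return cost / cost.max()


hist_a = gauss_hist(100, 20, 10)
hist_b = gauss_hist(100, 60, 30)
hist_c = gauss_hist(100, 40, 20)
M_alt = sq_cost(100)
```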


```python
transp = ot.bregman.sinkhorn(a,b,M,reg=1e-3)
transp2 = ot.bregman.sinkhorn_stabilized(a,b,M,reg=1e-3)
```

```python
(transp*M).sum(), (transp2*M).sum()
```

```
(0.15025606400382638, 0.1502669613228855)
```
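Both numbers above are $\langle P, M\rangle$ for a transport plan $P$ returned by POT, while the layers compute the same quantity as $\sum_{ij} u_i K_{ij} M_{ij} v_j$ without ever forming $P$. A small numpy check of that identity for a single pair (toy size and $\lambda$ chosen arbitrarily, batching left out):

```python
import numpy as np

# Toy single-pair problem (size and lambda are arbitrary).
n_toy, lam_toy = 6, 0.5
x_toy = np.arange(n_toy, dtype=np.float64)
M_toy = (x_toy[:, None] - x_toy[None, :]) ** 2
M_toy /= M_toy.max()
a_toy = np.full(n_toy, 1.0 / n_toy)
b_toy = np.linspace(1.0, 2.0, n_toy)
b_toy /= b_toy.sum()

K_toy = np.exp(-M_toy / lam_toy)
u_toy = np.full(n_toy, 1.0 / n_toy)
for _ in range(500):
    v_toy = b_toy / (K_toy.T @ u_toy)
    u_toy = a_toy / (K_toy @ v_toy)

P_toy = u_toy[:, None] * K_toy * v_toy[None, :]         # plan diag(u) K diag(v)
cost_from_plan = (P_toy * M_toy).sum()                  # as in (transp*M).sum()
cost_from_scalings = u_toy @ ((K_toy * M_toy) @ v_toy)  # as in the layers, P never formed
print(cost_from_plan, cost_from_scalings)
```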

```python
cabt = Variable(torch.from_numpy(numpy.stack((c,a,b),axis=0)))
abct = Variable(torch.from_numpy(numpy.stack((a,b,c),axis=0)))
```

```python
loss = WassersteinLossVanilla(torch.from_numpy(M), lam=0.1)
# loss on the whole batch, then on each of the three pairs separately
losses = loss(cabt,abct), loss(cabt[:1],abct[:1]), loss(cabt[1:2],abct[1:2]), loss(cabt[2:],abct[2:])
# the batch loss is a batch mean, so it should equal the mean of the single-pair losses
sum(losses[1:])/3, losses
```

```
(Variable containing:
  0.1053
 [torch.FloatTensor of size 1], (Variable containing:
   0.1053
  [torch.FloatTensor of size 1], Variable containing:
  1.00000e-02 *
    7.5846
  [torch.FloatTensor of size 1], Variable containing:
   0.1773
  [torch.FloatTensor of size 1], Variable containing:
  1.00000e-02 *
    6.2796
  [torch.FloatTensor of size 1]))
```

```python
loss = WassersteinLossStab(torch.from_numpy(M), lam=0.1)
losses = loss(cabt,abct), loss(cabt[:1],abct[:1]), loss(cabt[1:2],abct[1:2]), loss(cabt[2:],abct[2:])
sum(losses[1:])/3, losses
```

```
(Variable containing:
  0.1053
 [torch.FloatTensor of size 1], (Variable containing:
   0.1053
  [torch.FloatTensor of size 1], Variable containing:
  1.00000e-02 *
    7.5846
  [torch.FloatTensor of size 1], Variable containing:
   0.1773
  [torch.FloatTensor of size 1], Variable containing:
  1.00000e-02 *
    6.2796
  [torch.FloatTensor of size 1]))
```

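The stabilized version can handle the smaller $\lambda$ needed to get closer to the Python Optimal Transport loss: below, $\lambda = 0.01$ is compared with POT's `sinkhorn_stabilized` at `reg=1e-2`, which uses the same convention $K = e^{-M/\mathrm{reg}}$.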
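Why small $\lambda$ is hard: with the cost normalised to $\max M = 1$ as in the test problem, the smallest kernel entry is $e^{-1/\lambda}$, and the scalings $u$ and $v$ have to compensate for such tiny entries (which is what trips the `1e9` guard in the vanilla layer). A quick check in plain Python/numpy, independent of the layers, of where that entry leaves the normal float32 range:

```python
import math

import numpy as np

# With M scaled so that max(M) = 1, the smallest entry of K = exp(-M / lam) is exp(-1 / lam).
tiny32 = float(np.finfo(np.float32).tiny)  # smallest positive *normal* float32, about 1.18e-38
k_min = {lam: math.exp(-1.0 / lam) for lam in (0.1, 0.01, 1e-3)}  # evaluated in float64
for lam, k in k_min.items():
    print(f"lam={lam:g}: smallest K entry {k:.3g}, below float32 normal range: {k < tiny32}")
```

At $\lambda = 0.01$ the smallest entry is already below the smallest normal float32 (at best representable as a subnormal, with reduced precision), and at $\lambda = 10^{-3}$ it underflows to 0 even in float64. The log-domain scalings avoid overflow in $u$ and $v$, but a kernel entry that is already 0 stays 0.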

```python
transp3 = ot.bregman.sinkhorn_stabilized(a,b,M,reg=1e-2)
loss = WassersteinLossStab(torch.from_numpy(M), lam=0.01)
(transp3*M).sum(), loss(cabt[1:2],abct[1:2]).data[0]
```

```
(0.15445974773824594, 0.15445145964622498)
```

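By the linear expansion, we should have
$$
L(x_2) \approx L(x_1) + \nabla L\left(\tfrac{x_1 + x_2}{2}\right) \cdot (x_2 - x_1),
$$
so in particular, after a gradient step $x_2 = x_1 - \epsilon \nabla L(x_1)$, we can check on an example whether
$$
\frac{L(x_1) - L(x_2)}{\epsilon \left\| \tfrac{1}{2}\left(\nabla L(x_1) + \nabla L(x_2)\right) \right\|^2} \approx 1,
$$
approximating the gradient at the midpoint by the average of the two gradients.

This seems to be the case ... sometimes.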
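The same kind of check can be run outside PyTorch. Below is a minimal numpy sketch, assuming a single pair, a fixed number of plain Sinkhorn iterations and a central difference instead of a single gradient step (the helper `entropic_ot` and the toy data are hypothetical). It applies the ratio above with the centred dual potential $\lambda \log u$ as the gradient twice: once for the regularised objective $\langle P, M\rangle + \lambda \sum_{ij} P_{ij}(\log P_{ij} - 1)$, and once for the transport cost $\langle P, M\rangle$ alone, which is what the layers return. By the envelope theorem, $\lambda \log u$ is (up to a constant) the gradient of the regularised objective, so only the first ratio is expected to be 1. Whether the gap in the second one explains the mixed results here is not established.

```python
import numpy as np


def entropic_ot(a, b, M, lam, n_iter=500):
    """Entropy-regularised OT for a single pair via plain Sinkhorn (sketch).

    Returns the regularised objective <P, M> + lam * sum_ij P_ij (log P_ij - 1),
    the dual potential f = lam * log(u), and the transport cost <P, M> alone.
    """
    K = np.exp(-M / lam)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    cost = (P * M).sum()
    objective = cost + lam * (P * (np.log(P) - 1.0)).sum()
    return objective, lam * np.log(u), cost


rng = np.random.default_rng(2)
lam_toy, eps_toy = 0.5, 1e-5
x_toy = np.linspace(0.0, 1.0, 5)
M_toy = (x_toy[:, None] - x_toy[None, :]) ** 2
a_toy = rng.random(5) + 0.5
a_toy /= a_toy.sum()
b_toy = rng.random(5) + 0.5
b_toy /= b_toy.sum()

_, f_toy, _ = entropic_ot(a_toy, b_toy, M_toy, lam_toy)
d_toy = f_toy - f_toy.mean()  # centred gradient; sums to 0, so a +/- eps*d keeps total mass 1
obj_p, _, cost_p = entropic_ot(a_toy + eps_toy * d_toy, b_toy, M_toy, lam_toy)
obj_m, _, cost_m = entropic_ot(a_toy - eps_toy * d_toy, b_toy, M_toy, lam_toy)
# central difference divided by the first-order prediction <f, d> = ||d||^2
ratio_objective = (obj_p - obj_m) / (2 * eps_toy) / (d_toy @ d_toy)
ratio_cost = (cost_p - cost_m) / (2 * eps_toy) / (d_toy @ d_toy)
print(ratio_objective, ratio_cost)
```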

```python
theloss = WassersteinLossStab(torch.from_numpy(M), lam=0.01, sinkhorn_iter=50)
cabt = Variable(torch.from_numpy(numpy.stack((c,a,b),axis=0)))
abct = Variable(torch.from_numpy(numpy.stack((a,b,c),axis=0)),requires_grad=True)
lossv1 = theloss(abct,cabt)
lossv1.backward()
grv = abct.grad
epsilon = 1e-5
abctv2 = Variable(abct.data-epsilon*grv.data, requires_grad=True)
lossv2 = theloss(abctv2, cabt)
lossv2.backward()
grv2 = abctv2.grad
(lossv1.data-lossv2.data)/(epsilon*((0.5*(grv.data+grv2.data))**2).sum()) # should be around 1
```

```
 0.9228
[torch.FloatTensor of size 1]
```

Naturally, one has to check that `abctv2` is still a valid probability distribution, i.e. that all entries are $>0$ (the total mass of each row is preserved, because each row of the gradient is centred to mean zero). It seems that the range of $\lambda$ in which the gradient works well is somewhat limited. This may point to a bug in the implementation.
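This check is easy to automate. In the small numpy sketch below (the helper `gradient_step` and the toy numbers are hypothetical), the gradient rows are centred, so the step keeps every row summing to 1 and only positivity can fail, which happens once `eps` is large relative to the smallest entries:

```python
import numpy as np


def gradient_step(hist, grad, eps):
    """Hypothetical helper: step hist - eps * grad, report which rows are still probability vectors."""
    new = hist - eps * grad
    valid = (new > 0).all(axis=1) & np.isclose(new.sum(axis=1), 1.0)
    return new, valid


hist_toy = np.array([[0.7, 0.2, 0.1]])
raw_grad = np.array([[0.0, 1.0, 2.0]])
grad_toy = raw_grad - raw_grad.mean(axis=1, keepdims=True)  # centred: [-1, 0, 1]

_, ok_small = gradient_step(hist_toy, grad_toy, 0.05)  # step gives [0.75, 0.2, 0.05]
_, ok_large = gradient_step(hist_toy, grad_toy, 0.2)   # step gives [0.9, 0.2, -0.1]
print(ok_small, ok_large)
```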

Note also that feeding the same distribution as both arguments results in a NaN, whereas the Wasserstein distance between a distribution and itself is 0.