In [4]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable, Function

In [5]:
class _AttachStoredGradient(Function):
    """Autograd glue used by ``WassersteinLossVanilla``.

    The forward pass hands back an already-computed loss value; the backward
    pass returns an already-computed gradient with respect to ``pred``.
    """

    @staticmethod
    def forward(ctx, pred, value, grad):
        # ``pred`` is an input only so that the output is linked to it in the
        # autograd graph; the numbers themselves come from ``value``/``grad``.
        ctx.save_for_backward(grad)
        # Clone so the output is a fresh tensor rather than an input alias.
        return value.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        # One gradient per forward input: pred, value, grad.
        return grad * grad_output[0], None, None


class WassersteinLossVanilla(nn.Module):
    """Entropy-regularised Wasserstein loss computed with Sinkhorn iterations.

    Instances are callable as ``loss_fn(pred, target)`` and return a 1-element
    tensor whose gradient with respect to ``pred`` is the centred Sinkhorn
    dual variable ``lam * log(u) / batch``. No gradient is produced for
    ``target``.

    The class is an ``nn.Module`` that delegates to a static autograd
    ``Function``: instances of legacy (non-static) ``Function`` subclasses
    cannot be called in current PyTorch releases.

    Args:
        cost: distance matrix M of shape (na, nb); the diagonal should be 0.
        lam: regularisation strength lambda, a float > 0.
        sinkhorn_iter: number of Sinkhorn iterations, > 0.

    Raises:
        ValueError: if ``lam`` or ``sinkhorn_iter`` is not positive.

    Note:
        ``cost``, ``K`` and ``KM`` are plain attributes (not registered
        buffers), so ``.to(device)`` on the module does not move them.
    """

    def __init__(self,cost, lam = 1e-3, sinkhorn_iter = 50):
        super(WassersteinLossVanilla,self).__init__()

        if not lam > 0:
            raise ValueError("lam must be > 0")
        if sinkhorn_iter < 1:
            raise ValueError("sinkhorn_iter must be > 0")

        self.cost = cost
        self.lam = lam
        self.sinkhorn_iter = sinkhorn_iter
        self.na = cost.size(0)
        self.nb = cost.size(1)
        # Gibbs kernel K = exp(-M / lambda) and the elementwise product M * K
        # used to evaluate the transport cost.
        self.K = torch.exp(-self.cost/self.lam)
        self.KM = self.cost*self.K
        # Gradient w.r.t. pred from the most recent forward call.
        self.stored_grad = None

    def _sinkhorn(self, pred, target):
        """Run the Sinkhorn iterations; return (1-element loss, gradient)."""
        nbatch = pred.size(0)

        u = self.cost.new_full((nbatch, self.na), 1.0/self.na)

        for i in range(self.sinkhorn_iter):
            # Plan is diag(u) K diag(v) with K of shape (na, nb):
            #   column marginals: v = target / (K^T u)  -> u @ K
            #   row marginals:    u = pred   / (K v)    -> v @ K^T
            v = target/(torch.mm(u,self.K))
            u = pred/(torch.mm(v,self.K.t()))
            if (torch.isnan(u).any() or torch.isnan(v).any()
                    or u.max()>1e9 or v.max()>1e9):
                # Machine precision has been reached; abort instead of
                # returning a meaningless loss.
                raise Exception(str(('Warning: numerical errrors',i+1,"u",(u!=u).sum(),u.max(),"v",(v!=v).sum(),v.max())))

        # <P, M> = sum_ij u_i K_ij M_ij v_j, averaged over the batch.
        loss = (u*torch.mm(v,self.KM.t())).mean(0).sum()
        grad = self.lam*u.log()/nbatch
        # Centre each row: the dual variable is only defined up to an
        # additive constant. The second pass removes floating-point residue.
        grad = grad-grad.mean(dim=1, keepdim=True)
        grad = grad-grad.mean(dim=1, keepdim=True)
        return loss.reshape(1), grad

    def forward(self, pred, target):
        """Compute the loss.

        pred: Batch * K: K = # mass points (must equal cost.size(0))
        target: Batch * L: L = # mass points (must equal cost.size(1))

        Returns a tensor of shape (1,) holding the batch-averaged loss.
        Raises ``Exception`` if the iterations produce NaNs or values
        above 1e9.
        """
        assert pred.size(1)==self.na, "pred.size(1) must equal cost.size(0)"
        assert target.size(1)==self.nb, "target.size(1) must equal cost.size(1)"

        # The iterations are not differentiated through; the gradient is the
        # analytic dual variable attached below.
        with torch.no_grad():
            dist, grad = self._sinkhorn(pred, target)
        self.stored_grad = grad

        return _AttachStoredGradient.apply(pred, dist, grad)

    def backward(self, grad_output):
        """Legacy-style hook kept for callers of the old Function interface.

        Autograd does not call this method; it returns the gradient stored by
        the most recent ``forward`` scaled by ``grad_output[0]``, and ``None``
        for ``target``.
        """
        return self.stored_grad*grad_output[0],None

In [6]:
# Toy batch of two 10-bin histograms (presumably inputs for the loss above —
# confirm against the cells that follow). `a` is a leaf that requires grad,
# so its entries are filled through `.data`; `b` is a plain tensor of the
# same shape.
a = Variable(torch.zeros(2, 10), requires_grad=True)
a.data[0, 3:7] = 1
a.data[1, 1:3] = 1

b = torch.zeros(a.size())
b[0, 3:5] = 1
b[1, 8:] = 1

In [None]:
# Short alias for the loss class itself; no instance is constructed here.
# NOTE(review): the constructor requires a cost matrix, so a usable loss still
# needs `wv(cost)` — confirm whether a call was intended in this cell.
wv = WassersteinLossVanilla