[Pytorch Question] How to differentiate GW wrt C1 and C2? #1

Open
ucalyptus opened this issue Sep 27, 2020 · 0 comments
Hi @tvayer,
I have been following some of your work and the POT framework lately. I noticed that there isn't an autograd.Function available for the Gromov-Wasserstein distance, so I started creating one myself.
Could you let me know if/how you compute the gradient of the gromov_wasserstein2 loss (given here) with respect to the cost matrices C1 and C2 of the source and target spaces? I need this because the backward() function of the autograd.Function mentioned above requires these gradients.
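
For context, my current (possibly wrong) understanding is that gromov_wasserstein2 solves

$$\mathrm{GW}(C_1, C_2, p, q) = \min_{T \in \Pi(p, q)} \sum_{i,j,k,l} L\big((C_1)_{ij}, (C_2)_{kl}\big)\, T_{ik}\, T_{jl}$$

and that, if the envelope theorem applies (i.e. the optimal plan $T^{\star}$ can be treated as constant when differentiating), the gradient would reduce to

$$\frac{\partial\, \mathrm{GW}}{\partial (C_1)_{ij}} = \sum_{k,l} \partial_1 L\big((C_1)_{ij}, (C_2)_{kl}\big)\, T^{\star}_{ik}\, T^{\star}_{jl}$$

(and symmetrically for C2), but I am not at all sure this is right.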

Here is what I have so far:

import numpy as np
import torch
from torch.autograd import Function
from ot.gromov import gromov_wasserstein2
class GromovWassersteinLossFunction(Function):
    """Return the GW loss for inputs (C1, C2, p, q)."""

    @staticmethod
    def forward(ctx, C1, C2, p, q):

        # convert to numpy
        C1 = C1.detach().cpu().numpy().astype(np.float64)
        C2 = C2.detach().cpu().numpy().astype(np.float64)
        p = p.detach().cpu().numpy().astype(np.float64)
        q = q.detach().cpu().numpy().astype(np.float64)
        p /= p.sum()
        q /= q.sum()
        # gromov_wasserstein2 returns the scalar GW loss; the optimal
        # coupling is stored in the log dict under 'T'
        gw_dist, log = gromov_wasserstein2(C1, C2, p, q, loss_fun='kl_loss', log=True)
        T = torch.from_numpy(log['T'])
        grad_C1 = None  # TODO: gradient of the GW loss w.r.t. C1
        grad_C2 = None  # TODO: gradient of the GW loss w.r.t. C2
        # p and q are treated as non-differentiable; backward() returns None
        # for them (ctx.mark_non_differentiable only applies to outputs)
        ctx.save_for_backward(grad_C1, grad_C2)
        return torch.tensor(gw_dist)

    @staticmethod
    def backward(ctx, grad_output):

        grad_C10, grad_C20 = ctx.saved_tensors
        grad_C1, grad_C2 = None, None
        if ctx.needs_input_grad[0]:
            grad_C1 = grad_output * grad_C10
        if ctx.needs_input_grad[1]:
            grad_C2 = grad_output * grad_C20
        # no gradients for the non-differentiable inputs p and q
        return grad_C1, grad_C2, None, None



def GW(C1, C2, p, q):
    """loss = GW(C1, C2, p, q), the Gromov-Wasserstein loss."""
    return GromovWassersteinLossFunction.apply(C1, C2, p, q)
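
Based on the envelope-theorem guess above (which may well be wrong), here is a minimal sketch of how the two TODOs could be filled for loss_fun='kl_loss', where L(a, b) = a*log(a/b) - a + b. The helper name gw_grads_kl is mine, not part of POT, and this is not necessarily how POT computes things internally:

import torch

def gw_grads_kl(C1, C2, p, q, T):
    """Tentative envelope-theorem gradients of the GW loss for kl_loss,
    treating the optimal plan T (a torch tensor) as constant.

    Uses the marginal constraints sum_k T[i, k] = p[i] and
    sum_i T[i, k] = q[k] to collapse the quadruple sum.
    """
    # dL/da = log(a) - log(b), so
    # grad_C1[i, j] = log(C1[i, j]) * p[i] * p[j] - (T @ log(C2) @ T.T)[i, j]
    grad_C1 = torch.log(C1) * (p[:, None] * p[None, :]) - T @ torch.log(C2) @ T.T
    # dL/db = 1 - a/b, so
    # grad_C2[k, l] = q[k] * q[l] - (T.T @ C1 @ T)[k, l] / C2[k, l]
    grad_C2 = (q[:, None] * q[None, :]) - (T.T @ C1 @ T) / C2
    return grad_C1, grad_C2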


It would be a huge help if you could let me know whether, and how, you compute these gradients. Someone on the POT Slack suggested deriving them by hand, but as a beginner in optimal transport I am not sure the derivation sketched above is correct.
