Hi @tvayer,
I have been following some of your work and the POT framework lately. I noticed that there isn't an autograd.Function available for the Gromov-Wasserstein distance, so I started writing one myself.
Can you let me know if/how you compute the gradient of the gromov_wasserstein2 loss (given here) with respect to C1 and C2, the cost matrices of the source and target spaces?
I need this because the backward() method of the autograd.Function mentioned above has to return those gradients.
Here is what I have so far:
import numpy as np
import torch
from torch.autograd import Function
from ot.gromov import gromov_wasserstein2


class GromovWassersteinLossFunction(Function):
    """Return the GW loss for input (C1, C2, p, q)."""

    @staticmethod
    def forward(ctx, C1, C2, p, q):
        # Convert to NumPy (astype returns a copy, so normalizing in place is safe)
        C1 = C1.detach().cpu().numpy().astype(np.float64)
        C2 = C2.detach().cpu().numpy().astype(np.float64)
        p = p.detach().cpu().numpy().astype(np.float64)
        q = q.detach().cpu().numpy().astype(np.float64)
        p /= p.sum()
        q /= q.sum()
        # gromov_wasserstein2 returns the loss value (not the coupling) and a log dict
        gw_loss, log = gromov_wasserstein2(C1, C2, p, q, loss_fun='kl_loss', log=True)
        gw_loss = torch.from_numpy(np.asarray(gw_loss))
        grad_C1 = None  # TODO: gradient of the loss w.r.t. C1
        grad_C2 = None  # TODO: gradient of the loss w.r.t. C2
        ctx.save_for_backward(grad_C1, grad_C2)
        return gw_loss

    @staticmethod
    def backward(ctx, grad_output):
        grad_C10, grad_C20 = ctx.saved_tensors
        grad_C1, grad_C2 = None, None
        if ctx.needs_input_grad[0]:
            grad_C1 = grad_output * grad_C10
        if ctx.needs_input_grad[1]:
            grad_C2 = grad_output * grad_C20
        # One return value per forward input; p and q receive no gradient
        return grad_C1, grad_C2, None, None


def GW(C1, C2, p, q):
    """loss = gromov_wasserstein2(C1, C2, p, q)"""
    return GromovWassersteinLossFunction.apply(C1, C2, p, q)
It would be a huge help if you could let me know whether and how you compute these gradients. Someone on the POT Slack suggested that I derive them by hand, but I couldn't work out the underlying equation for this function, as I am a beginner in Optimal Transport.
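For reference, here is a minimal sketch of what "computing them by hand" could look like. It rests on assumptions that are not in the snippet above: it uses the square loss rather than 'kl_loss', and it treats the coupling T as fixed when differentiating (the usual envelope-theorem argument when T is optimal). Under those assumptions the objective is sum over i,j,k,l of (C1[i,k] - C2[j,l])^2 * T[i,j] * T[k,l], and its partial derivatives with respect to C1 and C2 have a closed form. The function names gw_square_loss and gw_square_loss_grads are hypothetical helpers, not part of POT.

```python
import numpy as np


def gw_square_loss(C1, C2, T):
    """Square-loss GW objective for a fixed coupling T.

    Computes sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l]
    without materializing the 4-D tensor.
    """
    p = T.sum(axis=1)  # row marginal of the coupling
    q = T.sum(axis=0)  # column marginal of the coupling
    term_c1 = p @ (C1 ** 2) @ p
    term_c2 = q @ (C2 ** 2) @ q
    cross = np.sum(C1 * (T @ C2 @ T.T))
    return term_c1 + term_c2 - 2.0 * cross


def gw_square_loss_grads(C1, C2, T):
    """Partial derivatives of gw_square_loss w.r.t. C1 and C2, with T held fixed."""
    p = T.sum(axis=1)
    q = T.sum(axis=0)
    grad_C1 = 2.0 * (C1 * np.outer(p, p) - T @ C2 @ T.T)
    grad_C2 = 2.0 * (C2 * np.outer(q, q) - T.T @ C1 @ T)
    return grad_C1, grad_C2


# Toy data. The independent coupling outer(p, q) is only a stand-in here;
# in practice T would be the coupling returned by a GW solver such as
# ot.gromov.gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss').
rng = np.random.default_rng(0)
C1 = rng.random((4, 4))
C2 = rng.random((5, 5))
p = np.full(4, 1.0 / 4)
q = np.full(5, 1.0 / 5)
T = np.outer(p, q)

grad_C1, grad_C2 = gw_square_loss_grads(C1, C2, T)
print(grad_C1.shape, grad_C2.shape)
```

In the autograd.Function above, these two arrays (converted with torch.from_numpy) would be the natural candidates for the grad_C1 and grad_C2 placeholders, provided the forward pass also uses the square loss. The KL loss would need its own derivation.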