Unbalanced FGW is unstable, especially when marginals are provided. I played with epsilon and the taus, but it still doesn't converge. I think this started after 41906a2.
To Reproduce
```python
import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve

# Generating random data for x and y
x = np.random.rand(96, 2)  # 96 points in 2D
y = np.random.rand(96, 2)  # Another 96 points in 2D

# Create PointCloud instances
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)

# a and b are vectors of ones with lengths matching the number of points in x and y, respectively
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])

# Call solve function with the specified parameters
solve(geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, tau_a=0.9, tau_b=0.9,
      fused_penalty=1.0, epsilon=1.0, a=a, b=b)
```
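Since the report mentions tuning epsilon and the taus together, it may help to translate tau into the strength of the marginal penalty. The sketch below assumes the convention used in ott's documentation for unbalanced problems, tau = rho / (rho + epsilon), where rho weights the KL penalty on a marginal and tau = 1 is the balanced case; treat that convention as an assumption to check against the installed version. `tau_to_rho` and `rho_to_tau` are hypothetical helpers, not ott functions.

```python
def tau_to_rho(tau: float, epsilon: float) -> float:
    """Convert an unbalancedness parameter tau in (0, 1] to a KL penalty strength rho.

    Assumes the convention tau = rho / (rho + epsilon); tau == 1 means balanced (rho = inf).
    """
    if not 0.0 < tau <= 1.0:
        raise ValueError("tau must lie in (0, 1]")
    if tau == 1.0:
        return float("inf")
    return epsilon * tau / (1.0 - tau)


def rho_to_tau(rho: float, epsilon: float) -> float:
    """Inverse map under the same assumed convention."""
    return rho / (rho + epsilon)


# The reproduction above uses tau_a = tau_b = 0.9 with epsilon = 1.0.
for eps in (1.0, 0.1, 0.01):
    print(f"epsilon={eps}: rho={tau_to_rho(0.9, eps):.4g}")
```

Under this convention rho = epsilon * tau / (1 - tau) is linear in epsilon, so lowering epsilon at a fixed tau also weakens the marginal penalty by the same factor; sweeping the two independently therefore changes more than one knob at a time.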
selmanozleyen changed the title from "Unbalanced FGW is doesn't converge when margins are provided" to "Unbalanced FGW doesn't converge when margins are provided" on Apr 17, 2024.
Hi @selmanozleyen, this seems to come from numerical imprecision. More specifically, the NaNs come directly from the initialization here, where marginal_1 is an array of all 0s; this leads to a transport mass of 0 and, later, to a NaN rescaling factor.
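To make that explanation concrete, here is a toy NumPy sketch of how a marginal that is all zeros turns into a NaN. It is not ott's code: `rescale_to_mass` and the choice of target mass (the geometric mean of the two marginal masses) are hypothetical stand-ins for whatever rescaling the initialization actually performs. The only point is that a coupling with zero total mass makes any divide-by-mass rescaling evaluate 0/0.

```python
import numpy as np


def rescale_to_mass(coupling: np.ndarray, target_mass: float):
    """Hypothetical rescaling step: scale `coupling` so its total mass equals `target_mass`."""
    mass = coupling.sum()
    # Silence NumPy's warnings so the NaN propagates quietly, as it does in JAX.
    with np.errstate(divide="ignore", invalid="ignore"):
        factor = target_mass / mass
        scaled = factor * coupling
    return scaled, factor


marginal_1 = np.zeros(4, dtype=np.float32)   # all zeros, as described above
marginal_2 = np.ones(4, dtype=np.float32)
coupling = np.outer(marginal_1, marginal_2)  # initial coupling with zero transport mass

# Hypothetical target mass: geometric mean of the two marginal masses (0 here).
target = np.sqrt(marginal_1.sum() * marginal_2.sum())
scaled, factor = rescale_to_mass(coupling, target)
print(factor)                  # nan
print(np.isnan(scaled).all())  # True
```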
I will take a look at whether there's a more numerically stable way of computing this; however, simply using
@michalk8, as you said, it works when I normalize. But when the marginals don't sum to 1, it still fails in many cases; see, for example, the cases below. I'd expect unbalanced OT not to require the marginals to sum to 1.
```python
a = np.ones(x.shape[0]) * 2
a[0:4] = 1
b = np.ones(y.shape[0]) * 2
b[0:4] = 1

# or
a = np.ones(x.shape[0]) * 2
b = np.ones(y.shape[0]) * 2
```
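The workaround from the thread, normalizing the marginals so that each sums to 1, can be sketched on the first set of weights above. This is an assumption about what the maintainer's elided snippet did; `normalize_marginal` is a hypothetical helper, and `n = 96` stands in for `x.shape[0]` so the block runs on its own.

```python
import numpy as np


def normalize_marginal(w) -> np.ndarray:
    """Hypothetical helper: rescale a nonnegative weight vector so that it sums to 1."""
    w = np.asarray(w, dtype=np.float64)
    total = w.sum()
    if total <= 0:
        raise ValueError("marginal must have positive total mass")
    return w / total


n = 96  # stands in for x.shape[0] / y.shape[0] from the reproduction
a = np.ones(n) * 2
a[0:4] = 1
b = np.ones(n) * 2
b[0:4] = 1

a_n, b_n = normalize_marginal(a), normalize_marginal(b)
print(a.sum(), a_n.sum())  # original total mass vs. normalized mass
```

Normalizing preserves the relative weights (point 0 still carries half the weight of point 4) but changes the total masses the solver sees, which is exactly what the reporter argues an unbalanced solver should not need.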
Describe the bug
For an application use case, see the tests from moscot: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677