I am getting different results from the paper: the W2 distance computed with OTT is much larger than in the paper, and a larger epsilon gives a larger W2, which is the opposite of the trend in the paper's figure. (Note that the color coding is reversed between the two figures.)
Figure produced with the attached script:
Figure in the paper:
```python
import functools

import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
import matplotlib.pyplot as plt
import numpy as np


# Samples from the hypercube
def make_samples(key, ndims, nsamples):
    return jax.random.uniform(key, shape=(nsamples, ndims))


@jax.jit
def W(x, y, epsilon):
    geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn()
    ot = solver(ot_prob)
    return ot.primal_cost


def run():
    key = jax.random.PRNGKey(0)
    Epsilons = 10.**jnp.arange(-3, 3)
    Nsamples = np.int64(np.exp(np.linspace(1.0, 2.5, 200)))
    ndims = 7
    numiter = 300

    @functools.partial(jax.jit, static_argnums=(2, 3))
    def calc_W(key, epsilon, nsamples, ndims):
        key, subkey = jax.random.split(key)
        x = make_samples(subkey, ndims, nsamples)
        key, subkey = jax.random.split(key)
        y = make_samples(subkey, ndims, nsamples)
        return W(x, y, epsilon)

    # Calculate log of W2 distances for all options in the paper
    results = jnp.log(jnp.asarray([[
        jax.vmap(
            lambda key, i: calc_W(key, epsilon, nsamples, ndims)
        )(jax.random.split(key, numiter), jnp.arange(numiter))
        for key, nsamples in zip(jax.random.split(key, len(Nsamples)), Nsamples)
    ] for key, epsilon in zip(jax.random.split(key, len(Epsilons)), Epsilons)
    ]))

    fig, ax = plt.subplots()
    for i, epsilon in enumerate(Epsilons):
        ax.errorbar(np.log(Nsamples), jnp.mean(results[i], axis=-1),
                    yerr=jnp.std(results[i], axis=-1), label=f"epsilon={epsilon:3f}")
    ax.legend()
    fig.tight_layout()
    fig.savefig("results.png")


run()
```
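This line seems to be wrong: `Nsamples = np.int64(np.exp(np.linspace(1.0, 2.5, 200)))` exponentiates with base e, so it only gives sample sizes from 2 to 12. Consider using `jnp.logspace(1.0, 2.5, 200)` instead, which uses base 10 and spans 10 to about 316 samples.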
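A quick check of the two sample-size grids makes the difference in range concrete. This sketch uses plain NumPy; `np.logspace` has the same default base (10) as `jnp.logspace`, and the variable names `orig` and `fixed` are illustrative only.

```python
import numpy as np

# Original line: base-e exponent, so sizes span e**1 ~ 2.7 to e**2.5 ~ 12.2 (truncated to ints)
orig = np.exp(np.linspace(1.0, 2.5, 200)).astype(np.int64)

# Suggested: logspace uses base 10 by default, so sizes span 10**1 to 10**2.5 ~ 316.2
fixed = np.logspace(1.0, 2.5, 200).astype(np.int64)

print("original:", orig.min(), "to", orig.max(), "with", len(np.unique(orig)), "distinct sizes")
print("logspace:", fixed.min(), "to", fixed.max(), "with", len(np.unique(fixed)), "distinct sizes")
```

Keeping the sizes as NumPy integers also matters in the script: `calc_W` marks `nsamples` as static via `static_argnums`, and static arguments must be hashable, which elements of a JAX array are not.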
You're computing $OT_\epsilon$, not the Sinkhorn divergence (SD) $\tilde{W}_\epsilon$; please see here for the docs.
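For reference, the standard definition of the Sinkhorn divergence subtracts the two self-transport terms from the entropic cost, so that $\tilde{W}_\epsilon(\alpha, \alpha) = 0$:

$$\tilde{W}_\epsilon(\alpha, \beta) = OT_\epsilon(\alpha, \beta) - \tfrac{1}{2}\, OT_\epsilon(\alpha, \alpha) - \tfrac{1}{2}\, OT_\epsilon(\beta, \beta)$$

Without this correction, $OT_\epsilon(\alpha, \alpha)$ is strictly positive and nondecreasing in $\epsilon$, because the optimal plan spreads toward the independent coupling $\alpha \otimes \alpha$ as $\epsilon$ grows. That bias is consistent with the raw values in the script increasing with $\epsilon$.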
You can compute SD as:
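```python
from ott.tools import sinkhorn_divergence

# Drop-in replacement for W in the script above (jax and pointcloud are imported there).
@jax.jit
def W(x, y, epsilon):
    return sinkhorn_divergence.segment_sinkhorn_divergence(
        pointcloud.PointCloud, x, y, epsilon=epsilon
    ).divergence
```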
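To see the bias numerically without depending on a particular OTT version, here is a minimal NumPy sketch. It is a plain log-domain Sinkhorn with uniform weights and squared Euclidean cost, not OTT's implementation, and the helper names `entropic_ot` and `sinkhorn_div` are hypothetical. `entropic_ot` returns both the entropic value $OT_\epsilon$ and the linear transport cost $\langle C, P \rangle$ of the plan; treating the latter as what `ot.primal_cost` reports in the script is an assumption.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along `axis`.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def entropic_ot(x, y, epsilon, n_iter=1000):
    """Hypothetical helper: log-domain Sinkhorn, uniform weights, squared Euclidean cost.

    Returns (reg_cost, primal_cost): the entropic value <f, a> + <g, b> at the
    final dual potentials, and the linear transport cost <C, P> of the plan.
    """
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        # Alternate exact updates so that P matches the row, then column, marginals.
        f = -epsilon * _logsumexp(np.log(b)[None, :] + (g[None, :] - C) / epsilon, axis=1)
        g = -epsilon * _logsumexp(np.log(a)[:, None] + (f[:, None] - C) / epsilon, axis=0)
    # Plan P_ij = a_i b_j exp((f_i + g_j - C_ij) / epsilon).
    P = np.exp(np.log(a)[:, None] + np.log(b)[None, :] + (f[:, None] + g[None, :] - C) / epsilon)
    return f @ a + g @ b, np.sum(P * C)


def sinkhorn_div(x, y, epsilon):
    # Hypothetical helper: OT_eps(x, y) - 0.5 * OT_eps(x, x) - 0.5 * OT_eps(y, y).
    xy, _ = entropic_ot(x, y, epsilon)
    xx, _ = entropic_ot(x, x, epsilon)
    yy, _ = entropic_ot(y, y, epsilon)
    return xy - 0.5 * (xx + yy)


rng = np.random.default_rng(0)
x = rng.uniform(size=(50, 7))  # samples from the unit hypercube, d = 7
y = rng.uniform(size=(50, 7))
for eps in [0.1, 1.0, 10.0]:
    _, primal = entropic_ot(x, y, eps)
    print(f"eps={eps}: <C,P>={primal:.4f}, divergence={sinkhorn_div(x, y, eps):.4f}")
```

In this sketch the transport cost of the plan grows with epsilon toward the mean squared distance between independent samples, while the debiased value is exactly zero when both inputs are the same point cloud.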
Thanks a lot, I was not aware of this distinction. Minor correction: I'm guessing the function should be sinkhorn_divergence instead of segment_sinkhorn_divergence. If so, you can close the issue.
Hi, I'm trying to reproduce Figure 1 (right subfigure for d=7) from this paper: https://arxiv.org/abs/1810.02733