
Reproducing results from paper #557

Closed · zwei-beiner opened this issue Jul 8, 2024 · 2 comments

zwei-beiner commented Jul 8, 2024

Hi, I'm trying to reproduce Figure 1 (right subfigure for d=7) from this paper: https://arxiv.org/abs/1810.02733

However, I am getting different results: the W2 distance computed with OTT is much larger than in the paper, and larger epsilon gives a larger W2, which is the opposite of the trend in the paper. (Note that the color coding differs between the two figures.)

Figure produced with the attached script: [results.png]

Figure in the paper: [Figure 1, right subfigure, d=7]

import functools
import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

import matplotlib.pyplot as plt
import numpy as np

# Samples from the hypercube
def make_samples(key, ndims, nsamples):
    return jax.random.uniform(key, shape=(nsamples, ndims))

# Entropy-regularized OT cost OT_eps between the point clouds x and y
@jax.jit
def W(x, y, epsilon):
    geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn()
    ot = solver(ot_prob)
    return ot.primal_cost

def run():
    key = jax.random.PRNGKey(0)
    Epsilons = 10. ** jnp.arange(-3, 3)
    Nsamples = np.int64(np.exp(np.linspace(1.0, 2.5, 200)))
    ndims = 7
    numiter = 300

    # Draw two independent samples and compute their regularized OT cost
    @functools.partial(jax.jit, static_argnums=(2, 3))
    def calc_W(key, epsilon, nsamples, ndims):
        key, subkey = jax.random.split(key)
        x = make_samples(subkey, ndims, nsamples)
        key, subkey = jax.random.split(key)
        y = make_samples(subkey, ndims, nsamples)
        return W(x, y, epsilon)

    # Calculate log of W2 distances for all options in the paper,
    # vmapping over numiter independent repetitions per (epsilon, nsamples) pair
    results = jnp.log(jnp.asarray([[
            jax.vmap(
                lambda key, i: calc_W(key, epsilon, nsamples, ndims)
            )(jax.random.split(key, numiter), jnp.arange(numiter))
            for key, nsamples in zip(jax.random.split(key, len(Nsamples)), Nsamples)
        ] for key, epsilon in zip(jax.random.split(key, len(Epsilons)), Epsilons)
    ]))


    fig, ax = plt.subplots()
    for i, epsilon in enumerate(Epsilons):
        ax.errorbar(np.log(Nsamples), jnp.mean(results[i], axis=-1), yerr=jnp.std(results[i], axis=-1), label=f"epsilon={epsilon:.3f}")
    ax.legend()
    fig.tight_layout()
    fig.savefig("results.png")
run()
michalk8 (Collaborator) commented Jul 8, 2024

1. This seems to be wrong; consider using jnp.logspace(1.0, 2.5, 200) instead: [image]
2. You're computing $OT_\epsilon$, not the Sinkhorn Divergence (SD) $\tilde{W}_\epsilon$; please see here for the docs.
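
For reference, the Sinkhorn divergence in the paper debiases the regularized cost, $\tilde{W}_\epsilon(\alpha, \beta) = OT_\epsilon(\alpha, \beta) - \frac{1}{2} OT_\epsilon(\alpha, \alpha) - \frac{1}{2} OT_\epsilon(\beta, \beta)$; the uncanceled entropic bias in plain $OT_\epsilon$ grows with $\epsilon$, which matches the trend you observed.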

You can compute SD as:

from ott.tools import sinkhorn_divergence

@jax.jit
def W(x, y, epsilon):
    return sinkhorn_divergence.segment_sinkhorn_divergence(
        pointcloud.PointCloud, x, y, epsilon=epsilon
    ).divergence

zwei-beiner (Author) commented

Thanks a lot, I was not aware of this distinction. Minor correction: I'm guessing the function should be sinkhorn_divergence instead of segment_sinkhorn_divergence. If so, you can close the issue.
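
For reference, the corrected snippet would then presumably read (same imports as in the snippet above; the point clouds are forwarded to the pointcloud.PointCloud constructor):

from ott.tools import sinkhorn_divergence

@jax.jit
def W(x, y, epsilon):
    # sinkhorn_divergence builds three OT problems, (x, y), (x, x) and (y, y),
    # and returns their debiased combination
    return sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y, epsilon=epsilon
    ).divergence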

michalk8 closed this as completed Jul 9, 2024