Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Grid.prepare_divergences #563

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,7 @@ def prepare_divergences(
**kwargs: Any
) -> Tuple["Grid", ...]:
"""Instantiate the geometries used for a divergence computation."""
grid_size = kwargs.pop("grid_size", None)
x = kwargs.pop("x", args)

sep_grid = cls(x=x, grid_size=grid_size, **kwargs)
sep_grid = cls(*args, **kwargs)
size = 2 if static_b else 3
return tuple(sep_grid for _ in range(size))

Expand Down
24 changes: 22 additions & 2 deletions tests/tools/sinkhorn_divergence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, geometry, pointcloud
from ott.geometry import costs, geometry, grid, pointcloud
from ott.solvers import linear
from ott.solvers.linear import acceleration
from ott.tools import sinkhorn_divergence
Expand Down Expand Up @@ -436,3 +436,23 @@ def loss_fn(cloud_a: jnp.ndarray, cloud_b: jnp.ndarray) -> float:
np.testing.assert_allclose(
custom_grad, finite_diff_grad, rtol=1e-2, atol=1e-2
)

@pytest.mark.parametrize("grid_size", [(5,), (2, 3), (3, 4, 5)])
def test_grid_geometry(self, rng: jax.Array, grid_size: Tuple[int, ...]):
rng1, rng2 = jax.random.split(rng, 2)
gs = (5,)

a = jax.random.uniform(rng1, shape=gs)
a = a / jnp.sum(a)
b = jax.random.uniform(rng2, shape=gs)
b = b / jnp.sum(b)

out = sinkhorn_divergence.sinkhorn_divergence(
grid.Grid,
grid_size=gs,
a=a,
b=b,
epsilon=1e-1,
)

assert jnp.isfinite(out.divergence)
Loading