Skip to content

Commit

Permalink
[ci skip] Mention gamma > 0
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Feb 27, 2023
1 parent 68f088b commit f6420ff
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ class SoftDTW(CostFn):
"""Soft dynamic time warping (DTW) cost :cite:`cuturi:17`.
Args:
gamma: Smoothing parameter for the soft-min operator.
gamma: Smoothing parameter :math:`> 0` for the soft-min operator.
ground_cost: Ground cost function. If ``None``,
use :class:`~ott.geometry.costs.SqEuclidean`.
debiased: Whether to compute the debiased soft-DTW :cite:`blondel:21`.
Expand All @@ -779,7 +779,7 @@ def __init__(
self.ground_cost = SqEuclidean() if ground_cost is None else ground_cost
self.debiased = debiased

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: # noqa: D102
c_xy = self._soft_dtw(x, y)
if self.debiased:
return c_xy - 0.5 * (self._soft_dtw(x, x) + self._soft_dtw(y, y))
Expand Down Expand Up @@ -829,11 +829,11 @@ def body(
(_, carry), _ = jax.lax.scan(body, init, model_matrix[2:])
return carry[-1]

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (self.gamma, self.ground_cost), {"debiased": self.debiased}

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*children, **aux_data)


Expand Down
2 changes: 1 addition & 1 deletion src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def softmin(
Args:
x: Input data.
gamma: Smoothing parameter.
gamma: Smoothing parameter :math:`> 0`.
axis: Axis or axes over which to operate. If ``None``, use flattened input.
Returns:
Expand Down
6 changes: 3 additions & 3 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_stronger_regularization_increases_sparsity(


@pytest.mark.skipif(ts_metrics is None, reason="Not supported for Python 3.11")
@pytest.mark.fast
@pytest.mark.fast()
class TestSoftDTW:

@pytest.mark.parametrize("n", [11, 16])
Expand All @@ -208,7 +208,7 @@ def test_soft_dtw(

np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("debiased,jit", [(False, True), (True, False)])
@pytest.mark.parametrize(("debiased", "jit"), [(False, True), (True, False)])
def test_soft_dtw_debiased(
self,
rng: jax.random.PRNGKeyArray,
Expand All @@ -235,7 +235,7 @@ def test_soft_dtw_debiased(
np.testing.assert_allclose(cost_fn(t1, t1), 0.0, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(cost_fn(t2, t2), 0.0, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("debiased,jit", [(False, False), (True, True)])
@pytest.mark.parametrize(("debiased", "jit"), [(False, False), (True, True)])
@pytest.mark.parametrize("gamma", [1e-2, 1])
def test_soft_dtw_grad(
self, rng: jax.random.PRNGKeyArray, debiased: bool, jit: bool,
Expand Down

0 comments on commit f6420ff

Please sign in to comment.