Skip to content

Commit

Permalink
Fix/sinkdiv hessian (#321)
Browse files Browse the repository at this point in the history
* Use symmetric in `sinkhorn_divergence`

* Re-run notebook

* [ci skip] Regenerate notebook
  • Loading branch information
michalk8 committed Feb 27, 2023
1 parent 46b7567 commit f47a9af
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/ott/solvers/linear/implicit_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions entering the implicit differentiation of Sinkhorn."""

from typing import TYPE_CHECKING, Callable, Optional, Tuple
import dataclasses
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -286,3 +286,6 @@ def gradient(
# Carries pullback onto original inputs, here geom, a and b.
_, pull_prob = jax.vjp(foc_prob, prob)
return pull_prob(vjp_gr)

def replace(self, **kwargs: Any) -> "ImplicitDiff": # noqa: D102
return dataclasses.replace(self, **kwargs)
4 changes: 3 additions & 1 deletion src/ott/tools/sinkhorn_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,10 @@ def _sinkhorn_divergence(
parallel_dual_updates=True,
momentum=acceleration.Momentum(start=0, value=0.5),
anderson=None,
# TODO(michalk8): implicit_diff
)
implicit_diff = kwargs.get("implicit_diff", None)
if implicit_diff is not None:
kwargs_symmetric["implicit_diff"] = implicit_diff.replace(symmetric=True)

out_xy = sinkhorn.solve(geometry_xy, a, b, **kwargs)
out_xx = sinkhorn.solve(geometry_xx, a, a, **kwargs_symmetric)
Expand Down

0 comments on commit f47a9af

Please sign in to comment.