From 09bffc0a7f411f003b8261afd36acb297935ab79 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Tue, 15 Nov 2022 00:36:51 +0100 Subject: [PATCH] last fixes. --- ott/core/potentials.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ott/core/potentials.py b/ott/core/potentials.py index 1770acb01..126348b5f 100644 --- a/ott/core/potentials.py +++ b/ott/core/potentials.py @@ -125,13 +125,13 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]: @property def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: assert isinstance(self.cost_fn, costs.TICost), ( - "Cost must be RBF and ", - "provide access to Legendre Legendre transform of `h`." + "Cost must be a `TICost` and " + "provide access to Legendre transform of `h`." ) return jax.vmap(jax.grad(self.cost_fn.h_legendre)) def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: - return [self._f, self._g, self.cost_fn], {"cor": self._cor} + return [self._f, self._g, self.cost_fn], {"corr": self._corr} @classmethod def tree_unflatten(