Skip to content

Commit

Permalink
last fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Nov 14, 2022
1 parent 9ba0879 commit 09bffc0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ott/core/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
@property
def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
assert isinstance(self.cost_fn, costs.TICost), (
"Cost must be RBF and ",
"provide access to Legendre Legendre transform of `h`."
"Cost must be a `TICost` and "
"provide access to Legendre transform of `h`."
)
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self.cost_fn], {"cor": self._cor}
return [self._f, self._g, self.cost_fn], {"corr": self._corr}

@classmethod
def tree_unflatten(
Expand Down

0 comments on commit 09bffc0

Please sign in to comment.