diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index de76be307..0125af75c 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -242,6 +242,11 @@ def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 # not defined for `p=1` return mu.norm(z, self.q) ** self.q / self.q + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: + """Output barycenter of vectors.""" + return jnp.average(xs, weights=weights, axis=0), None + def tree_flatten(self): # noqa: D102 return (), (self.p,) @@ -552,6 +557,11 @@ def f_h(x: jnp.ndarray) -> float: return f_h + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: + """Output barycenter of vectors.""" + return jnp.average(xs, weights=weights, axis=0), None + def tree_flatten(self): # noqa: D102 return (self.scaling_reg, self.matrix), {"orthogonal": self.orthogonal}