From cbff1c7320e8723d23cb468e62cf1b9ab3d48f42 Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Thu, 7 Mar 2024 22:37:06 +0100 Subject: [PATCH] add barycenter operator for (convex) translation invariant costs (#498) --- src/ott/geometry/costs.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index de76be307..0125af75c 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -242,6 +242,11 @@ def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 # not defined for `p=1` return mu.norm(z, self.q) ** self.q / self.q + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: + """Output barycenter of vectors.""" + return jnp.average(xs, weights=weights, axis=0), None + def tree_flatten(self): # noqa: D102 return (), (self.p,) @@ -552,6 +557,11 @@ def f_h(x: jnp.ndarray) -> float: return f_h + def barycenter(self, weights: jnp.ndarray, + xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]: + """Output barycenter of vectors.""" + return jnp.average(xs, weights=weights, axis=0), None + def tree_flatten(self): # noqa: D102 return (self.scaling_reg, self.matrix), {"orthogonal": self.orthogonal}