Skip to content

Commit

Permalink
add barycenter operator for (convex) translation invariant costs (#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Mar 7, 2024
1 parent abad9a0 commit cbff1c7
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
# not defined for `p=1`
return mu.norm(z, self.q) ** self.q / self.q

def barycenter(self, weights: jnp.ndarray,
xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]:
"""Output barycenter of vectors."""
return jnp.average(xs, weights=weights, axis=0), None

def tree_flatten(self): # noqa: D102
return (), (self.p,)

Expand Down Expand Up @@ -552,6 +557,11 @@ def f_h(x: jnp.ndarray) -> float:

return f_h

def barycenter(self, weights: jnp.ndarray,
xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]:
"""Output barycenter of vectors."""
return jnp.average(xs, weights=weights, axis=0), None

def tree_flatten(self): # noqa: D102
return (self.scaling_reg, self.matrix), {"orthogonal": self.orthogonal}

Expand Down

0 comments on commit cbff1c7

Please sign in to comment.