Skip to content

Commit

Permalink
remove unused barycenters functions
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Jul 12, 2023
1 parent a5723c8 commit 8fc3dc9
Showing 1 changed file with 0 additions and 45 deletions.
45 changes: 0 additions & 45 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
import scipy

from ott import utils
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
Expand Down Expand Up @@ -218,47 +217,3 @@ def animate(
interval=1000 / frame_rate,
blit=True
)


def _barycenters(
ax: "plt.Axes",
y: jnp.ndarray,
a: jnp.ndarray,
b: jnp.ndarray,
matrix: jnp.ndarray,
scale: int = 200
) -> None:
"""Plot 2-D sinkhorn barycenters."""
sa, sb = jnp.min(a) / scale, jnp.min(b) / scale
ax.scatter(*y.T, s=b / sb, edgecolors="k", marker="X", label="y")
tx = 1 / a[:, None] * jnp.matmul(matrix, y)
ax.scatter(*tx.T, s=a / sa, edgecolors="k", marker="X", label="T(x)")
ax.legend(fontsize=15)


def barycentric_projections(
arg: Union[Transport, jnp.ndarray],
a: jnp.ndarray = None,
b: jnp.ndarray = None,
matrix: jnp.ndarray = None,
ax: Optional["plt.Axes"] = None,
**kwargs
):
"""Plot the barycenters, from the Transport object or from arguments."""
if ax is None:
_, ax = plt.subplots(1, 1, figsize=(8, 5))

if utils.is_jax_array(arg):
if matrix is None:
raise ValueError("The `matrix` argument cannot be None.")

a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a
b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b
return _barycenters(ax, arg, a, b, matrix, **kwargs)

if isinstance(arg, gromov_wasserstein.GWOutput):
geom = arg.linear_state.geom
else:
geom = arg.geom

return _barycenters(ax, geom.y, arg.a, arg.b, arg.matrix, **kwargs)

0 comments on commit 8fc3dc9

Please sign in to comment.