diff --git a/pyproject.toml b/pyproject.toml index 2dcf0801..65e9237f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ test = [ # tslearn needs numba, which isn't supported for 3.11 "tslearn>=0.5; python_version < '3.11'", "lineax; python_version >= '3.9'", + "matplotlib", ] docs = [ "sphinx>=4.0", diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 746bcf11..7ab22607 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -183,6 +183,7 @@ def plot_ot_map( self, source: jnp.ndarray, target: jnp.ndarray, + samples: Optional[jnp.ndarray] = None, forward: bool = True, ax: Optional["plt.Axes"] = None, legend_kwargs: Optional[Dict[str, Any]] = None, @@ -193,8 +194,10 @@ def plot_ot_map( Args: source: samples from the source measure target: samples from the target measure - forward: use the forward map from the potentials - if ``True``, otherwise use the inverse map + samples: extra samples to transport, either ``source`` (if ``forward``) or + ``target`` (if not ``forward``) by default. + forward: use the forward map from the potentials if ``True``, + otherwise use the inverse map. ax: axis to add the plot to scatter_kwargs: additional kwargs passed into :meth:`~matplotlib.axes.Axes.scatter` @@ -247,8 +250,8 @@ def plot_ot_map( ) # plot the transported samples - base_samples = source if forward else target - transported_samples = self.transport(base_samples, forward=forward) + samples = (source if forward else target) if samples is None else samples + transported_samples = self.transport(samples, forward=forward) ax.scatter( transported_samples[:, 0], transported_samples[:, 1], @@ -257,12 +260,12 @@ def plot_ot_map( **scatter_kwargs, ) - for i in range(base_samples.shape[0]): + for i in range(samples.shape[0]): ax.arrow( - base_samples[i, 0], - base_samples[i, 1], - transported_samples[i, 0] - base_samples[i, 0], - transported_samples[i, 1] - base_samples[i, 1], + samples[i, 0], + samples[i, 1], + transported_samples[i, 0] - samples[i, 0], + transported_samples[i, 1] - samples[i, 1], color=[0.5, 0.5, 1], alpha=0.3 ) diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index b8771219..f0f30696 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import jax import jax.numpy as jnp +import matplotlib.pyplot as plt import numpy as np import pytest from ott.geometry import costs, pointcloud @@ -90,8 +92,9 @@ def test_entropic_potentials_dist( @pytest.mark.fast.with_args(forward=[False, True], only_fast=0) def test_entropic_potentials_displacement( - self, rng: jax.random.PRNGKeyArray, forward: bool + self, rng: jax.random.PRNGKeyArray, forward: bool, monkeypatch ): + """Tests entropic displacements, as well as their plots.""" n1, n2, d = 96, 128, 2 eps = 1e-2 rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) @@ -122,6 +125,12 @@ def test_entropic_potentials_displacement( error = jnp.mean(jnp.sum((expected_points - actual_points) ** 2, axis=-1)) assert error <= 0.3 + # Test plot functionality, but ensure it does not block execution + monkeypatch.setattr(plt, "show", lambda: None) + potentials.plot_ot_map(x, y, x_test, forward=True) + potentials.plot_ot_map(x, y, y_test, forward=False) + potentials.plot_potential() + @pytest.mark.fast.with_args( p=[1.3, 2.2, 1.0], forward=[False, True], only_fast=0 )