Skip to content

Commit

Permalink
add tests for plots in potentials (#393)
Browse files Browse the repository at this point in the history
* add tests for plots in potentials

* add matplotlib to test env

* monkeypatch used properly

* add `plot_potential`
  • Loading branch information
marcocuturi committed Jul 11, 2023
1 parent 63a79f9 commit 1886edf
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ test = [
# tslearn needs numba, which isn't supported for 3.11
"tslearn>=0.5; python_version < '3.11'",
"lineax; python_version >= '3.9'",
"matplotlib",
]
docs = [
"sphinx>=4.0",
Expand Down
21 changes: 12 additions & 9 deletions src/ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def plot_ot_map(
self,
source: jnp.ndarray,
target: jnp.ndarray,
samples: Optional[jnp.ndarray] = None,
forward: bool = True,
ax: Optional["plt.Axes"] = None,
legend_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -193,8 +194,10 @@ def plot_ot_map(
Args:
source: samples from the source measure
target: samples from the target measure
forward: use the forward map from the potentials
if ``True``, otherwise use the inverse map
samples: extra samples to transport, either ``source`` (if ``forward``) or
``target`` (if not ``forward``) by default.
forward: use the forward map from the potentials if ``True``,
otherwise use the inverse map.
ax: axis to add the plot to
scatter_kwargs: additional kwargs passed into
:meth:`~matplotlib.axes.Axes.scatter`
Expand Down Expand Up @@ -247,8 +250,8 @@ def plot_ot_map(
)

# plot the transported samples
base_samples = source if forward else target
transported_samples = self.transport(base_samples, forward=forward)
samples = (source if forward else target) if samples is None else samples
transported_samples = self.transport(samples, forward=forward)
ax.scatter(
transported_samples[:, 0],
transported_samples[:, 1],
Expand All @@ -257,12 +260,12 @@ def plot_ot_map(
**scatter_kwargs,
)

for i in range(base_samples.shape[0]):
for i in range(samples.shape[0]):
ax.arrow(
base_samples[i, 0],
base_samples[i, 1],
transported_samples[i, 0] - base_samples[i, 0],
transported_samples[i, 1] - base_samples[i, 1],
samples[i, 0],
samples[i, 1],
transported_samples[i, 0] - samples[i, 0],
transported_samples[i, 1] - samples[i, 1],
color=[0.5, 0.5, 1],
alpha=0.3
)
Expand Down
11 changes: 10 additions & 1 deletion tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pytest
from ott.geometry import costs, pointcloud
Expand Down Expand Up @@ -90,8 +92,9 @@ def test_entropic_potentials_dist(

@pytest.mark.fast.with_args(forward=[False, True], only_fast=0)
def test_entropic_potentials_displacement(
self, rng: jax.random.PRNGKeyArray, forward: bool
self, rng: jax.random.PRNGKeyArray, forward: bool, monkeypatch
):
"""Tests entropic displacements, as well as their plots."""
n1, n2, d = 96, 128, 2
eps = 1e-2
rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)
Expand Down Expand Up @@ -122,6 +125,12 @@ def test_entropic_potentials_displacement(
error = jnp.mean(jnp.sum((expected_points - actual_points) ** 2, axis=-1))
assert error <= 0.3

# Test plot functionality, but ensure it does not block execution
monkeypatch.setattr(plt, "show", lambda: None)
potentials.plot_ot_map(x, y, x_test, forward=True)
potentials.plot_ot_map(x, y, y_test, forward=False)
potentials.plot_potential()

@pytest.mark.fast.with_args(
p=[1.3, 2.2, 1.0], forward=[False, True], only_fast=0
)
Expand Down

0 comments on commit 1886edf

Please sign in to comment.