|
# Solve the OT problem with the custom cost matrix
P_city = ot.solve(C).plan
# The parameters a and b are not provided, so uniform weights are assumed
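
# A minimal sketch (the names a_unif and b_unif are illustrative, not from the
# original example): the same plan is obtained by passing explicit uniform
# weights built with ot.unif
a_unif, b_unif = ot.unif(C.shape[0]), ot.unif(C.shape[1])
P_city_explicit = ot.solve(C, a_unif, b_unif).plan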

# Compute the OT loss (equivalent to ot.solve(C).value)
loss_city = np.sum(P_city * C)
|
|
# But the same can be done with the :func:`ot.solve_sample` function by passing
# :code:`metric='cityblock'` as an argument.
#
# The cost matrix can be computed with the :func:`ot.dist` function, which
# computes the pairwise distances between two sets of samples, or it can be
# provided directly as a matrix by the user when no samples are available.
#
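
# A minimal sketch (not part of the original example): the two routes below
# build the same cityblock cost, either explicitly with ot.dist or implicitly
# inside ot.solve_sample
C_city = ot.dist(x1, x2, metric="cityblock")
P_from_cost = ot.solve(C_city, a, b).plan
P_from_samples = ot.solve_sample(x1, x2, a, b, metric="cityblock").plan

# %%
#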
# .. note::
#     The examples above use the new API of POT. The old API is still available
#     and the OT plan and loss can be computed with the :func:`ot.emd` and
|
@@ -388,7 +393,7 @@ def df(G):
|
# sphinx_gallery_end_ignore
# %%
#
# Gromov-Wasserstein and Fused GW
# -------------------------------------
#
# Solve the Gromov-Wasserstein problem
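
# A minimal sketch with the ot.solve_gromov solver (assuming the intra-domain
# cost matrices are built from the samples; the names Ca, Cb and P_gw_sketch
# are illustrative, not from the original example; uniform weights are assumed
# when a and b are not given):
Ca = ot.dist(x1, x1)
Cb = ot.dist(x2, x2)
P_gw_sketch = ot.solve_gromov(Ca, Cb).plan
# Fused GW additionally takes a cross-domain feature cost M and a trade-off
# parameter alpha, e.g. ot.solve_gromov(Ca, Cb, M=M, alpha=0.5).plan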
|
@@ -519,3 +524,121 @@ def df(G):
|
# pl.title("Unbalanced Entropic GW plan")
# pl.show()
# # sphinx_gallery_end_ignore
# %%
#
# Large scale OT
# --------------
#
# We discuss here strategies for solving large scale OT problems using
# approximations of the exact OT problem.
#
# Large scale Sinkhorn
# ~~~~~~~~~~~~~~~~~~~~
#
# When the samples contain a large number of points, the Sinkhorn algorithm can
# be run in a lazy version that is more memory efficient and avoids
# materializing the full :math:`n \times m` cost matrix in memory.
#
# POT provides two implementations of the lazy Sinkhorn algorithm that return
# their result in a lazy form of type :class:`ot.utils.LazyTensor`. This object
# can be used to compute the loss or the OT plan in a lazy way, or to recover
# its values in a dense form.
#

# Solve the Sinkhorn problem in a lazy way
sol = ot.solve_sample(x1, x2, a, b, reg=1e-1, lazy=True)

# Solve the Sinkhorn problem in a lazy way with the GeomLoss backend
sol_geo = ot.solve_sample(x1, x2, a, b, reg=1e-1, method="geomloss", lazy=True)

# Get the lazy OT plan
P_sink_lazy = sol.lazy_plan

# Recover values from the lazy plan
P12 = P_sink_lazy[1, 2]  # a single entry
P1dots = P_sink_lazy[1, :]  # a full row
# Convert to a dense matrix (warning: this can be very memory consuming)
P_sink_lazy_dense = P_sink_lazy[:]

# sphinx_gallery_start_ignore
pl.figure(1, (3, 3))
plot2D_samples_mat(x1, x2, P_sink_lazy_dense)
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
pl.title("Lazy Sinkhorn OT plan")
pl.show()

pl.figure(2, (3, 1.7))
pl.imshow(P_sink_lazy_dense, cmap="Greys")
pl.title("Lazy Sinkhorn OT plan")
pl.show()

# sphinx_gallery_end_ignore
#
# %%
#
# The first example above solves the Sinkhorn problem in a lazy way with the
# default POT implementation. The second one uses the PyKeOps/GeomLoss
# implementation, which provides a very efficient solver for large scale
# problems on low-dimensional samples.
#
# Factored and Low rank OT
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# The Sinkhorn algorithm can be implemented in a low rank version that
# approximates the OT plan with a low rank matrix. This can be useful for
# accelerating the computation of the OT plan on large scale problems.
# A similar non-regularized low rank factorization of the plan is also available.
#

# Solve the Factored OT problem (use lazy=True for large scale)
P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=8).plan

# Solve the Low rank Sinkhorn problem (use lazy=True for large scale)
P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=8).plan

# sphinx_gallery_start_ignore
pl.figure(1, (6, 3))

pl.subplot(1, 2, 1)
plot2D_samples_mat(x1, x2, P_fact)
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
pl.title("Factored OT plan")

pl.subplot(1, 2, 2)
plot2D_samples_mat(x1, x2, P_lowrank)
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
pl.title("Low rank OT plan")
pl.show()

pl.figure(2, (6, 1.7))

pl.subplot(1, 2, 1)
pl.imshow(P_fact, cmap="Greys")
pl.title("Factored OT plan")

pl.subplot(1, 2, 2)
pl.imshow(P_lowrank, cmap="Greys")
pl.title("Low rank OT plan")
pl.show()

# sphinx_gallery_end_ignore

# %%
#
# Gaussian OT with Bures-Wasserstein
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The Gaussian Wasserstein or Bures-Wasserstein distance is the Wasserstein
# distance between Gaussian distributions. It can be used as an approximation of
# the Wasserstein distance between empirical distributions by estimating the
# means and covariance matrices of the samples.
#

# Compute the Bures-Wasserstein distance
bw_value = ot.solve_sample(x1, x2, a, b, method="gaussian").value

print(f"Bures-Wasserstein distance = {bw_value:1.3f}")
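
# The lower-level ot.gaussian module exposes the corresponding estimator
# directly; a minimal sketch (assuming ot.gaussian.empirical_bures_wasserstein_distance
# is available with this signature; it estimates the means and covariances from
# the samples before applying the closed-form Bures-Wasserstein formula):
bw_emp = ot.gaussian.empirical_bures_wasserstein_distance(x1, x2)
print(f"Empirical Bures-Wasserstein distance = {float(bw_emp):1.3f}")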