Skip to content

Commit db9942f

Browse files
committed
first shot done
1 parent 04ffaa4 commit db9942f

File tree

2 files changed

+128
-1
lines changed

2 files changed

+128
-1
lines changed

examples/plot_quickstart_guide.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154

155155
# Solve the OT problem with the custom cost matrix
156156
P_city = ot.solve(C).plan
157+
# the parameters a and b are not provided so uniform weights are assumed
157158

158159
# Compute the OT loss (equivalent to ot.solve(C).value)
159160
loss_city = np.sum(P_city * C)
@@ -177,6 +178,10 @@
177178
# But the same can be done with the :func:`ot.solve_sample` function by passing
178179
# :code:`metric='cityblock'` as argument.
179180
#
181+
# The cost matrix can be computed with the :func:`ot.dist` function which
182+
# computes the pairwise distance between two sets of samples or can be provided
183+
# directly as a matrix by the user when no samples are available.
184+
#
180185
# .. note::
181186
# The examples above use the new API of POT. The old API is still available
182187
# and and OT plan and loss can be computed with the :func:`ot.emd` and
@@ -388,7 +393,7 @@ def df(G):
388393
# sphinx_gallery_end_ignore
389394
# %%
390395
#
391-
# Gromov-Wasserstein (GW) and Fused GW
396+
# Gromov-Wasserstein and Fused GW
392397
# -------------------------------------
393398
#
394399
# Solve the Gromov-Wasserstein problem
@@ -519,3 +524,121 @@ def df(G):
519524
# pl.title("Unbalanced Entropic GW plan")
520525
# pl.show()
521526
# # sphinx_gallery_end_ignore
527+
# %%
528+
#
529+
# Large scale OT
530+
# --------------
531+
#
532+
# We discuss here strategies to solve large scale OT problems using approximations
533+
# of the exact OT problem.
534+
#
535+
# Large scale Sinkhorn
536+
# ~~~~~~~~~~~~~~~~~~~~
537+
#
538+
# When having samples with a large number of points, the Sinkhorn algorithm can
539+
# be implemented in a Lazy version which is more memory efficient and avoids
540+
# the computation of the :math:`n \times m` cost matrix.
541+
#
542+
# POT provides two implementation of the lazy Sinkhorn algorithm that return their
543+
# results in a lazy form of type :class:`ot.utils.LazyTensor`. This object can be
544+
# used to compute the loss or the OT plan in a lazy way or to recover its values
545+
# in a dense form.
546+
#
547+
548+
# Solve the Sinkhorn problem in a lazy way
549+
sol = ot.solve_sample(x1, x2, a, b, reg=1e-1, lazy=True)
550+
551+
# Solve the sinkhoorn in a lazy way with geomloss
552+
sol_geo = ot.solve_sample(x1, x2, a, b, reg=1e-1, method="geomloss", lazy=True)
553+
554+
# get the OT lazy plan and loss
555+
P_sink_lazy = sol.lazy_plan
556+
557+
# recover values for Lazy plan
558+
P12 = P_sink_lazy[1, 2]
559+
P1dots = P_sink_lazy[1, :]
560+
P_sink_lazy_dense = P_sink_lazy[
561+
:
562+
] # convert to dense matrix !!warning this can be memory consuming
563+
564+
# sphinx_gallery_start_ignore
565+
pl.figure(1, (3, 3))
566+
plot2D_samples_mat(x1, x2, P_sink_lazy_dense)
567+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
568+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
569+
pl.title("Lazy Sinkhorn OT plan")
570+
pl.show()
571+
572+
pl.figure(2, (3, 1.7))
573+
pl.imshow(P_sink_lazy_dense, cmap="Greys")
574+
pl.title("Lazy Sinkhorn OT plan")
575+
pl.show()
576+
577+
# sphinx_gallery_end_ignore
578+
#
579+
# %%
580+
#
581+
# the first example shows how to solve the Sinkhorn problem in a lazy way with
582+
# the default POT implementation. The second example shows how to solve the
583+
# Sinkhorn problem in a lazy way with the PyKeops/Geomloss implementation that provides
584+
# a very efficient way to solve large scale problems on low dimensionality
585+
# samples.
586+
#
587+
# Factored and Low rank OT
588+
# ------------------------
589+
#
590+
# The Sinkhorn algorithm can be implemented in a low rank version that
591+
# approximates the OT plan with a low rank matrix. This can be useful to
592+
# accelerate the computation of the OT plan for large scale problems.
593+
# A similar non-regularized version of low rank factorization is also available.
594+
#
595+
596+
# Solve the Factored OT problem (use lazy=True for large scale)
597+
P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=8).plan
598+
599+
P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=8).plan
600+
601+
# sphinx_gallery_start_ignore
602+
pl.figure(1, (6, 3))
603+
604+
pl.subplot(1, 2, 1)
605+
plot2D_samples_mat(x1, x2, P_fact)
606+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
607+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
608+
pl.title("Factored OT plan")
609+
610+
pl.subplot(1, 2, 2)
611+
plot2D_samples_mat(x1, x2, P_lowrank)
612+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
613+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
614+
pl.title("Low rank OT plan")
615+
pl.show()
616+
617+
pl.figure(2, (6, 1.7))
618+
619+
pl.subplot(1, 2, 1)
620+
pl.imshow(P_fact, cmap="Greys")
621+
pl.title("Factored OT plan")
622+
623+
pl.subplot(1, 2, 2)
624+
pl.imshow(P_lowrank, cmap="Greys")
625+
pl.title("Low rank OT plan")
626+
pl.show()
627+
628+
# sphinx_gallery_end_ignore
629+
630+
# %%
631+
#
632+
# Gaussian OT with Bures-Wasserstein
633+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
634+
#
635+
# The Gaussian Wasserstein or Bures-Wasserstein distance is the Wasserstein distance
636+
# between Gaussian distributions. It can be used as an approximation of the
637+
# Wasserstein distance between empirical distributions by estimating the
638+
# covariance matrices of the samples.
639+
#
640+
641+
# Compute the Bures-Wasserstein distance
642+
bw_value = ot.solve_sample(x1, x2, a, b, method="gaussian").value
643+
644+
print(f"Bures-Wasserstein distance = {bw_value:1.3f}")

ot/bregman/_empirical.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric="sqeuclidean", reg=1e-1, nx=N
5353
shape = (X_a.shape[0], X_b.shape[0])
5454

5555
def func(i, j, X_a, X_b, f, g, metric, reg):
56+
if isinstance(i, int):
57+
i = slice(i, i + 1)
58+
if isinstance(j, int):
59+
j = slice(j, j + 1)
5660
C = dist(X_a[i], X_b[j], metric=metric)
5761
return nx.exp(f[i, None] + g[None, j] - C / reg)
5862

0 commit comments

Comments
 (0)