# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""A Jax implementation of the Sinkhorn algorithm."""
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from ott.core import anderson as anderson_lib
from ott.core import fixed_point_loop
from ott.core import implicit_differentiation as implicit_lib
from ott.core import initializers as init_lib
from ott.core import linear_problems
from ott.core import momentum as momentum_lib
from ott.core import potentials, unbalanced_functions
from ott.geometry import geometry
class SinkhornState(NamedTuple):
"""Holds the state variables used to solve OT with Sinkhorn."""
errors: Optional[jnp.ndarray] = None
fu: Optional[jnp.ndarray] = None
gv: Optional[jnp.ndarray] = None
old_fus: Optional[jnp.ndarray] = None
old_mapped_fus: Optional[jnp.ndarray] = None
def set(self, **kwargs: Any) -> 'SinkhornState':
"""Return a copy of self, with potential overwrites."""
return self._replace(**kwargs)
def solution_error(
self, ot_prob: linear_problems.LinearProblem, norm_error: Sequence[int],
lse_mode: bool
) -> jnp.ndarray:
return solution_error(self.fu, self.gv, ot_prob, norm_error, lse_mode)
def ent_reg_cost(
self, ot_prob: linear_problems.LinearProblem, lse_mode: bool
) -> float:
return ent_reg_cost(self.fu, self.gv, ot_prob, lse_mode)
def solution_error(
f_u: jnp.ndarray, g_v: jnp.ndarray, ot_prob: linear_problems.LinearProblem,
norm_error: Sequence[int], lse_mode: bool
) -> jnp.ndarray:
"""Given two potential/scaling solutions, computes deviation to optimality.
When the ``ot_prob`` problem is balanced, this is simply deviation to the
target marginals defined in ``ot_prob.a`` and ``ot_prob.b``. When the problem
is unbalanced, additional quantities must be taken into account.
Args:
f_u: jnp.ndarray, potential or scaling
g_v: jnp.ndarray, potential or scaling
ot_prob: linear OT problem
norm_error: int, p-norm used to compute error.
lse_mode: True if log-sum-exp operations, False if kernel vector products.
Returns:
    a positive number quantifying how far from optimality the current solution is.
"""
if ot_prob.is_balanced:
return marginal_error(
f_u, g_v, ot_prob.b, ot_prob.geom, 0, norm_error, lse_mode
)
  # In the unbalanced case, we compute the norm of the gradient.
  # The gradient is equal to the marginal of the current plan minus the
  # gradient of <z, rho_z(exp(-h/rho_z) - 1)>, where z is either a or b and
  # h is either f or g. Note this is equal to z when rho_z → inf, which
  # is the case when tau_z → 1.0.
if lse_mode:
grad_a = unbalanced_functions.grad_of_marginal_fit(
ot_prob.a, f_u, ot_prob.tau_a, ot_prob.epsilon
)
grad_b = unbalanced_functions.grad_of_marginal_fit(
ot_prob.b, g_v, ot_prob.tau_b, ot_prob.epsilon
)
else:
u = ot_prob.geom.potential_from_scaling(f_u)
v = ot_prob.geom.potential_from_scaling(g_v)
grad_a = unbalanced_functions.grad_of_marginal_fit(
ot_prob.a, u, ot_prob.tau_a, ot_prob.epsilon
)
grad_b = unbalanced_functions.grad_of_marginal_fit(
ot_prob.b, v, ot_prob.tau_b, ot_prob.epsilon
)
err = marginal_error(f_u, g_v, grad_a, ot_prob.geom, 1, norm_error, lse_mode)
err += marginal_error(f_u, g_v, grad_b, ot_prob.geom, 0, norm_error, lse_mode)
return err
def marginal_error(
f_u: jnp.ndarray,
g_v: jnp.ndarray,
target: jnp.ndarray,
geom: geometry.Geometry,
axis: int = 0,
norm_error: Sequence[int] = (1,),
lse_mode: bool = True
) -> jnp.ndarray:
"""Output how far Sinkhorn solution is w.r.t target.
Args:
f_u: a vector of potentials or scalings for the first marginal.
g_v: a vector of potentials or scalings for the second marginal.
target: target marginal.
geom: Geometry object.
axis: axis (0 or 1) along which to compute marginal.
norm_error: (tuple of int) p's to compute p-norm between marginal/target
lse_mode: whether operating on scalings or potentials
Returns:
Array of floats, quantifying difference between target / marginal.
"""
if lse_mode:
marginal = geom.marginal_from_potentials(f_u, g_v, axis=axis)
else:
marginal = geom.marginal_from_scalings(f_u, g_v, axis=axis)
norm_error = jnp.asarray(norm_error)
error = jnp.sum(
jnp.abs(marginal - target) ** norm_error[:, jnp.newaxis], axis=1
) ** (1.0 / norm_error)
return error
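# Illustrative sketch, not part of the library API: how ``marginal_error``
# broadcasts several p-norms at once. The toy point cloud and epsilon below
# are assumptions made purely for demonstration.
def _marginal_error_example() -> jnp.ndarray:
  from ott.geometry import pointcloud  # Local import, demo only.
  x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
  geom = pointcloud.PointCloud(x, x, epsilon=1e-1)
  f = jnp.zeros(5)
  g = jnp.zeros(5)
  # Deviation of the plan's marginal from uniform, in 1-norm and 2-norm.
  return marginal_error(f, g, jnp.ones(5) / 5, geom, axis=0, norm_error=(1, 2))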
def ent_reg_cost(
f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problems.LinearProblem,
lse_mode: bool
) -> float:
r"""Compute objective of Sinkhorn for OT problem given dual solutions.
The objective is evaluated for dual solution ``f`` and ``g``, using
information contained in ``ot_prob``. The objective is the regularized
optimal transport cost (i.e. the cost itself plus entropic and unbalanced
terms). Situations where marginals ``a`` or ``b`` in ot_prob have zero
coordinates are reflected in minus infinity entries in their corresponding
dual potentials. To avoid NaN that may result when multiplying 0's by infinity
values, ``jnp.where`` is used to cancel these contributions.
Args:
f: jnp.ndarray, potential
g: jnp.ndarray, potential
ot_prob: linear optimal transport problem.
lse_mode: bool, whether to compute total mass in lse or kernel mode.
Returns:
The regularized transport cost.
"""
supp_a = ot_prob.a > 0
supp_b = ot_prob.b > 0
fa = ot_prob.geom.potential_from_scaling(ot_prob.a)
if ot_prob.tau_a == 1.0:
div_a = jnp.sum(jnp.where(supp_a, ot_prob.a * (f - fa), 0.0))
else:
rho_a = ot_prob.epsilon * (ot_prob.tau_a / (1 - ot_prob.tau_a))
div_a = -jnp.sum(
jnp.where(
supp_a, ot_prob.a * unbalanced_functions.phi_star(-(f - fa), rho_a),
0.0
)
)
gb = ot_prob.geom.potential_from_scaling(ot_prob.b)
if ot_prob.tau_b == 1.0:
div_b = jnp.sum(jnp.where(supp_b, ot_prob.b * (g - gb), 0.0))
else:
rho_b = ot_prob.epsilon * (ot_prob.tau_b / (1 - ot_prob.tau_b))
div_b = -jnp.sum(
jnp.where(
supp_b, ot_prob.b * unbalanced_functions.phi_star(-(g - gb), rho_b),
0.0
)
)
# Using https://arxiv.org/pdf/1910.12958.pdf (24)
if lse_mode:
total_sum = jnp.sum(ot_prob.geom.marginal_from_potentials(f, g))
else:
u = ot_prob.geom.scaling_from_potential(f)
v = ot_prob.geom.scaling_from_potential(g)
total_sum = jnp.sum(ot_prob.geom.marginal_from_scalings(u, v))
return div_a + div_b + ot_prob.epsilon * (
jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum
)
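# A small sketch of the tau/rho reparameterization used above and in the
# ``sinkhorn`` docstring: rho = eps * tau / (1 - tau) inverts
# tau = rho / (eps + rho). The values below are illustrative only.
def _rho_from_tau_example(epsilon: float = 0.1, tau: float = 0.9) -> float:
  rho = epsilon * tau / (1 - tau)
  assert abs(rho / (epsilon + rho) - tau) < 1e-12  # Mapping back recovers tau.
  return rho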
class SinkhornOutput(NamedTuple):
"""Implements the problems.Transport interface, for a Sinkhorn solution."""
f: Optional[jnp.ndarray] = None
g: Optional[jnp.ndarray] = None
errors: Optional[jnp.ndarray] = None
reg_ot_cost: Optional[float] = None
ot_prob: Optional[linear_problems.LinearProblem] = None
def set(self, **kwargs: Any) -> 'SinkhornOutput':
"""Return a copy of self, with potential overwrites."""
return self._replace(**kwargs)
def set_cost(
self, ot_prob: linear_problems.LinearProblem, lse_mode: bool,
use_danskin: bool
) -> 'SinkhornOutput':
f = jax.lax.stop_gradient(self.f) if use_danskin else self.f
g = jax.lax.stop_gradient(self.g) if use_danskin else self.g
return self.set(reg_ot_cost=ent_reg_cost(f, g, ot_prob, lse_mode))
@property
def linear(self) -> bool:
return isinstance(self.ot_prob, linear_problems.LinearProblem)
@property
def geom(self) -> geometry.Geometry:
return self.ot_prob.geom
@property
def a(self) -> jnp.ndarray:
return self.ot_prob.a
@property
def b(self) -> jnp.ndarray:
return self.ot_prob.b
@property
def linear_output(self) -> bool:
return True
@property
def converged(self) -> bool:
if self.errors is None:
return False
return jnp.logical_and(
jnp.any(self.errors == -1), jnp.all(jnp.isfinite(self.errors))
)
@property
def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
u = self.ot_prob.geom.scaling_from_potential(self.f)
v = self.ot_prob.geom.scaling_from_potential(self.g)
return u, v
@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
try:
return self.ot_prob.geom.transport_from_potentials(self.f, self.g)
except ValueError:
return self.ot_prob.geom.transport_from_scalings(*self.scalings)
def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
"""Apply the transport to a ndarray; axis=1 for its transpose."""
try:
return self.ot_prob.geom.apply_transport_from_potentials(
self.f, self.g, inputs, axis=axis
)
except ValueError:
u, v = self.scalings
return self.ot_prob.geom.apply_transport_from_scalings(
u, v, inputs, axis=axis
)
def marginal(self, axis: int) -> jnp.ndarray:
return self.ot_prob.geom.marginal_from_potentials(self.f, self.g, axis=axis)
def cost_at_geom(self, other_geom: geometry.Geometry) -> float:
"""Return reg-OT cost for matrix, evaluated at other cost matrix."""
return (
jnp.sum(self.matrix * other_geom.cost_matrix) -
self.geom.epsilon * jnp.sum(jax.scipy.special.entr(self.matrix))
)
def transport_mass(self) -> float:
"""Sum of transport matrix."""
return self.marginal(0).sum()
def to_dual_potentials(self) -> potentials.EntropicPotentials:
"""Return the entropic map estimator."""
return potentials.EntropicPotentials(
self.f, self.g, self.geom, self.a, self.b
)
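# Illustrative sketch: a ``SinkhornOutput`` can be applied to a vector
# without materializing the n x m transport matrix, which matters for large
# or online geometries. ``out`` and a suitably-sized ``vec`` are assumed
# inputs here.
def _apply_output_example(out: SinkhornOutput, vec: jnp.ndarray) -> jnp.ndarray:
  # Same result as multiplying by the dense plan (or its transpose, with
  # axis=1), but computed from potentials/scalings directly.
  return out.apply(vec, axis=0)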
@jax.tree_util.register_pytree_node_class
class Sinkhorn:
"""A Sinkhorn solver for linear reg-OT problem.
A Sinkhorn solver takes a linear OT problem object as an input and returns a
SinkhornOutput object that contains all the information required to compute
transports. See :func:`~ott.core.sinkhorn.sinkhorn` for a functional wrapper.
Args:
lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel
multiplication.
threshold: tolerance used to stop the Sinkhorn iterations. This is
typically the deviation between a target marginal and the marginal of the
current primal solution when either or both tau_a and tau_b are 1.0
(balanced or semi-balanced problem), or the relative change between two
successive solutions in the unbalanced case.
norm_error: power used to define p-norm of error for marginal/target.
    inner_iterations: the Sinkhorn error is not recomputed at each
      iteration but every ``inner_iterations`` instead.
min_iterations: the minimum number of Sinkhorn iterations carried
out before the error is computed and monitored.
max_iterations: the maximum number of Sinkhorn iterations. If
``max_iterations`` is equal to ``min_iterations``, sinkhorn iterations are
run by default using a :func:`jax.lax.scan` loop rather than a custom,
unroll-able :func:`jax.lax.while_loop` that monitors convergence.
In that case the error is not monitored and the ``converged``
flag will return ``False`` as a consequence.
    momentum: a Momentum instance. See ott.core.momentum.
anderson: an AndersonAcceleration instance. See ott.core.anderson.
implicit_diff: instance used to solve implicit differentiation. Unrolls
iterations if None.
parallel_dual_updates: updates potentials or scalings in parallel if True,
sequentially (in Gauss-Seidel fashion) if False.
    use_danskin: when ``True``, it is assumed the entropy regularized cost is
      evaluated using optimal potentials that are frozen, i.e. whose
gradients have been stopped. This is useful when carrying out first order
differentiation, and is only valid (as with ``implicit_differentiation``)
when the algorithm has converged with a low tolerance.
jit: if True, automatically jits the function upon first call.
Should be set to False when used in a function that is jitted by the user,
      or when computing gradients (in which case the gradient function
      should be jitted by the user).
initializer: how to compute the initial potentials/scalings.
"""
def __init__(
self,
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
momentum: Optional[momentum_lib.Momentum] = None,
anderson: Optional[anderson_lib.AndersonAcceleration] = None,
parallel_dual_updates: bool = False,
use_danskin: Optional[bool] = None,
implicit_diff: Optional[implicit_lib.ImplicitDiff
] = implicit_lib.ImplicitDiff(), # noqa: E124
initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(),
jit: bool = True
):
self.lse_mode = lse_mode
self.threshold = threshold
self.inner_iterations = inner_iterations
self.min_iterations = min_iterations
self.max_iterations = max_iterations
self._norm_error = norm_error
if momentum is not None:
self.momentum = momentum_lib.Momentum(
momentum.start, momentum.value, self.inner_iterations
)
else:
self.momentum = momentum_lib.Momentum(
inner_iterations=self.inner_iterations
)
self.anderson = anderson
self.implicit_diff = implicit_diff
self.parallel_dual_updates = parallel_dual_updates
self.initializer = initializer
self.jit = jit
    # Force implicit differentiation when using Anderson acceleration, and
    # reset all momentum parameters.
if anderson:
self.implicit_diff = (
implicit_lib.ImplicitDiff()
if self.implicit_diff is None else self.implicit_diff
)
self.momentum = momentum_lib.Momentum(start=0, value=1.0)
    # By default, use Danskin's theorem to differentiate
    # the objective when using implicit differentiation.
self.use_danskin = ((self.implicit_diff is not None)
if use_danskin is None else use_danskin)
def __call__(
self,
ot_prob: linear_problems.LinearProblem,
init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None),
) -> SinkhornOutput:
"""Run Sinkhorn algorithm.
Args:
ot_prob: Linear OT problem.
init: Initial dual potentials/scalings f_u and g_v, respectively.
Any `None` values will be initialized using the initializer.
Returns:
The Sinkhorn output.
"""
init_dual_a, init_dual_b = self.initializer(
ot_prob, *init, lse_mode=self.lse_mode
)
run_fn = jax.jit(run) if self.jit else run
return run_fn(ot_prob, self, (init_dual_a, init_dual_b))
def lse_step(
self, ot_prob: linear_problems.LinearProblem, state: SinkhornState,
iteration: int
) -> SinkhornState:
"""Sinkhorn LSE update."""
w = self.momentum.weight(state, iteration)
old_gv = state.gv
new_gv = ot_prob.tau_b * ot_prob.geom.update_potential(
state.fu, state.gv, jnp.log(ot_prob.b), iteration, axis=0
)
gv = self.momentum(w, state.gv, new_gv, self.lse_mode)
new_fu = ot_prob.tau_a * ot_prob.geom.update_potential(
state.fu,
old_gv if self.parallel_dual_updates else gv,
jnp.log(ot_prob.a),
iteration,
axis=1
)
fu = self.momentum(w, state.fu, new_fu, self.lse_mode)
return state.set(fu=fu, gv=gv)
def kernel_step(
self, ot_prob: linear_problems.LinearProblem, state: SinkhornState,
iteration: int
) -> SinkhornState:
"""Sinkhorn multiplicative update."""
w = self.momentum.weight(state, iteration)
old_gv = state.gv
new_gv = ot_prob.geom.update_scaling(
state.fu, ot_prob.b, iteration, axis=0
) ** ot_prob.tau_b
gv = self.momentum(w, state.gv, new_gv, self.lse_mode)
new_fu = ot_prob.geom.update_scaling(
old_gv if self.parallel_dual_updates else gv,
ot_prob.a,
iteration,
axis=1
) ** ot_prob.tau_a
fu = self.momentum(w, state.fu, new_fu, self.lse_mode)
return state.set(fu=fu, gv=gv)
def one_iteration(
self, ot_prob: linear_problems.LinearProblem, state: SinkhornState,
iteration: int, compute_error: bool
) -> SinkhornState:
"""Carries out sinkhorn iteration.
Depending on lse_mode, these iterations can be either in:
- log-space for numerical stability.
- scaling space, using standard kernel-vector multiply operations.
Args:
ot_prob: the transport problem definition
state: SinkhornState named tuple.
iteration: the current iteration of the Sinkhorn loop.
compute_error: flag to indicate this iteration computes/stores an error
Returns:
The updated state.
"""
    # When running updates in parallel (Jacobi fashion), old_gv is used to
    # update fu, rather than the latest gv computed in this loop.
    # Unused otherwise.
if self.anderson:
state = self.anderson.update(state, iteration, ot_prob, self.lse_mode)
if self.lse_mode: # In lse_mode, run additive updates.
state = self.lse_step(ot_prob, state, iteration)
else:
state = self.kernel_step(ot_prob, state, iteration)
if self.anderson:
state = self.anderson.update_history(state, ot_prob, self.lse_mode)
    # Re-computes the error if compute_error is True, else sets it to inf.
err = jnp.where(
jnp.logical_and(compute_error, iteration >= self.min_iterations),
state.solution_error(ot_prob, self.norm_error, self.lse_mode), jnp.inf
)
errors = (state.errors.at[iteration // self.inner_iterations, :].set(err))
return state.set(errors=errors)
def _converged(self, state: SinkhornState, iteration: int) -> bool:
err = state.errors[iteration // self.inner_iterations - 1, 0]
return jnp.logical_and(iteration > 0, err < self.threshold)
def _diverged(self, state: SinkhornState, iteration: int) -> bool:
err = state.errors[iteration // self.inner_iterations - 1, 0]
return jnp.logical_not(jnp.isfinite(err))
def _continue(self, state: SinkhornState, iteration: int) -> bool:
"""Continue while not(converged) and not(diverged)."""
return jnp.logical_and(
jnp.logical_not(self._diverged(state, iteration)),
jnp.logical_not(self._converged(state, iteration))
)
@property
def outer_iterations(self) -> int:
"""Upper bound on number of times inner_iterations are carried out.
This integer can be used to set constant array sizes to track the algorithm
progress, notably errors.
"""
return np.ceil(self.max_iterations / self.inner_iterations).astype(int)
def init_state(
self, ot_prob: linear_problems.LinearProblem, init: Tuple[jnp.ndarray,
jnp.ndarray]
) -> SinkhornState:
"""Return the initial state of the loop."""
fu, gv = init
errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
dtype=fu.dtype)
state = SinkhornState(errors=errors, fu=fu, gv=gv)
return self.anderson.init_maps(ot_prob, state) if self.anderson else state
def output_from_state(
self, ot_prob: linear_problems.LinearProblem, state: SinkhornState
) -> SinkhornOutput:
"""Create an output from a loop state.
Note:
When differentiating the regularized OT cost, and assuming Sinkhorn has
run to convergence, Danskin's (or the envelope)
`theorem <https://en.wikipedia.org/wiki/Danskin%27s_theorem>`_
states that the resulting OT cost as a function of any of the inputs
(``geometry``, ``a``, ``b``) behaves locally as if the dual optimal
potentials were frozen and did not vary with those inputs.
      Notice this is only valid, as when using ``implicit_differentiation``
      mode, if the Sinkhorn algorithm outputs potentials that are near-optimal,
      namely when the threshold value is set to a small tolerance.
The flag ``use_danskin`` controls whether that assumption is made. By
default, that flag is set to the value of ``implicit_differentiation`` if
not specified. If you wish to compute derivatives of order 2 and above,
set ``use_danskin`` to ``False``.
Args:
ot_prob: the transport problem.
state: a SinkhornState.
Returns:
A SinkhornOutput.
"""
geom = ot_prob.geom
f = state.fu if self.lse_mode else geom.potential_from_scaling(state.fu)
g = state.gv if self.lse_mode else geom.potential_from_scaling(state.gv)
errors = state.errors[:, 0]
return SinkhornOutput(f=f, g=g, errors=errors)
@property
def norm_error(self) -> Tuple[int, ...]:
"""Powers used to compute the p-norm between marginal/target."""
# To change momentum adaptively, one needs errors in ||.||_1 norm.
# In that case, we add this exponent to the list of errors to compute,
# notably if that was not the error requested by the user.
if self.momentum and self.momentum.start > 0 and self._norm_error != 1:
return self._norm_error, 1
return self._norm_error,
def tree_flatten(self):
aux = vars(self).copy()
aux['norm_error'] = aux.pop('_norm_error')
aux.pop('threshold')
return [self.threshold], aux
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**aux_data, threshold=children[0])
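# End-to-end usage sketch for the solver class above, mirroring what the
# functional ``sinkhorn`` wrapper below does. The point clouds and epsilon
# are assumptions made for demonstration only.
def _solver_usage_example():
  from ott.geometry import pointcloud  # Local import, demo only.
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (13, 2))
  y = jax.random.normal(rng_y, (17, 2)) + 1.0
  geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
  prob = linear_problems.LinearProblem(geom)  # Uniform marginals by default.
  out = Sinkhorn(threshold=1e-3, max_iterations=1000)(prob)
  return out.reg_ot_cost, out.converged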
def run(
ot_prob: linear_problems.LinearProblem, solver: Sinkhorn,
init: Tuple[jnp.ndarray, ...]
) -> SinkhornOutput:
"""Run loop of the solver, outputting a state upgraded to an output."""
iter_fun = _iterations_implicit if solver.implicit_diff else iterations
out = iter_fun(ot_prob, solver, init)
  # Be careful here: the geom and the cost are injected at the end, where they
  # do not interfere with the implicit differentiation.
out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
return out.set(ot_prob=ot_prob)
def iterations(
ot_prob: linear_problems.LinearProblem, solver: Sinkhorn,
init: Tuple[jnp.ndarray, ...]
) -> SinkhornOutput:
"""Jittable Sinkhorn loop. args contain initialization variables."""
def cond_fn(
iteration: int, const: Tuple[linear_problems.LinearProblem, Sinkhorn],
state: SinkhornState
) -> bool:
_, solver = const
return solver._continue(state, iteration)
def body_fn(
iteration: int, const: Tuple[linear_problems.LinearProblem, Sinkhorn],
state: SinkhornState, compute_error: bool
) -> SinkhornState:
ot_prob, solver = const
return solver.one_iteration(ot_prob, state, iteration, compute_error)
  # Run the Sinkhorn loop: use a standard fixpoint_iter loop if
  # differentiation is implicit, otherwise switch to the backprop-friendly
  # version of that loop when unrolling to differentiate.
if solver.implicit_diff:
fix_point = fixed_point_loop.fixpoint_iter
else:
fix_point = fixed_point_loop.fixpoint_iter_backprop
const = ot_prob, solver
state = solver.init_state(ot_prob, init)
state = fix_point(
cond_fn, body_fn, solver.min_iterations, solver.max_iterations,
solver.inner_iterations, const, state
)
return solver.output_from_state(ot_prob, state)
def _iterations_taped(
ot_prob: linear_problems.LinearProblem, solver: Sinkhorn,
init: Tuple[jnp.ndarray, ...]
) -> Tuple[SinkhornOutput, Tuple[jnp.ndarray, jnp.ndarray,
linear_problems.LinearProblem, Sinkhorn]]:
"""Run forward pass of the Sinkhorn algorithm storing side information."""
state = iterations(ot_prob, solver, init)
return state, (state.f, state.g, ot_prob, solver)
def _iterations_implicit_bwd(res, gr):
"""Run Sinkhorn in backward mode, using implicit differentiation.
Args:
    res: residual data sent from fwd pass, used for computations below. In
      this case it consists of the output itself, as well as the inputs
      against which we wish to differentiate.
    gr: gradients w.r.t. outputs of fwd pass, here w.r.t. f, g and errors.
      Note that differentiability w.r.t. errors is not handled; only f and g
      are considered.
Returns:
a tuple of gradients: PyTree for geom, one jnp.ndarray for each of a and b.
"""
f, g, ot_prob, solver = res
gr = gr[:2]
return (
*solver.implicit_diff.gradient(ot_prob, f, g, solver.lse_mode, gr), None,
None
)
# Set threshold, norm_errors, geom, a and b to be differentiable, as those are
# non-static. Only differentiability w.r.t. geom, a and b will be used.
_iterations_implicit = jax.custom_vjp(iterations)
_iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd)
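# Differentiation sketch: thanks to the custom VJP registered above, gradients
# of the regularized OT cost flow through the optimality conditions rather
# than through the unrolled loop. The point clouds are assumed inputs; names
# are illustrative only.
def _implicit_grad_example(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  from ott.geometry import pointcloud  # Local import, demo only.

  def reg_ot_cost(x_: jnp.ndarray) -> float:
    geom = pointcloud.PointCloud(x_, y, epsilon=1e-1)
    # jit=False, since the gradient function itself is what should be jitted.
    return Sinkhorn(jit=False)(linear_problems.LinearProblem(geom)).reg_ot_cost

  return jax.grad(reg_ot_cost)(x)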
def make(
tau_a: float = 1.0,
tau_b: float = 1.0,
threshold: float = 1e-3,
norm_error: int = 1,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
momentum: float = 1.0,
chg_momentum_from: int = 0,
anderson_acceleration: int = 0,
refresh_anderson_frequency: int = 1,
lse_mode: bool = True,
implicit_differentiation: bool = True,
implicit_solver_fun=jax.scipy.sparse.linalg.cg,
implicit_solver_ridge_kernel: float = 0.0,
implicit_solver_ridge_identity: float = 0.0,
implicit_solver_symmetric: bool = False,
precondition_fun: Optional[Callable[[float], float]] = None,
parallel_dual_updates: bool = False,
    use_danskin: Optional[bool] = None,
initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(),
jit: bool = False
) -> Sinkhorn:
"""For backward compatibility."""
del tau_a, tau_b
if not implicit_differentiation:
implicit_diff = None
else:
implicit_diff = implicit_lib.ImplicitDiff(
solver_fun=implicit_solver_fun,
ridge_kernel=implicit_solver_ridge_kernel,
ridge_identity=implicit_solver_ridge_identity,
symmetric=implicit_solver_symmetric,
precondition_fun=precondition_fun
)
if anderson_acceleration > 0:
anderson = anderson_lib.AndersonAcceleration(
memory=anderson_acceleration, refresh_every=refresh_anderson_frequency
)
else:
anderson = None
return Sinkhorn(
lse_mode=lse_mode,
threshold=threshold,
norm_error=norm_error,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
momentum=momentum_lib.Momentum(start=chg_momentum_from, value=momentum),
anderson=anderson,
implicit_diff=implicit_diff,
parallel_dual_updates=parallel_dual_updates,
use_danskin=use_danskin,
initializer=initializer,
jit=jit
)
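# Backward-compatibility sketch: ``make`` maps the flat keyword arguments of
# the legacy API onto the ``Sinkhorn`` constructor. The momentum and Anderson
# values below are illustrative only.
def _make_example() -> Sinkhorn:
  return make(momentum=0.5, chg_momentum_from=20, anderson_acceleration=5)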
def sinkhorn(
geom: geometry.Geometry,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
tau_a: float = 1.0,
tau_b: float = 1.0,
init_dual_a: Optional[jnp.ndarray] = None,
init_dual_b: Optional[jnp.ndarray] = None,
**kwargs: Any,
):
r"""Solve regularized OT problem using Sinkhorn iterations.
The Sinkhorn algorithm is a fixed point iteration that solves a regularized
optimal transport (reg-OT) problem between two measures.
The optimization variables are a pair of vectors (called potentials, or
scalings when parameterized as exponentials of the former). Calling this
  function therefore returns a pair of optimal vectors. In addition to these,
  it also returns the objective value achieved by these optimal vectors;
  a vector of size ``max_iterations/inner_iterations`` that records the
  errors monitored to assess convergence throughout the execution of the
  algorithm (padded with `-1` if convergence happens earlier), as well as a
boolean to signify whether the algorithm has converged within the number of
iterations specified by the user.
The reg-OT problem is specified by two measures, of respective sizes ``n`` and
``m``. From the viewpoint of the ``sinkhorn`` function, these two measures are
only seen through a triplet (``geom``, ``a``, ``b``), where ``geom`` is a
``Geometry`` object, and ``a`` and ``b`` are weight vectors of respective
sizes ``n`` and ``m``. Starting from two initial values for those potentials
or scalings (both can be defined by the user by passing value in
``init_dual_a`` or ``init_dual_b``), the Sinkhorn algorithm will use
elementary operations that are carried out by the ``geom`` object.
Some maths:
Given a geometry ``geom``, which provides a cost matrix :math:`C` with its
regularization parameter :math:`\varepsilon`, (or a kernel matrix :math:`K`)
the reg-OT problem consists in finding two vectors `f`, `g` of size ``n``,
``m`` that maximize the following criterion.
  .. math::

    \arg\max_{f, g} - \langle a, \phi_a^{*}(-f) \rangle - \langle b,
    \phi_b^{*}(-g) \rangle - \varepsilon \langle e^{f/\varepsilon},
    e^{-C/\varepsilon} e^{g/\varepsilon} \rangle

  where :math:`\phi_a(z) = \rho_a z(\log z - 1)` is a scaled entropy, and
  :math:`\phi_a^{*}(z) = \rho_a e^{z/\rho_a}`, its Legendre transform.
That problem can also be written, instead, using positive scaling vectors
`u`, `v` of size ``n``, ``m``, handled with the kernel
:math:`K := e^{-C/\varepsilon}`,
  .. math::

    \arg\max_{u, v > 0} - \langle a, \phi_a^{*}(-\varepsilon\log u) \rangle -
    \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle -
    \varepsilon \langle u, K v \rangle

  Both of these problems correspond, in their *primal* formulation, to
solving the unbalanced optimal transport problem with a variable matrix
:math:`P` of size ``n`` x ``m``:
  .. math::

    \arg\min_{P>0} \langle P, C \rangle + \varepsilon \text{KL}(P | ab^T)
    + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n
    | b)

  where :math:`\text{KL}` is the generalized Kullback-Leibler divergence.
The very same primal problem can also be written using a kernel :math:`K`
instead of a cost :math:`C` as well:
  .. math::

    \arg\min_{P>0} \varepsilon \text{KL}(P | K) +
    \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n | b)
The *original* OT problem taught in linear programming courses is recovered
by using the formulation above relying on the cost :math:`C`, and letting
:math:`\varepsilon \rightarrow 0`, and :math:`\rho_a, \rho_b \rightarrow
\infty`.
  In that case the entropy disappears, whereas the :math:`\text{KL}`
  regularizations above become constraints on the marginals of :math:`P`:
  this results in a standard min cost flow problem. This problem is not
  handled for now in this toolbox, which focuses exclusively on the case
  :math:`\varepsilon > 0`.
The *balanced* regularized OT problem is recovered for finite
:math:`\varepsilon > 0` but letting :math:`\rho_a, \rho_b \rightarrow
\infty`. This problem can be shown to be equivalent to a matrix scaling
problem, which can be solved using the Sinkhorn fixed-point algorithm.
  To handle the case :math:`\rho_a, \rho_b \rightarrow \infty`, the
  ``sinkhorn`` function uses parameters :math:`\tau_a := \rho_a /
  (\varepsilon + \rho_a)` and :math:`\tau_b := \rho_b / (\varepsilon +
  \rho_b)` instead. Setting either of these parameters to 1 corresponds to
  setting the corresponding :math:`\rho_a, \rho_b` to :math:`\infty`.
The Sinkhorn algorithm solves the reg-OT problem by seeking optimal `f`, `g`
potentials (or alternatively their parametrization as positive scalings
`u`, `v`), rather than solving the primal problem in :math:`P`.
This is mostly for efficiency (potentials and scalings have a ``n + m``
memory footprint, rather than ``n m`` required to store `P`). This is also
  because both problems are, in fact, equivalent, since the optimal transport
  plan :math:`P^*` can be recovered from optimal potentials :math:`f^*`,
  :math:`g^*`
or scalings :math:`u^*`, :math:`v^*`, using the geometry's cost or kernel
matrix respectively:
.. math::
P^* = \exp\left(\frac{f^*\mathbf{1}_m^T + \mathbf{1}_n g^{*T} -
C}{\varepsilon}\right) \text{ or } P^* = \text{diag}(u^*) K
\text{diag}(v^*)
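  For instance, here is a minimal sketch of recovering, or applying, the
  primal solution from this function's output (``geom``, ``a``, ``b`` and a
  suitably-sized ``vec`` are assumed to be defined):

  .. code-block:: python

    out = sinkhorn(geom, a, b)
    transport = out.matrix  # Dense n x m primal plan, if it fits in memory.
    applied = out.apply(vec)  # Applies the plan without materializing it.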
By default, the Sinkhorn algorithm solves this dual problem in `f, g` or
`u, v` using block coordinate ascent, i.e. devising an update for each `f`
and `g` (resp. `u` and `v`) that cancels their respective gradients, one at
a time. These two iterations are repeated ``inner_iterations`` times, after
which the norm of these gradients will be evaluated and compared with the
``threshold`` value. The iterations are then repeated as long as that error
exceeds ``threshold``.
Note on Sinkhorn updates:
The boolean flag ``lse_mode`` sets whether the algorithm is run in either:
- log-sum-exp mode (``lse_mode=True``), in which case it is directly
defined in terms of updates to `f` and `g`, using log-sum-exp
computations. This requires access to the cost matrix :math:`C`, as it is
stored, or possibly computed on the fly by ``geom``.
- kernel mode (``lse_mode=False``), in which case it will require access
to a matrix vector multiplication operator :math:`z \rightarrow K z`,
where :math:`K` is either instantiated from :math:`C` as
:math:`\exp(-C/\varepsilon)`, or provided directly. In that case, rather
than optimizing on :math:`f` and :math:`g`, it is more convenient to
optimize on their so called scaling formulations,
:math:`u := \exp(f / \varepsilon)` and :math:`v := \exp(g / \varepsilon)`.
While faster (applying matrices is faster than applying ``lse`` repeatedly
over lines), this mode is also less stable numerically, notably for
smaller :math:`\varepsilon`.
In the source code, the variables ``f_u`` or ``g_v`` can be either regarded
as potentials (real) or scalings (positive) vectors, depending on the choice
of ``lse_mode`` by the user. Once optimization is carried out, we only
return dual variables in potential form, i.e. ``f`` and ``g``.
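  For instance, a sketch toggling between the two kinds of updates, leaving
  all other parameters to their defaults:

  .. code-block:: python

    out_lse = sinkhorn(geom, a, b, lse_mode=True)  # Stable, slower.
    out_kernel = sinkhorn(geom, a, b, lse_mode=False)  # Faster, less stable.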
In addition to standard Sinkhorn updates, the user can also use heavy-ball
type updates using a ``momentum`` parameter in ]0,2[. We also implement a
  strategy that tries to set that parameter adaptively, starting from
  iteration ``chg_momentum_from``, as a function of progress in the error,
as proposed in the literature.
  Another upgrade to the standard Sinkhorn updates provided to the users lies
  in using Anderson acceleration. This can be parameterized by setting the
  otherwise null ``anderson_acceleration`` to a positive integer. When
  selected, the algorithm will recompute, every ``refresh_anderson_frequency``
  iterations (set by default to 1), an extrapolation of the most recently
  computed ``anderson_acceleration`` iterates. When using that option, notice
  that differentiation (if required) can only be carried out using implicit
  differentiation, and that all momentum related parameters are ignored.
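  For instance, a sketch activating Anderson acceleration (the memory value
  of 5 is purely illustrative):

  .. code-block:: python

    out = sinkhorn(
        geom, a, b, anderson_acceleration=5, refresh_anderson_frequency=1
    )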
The ``parallel_dual_updates`` flag is set to ``False`` by default. In that
setting, ``g_v`` is first updated using the latest values for ``f_u`` and
``g_v``, before proceeding to update ``f_u`` using that new value for
``g_v``. When the flag is set to ``True``, both ``f_u`` and ``g_v`` are
  updated simultaneously. Note that setting that choice to ``True`` requires
  using some form of averaging (e.g. ``momentum=0.5``); without it,
  ``parallel_dual_updates`` on its own won't work.
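  For instance, a sketch pairing parallel updates with averaging (the
  momentum value of 0.5 is illustrative):

  .. code-block:: python

    out = sinkhorn(geom, a, b, parallel_dual_updates=True, momentum=0.5)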
Differentiation:
The optimal solutions ``f`` and ``g`` and the optimal objective
(``reg_ot_cost``) outputted by the Sinkhorn algorithm can be differentiated
w.r.t. relevant inputs ``geom``, ``a`` and ``b`` using, by default, implicit
differentiation of the optimality conditions (``implicit_differentiation``
set to ``True``). This choice has two consequences.
- The termination criterion used to stop Sinkhorn (cancellation of
gradient of objective w.r.t. ``f_u`` and ``g_v``) is used to differentiate
``f`` and ``g``, given a change in the inputs. These changes are computed
by solving a linear system. The arguments starting with
``implicit_solver_*`` allow to define the linear solver that is used, and
    to control two types of regularization (we have observed that,
    depending on the architecture, linear solves may require higher ridge
    parameters to remain stable). The optimality conditions in Sinkhorn can be
    analyzed as satisfying a ``z=z'`` condition, which is then
differentiated. It might be beneficial (e.g., as in :cite:`cuturi:20a`)
to use a preconditioning function ``precondition_fun`` to differentiate
instead ``h(z) = h(z')``.
- The objective ``reg_ot_cost`` returned by Sinkhorn uses the so-called
envelope (or Danskin's) theorem. In that case, because it is assumed that
the gradients of the dual variables ``f_u`` and ``g_v`` w.r.t. dual
objective are zero (reflecting the fact that they are optimal), small
variations in ``f_u`` and ``g_v`` due to changes in inputs (such as
``geom``, ``a`` and ``b``) are considered negligible. As a result,
``stop_gradient`` is applied on dual variables ``f_u`` and ``g_v`` when
evaluating the ``reg_ot_cost`` objective. Note that this approach is
`invalid` when computing higher order derivatives. In that case the
``use_danskin`` flag must be set to ``False``.
An alternative yet more costly way to differentiate the outputs of the
Sinkhorn iterations is to use unrolling, i.e. reverse mode differentiation
of the Sinkhorn loop. This is possible because Sinkhorn iterations are
wrapped in a custom fixed point iteration loop, defined in
``fixed_point_loop``, rather than a standard while loop. This is to ensure
the end result of this fixed point loop can also be differentiated, if
needed, using standard JAX operations. To ensure backprop differentiability,
the ``fixed_point_loop.fixpoint_iter_backprop`` loop does checkpointing of
state variables (here ``f_u`` and ``g_v``) every ``inner_iterations``, and
backpropagates automatically, block by block, through blocks of
``inner_iterations`` at a time.
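  For instance, a sketch differentiating the regularized OT cost w.r.t. the
  first marginal, in both modes (``geom``, ``a`` and ``b`` assumed defined):

  .. code-block:: python

    # Default: implicit differentiation of the optimality conditions.
    grad_a = jax.grad(lambda a: sinkhorn(geom, a, b).reg_ot_cost)(a)
    # Alternative: unrolled, backprop-friendly differentiation of the loop.
    grad_a_unrolled = jax.grad(
        lambda a: sinkhorn(
            geom, a, b, implicit_differentiation=False).reg_ot_cost)(a)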
Note:
* The Sinkhorn algorithm may not converge within the maximum number of
iterations for possibly several reasons:
1. the regularizer (defined as ``epsilon`` in the geometry ``geom``
object) is too small. Consider either switching to ``lse_mode=True``
(at the price of a slower execution), increasing ``epsilon``, or,
alternatively, if you are unable or unwilling to increase ``epsilon``,
either increase ``max_iterations`` or ``threshold``.
2. the probability weights ``a`` and ``b`` do not have the same total
mass, while using a balanced (``tau_a=tau_b=1.0``) setup.
      Consider either normalizing ``a`` and ``b``, or setting ``tau_a``
      and/or ``tau_b`` to be strictly smaller than ``1.0``.
    3. OOM issues may arise when storing either cost or kernel matrices that
      are too large in ``geom``. In the case where the ``geom`` geometry is
      a ``PointCloud``, some of these issues might be solved by setting the
      ``online`` flag to ``True``. This will trigger an on-the-fly
      re-computation of the cost/kernel matrix.
* The weight vectors ``a`` and ``b`` can be passed on with coordinates that
have zero weight. This is then handled by relying on simple arithmetic for
``inf`` values that will likely arise (due to :math:`\log 0` when
``lse_mode`` is ``True``, or divisions by zero when ``lse_mode`` is
``False``). Whenever that arithmetic is likely to produce ``NaN`` values
(due to ``-inf * 0``, or ``-inf - -inf``) in the forward pass, we use
``jnp.where`` conditional statements to carry ``inf`` rather than ``NaN``
values. In the reverse mode differentiation, the inputs corresponding to
these 0 weights (a location `x`, or a row in the corresponding cost/kernel
    matrix), and the weight itself will have ``NaN`` gradient values. This
    reflects that these gradients are undefined, since these points were not
    considered in the optimization and therefore have no impact on the output.
Args:
geom: a Geometry object.
a: The first marginal. If `None`, it will be uniform.
b: The second marginal. If `None`, it will be uniform.
tau_a: ratio rho/(rho+eps) between KL divergence regularizer to first
marginal and itself + epsilon regularizer used in the unbalanced
formulation.
    tau_b: ratio rho/(rho+eps) between KL divergence regularizer to second
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
init_dual_a: optional initialization for potentials/scalings w.r.t.
first marginal (``a``) of reg-OT problem.
init_dual_b: optional initialization for potentials/scalings w.r.t.
second marginal (``b``) of reg-OT problem.
threshold: tolerance used to stop the Sinkhorn iterations. This is
typically the deviation between a target marginal and the marginal of the
current primal solution when either or both tau_a and tau_b are 1.0
(balanced or semi-balanced problem), or the relative change between two