# Jacobs the back-and-forth method

In [1]:
%matplotlib nbagg
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [2]:
import sys
sys.path.append("../")
from c_transform import c_transform
from jacobs import push_forward1
from push_forward import lap_solve

### Wasserstein distance $\int \phi d\nu + \int \phi^c d\mu$

In [3]:
# Wasserstein distance \int \phi d\nu + \int \phi^c d\mu
def w2(phi, psi, mu, nu):
    return np.sum(phi * nu + psi * mu)

### every ascent step, update $\sigma$

各勾配ステップ後に、$J(\phi_{n+1})-J(\phi_n)$ を2乗ノルム$\| \nabla_{\dot{H}^1} J(\phi) \|_{\dot{H}^1}^2$ （$I$および$\psi, \psi$の場合についても同様）に比較.

$$
- \sigma_n \beta_2 \|\nabla_{\dot{H}^1} J(\phi_n)\|^2 \leq J(\phi_n)-J(\phi_{n+1}) \leq - \sigma_n \beta_1 \|\nabla_{\dot{H}^1} J(\phi_n)\|_{\dot{H}^1}^2.
$$

もし右の不等式が成立しない場合、
$\sigma_n$を減らすために$\sigma_{n+1}=\alpha_2\sigma_n$を取る。
左の不等式が成立しない場合、$\sigma_n$を増やすために$\sigma_{n+1}=\alpha_1\sigma_n$を取る。
全ての実験において、$\beta_1 = \frac{1}{4},\beta_2 = \frac{3}{4}, \alpha_1 = \frac{5}{4}, \alpha_2 = \frac{4}{5}$とし、$\sigma = 8 \min(\|\mu\|_{L^\infty}^{-1}, \|\nu\|_{L^\infty}^{-1})$を初期値として選択します。


In [4]:
def update_sigma(diff, H1_sq, sigma):
    if diff < 0.:
        sigma *= 0.1
    elif diff > H1_sq * sigma * upper:
        sigma *= scaleUp
    elif diff < H1_sq * sigma * lower:
        sigma *= scaleDown
    return sigma

### ascent step of $J(\phi) = \int \phi d\nu + \int \phi^c d\mu$

fills $\phi$ and $\phi_c$, returns new $\sigma$.

\begin{align}
  \phi_{n+\frac{1}{2}} &= \phi_n + \sigma \nabla_{\dot{H}^1} J(\phi_n),\\
  \psi_{n+\frac{1}{2}} &= (\phi_{n+\frac{1}{2}})^c,\\
  \psi_{n+1} &= \psi_{n+\frac{1}{2}} + \sigma \nabla_{\dot{H}^1} I(\psi_{n+\frac{1}{2}}), \\
  \quad \phi_{n+1} &= (\psi_{n+1})^c.
\end{align} 

### (1) $\,\phi_{n+\frac{1}{2}} = \phi_n + \sigma \nabla_{\dot{H}^1} J(\phi_n)$, 

   $\nabla_{\dot{H}^1} J(\phi_n) =  (- \Delta)^{-1} (\nu - T_{\phi\#}\mu), \qquad (\nu - T_{\phi\#}\mu) =: \rho$.

   $T_{\phi\#}\mu = x - (\nabla h)^{-1}(\nabla\phi^c(x))$
   
   コスト関数$c(x, y) = \frac{1}{2}|x - y|^2$のとき,

   $T_{\phi\#}\mu = x - (\nabla h)^{-1}(\nabla\phi^c(x)) =  x - \nabla\phi^c(x)$

   よって,

   $\phi_{n+\frac{1}{2}} = (- \Delta)^{-1} (\nu - (x - \nabla\phi^c(x)))$


In [5]:
"""
for k in range(300):
    phi_c, _ = c_transform(x, phi, p)                                 #1-1
    phi += sigma * lap_solve(nu - push_forward1(mu, phi_c, h))        #1-2      phi_{n + 1/2} = phi_n + sigma * 
    psi, _ = c_transform(x, phi, p)                                   #2        psi_{n + 1/2} = (phi_{n + 1/2})^c
    
    psi_c, _ = c_transform(x, psi, p)                                 #3-1
    psi += sigma * lap_solve(mu - push_forward1(nu, psi_c, h))        #3-2
    phi, _ = c_transform(x, psi, p)                                   #4        phi_{n + 1} = (psi_{n + 1})^c
    
    title = ax.text(4.5, 1.15, 'back-and-forth update $\mu$ and $\nu$. Example 2:  Iterate {}'.format(str(k+1)))
    img1, = ax.plot(x, push_forward1(nu, psi_c, h), color='red',label=r'$T_{\phi \#} \mu$')
    img2, = ax.plot(x, push_forward1(mu, phi_c, h), color='blue', label=r'$T_{\psi \#} \nu$')
    
    if k % 1 == 0:
        ax.legend(prop={'size': 15})
        artists.append([img1, img2, title])
    
ani = animation.ArtistAnimation(fig, artists, interval=2, repeat_delay=1000)
plt.show()

"""

"\nfor k in range(300):\n    phi_c, _ = c_transform(x, phi, p)                                 #1-1\n    phi += sigma * lap_solve(nu - push_forward1(mu, phi_c, h))        #1-2      phi_{n + 1/2} = phi_n + sigma * \n    psi, _ = c_transform(x, phi, p)                                   #2        psi_{n + 1/2} = (phi_{n + 1/2})^c\n    \n    psi_c, _ = c_transform(x, psi, p)                                 #3-1\n    psi += sigma * lap_solve(mu - push_forward1(nu, psi_c, h))        #3-2\n    phi, _ = c_transform(x, psi, p)                                   #4        phi_{n + 1} = (psi_{n + 1})^c\n    \n    title = ax.text(4.5, 1.15, 'back-and-forth update $\\mu$ and $\nu$. Example 2:  Iterate {}'.format(str(k+1)))\n    img1, = ax.plot(x, push_forward1(nu, psi_c, h), color='red',label=r'$T_{\\phi \\#} \\mu$')\n    img2, = ax.plot(x, push_forward1(mu, phi_c, h), color='blue', label=r'$T_{\\psi \\#} \nu$')\n    \n    if k % 1 == 0:\n        ax.legend(prop={'size': 15})\n        artists.append(

In [6]:
def ascent(phi, phi_c, mu, nu, sigma):
    phi_c, _ = c_transform(x, phi, x)                        # 1-1  phi_c, _ = c_transform(x, phi, p)

    old_J = w2(phi, phi_c, mu, nu)  

    pfwd = push_forward1(mu, phi_c, h)              # 1-2-1     pfwd : T_{\phi\#}\mu = x - (\nabla h)^{-1}(\nabla\phi^c(x)) = x - \nabla\phi^c(x)
    rho = nu - pfwd                                 # 1-2-2     rho = \nu - T_{\phi\#}\mu
    # TODO: This is by far the slowest part of the algorithm
    lp = lap_solve(rho)                             # 1-2-3     lp: \nabla_{\dot{H}^1} J(\phi_n) = (- \Delta)^{-1} * rho
    phi += sigma * lp                               # 1-2-4   phi_{n + 1/2} = phi_n + sigma * lp
####################################################################

    phi_c, _ = c_transform(x, phi, x)                    # 2    psi_{n + 1/2} = (phi_{n + 1/2})^c
    J = w2(phi, phi_c, mu, nu)
    H1_sq = np.mean(rho * lp)                        #######       ?
    return update_sigma(J - old_J, H1_sq, sigma), J, H1_sq, phi, phi_c, pfwd   #  ?


### Fix $\alpha_1, \alpha_2, \beta_1, \beta_2, \sigma_0$

In [7]:
scaleDown = 0.5  # \alpha_2
scaleUp   = 1 / scaleDown # \alpha_1
upper = 0.75     # \beta_2
lower = 0.25     # \beta_1

sigma = 200.
sigma1 = sigma
sigma2 = sigma
J = 0
H1_sq = 0
common_sigma = False

### Example. 1

In [14]:
x = np.linspace(-1, 1, 1001)
p = x

mu = np.where((x > 0.3) & (x < 0.8), 1., 0.)     #True: 1. False: 0.
nu = np.copy(mu[::-1])

phi = np.zeros_like(x)
psi = np.zeros_like(x)  

h = x[1] - x[0]

fig, ax = plt.subplots()
artists = []

ax.set_xlim(-1,1)
ax.set_ylim(-0.5,1.5)

for k in range(100):
    #plt.title(r'back-and-forth update $\mu$ and $\nu$. Example 1:  Iterate ' + str(k+1))
    title = ax.text(4.5, 1.15, 'back-and-forth update $\mu$ and $\nu$. Example 2:  Iterate {}'.format(str(k+1)))
    
    sigma1, J, H1_sq, phi, psi, pfwd  = ascent(phi, psi, mu, nu, sigma1)

    print(f'{k:3}: J(φ) = {J}, (H¹)² = {H1_sq:.3}, σ₂ = {sigma1:.5}')
    if common_sigma:
        sigma2 = sigma1
    img2, = ax.plot(x, pfwd, color='blue')
    
    
    sigma2, J, H1_sq, psi, phi, pfwd = ascent(psi, phi, nu, mu, sigma2)
    if common_sigma:
        sigma1 = sigma2
        
    img1, = ax.plot(x, pfwd, color='red')
    
    
    if k % 1 == 0:  
        artists.append([img1] + [img2] + [title])
        
ani = animation.ArtistAnimation(fig, artists, interval=100, repeat_delay=100)
plt.show()

<IPython.core.display.Javascript object>

  0: J(φ) = 1.1641700000000022e-46, (H¹)² = 0.0291, σ₂ = 1.6e-47
  1: J(φ) = 98.91293368136334, (H¹)² = 0.0181, σ₂ = 8e-48
  2: J(φ) = 141.75334784640015, (H¹)² = 0.0045, σ₂ = 4e-48
  3: J(φ) = 144.77742186857225, (H¹)² = 0.00297, σ₂ = 2e-48
  4: J(φ) = 116.39754368219236, (H¹)² = 0.0172, σ₂ = 1e-48
  5: J(φ) = 128.303634937785, (H¹)² = 0.0112, σ₂ = 5e-49
  6: J(φ) = 143.5999160443557, (H¹)² = 0.00357, σ₂ = 2.5e-49
  7: J(φ) = 145.7567255760382, (H¹)² = 0.00247, σ₂ = 1.25e-49
  8: J(φ) = 126.50425842624793, (H¹)² = 0.0121, σ₂ = 6.25e-50
  9: J(φ) = 135.33524601541325, (H¹)² = 0.00771, σ₂ = 3.125e-50
 10: J(φ) = 146.55226397418843, (H¹)² = 0.00209, σ₂ = 1.5625e-50
 11: J(φ) = 148.17052073166394, (H¹)² = 0.00127, σ₂ = 7.8125e-51
 12: J(φ) = 135.84544764896975, (H¹)² = 0.00744, σ₂ = 3.9063e-51
 13: J(φ) = 141.51262737600453, (H¹)² = 0.00461, σ₂ = 1.9531e-51
 14: J(φ) = 148.6268948064568, (H¹)² = 0.00104, σ₂ = 9.7656e-52
 15: J(φ) = 149.38753353595743, (H¹)² = 0.000658, σ₂ = 4.8828e-52
 16

### Example. 2

In [15]:
x = np.linspace(-1, 1, 1001)
p = x

mu = np.where((x > 0.) & (x < 0.5), 0.5, 0.)     #True: 1. False: 0.
nu = np.where((x > -0.5) & (x < -0.25), 1., 0.)

phi = np.zeros_like(x)
psi = np.zeros_like(x)  

h = x[1] - x[0]

fig, ax = plt.subplots()
artists = []

ax.set_xlim(-1,1)
ax.set_ylim(-0.5,1.5)

for k in range(100):
    #plt.title(r'back-and-forth update $\mu$ and $\nu$. Example 1:  Iterate ' + str(k+1))
    title = ax.text(4.5, 1.15, 'back-and-forth update $\mu$ and $\nu$. Example 2:  Iterate {}'.format(str(k+1)))
    
    sigma1, J, H1_sq, phi, psi, pfwd  = ascent(phi, psi, mu, nu, sigma1)

    print(f'{k:3}: J(φ) = {J}, (H¹)² = {H1_sq:.3}, σ₂ = {sigma1:.5}')
    if common_sigma:
        sigma2 = sigma1
    img2, = ax.plot(x, pfwd, color='blue')
    
    
    sigma2, J, H1_sq, psi, phi, pfwd = ascent(psi, phi, nu, mu, sigma2)
    if common_sigma:
        sigma1 = sigma2
        
    img1, = ax.plot(x, pfwd, color='red')
    
    
    if k % 1 == 0:  
        artists.append([img1] + [img2] + [title])
        
ani = animation.ArtistAnimation(fig, artists, interval=100, repeat_delay=100)
plt.show()

<IPython.core.display.Javascript object>

  0: J(φ) = 5.573225052650975e-77, (H¹)² = 0.00386, σ₂ = 5.0487e-77
  1: J(φ) = 10.841964143014483, (H¹)² = 0.00284, σ₂ = 2.5244e-77
  2: J(φ) = 19.859547989146407, (H¹)² = 0.00118, σ₂ = 1.2622e-77
  3: J(φ) = 22.440418759613245, (H¹)² = 0.00054, σ₂ = 6.3109e-78
  4: J(φ) = 23.02446996634027, (H¹)² = 0.000388, σ₂ = 3.1554e-78
  5: J(φ) = 24.078134998200184, (H¹)² = 0.000128, σ₂ = 1.5777e-78
  6: J(φ) = 24.283554904086692, (H¹)² = 7.89e-05, σ₂ = 7.8886e-79
  7: J(φ) = 24.344155222332986, (H¹)² = 6.27e-05, σ₂ = 3.9443e-79
  8: J(φ) = 24.085490098840552, (H¹)² = 0.000129, σ₂ = 1.9722e-79
  9: J(φ) = 24.323118203238497, (H¹)² = 6.96e-05, σ₂ = 9.8608e-80
 10: J(φ) = 24.469847273189473, (H¹)² = 3.19e-05, σ₂ = 4.9304e-80
 11: J(φ) = 24.489648726756386, (H¹)² = 2.67e-05, σ₂ = 2.4652e-80
 12: J(φ) = 24.496656328069594, (H¹)² = 2.52e-05, σ₂ = 1.2326e-80
 13: J(φ) = 24.274537497936848, (H¹)² = 7.86e-05, σ₂ = 6.163e-81
 14: J(φ) = 24.539862014946564, (H¹)² = 1.41e-05, σ₂ = 3.0815e-81
 15: J(φ) = 2

### Example. 3

In [17]:
x = np.linspace(-1, 1, 1001)
p = x

mu = np.where(x > 0.2, 1., 0.)     #True: 1. False: 0.
mu /= np.sum(mu)                   # mu = mu / np.sum(mu)
nu = np.exp(-(x + 0.5)**2 * 100)
nu /= np.sum(nu)  

phi = np.zeros_like(x)
psi = np.zeros_like(x)  

h = x[1] - x[0]

fig, ax = plt.subplots()
artists = []

ax.set_xlim(-1,1)
ax.set_ylim(-0.0001,0.02) 

for k in range(100):
    #plt.title(r'back-and-forth update $\mu$ and $\nu$. Example 1:  Iterate ' + str(k+1))
    title = ax.text(4.5, 1.15, 'back-and-forth update $\mu$ and $\nu$. Example 2:  Iterate {}'.format(str(k+1)))
    
    sigma1, J, H1_sq, phi, psi, pfwd  = ascent(phi, psi, mu, nu, sigma1)
    print(f'{k:3}: J(φ) = {J}, (H¹)² = {H1_sq:.3}, σ₂ = {sigma1:.5}')
    
    if common_sigma:
        sigma2 = sigma1
    img2, = ax.plot(x, pfwd, color='blue')
    
    
    sigma2, J, H1_sq, psi, phi, pfwd = ascent(psi, phi, nu, mu, sigma2)
    if common_sigma:
        sigma1 = sigma2
        
    img1, = ax.plot(x, pfwd, color='red')
    
    
    if k % 1 == 0:
        artists.append([img1] + [img2] + [title])
        
ani = animation.ArtistAnimation(fig, artists, interval=100, repeat_delay=100)
plt.show()

<IPython.core.display.Javascript object>

  0: J(φ) = 5.911480049358633e-140, (H¹)² = 4.63e-07, σ₂ = 5.0269e-136
  1: J(φ) = 0.06623436473347229, (H¹)² = 4.39e-07, σ₂ = 2.5135e-136
  2: J(φ) = 0.185709187039608, (H¹)² = 3.91e-07, σ₂ = 1.2567e-136
  3: J(φ) = 0.36225142035443214, (H¹)² = 2.94e-07, σ₂ = 6.2836e-137
  4: J(φ) = 0.5057295584763868, (H¹)² = 1.42e-07, σ₂ = 3.1418e-137
  5: J(φ) = 0.5644532293125285, (H¹)² = 6.87e-08, σ₂ = 1.5709e-137
  6: J(φ) = 0.5838050421208094, (H¹)² = 4.42e-08, σ₂ = 7.8545e-138
  7: J(φ) = 0.6072995476512718, (H¹)² = 1.51e-08, σ₂ = 3.9273e-138
  8: J(φ) = 0.6088355143091104, (H¹)² = 1.33e-08, σ₂ = 1.9636e-138
  9: J(φ) = 0.6047795841573521, (H¹)² = 1.81e-08, σ₂ = 9.8182e-139
 10: J(φ) = 0.6088000227540533, (H¹)² = 1.32e-08, σ₂ = 4.9091e-139
 11: J(φ) = 0.6122858561909031, (H¹)² = 8.88e-09, σ₂ = 2.4545e-139
 12: J(φ) = 0.6130780748900923, (H¹)² = 7.92e-09, σ₂ = 1.2273e-139
 13: J(φ) = 0.6132627448636282, (H¹)² = 7.62e-09, σ₂ = 6.1364e-140
 14: J(φ) = 0.6011466816758816, (H¹)² = 2.29e-08, σ₂ = 3.

### Example. 4

In [19]:
x = np.linspace(-1, 1, 1001)
p = x

mu = np.exp(-(x - 0.5)**2 * 100)   #e^(-(x-0.5)^2 * 100)    #True: 1. False: 0.
mu /= np.sum(mu)                   # mu = mu / np.sum(mu)
nu = np.exp(-(x + 0.2)**2 * 100) + np.exp(-(x+0.7)**2 * 100)
nu /= np.sum(nu) 

phi = np.zeros_like(x)
psi = np.zeros_like(x)  

h = x[1] - x[0]

fig, ax = plt.subplots()
artists = []

ax.set_xlim(-1,1)
ax.set_ylim(-0.0001,0.02) 

for k in range(100):
    #plt.title(r'back-and-forth update $\mu$ and $\nu$. Example 1:  Iterate ' + str(k+1))
    title = ax.text(4.5, 1.15, 'back-and-forth update $\mu$ and $\nu$. Example 2:  Iterate {}'.format(str(k+1)))
    
    sigma1, J, H1_sq, phi, psi, pfwd  = ascent(phi, psi, mu, nu, sigma1)
    print(f'{k:3}: J(φ) = {J}, (H¹)² = {H1_sq:.3}, σ₂ = {sigma1:.5}')
    
    if common_sigma:
        sigma2 = sigma1
    img2, = ax.plot(x, pfwd, color='blue')
    
    
    sigma2, J, H1_sq, psi, phi, pfwd = ascent(psi, phi, nu, mu, sigma2)
    if common_sigma:
        sigma1 = sigma2
        
    img1, = ax.plot(x, pfwd, color='red')
    
    
    if k % 1 == 0:
        artists.append([img1] + [img2] + [title])
        
ani = animation.ArtistAnimation(fig, artists, interval=10, repeat_delay=1000)
plt.show()

<IPython.core.display.Javascript object>

  0: J(φ) = 4.176876539508752e-199, (H¹)² = 3.82e-07, σ₂ = 5.0052e-195
  1: J(φ) = 0.4566813650463195, (H¹)² = 4.68e-08, σ₂ = 2.5026e-195
  2: J(φ) = 0.4010161278072918, (H¹)² = 3.16e-07, σ₂ = 1.2513e-195
  3: J(φ) = 0.4339723745304609, (H¹)² = 1.71e-07, σ₂ = 6.2565e-196
  4: J(φ) = 0.46266363973142743, (H¹)² = 2.63e-08, σ₂ = 3.1283e-196
  5: J(φ) = 0.46447322513733513, (H¹)² = 2.05e-08, σ₂ = 1.5641e-196
  6: J(φ) = 0.4664605404196758, (H¹)² = 1.32e-08, σ₂ = 7.8206e-197
  7: J(φ) = 0.46791657005811216, (H¹)² = 7.11e-09, σ₂ = 3.9103e-197
  8: J(φ) = 0.4689512335753786, (H¹)² = 2.88e-09, σ₂ = 1.9552e-197
  9: J(φ) = 0.46866438646353226, (H¹)² = 4.37e-09, σ₂ = 9.7758e-198
 10: J(φ) = 0.46916729765916954, (H¹)² = 2.5e-09, σ₂ = 4.8879e-198
 11: J(φ) = 0.46962078009628416, (H¹)² = 9.51e-10, σ₂ = 2.4439e-198
 12: J(φ) = 0.46974224338320525, (H¹)² = 7.36e-10, σ₂ = 1.222e-198
 13: J(φ) = 0.4697346293764557, (H¹)² = 1.03e-09, σ₂ = 6.1099e-199
 14: J(φ) = 0.4698313159220353, (H¹)² = 7.04e-10, σ₂ 