goal
- [done] register simple pointsets in 2d using lddmm+shooting+data loss
- [done] variational q(m) setup optimize for elbo
    - stochasticity in shooting particles ... time consuming
- extension to probability related to gp + dynamic model work ?

In [None]:
from utils_jpt import jpt_autoreload, jpt_full_width
jpt_autoreload(); jpt_full_width()


## Geodesic shooting for registering simple shapes
# 
from functools import partial

import os
os.environ['JAX_ENABLE_X64'] = 'False'

import jax
from jax import random, grad, jit, value_and_grad, vmap
import jax.numpy as np
import jax.numpy.linalg as linalg
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.experimental import optimizers

import matplotlib.pyplot as plt
import matplotlib as mpl
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
cmap = plt.cm.get_cmap('bwr')

import sklearn.datasets as datasets

import sys
sys.path.append('../gp')
sys.path.append('../ot')
from gpax import cholesky_jitter, BijSoftplus, kl_mvn_tril_zero_mean_prior
from otax import sinkhorn_log_stabilized, sinkhorn_divergence, plt_transport_plan

from jax_registration import *
from lddmm_pytorch import Curve, level_curves


In [None]:
## Parameters

datakey = 'amoeba'
# Compare strings by value. `is not` tests object identity, which only gives
# the right answer by accident of string interning and triggers a
# SyntaxWarning on Python 3.8+.
n = 4*20 if datakey != 'amoeba' else 198
fill = False
xlim = (0, 1)
ylim = (0, 1)
ℓ = .1
# Integrator settings handed to HamiltonianShooting / HamiltonianCarrying.
euler_steps = 7
δt = .1
grid_nlines = 11
# ϵ handed to sinkhorn_log_stabilized; values tried before: (.015)**2, (.025)**2.
ot_ϵ = (.03)**2
ot_λ = .998
# ρ is derived from ϵ and λ. The earlier hand-set values (1e5, then 3**2)
# were always overwritten by this line, so only the derived value is kept.
ot_ρ = ot_ϵ*ot_λ/(1-ot_λ)
ot_n_iters = 100
# Weight of the regularisation term in the loss; value tried before: .01.
λ_regu = .002

# amoeba: 
# ot_ϵ = (.025)**2
# ot_λ = .998
# λ_regu = .003

# this 


## partials

# @jax.jit
# def k(X, Y=None):
#     return cov_se(X,Y,σ2=.5,ℓ=.01) + cov_se(X,Y,σ2=.3,ℓ=.075) + cov_se(X,Y,σ2=.2,ℓ=.15)


# @jax.jit
# def k(X, Y=None):
#     return cov_se(X,Y,σ2=1.,ℓ=.025) + cov_se(X,Y,σ2=.75,ℓ=.15)


# @jax.jit
# def k(X, Y=None):
#     return cov_se(X, Y, ℓ=.1)

# Integrator settings shared by the shooting and carrying solvers.
_integrator_kw = dict(euler_steps=euler_steps, δt=δt)

# JIT-compiled solvers; the batched variant maps over the leading axis of the
# second argument only.
shooting = jit(partial(HamiltonianShooting, **_integrator_kw))
shooting_batched = jit(vmap(shooting, in_axes=(None, 0)))
carrying = jit(partial(HamiltonianCarrying, **_integrator_kw))

# Data-matching term (Sinkhorn with the OT settings above) and pairwise cost.
data_fn = jit(partial(sinkhorn_log_stabilized, n_iters=ot_n_iters, ϵ=ot_ϵ, ρ=ot_ρ))
cost_fn = jit(sqdist)


## Data
# Global PRNG key: later cells split it before drawing samples, and the
# training cell re-seeds it.
key = random.PRNGKey(5)

def dataset_to_XY(dataset_cls, scale=.25, center=np.array([.5,.5])):
    """Build a (source, target) pair of point sets from a labelled dataset.

    `dataset_cls` is called with no arguments and must return
    `(points, labels)`. The points are multiplied by `scale`, translated so
    that their mean lands on `center`, and split by label: rows labelled 0
    are returned first, rows labelled 1 second.
    """
    points, classes = dataset_cls()
    scaled = scale*points
    recentred = scaled + (center - np.mean(scaled, axis=0))
    return recentred[classes==0], recentred[classes==1]

# Amoeba outlines, read from the two images with `n` points requested each.
f1_name = 'data/amoeba_1.png'
f2_name = 'data/amoeba_2.png'
Q0 = level_curves(f1_name, npoints=n)
Xt = level_curves(f2_name, npoints=n)


def get_amoeba():
    """Return the (source, target) amoeba point sets held by `Q0` and `Xt`."""
    return Q0.points, Xt.points


# Dataset factories keyed by `datakey`; each one is called with no arguments
# and its result is unpacked into (X, Y).
make_dataaset_dict = {
    'shapes': partial(make_two_shapes, shapes=(square, circle), n=n, 
        center=((.25, .5), (.75,  .5)), radius=(.15, .15), fill=fill, nlayers=4),
    'moons': partial(dataset_to_XY, dataset_cls=partial(
        datasets.make_moons, n_samples=(n//2, n-n//2), noise=.05)),
    'circles': partial(dataset_to_XY, dataset_cls=partial(
        datasets.make_circles, n_samples=2*n,  factor=.5, noise=.05)),
    'amoeba': get_amoeba
}

get_dataset = make_dataaset_dict[datakey]
X, Y = get_dataset()

g0, gL = GridData(nlines=grid_nlines, xlim=xlim, ylim=ylim, nsubticks=6)

# Zero initial momentum, one row per source point.
p0 = np.zeros(X.shape) * 1.
# Use `==` here: `is` compares object identity, which is not a reliable string
# comparison (SyntaxWarning on Python 3.8+).
if datakey == 'amoeba':
    # Point weights derived from the curve connectivity.
    μ = line_vertex_area(X, Q0.connectivity)
    ν = line_vertex_area(Y, Xt.connectivity)
    print(f'sum(μ)={np.sum(μ):3f}, sum(ν)={np.sum(ν):3f}')
else:
    # Unit weight on every point.
    μ = np.ones((X.shape[0],))
    ν = np.ones((Y.shape[0],))


# Preview figure: the left panel overlays both point sets and the grid.
fig, axs = plt.subplots(1, 2, figsize=(20, 10))

ax = axs[0]
for pts, shade, mk in ((X, .8, 'o'), (Y, .2, 'x')):
    ax.scatter(pts[:, 0], pts[:, 1], color=cmap(shade), marker=mk)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')
plt_grid(ax, g0, gL)

# Right panel is left empty (kernel-matrix view currently disabled).
ax = axs[1]
# ax.imshow(k(X))


In [None]:
n_inducing = 50

# Sub-keys used to jitter the initial kernel length-scales. They used to be
# created only in the later training cell, so this cell raised NameError on a
# fresh top-to-bottom run.
key, sk1, sk2 = random.split(key, 3)

# Variational parameters: mean 'μ' and (softplus-unconstrained) scale 'Σ' of
# the momentum at the inducing points 'Xu', plus hyperparameters of two
# kernel components. Inducing points are picked evenly along the rows of X.
params = {'μ': .01*np.ones((n_inducing,2)), 'Σ': BijSoftplus.reverse(np.ones((n_inducing,))*.01),
          'Xu': X[np.floor(np.linspace(0,len(X)-1,n_inducing)).astype(np.int32),:],
          'k_0': {'σ2': BijSoftplus.reverse(np.ones((1,))),
                  'ℓ':  BijSoftplus.reverse((np.ones((1,)) + random.normal(sk1, (1,)))*.1),},
          'k_1': {'σ2': BijSoftplus.reverse(np.ones((1,))),
                  'ℓ':  BijSoftplus.reverse((np.ones((1,)) + random.normal(sk2, (1,)))*.1),}}


# Draw `n_samples` momenta around the variational mean: p = Qμ + noise * QΣ.
n_samples = 2
qu = params['Xu']
Qμ = params['μ']
QΣ = BijSoftplus.forward(params['Σ'])
QΣ = QΣ.reshape(Qμ.shape[0],-1)
key, subkey = random.split(key, 2)
p = Qμ + random.normal(subkey, (n_samples, *Qμ.shape))*QΣ # (#samples, #V, 2)

print(Qμ.shape, QΣ.shape, p.shape)

# Carry the data points X along the flow of the inducing points, once per
# momentum sample (vmap over the leading axis of p only).
carrying = partial(HamiltonianCarrying, k=cov_se, euler_steps=euler_steps, δt=δt)
carrying_batched = vmap(carrying, (None, 0, None))

qu1, _, q = carrying_batched(qu, p, X)

# Data term: pairwise cost per sample, averaged over samples, then fed to
# the Sinkhorn-based data function.
C = vmap(cost_fn, (0,None))(q, Y)
Cμ = np.mean(C, axis=0)
π, loss_data = data_fn(μ, ν, Cμ)


print(X.shape, qu1.shape, q.shape)

## kl(q(v)||p(v)) where q ~ N(Kμ, K\diag[h]Kᵀ) and
#                       p ~ N(0, K)
# NOTE(review): the term enters with a minus sign; whether that is correct
# depends on the sign convention of kl_mvn_tril_zero_mean_prior — confirm.

kl_mvn_batched = vmap(kl_mvn_tril_zero_mean_prior, (1, None, None)) # wrt D dim in vμ

Kfu = cov_se(X, qu)
vμ, vΣ = mvn_linear(Kfu, Qμ, QΣ)
vL = cholesky_jitter(vΣ, jitter=5e-5)
K = cov_se(X)
L = cholesky_jitter(K, jitter=5e-5)
loss_regu = -np.sum(kl_mvn_batched(vμ, vL, L)) * λ_regu


fig, ax = plt.subplots(1,1,figsize=(10,10))

plt_grid(ax, g0, gL)

plt_points(ax, X, color=cmap(.8), marker='x', s=2)
plt_points(ax, Y, color=cmap(.2), marker='o', s=2)

plt_points(ax, qu, color='r', marker='+', s=30)
# First sample of the carried data points. The batched carry above names
# them `q`; the previous `q1` was never defined in this cell.
plt_points(ax, q[0], color='y', marker='o', s=30)
plt_points(ax, qu1[0], color='b', marker='+', s=30)


# p0_color = cmap(.1) if p0_color == 'b' else plt.cm.get_cmap('OrRd')(mpl.colors.Normalize()(p0Σ))
# plt_vectorfield(ax, X, p0, scale=.2, color=p0_color)

In [None]:
## Training configuration and kernel


n_samples = 1        # number of momentum samples drawn inside loss_fn
optimize_v = False   # when True, params['μ'] is converted through v2p

def k(X, Y=None, params=None):
    """Sum of two `cov_se` kernels with hyperparameters taken from `params`.

    `params['k_0']` and `params['k_1']` each hold unconstrained 'σ2' and 'ℓ'
    entries; both are passed through `BijSoftplus.forward` before being
    handed to `cov_se`.
    """
    components = []
    for name in ('k_0', 'k_1'):
        σ2 = BijSoftplus.forward(params[name]['σ2'])
        ℓ = BijSoftplus.forward(params[name]['ℓ'])
        components.append(cov_se(X, Y, σ2=σ2, ℓ=ℓ))
    return components[0] + components[1]


def loss_fn(params, key, x, μ, y, ν, g, λ_regu, k):
    """ Compute loss    D(q1, μ, y, ν) + λ*R(q̇)
            - carry the points `x` along the flow of the inducing points
              `params['Xu']` with momentum drawn around `params['μ']`
            - compute the data matching term with `data_fn` on the
              sample-averaged pairwise cost between carried points and `y`
            - add the regulariser built from `kl_mvn_tril_zero_mean_prior`
              at the inducing points, scaled by `λ_regu`

        params   dict with 'μ', 'Σ', 'Xu', 'k_0', 'k_1' entries
        key      PRNG key, split once for the momentum noise
        x, μ     source points and their weights
        y, ν     target points and their weights
        g        grid points carried along for visualisation
        λ_regu   weight of the regularisation term
        k        kernel `k(X, Y=None, params=...)`; `params` is bound here

        Returns `(loss, info)`; `info` holds the loss terms, the transport
        plan 'π', the weights 'μ', the per-sample cost 'C' and its variance
        'Cvar', the carried points/momenta/grid 'q1', 'p1', 'g1', and the
        inducing points before ('qu0') and after ('qu1') the flow.
    """
    k = partial(k, params=params)
    carrying = partial(HamiltonianCarrying, k=k, euler_steps=euler_steps, δt=δt)
    carrying_batched = vmap(carrying, (None, 0, None))
    kl_mvn_batched = vmap(kl_mvn_tril_zero_mean_prior, (1, None, None)) # wrt D dim in vμ

    if optimize_v:
        # NOTE(review): `p` computed here is overwritten by the sampling
        # step below, so this branch currently has no effect on the loss.
        v = params['μ']
        p = v2p(k, x, v)

    qu = params['Xu']
    Qμ = params['μ']
    QΣ = BijSoftplus.forward(params['Σ'])
    QΣ = QΣ.reshape(Qμ.shape[0],-1)

    _, subkey = random.split(key, 2)
    # The noise is multiplied by 0, so every sample equals the mean Qμ.
    p = Qμ + 0*random.normal(subkey, (n_samples, *Qμ.shape))*QΣ # (#samples, #V, 2)

    # Use the `x` / `g` arguments rather than the notebook globals X / g0,
    # so the function really depends on what the caller passes in.
    _, _, q = carrying_batched(qu, p, x)

    # Mean-momentum carry with gradients stopped: the results are used for
    # the amoeba weights and for reporting/plotting.
    _, p1, q1 = carrying(jax.lax.stop_gradient(qu),
                        jax.lax.stop_gradient(Qμ), x)
    qu1, _, g1 = carrying(jax.lax.stop_gradient(qu),
                          jax.lax.stop_gradient(Qμ), g)

    if datakey == 'amoeba':
        # Recompute the source weights from the carried curve.
        edge_lengths = line_edge_area(q1, Q0.connectivity)
        μ = np.sum(edge_lengths[Q0.line_vertex_area_ind], axis=1)/2

    C = vmap(cost_fn, (0,None))(q, y)
    Cμ = np.mean(C, axis=0)
    π, loss_data = data_fn(μ, ν, Cμ)

    ## kl(q(v)||p(v)) where q ~ N(Kμ, K\diag[h]Kᵀ) and
    #                       p ~ N(0, K)
    # NOTE(review): the term enters with a minus sign; whether that is right
    # depends on the sign convention of kl_mvn_tril_zero_mean_prior — confirm.

    K = k(qu)
    L = cholesky_jitter(K, jitter=5e-5)
    vμ, vΣ = mvn_linear(K, Qμ, QΣ)
    vL = cholesky_jitter(vΣ, jitter=5e-5)
    loss_regu = -np.sum(kl_mvn_batched(vμ, vL, L)) * λ_regu

#     Kfu = k(x, qu)
#     vμ, vΣ = mvn_linear(Kfu, Qμ, QΣ)
#     vL = cholesky_jitter(vΣ, jitter=5e-5)
#     K = k(X)
#     L = cholesky_jitter(K, jitter=5e-5)
#     loss_regu = -np.sum(kl_mvn_batched(vμ, vL, L)) * λ_regu + loss_regu

    #; loss_regu = Hqp(x, Qμ, k) * λ_regu

    loss = loss_regu + loss_data

    return loss, {'loss': loss,
                  'loss_regu': loss_regu,
                  'loss_data': loss_data,
                  'π': π, 'μ': μ, 'C': C, 'Cvar': np.var(C, axis=0),
                  'q1': q1, 'p1': p1, 'g1': g1,
                  'qu0': params['Xu'], 'qu1': qu1,}

# Bind data and settings once, then JIT the loss and its value-and-gradient
# (the loss returns an auxiliary info dict, hence has_aux=True).
loss_fn_capture = jit(partial(loss_fn, x=X, μ=μ, y=Y, ν=ν, g=g0, λ_regu=λ_regu, k=k))
value_and_grad_fn = jit(value_and_grad(loss_fn_capture, has_aux=True))


# Optimisation settings
n_steps = 800
lr = .001
n_inducing = 50

key = random.PRNGKey(0)
if not optimize_v:
    key, sk1, sk2 = random.split(key,3)
    # Inducing points are rows of X drawn at random (with the carried key).
    inducing_rows = random.randint(key,(n_inducing,),0,len(X))
    params = {
        'μ': .01*np.zeros((n_inducing,2)),
        'Σ': BijSoftplus.reverse(np.ones((n_inducing,))*.01),
        'Xu': X[inducing_rows,:],
        'k_0': {
            'σ2': BijSoftplus.reverse(np.ones((1,))),
            'ℓ':  BijSoftplus.reverse((np.ones((1,)) + random.normal(sk1, (1,)))*.01),
        },
        'k_1': {
            'σ2': BijSoftplus.reverse(np.ones((1,))),
            'ℓ':  BijSoftplus.reverse((np.ones((1,)) + random.normal(sk2, (1,)))*.1),
        },
    }
else:
    params = {'μ': v2p(k, X, p0)}


opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
opt_state = opt_init(params)

# Iterations at which a snapshot column is drawn; `axi` is the next free column.
axi = 0
display_fractions = (0., .03, .1, .3, 1-1/n_steps)
display_its = [int(frac*n_steps) for frac in display_fractions]
n_cols = len(display_its)
fig, axs = plt.subplots(2, n_cols,
                        figsize=(5*n_cols, 5*2), sharex=True, sharey=True)
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])



def plt_momentum_shooting(axs, params, info, p0_color='b'):
    """Draw one snapshot of the registration onto two axes.

    `axs[0]` shows the undeformed grid, the `p2v` field on the grid, the
    source points, the inducing points 'qu0' and the momentum arrows at
    `params['Xu']`. `axs[1]` shows the carried grid, the transport plan
    between the carried points and `Y`, both point sets, and the carried
    inducing points 'qu1'.

    With `p0_color='b'` the momentum arrows use a fixed colour; any other
    value colours them by the softplus-transformed `params['Σ']`.
    """
    def k(X, Y=None):
        # Two cov_se components with the current softplus-transformed
        # hyperparameters.
        terms = [cov_se(X, Y,
                        σ2=BijSoftplus.forward(params[name]['σ2']),
                        ℓ=BijSoftplus.forward(params[name]['ℓ']))
                 for name in ('k_0', 'k_1')]
        return terms[0] + terms[1]

    momentum = v2p(k, X, params['μ']) if optimize_v else params['μ']
    momentum_scale = BijSoftplus.forward(params['Σ'])
    carried, carried_grid, plan = info['q1'], info['g1'], info['π']

    if p0_color == 'b':
        arrow_color = cmap(.1)
    else:
        arrow_color = plt.cm.get_cmap('OrRd')(mpl.colors.Normalize()(momentum_scale))

    # Panel 1: initial configuration
    before = axs[0]
    plt_grid(before, g0, gL)
    plt_vectorfield(before, g0, p2v(k, params['Xu'], momentum, g0), scale=None, color='k')
    plt_points(before, X, color=cmap(.8), marker='x', s=2)
    plt_points(before, info['qu0'], color='fuchsia', marker='o', s=30)
    plt_vectorfield(before, params['Xu'], momentum, scale=.2, color=arrow_color)

    # Panel 2: configuration after the flow
    after = axs[1]
    plt_grid(after, carried_grid, gL)
    plt_transport_plan(after, plan, carried, Y, thresh=np.mean(plan))
    plt_points(after, Y, color=cmap(.2), marker='o', s=2)
    plt_points(after, carried, color=cmap(.8), marker='x', s=2)
    plt_points(after, info['qu1'], color='fuchsia', marker='o', s=30)
    

    
qs = []                    # carried points info['q1'], one entry per step
log_every = n_steps//10    # print a progress line ten times per run

for it in range(n_steps):
    # One Adam step; `info` describes the parameters *before* the update.
    key, subkey = random.split(key)
    params = get_params(opt_state)
    (loss, info), grads = value_and_grad_fn(params, subkey)
    opt_state = opt_update(it, grads, opt_state)
    qs.append(info['q1'])

    if it % log_every == 0:
        σ20, σ21 = (BijSoftplus.forward(params[name]["σ2"])[0] for name in ("k_0", "k_1"))
        ℓ0, ℓ1 = (BijSoftplus.forward(params[name]["ℓ"])[0] for name in ("k_0", "k_1"))
        neg_two_sum_log_scale = -2*np.sum(np.log(BijSoftplus.forward(params["Σ"])))
        cost_var_norm = linalg.norm(info["Cvar"])
        weight_total = np.sum(info["μ"])

        print(f'[{it:4}] loss={info["loss"]:7.3f}'
              f'({info["loss_data"]:7.4f} +{info["loss_regu"]:7.4f})'
              f'\tsum(μ)={weight_total:.3f}'
              f'\tlogdet={neg_two_sum_log_scale:.3f}'
              f'\tnorm(var(C))={cost_var_norm:.3f}'
              f'\tℓ|σ2={ℓ0:.3f},{ℓ1:.3f}|{σ20:.3f},{σ21:.3f}')

    if it in display_its:
        # Fill the next snapshot column of the training figure.
        plt_momentum_shooting(axs[:,axi], params, info)
        axs[0,axi].set_title(f't={it}', fontsize=40)
        axi += 1
    


def plt_savefig(fig, save_path):
    """Tighten the layout of `fig` and save it to `save_path` at 400 dpi.

    When `save_path` is a filesystem path whose parent directory does not
    exist yet, that directory is created first, so a fresh checkout does not
    fail on the first save. Other targets (e.g. file-like objects) are passed
    to `fig.savefig` unchanged.
    """
    import os  # local import keeps the helper self-contained

    if isinstance(save_path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(save_path))
        # An empty parent means "current directory": nothing to create.
        if parent:
            os.makedirs(parent, exist_ok=True)

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', dpi=400)

# Save the training-snapshot figure; plt_savefig applies tight_layout once
# more before writing the file.
fig.tight_layout()
plt_savefig(fig, f'summary/assets/plt_lddmm_points_inducing={n_inducing}_{datakey}_training.png')

# not optimizing Xu
# [   0] loss= -1.834( 0.1147 +-1.9488)	sum(μ)=3.114	logdet=460.517	norm(var(C))=0.000	ℓ=0.158,0.185
# [  50] loss= -1.940( 0.0143 +-1.9544)	sum(μ)=3.114	logdet=465.493	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 100] loss= -1.955( 0.0049 +-1.9594)	sum(μ)=3.114	logdet=470.473	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 150] loss= -1.961( 0.0036 +-1.9644)	sum(μ)=3.114	logdet=475.455	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 200] loss= -1.966( 0.0031 +-1.9693)	sum(μ)=3.114	logdet=480.441	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 250] loss= -1.972( 0.0027 +-1.9743)	sum(μ)=3.114	logdet=485.430	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 300] loss= -1.977( 0.0024 +-1.9792)	sum(μ)=3.114	logdet=490.421	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 350] loss= -1.982( 0.0022 +-1.9841)	sum(μ)=3.114	logdet=495.414	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 400] loss= -1.987( 0.0021 +-1.9890)	sum(μ)=3.114	logdet=500.409	norm(var(C))=0.000	ℓ=0.158,0.185
# [ 450] loss= -1.992( 0.0020 +-1.9940)	sum(μ)=3.114	logdet=505.407	norm(var(C))=0.000	ℓ=0.158,0.185

# some take-aways 
# - optimize kl(v@xu) is better than kl(v@x)! but points still do cluster,
# - optimize kernel hyperparameters ℓ etc. alongside with inducing points does the trick!
#     - points able to align with target ... 


In [None]:

    

def plt_momentum_shooting(axs, params, info, p0_color='b'):
    """Draw the fitted registration onto two axes (before / after the flow).

    `axs[0]` shows the undeformed grid, the `p2v` field on the grid, the
    source points, the inducing points 'qu0' and the momentum arrows at
    `params['Xu']`. `axs[1]` shows the carried grid, the transport plan
    between the carried points and `Y`, both point sets, and the carried
    inducing points 'qu1'.

    With `p0_color='b'` the momentum arrows use a fixed colour; any other
    value colours them by the softplus-transformed `params['Σ']`.
    """

    def k(X, Y=None):
        # Use the hyperparameters stored in `params`, as the loss does.
        # The previous version computed ℓ0/ℓ1 but then returned a kernel
        # with hard-coded σ2/ℓ, so the plotted field did not correspond to
        # the optimised model.
        σ20 = BijSoftplus.forward(params['k_0']['σ2'])
        σ21 = BijSoftplus.forward(params['k_1']['σ2'])
        ℓ0 = BijSoftplus.forward(params['k_0']['ℓ'])
        ℓ1 = BijSoftplus.forward(params['k_1']['ℓ'])
        return cov_se(X,Y,σ2=σ20,ℓ=ℓ0) + cov_se(X,Y,σ2=σ21,ℓ=ℓ1)
    
    p0 = v2p(k, X, params['μ'])if optimize_v else params['μ']
    p0Σ = BijSoftplus.forward(params['Σ'])
    q1, g1, π = info['q1'], info['g1'], info['π']
        
    # Panel 1: initial configuration
    ax = axs[0]
    plt_grid(ax, g0, gL)
    plt_vectorfield(ax, g0, p2v(k, params['Xu'], p0, g0), scale=None, color='k')
    plt_points(ax, X, color=cmap(.8), marker='x', s=2)
    plt_points(ax, info['qu0'], color='fuchsia', marker='o', s=30)
    p0_color = cmap(.1) if p0_color == 'b' else plt.cm.get_cmap('OrRd')(mpl.colors.Normalize()(p0Σ))
    plt_vectorfield(ax, params['Xu'], p0, scale=.2,
                    color=p0_color)
    # Panel 2: configuration after the flow
    ax = axs[1]
    plt_grid(ax, g1, gL)
    plt_transport_plan(ax, π, q1, Y, thresh=np.mean(π))
    plt_points(ax, Y, color=cmap(.2), marker='o', s=2)
    plt_points(ax, q1, color=cmap(.8), marker='x', s=2)
    plt_points(ax, info['qu1'], color='fuchsia', marker='o', s=30)
    
    
# Final figure: input point sets on the left, then the two panels drawn by
# plt_momentum_shooting for the last parameters/info of the training loop.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])

ax = axs[0]
for pts, shade, mk in ((X, .8, 'x'), (Y, .2, 'o')):
    plt_points(ax, pts, color=cmap(shade), marker=mk, s=2)
plt_grid(ax, g0, gL)

plt_momentum_shooting(axs[1:], params, info, p0_color='b')
fig.tight_layout()
plt_savefig(fig, f'summary/assets/plt_lddmm_points_inducing={n_inducing}_{datakey}.png')



In [None]:


# Two-panel summary of the fitted variational parameters:
#   left  - mean momentum Qμ at the inducing points Xu, coloured by QΣ
#   right - p2v output on the grid g0, coloured by diag(K diag(QΣ) Kᵀ)
fig, axs = plt.subplots(1,2,figsize=(20,10))
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])

ax = axs[0]

Xu = params['Xu']
Qμ = params['μ']
QΣ = BijSoftplus.forward(params['Σ'])
QΣ = QΣ.reshape(Qμ.shape[0],-1)

n_samples = 1
key, subkey = random.split(key, 2)
# NOTE(review): `ps` is only consumed by the commented-out shooting code below.
ps = Qμ + random.normal(subkey, (n_samples, *Qμ.shape))*QΣ
# shooting = jit(partial(HamiltonianShooting, k=partial(k,params=params), euler_steps=euler_steps, δt=δt))
# shooting_batched = jit(vmap(shooting, (None, 0)))
# qs, _ = shooting_batched(X, ps)

plt_points(ax, X, color='lightgrey', marker='x')

# Order matters: calling the Normalize instance on the data autoscales it,
# and the colorbar further down reuses that same (now scaled) instance.
mpl_norm = mpl.colors.Normalize()
cmap_OrRd = plt.cm.get_cmap('OrRd')
color_by_Σ = cmap_OrRd(mpl_norm(QΣ))
plt_vectorfield(ax, Xu, Qμ, scale=.2, color=color_by_Σ) # color_by_Σ


# NOTE(review): `ind` is only used by the commented-out code below.
ind = np.argsort(QΣ.flatten())
# nn = 1; ind = np.hstack((ind[:nn], ind[-nn:]))
# for q1 in qs:
#     plt_points(ax, q1[ind,:], color=color_by_Σ[ind,:], marker='o')
# plt_points(ax, Y, color=cmap(.2), marker='o', s=2)

    
fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_OrRd),
             cax=plt_scaled_colorbar_ax(ax), shrink=.5)

plt_grid(ax, g0, gL)

ax = axs[1]

# Fresh Normalize for the second panel; again it is scaled by the call on
# vΣ before the colorbar is built.
mpl_norm = mpl.colors.Normalize()
cmap_RdPu = plt.cm.get_cmap('RdPu')

# Kernel between grid points and inducing points, with the fitted params.
K = k(g0, Xu, params)
# One value per grid point: the diagonal of K diag(QΣ) Kᵀ.
vΣ = np.diag(K@np.diag(QΣ.flatten())@K.T)
plt_vectorfield(ax, g0, p2v(partial(k,params=params), Xu, Qμ, g0), scale=5, color=cmap_RdPu(mpl_norm(vΣ)))
fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_RdPu),
             cax=plt_scaled_colorbar_ax(ax), shrink=.5)

plt_grid(ax, g0, gL)

fig.tight_layout()
plt_savefig(fig, f'summary/assets/plt_lddmm_points_inducing={n_inducing}_{datakey}_mv.png')


In [None]:




# 2x3 grid of panels; each panel draws one momentum sample p = Qμ + QΣ·noise
# and shoots the points X with it.
# NOTE(review): `params['μ']` has one row per inducing point, while `X` holds
# the data points, and `shooting` was built in the partials cell without an
# explicit kernel. The shapes passed to `shooting(X, p)` and
# `plt_vectorfield(ax, X, Qμ, ...)` look inconsistent with the
# inducing-point setup — confirm before relying on this cell.
fig, axs = plt.subplots(2,3,figsize=(15,10))

for i in range(6):
    
    Qμ = params['μ']
    QΣ = BijSoftplus.forward(params['Σ'])
    QΣ = QΣ.reshape(Qμ.shape[0],-1)
    key, subkey = random.split(key, 2)
    p = Qμ + QΣ*random.normal(subkey, Qμ.shape)
    q1, p1 = shooting(X, p)
    
    
    # Panels are filled row by row.
    ax = axs[i//3, i%3]
    # The Normalize instance is scaled by the call on QΣ and then reused by
    # the colorbar below, so the order of these statements matters.
    mpl_norm = mpl.colors.Normalize()
    cmap_OrRd = plt.cm.get_cmap('OrRd')
    color_by_Σ = cmap_OrRd(mpl_norm(QΣ))
    
    plt_points(ax, X, color=cmap(.8), marker='x')
    plt_points(ax, q1, color=color_by_Σ, marker='o')
    plt_vectorfield(ax, X, Qμ, scale=.3, color=color_by_Σ)
    plt_grid(ax, g0, gL)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
    
    fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_OrRd),
                 cax=plt_scaled_colorbar_ax(ax), shrink=.5)
    




In [None]:

# NOTE(review): this cell looks stale. `optimize_p`, `params['p0']` and
# `params['v0']` are not defined in any cell visible here, `k` is called
# through v2p without its `params` argument, and plt_momentum_shooting is
# called with five positional arguments although the definitions above take
# (axs, params, info, p0_color). Confirm whether it should be updated or
# removed.
fig, axs = plt.subplots(1,3,figsize=(15,5))
ax = axs[0]
plt_shape(ax, X, Y)
plt_grid(ax, g0, gL)
# `ind` is not used afterwards in this cell.
ind = random.randint(key, (10,), 0, len(X))
paramsp0 = params['p0'] if optimize_p else v2p(k, X, params['v0'])
plt_momentum_shooting(axs[1:], paramsp0, info['q1'], info['g1'], info['π'])
fig.tight_layout()

In [None]:
# # try f-bfgs
# # - get Desired error not necessarily achieved due to precision loss. 
# #   cannot be fixed ... 

# from scipy.optimize import minimize

# def loss_fn_minimize(p, x, μ, y, ν, g, λ_regu):
#     """ Compute loss    D(q1, μ, y, ν) + λ*R(q̇)
#             - geodeisc shooting obtain q1
#             - compute data matching term D(q1, y)
#     """
#     p = p.reshape(-1,2)
#     q1, p1, g1 = shooting(x, p, g)

#     C = cost_fn(q1, y)
#     π, loss_data = data_fn(μ, ν, C)
#     loss_regu = regu_fn(x, p) * λ_regu

#     loss = loss_regu + loss_data

#     return loss


# loss_fn_minimize_capture = partial(loss_fn_minimize, x=X,μ=μ,y=Y,ν=ν,g=g0,λ_regu=λ_regu)
# value_and_grad = jax.jit(jax.value_and_grad(loss_fn_minimize_capture))
    
# results = minimize(value_and_grad, p0.reshape(-1,), method='BFGS',
#                    jac=True, tol=1e-10, options={'maxiter': 5000, 'disp': True})
# print(results.nit, results.message)


# params = {'p0': results.x.reshape(-1,2)}
# _, info = loss_fn_capture(params)

# fig, axs = plt.subplots(1,2,figsize=(20,10))

# ax = axs[0]
# plt_shape(ax, X, Y)
# plt_vectorfield(ax, X, params['p0'], color=cmap(.9))

# ax = axs[1]
# plt_shape(ax, info['q1'], Y)
# plt_grid(ax, info['g1'], gL)
# plt_vectorfield(ax, info['g1'], k(info['g1'], info['q1'])@info['p1'], scale=None, color='k')

