goal
- [done] register simple pointsets in 2d using lddmm+shooting+data loss
- [done] variational q(m) setup optimize for elbo
    - stochasticity in shooting particles ... time consuming
- extension to probability related to gp + dynamic model work ?

In [None]:
# Notebook housekeeping helpers from the local utils_jpt module.
# NOTE(review): judging by their names they enable autoreload and a
# full-width notebook layout; their implementation is not shown here.
from utils_jpt import jpt_autoreload, jpt_full_width

jpt_autoreload()
jpt_full_width()


## Geodesic shooting for registering simple shapes
# 
from functools import partial

import os
# Request non-x64 (32-bit) mode from JAX through its environment variable.
# NOTE(review): this is assigned ahead of the `import jax` below, presumably
# so that the flag is already set when JAX initialises — confirm.
os.environ['JAX_ENABLE_X64'] = 'False'

import jax
from jax import random, grad, jit, value_and_grad, vmap
import jax.numpy as np
import jax.numpy.linalg as linalg
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.experimental import optimizers

import matplotlib.pyplot as plt
import matplotlib as mpl
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
# Global styling applied to every figure created after this point.
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# Shared blue-white-red colormap; sampled further down as cmap(.8), cmap(.2), ...
# to colour the two point sets.
# NOTE(review): plt.cm.get_cmap is deprecated in recent Matplotlib releases —
# confirm against the Matplotlib version this notebook is pinned to.
cmap = plt.cm.get_cmap('bwr')

import sklearn.datasets as datasets
import scipy.interpolate as interp

import sys
sys.path.append('../gp')
sys.path.append('../ot')
from gpax import cholesky_jitter, BijSoftplus, kl_mvn_tril_zero_mean_prior
from otax import sinkhorn_log_stabilized, sinkhorn_divergence, plt_transport_plan

from jax_registration import *
from lddmm_pytorch import Curve, level_curves


In [None]:
## Parameters

# amoeba, octopus

# Which point-set pair to register.
datakey = 'amoeba'

# Number of points: 4*20 for the synthetic sets, 198 for everything else.
if datakey in ('shapes', 'moons', 'circles'):
    n = 4*20
else:
    n = 198

fill = False
xlim = (0, 1)
ylim = (0, 1)
ℓ = .1

# Time discretisation handed to the Hamiltonian integrators.
euler_steps = 7
δt = .1
grid_nlines = 11

# Values forwarded to the Sinkhorn data term as ϵ, ρ and n_iters.
# Earlier trial values, all superseded by the assignments below:
#   ot_ρ = 1e5, ot_ρ = (3)**2, ot_ϵ = (.025)**2
ot_ϵ = (.015)**2
ot_λ = .998  #  ; ot_λ = 0.9999999
ot_ρ = ot_ϵ*ot_λ/(1-ot_λ)
ot_n_iters = 100

# Multiplier of the regularisation term (earlier trials: .01, .003).
λ_regu = .001
scale_momentum = .2

# deer 
# λ_regu = .001 will get wrong results 
# ot_ϵ = .5

# amoeba: 
# ot_ϵ = (.025)**2
# ot_λ = .998
# λ_regu = .003

# shapes, moons, circles 
# ot_λ = 0.9999999
# λ_regu = .001;
# ot_ϵ = (.015)**2
# n_samples = 10

## partials

# @jax.jit
# def k(X, Y=None):
#     return cov_se(X,Y,σ2=.5,ℓ=.01) + cov_se(X,Y,σ2=.3,ℓ=.075) + cov_se(X,Y,σ2=.2,ℓ=.15)


# @jax.jit
# def k(X, Y=None):
#     return cov_se(X,Y,σ2=1.,ℓ=.025) + cov_se(X,Y,σ2=.75,ℓ=.15)


# @jax.jit
# def k(X, Y=None):
#     return cov_se(X, Y, ℓ=.1)

# Module-level wrappers with the integrator / Sinkhorn settings pre-bound.
# NOTE(review): no kernel `k` is bound here, whereas loss_fn builds its own
# shooting/carrying with k=...; confirm these top-level versions are still used.
shooting = jit(
    partial(HamiltonianShooting, euler_steps=euler_steps, δt=δt))
# Map over the leading axis of the second argument only.
shooting_batched = jit(
    vmap(shooting, in_axes=(None, 0)))
carrying = jit(
    partial(HamiltonianCarrying, euler_steps=euler_steps, δt=δt))
data_fn = jit(
    partial(sinkhorn_log_stabilized, ϵ=ot_ϵ, ρ=ot_ρ, n_iters=ot_n_iters))
cost_fn = jit(sqdist)


## Data
# PRNG key for this cell (seed 5).
key = random.PRNGKey(5)
# Folder that the figures saved further down are written into;
# exist_ok=True makes re-running the cell harmless.
os.makedirs(f'summary/assets/{datakey}', exist_ok=True)

def dataset_to_XY(dataset_cls, scale=.25, center=np.array([.5,.5])):
    """Build two point sets from a labelled dataset factory.

    ``dataset_cls()`` must return ``(points, labels)``. The points are
    multiplied by ``scale``, translated so that their mean lands on
    ``center``, and then split by label.

    Returns:
        (X, Y): the rows with label 0 and the rows with label 1.
    """
    points, labels = dataset_cls()
    points = scale * points
    # Shift that moves the centroid of the scaled points onto `center`.
    offset = center - np.mean(points, axis=0)
    points = points + offset
    return points[labels == 0], points[labels == 1]

# Image pair for each image-based dataset key; the first file becomes Q0
# and the second becomes Xt below.
filenames = {
    'amoeba': ['data/amoeba_1.png', 'data/amoeba_2.png'],
    'bat': ['data/mpeg7/bat-14.png', 'data/mpeg7/bat-15.png'],
    'octopus': ['data/mpeg7/octopus-6.png', 'data/mpeg7/octopus-1.png'],
    'deer': ['data/mpeg7/deer-12.png', 'data/mpeg7/deer-14.png'],
}

# Extra keyword arguments for level_curves, only for the keys listed here.
level_curve_kwargs = dict(
    octopus={'smoothing': 2},
    deer={'smoothing': 3},
)

if datakey in filenames:
    fns = filenames[datakey]
    # Keys without an entry get no extra keyword arguments.
    kwargs = level_curve_kwargs.get(datakey, {})
    Q0 = level_curves(fns[0], npoints=n, **kwargs)
    Xt = level_curves(fns[1], npoints=n, **kwargs)

def get_png_points():
    """Return the ``points`` of the two curves loaded from the image pair.

    Reads the module-level ``Q0`` and ``Xt``, which are only assigned when
    ``datakey`` is one of the keys in ``filenames``.
    """
    first_curve, second_curve = Q0, Xt
    return first_curve.points, second_curve.points

# Zero-argument factories for the synthetic datasets; the selected one is
# called below and its result unpacked as (X, Y).
make_dataset_dict = dict(
    shapes=partial(
        make_two_shapes, shapes=(square, circle), n=n,
        center=((.25, .5), (.75,  .5)), radius=(.15, .15), fill=fill, nlayers=4),
    moons=partial(
        dataset_to_XY,
        dataset_cls=partial(datasets.make_moons, n_samples=(n//2, n-n//2), noise=.05)),
    circles=partial(
        dataset_to_XY,
        dataset_cls=partial(datasets.make_circles, n_samples=2*n,  factor=.5, noise=.05)),
)

# Every image-based key shares the same loader.
make_dataset_dict.update(dict.fromkeys(filenames, get_png_points))

get_dataset = make_dataset_dict[datakey]
X, Y = get_dataset()

# Grids that are drawn (and carried along) to visualise the deformation:
# one with `grid_nlines` lines and a second one with 22 lines.
g0, gL = GridData(nlines=grid_nlines, xlim=xlim, ylim=ylim, nsubticks=6)
g0fine, gLfine = GridData(nlines=22, xlim=xlim, ylim=ylim)

# Initial momentum: zeros with the same shape as the source points.
p0 = np.zeros(X.shape) * 1.

# Compare the key by value. The previous `datakey is 'amoeba'` tested object
# identity against a string literal, which depends on interpreter string
# interning and emits a SyntaxWarning on Python >= 3.8.
if datakey == 'amoeba':
    # Weights computed from the curve geometry of each shape.
    μ = line_vertex_area(Q0.points, Q0.connectivity)
    ν = line_vertex_area(Xt.points, Xt.connectivity)
    print(f'sum(μ)={np.sum(μ):3f}, sum(ν)={np.sum(ν):3f}')
else:
    # Unit weight for every point.
    μ = np.ones((X.shape[0],))
    ν = np.ones((Y.shape[0],))


# Overview figure: both point sets on top of the undeformed grid (left panel).
fig, axs = plt.subplots(1, 2, figsize=(20, 10))

ax = axs[0]
ax.scatter(X[:, 0], X[:, 1], color=cmap(.8), marker='o')
ax.scatter(Y[:, 0], Y[:, 1], color=cmap(.2), marker='x')
# Fixed window, no ticks, equal aspect ratio.
ax.set(xlim=xlim, ylim=ylim, xticks=[], yticks=[], aspect='equal')
plt_grid(ax, g0, gL)

# Right panel is currently left empty.
ax = axs[1]
# ax.imshow(k(X))


In [None]:
## Plotting


# Number of momentum samples drawn per loss evaluation (read by loss_fn).
n_samples = 10
# Switch read by loss_fn, the parameter initialisation and plt_momentum_shooting.
optimize_v = False


def k(X, Y=None, params=None):
    """Kernel handed to the shooting code: the sum of two ``cov_se`` terms.

    The length scales stored under ``params['k_0']['ℓ']`` and
    ``params['k_1']['ℓ']`` are mapped through ``BijSoftplus.forward`` but are
    not used by the active return value, which has the fixed length scales
    .025 and .15. The commented-out return below is the variant that uses them.
    """
    ℓ0 = BijSoftplus.forward(params['k_0']['ℓ'])
    ℓ1 = BijSoftplus.forward(params['k_1']['ℓ'])
    narrow_term = cov_se(X, Y, σ2=1, ℓ=.025)
    wide_term = cov_se(X, Y, σ2=.75, ℓ=.15)
    return narrow_term + wide_term
#     return cov_se(X,Y,σ2=1,ℓ=ℓ0) + cov_se(X,Y,σ2=1,ℓ=ℓ1)


def loss_fn(params, key, x, μ, y, ν, g, λ_regu, k):
    """ Compute loss    D(q1, μ, y, ν) + λ*R(q̇)
            - geodesic shooting to obtain q1
            - compute data matching term D(q1, y)

        Args:
            params: dict with the variational momentum parameters 'μ' and 'Σ'
                (plus whatever entries the kernel `k` reads).
            key: PRNG key; split once to draw the momentum samples.
            x: points that are shot / carried.
            μ, ν: weights passed to `data_fn` together with the cost matrix.
            y: points that the shot samples are compared against via `cost_fn`.
            g: extra points transported alongside x by `carrying`.
            λ_regu: multiplier of the regularisation term.
            k: kernel callable that accepts a `params=` keyword.

        Returns:
            (loss, info) where info holds 'loss', 'loss_regu', 'loss_data',
            'π', 'μ', 'C', 'Cvar', 'q1', 'p1' and 'g1'.

        Also reads the module-level names euler_steps, δt, optimize_v,
        n_samples, cost_fn, data_fn, datakey and (for 'amoeba') Q0.
    """
    # Bind the current parameters into the kernel, then build the integrators
    # around that kernel.
    k = partial(k, params=params)
    shooting = partial(HamiltonianShooting, k=k, euler_steps=euler_steps, δt=δt)
    shooting_batched = vmap(shooting, (None, 0))
    carrying = partial(HamiltonianCarrying, k=k, euler_steps=euler_steps, δt=δt)
    kl_mvn_batched = vmap(kl_mvn_tril_zero_mean_prior, (1, None, None)) # wrt D dim in vμ

    # NOTE(review): the `p` computed in this branch is overwritten by the
    # sampling a few lines below, so the branch currently has no effect on the loss.
    if optimize_v:
        v = params['μ']
        p = v2p(k, x, v)

    # Mean and (softplus-transformed) scale of q(m); the scale is reshaped to
    # (Qμ.shape[0], -1) so that it broadcasts against the samples.
    Qμ = params['μ']
    QΣ = BijSoftplus.forward(params['Σ'])
    QΣ = QΣ.reshape(Qμ.shape[0],-1)

    # Draw n_samples momenta as mean + noise * scale.
    key, subkey = random.split(key, 2)
    p = Qμ + random.normal(subkey, (n_samples, *Qμ.shape))*QΣ # (#samples, #V, 2)

    # Shoot every sampled momentum; separately carry x and g with the mean
    # momentum, with gradients stopped through x and Qμ.
    q, _ = shooting_batched(x, p)
    q1, p1, g1 = carrying(jax.lax.stop_gradient(x),
                          jax.lax.stop_gradient(Qμ), g)

    # Cost matrix per sample, then averaged over the sample axis.
    C = vmap(cost_fn, (0,None))(q, y)
    Cμ = np.mean(C, axis=0)
    if datakey == 'amoeba':
        # Replaces the μ argument with weights recomputed from the carried curve q1.
        edge_lengths = line_edge_area(q1, Q0.connectivity)
        μ = np.sum(edge_lengths[Q0.line_vertex_area_ind], axis=1)/2
    π, loss_data = data_fn(μ, ν, Cμ)

    ## kl(q(v)||p(v)) where q ~ N(Kμ, K\diag[h]Kᵀ) and
    #                       p ~ N(0, K)
    K = k(x)
    L = cholesky_jitter(K, jitter=5e-5)
    vμ, vΣ = mvn_linear(K, Qμ, QΣ)
    vL = cholesky_jitter(vΣ, jitter=5e-5)
    # NOTE(review): the summed kl_mvn_batched output is negated before scaling;
    # the sign convention of kl_mvn_tril_zero_mean_prior is not visible here — confirm.
    loss_regu = -np.sum(kl_mvn_batched(vμ, vL, L)) * λ_regu

    #; loss_regu = Hqp(x, Qμ, k) * λ_regu

    loss = loss_regu + loss_data

    return loss, {'loss': loss,
                  'loss_regu': loss_regu,
                  'loss_data': loss_data,
                  'π': π, 'μ': μ, 'C': C, 'Cvar': np.var(C, axis=0),
                  'q1': q1, 'p1': p1, 'g1': g1}

# Bind data, grid, regularisation weight and kernel so only (params, key)
# remain free; gradients are taken w.r.t. the first argument (params).
loss_fn_capture = jit(partial(
    loss_fn, x=X, μ=μ, y=Y, ν=ν, g=g0, λ_regu=λ_regu, k=k))
value_and_grad_fn = jit(value_and_grad(loss_fn_capture, has_aux=True))


n_steps = 800
lr = .001

key = random.PRNGKey(0)
if not optimize_v:
    key, sk1, sk2 = random.split(key, 3)
    # Entries are stored through BijSoftplus.reverse and read back with
    # BijSoftplus.forward wherever they are used.
    params = {
        'μ': p0,
        'Σ': BijSoftplus.reverse(np.ones((len(p0)*1,))*.01),
        'k_0': {
            'σ2': BijSoftplus.reverse(np.ones((1,))),
            'ℓ': BijSoftplus.reverse((np.ones((1,)) + .2*random.normal(sk1, (1,)))*.01),
        },
        'k_1': {
            'σ2': BijSoftplus.reverse(np.ones((1,))),
            'ℓ': BijSoftplus.reverse((np.ones((1,)) + .2*random.normal(sk2, (1,)))*.1),
        },
    }
else:
    params = {'μ': v2p(k, X, p0)}

opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
opt_state = opt_init(params)

# Iterations at which a snapshot column is drawn, and the next free column.
axi = 0
display_its = [int(x*n_steps) for x in (0., .03, .1, .3, 1-1/n_steps)]

fig, axs = plt.subplots(
    2, len(display_its),
    figsize=(5*len(display_its), 5*2),
    sharex=True, sharey=True)
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])



def plt_momentum_shooting(axs, params, info, p0_color='b'):
    """Draw the initial momentum (axs[0]) and the shooting result (axs[1]).

    Args:
        axs: indexable with two axes; axs[0] gets the undeformed grid, the
            velocity field on the grid and the momentum arrows, axs[1] gets
            the carried grid, sampled trajectories, the transport plan and
            the point sets.
        params: parameter dict; reads 'μ', 'Σ' and the 'ℓ' entries of
            'k_0' / 'k_1'.
        info: diagnostics dict as returned by loss_fn; reads 'q1', 'g1', 'π'.
        p0_color: 'b' draws all momentum arrows in cmap(.1); any other value
            colours them by the normalised softplus-transformed 'Σ' using
            the 'OrRd' colormap.

    Also reads the module-level X, Y, g0, gL, key, euler_steps, δt,
    scale_momentum, optimize_v and cmap.
    """
    # Imported locally: LineCollection is not imported in this file's import
    # section, so do not rely on a wildcard import happening to re-export it.
    from matplotlib.collections import LineCollection

    def k(X, Y=None):
        # ℓ0 / ℓ1 are only needed by the commented-out variant below.
        ℓ0 = BijSoftplus.forward(params['k_0']['ℓ'])
        ℓ1 = BijSoftplus.forward(params['k_1']['ℓ'])
        return cov_se(X,Y,σ2=1,ℓ=.025) + cov_se(X,Y,σ2=.75,ℓ=.15)
#         return cov_se(X,Y,σ2=1,ℓ=ℓ0) + cov_se(X,Y,σ2=1,ℓ=ℓ1)

    p0 = v2p(k, X, params['μ']) if optimize_v else params['μ']
    p0Σ = BijSoftplus.forward(params['Σ'])
    q1, g1, π = info['q1'], info['g1'], info['π']

    # Panel 1: undeformed grid, velocity field on the grid, momentum arrows.
    ax = axs[0]
    plt_grid(ax, g0, gL)
    plt_vectorfield(ax, g0, p2v(k, X, p0, g0), scale=None, color='k')
    plt_points(ax, X, color=cmap(.8), marker='x', s=2)
    p0_color = cmap(.1) if p0_color == 'b' else plt.cm.get_cmap('OrRd')(mpl.colors.Normalize()(p0Σ))
    plt_vectorfield(ax, X, p0, scale=scale_momentum,
                    color=p0_color)

    # Panel 2: carried grid plus trajectories of 100 randomly indexed points.
    ax = axs[1]
    plt_grid(ax, g1, gL)
    # Map HamiltonianShooting over the step count (0..euler_steps) to get the
    # intermediate positions; the kernel (argument 2) is static for jit.
    shooting_intermediates = jit(vmap(HamiltonianShooting, (None,None,None,0,None)), static_argnums=(2,))
    qs, _ = shooting_intermediates(X, params['μ'], k, np.arange(euler_steps+1), δt)
    ind = random.randint(key, (100,), 0, len(X))
    # Reorder to (point, step, coordinate): one polyline per selected point.
    segs = qs[:,ind,:].transpose((1,0,2))
    ax.add_collection(LineCollection(
        segs, color='k', linewidths=(1,), linestyle='dashed'))
    plt_points(ax, X, color='k', marker='o', s=2)

    plt_transport_plan(ax, π, q1, Y, thresh=np.mean(π))
    plt_points(ax, Y, color=cmap(.2), marker='o', s=2)
    plt_points(ax, q1, color=cmap(.8), marker='x', s=2)
    

    
qs = []

# Print progress about ten times over the run. The interval is clamped to at
# least 1 so the modulus below cannot divide by zero when n_steps < 10.
log_every = max(1, n_steps // 10)

for it in range(n_steps):
    key, subkey = random.split(key)
    # Parameters before this step's update; `info` is evaluated at these values.
    params = get_params(opt_state)
    (loss, info), grads = value_and_grad_fn(params, subkey)
    opt_state = opt_update(it, grads, opt_state)
    # Keep the carried points of every iteration.
    qs.append(info['q1'])

    if it % log_every == 0:
        ℓ0 = BijSoftplus.forward(params["k_0"]["ℓ"])[0]
        ℓ1 = BijSoftplus.forward(params["k_1"]["ℓ"])[0]

        print(f'[{it:4}] loss={info["loss"]:7.3f}'
              f'({info["loss_data"]:7.4f} +{info["loss_regu"]:7.4f})'
              f'\tsum(μ)={np.sum(info["μ"]):.3f}'
              f'\tlogdet={-2*np.sum(np.log(BijSoftplus.forward(params["Σ"]))):.3f}'
              f'\tnorm(var(C))={linalg.norm(info["Cvar"]):.3f}'
              f'\tℓ={ℓ0:.3f},{ℓ1:.3f}')

    # Draw a snapshot column at each preselected iteration.
    if it in display_its:
        plt_momentum_shooting(axs[:,axi], params, info)
        axs[0,axi].text(0.025, .975, f't={it}', fontsize=30, ha='left', va='top')
        axi += 1
    


def plt_savefig(fig, save_path):
    """Apply a tight layout to ``fig`` and write it to ``save_path`` at 400 dpi."""
    fig.tight_layout()
    save_options = dict(bbox_inches='tight', dpi=400)
    fig.savefig(save_path, **save_options)
    
# Save the training snapshot figure.
# NOTE: plt_savefig itself calls fig.tight_layout() again before writing the file.
fig.tight_layout()
plt_savefig(fig, f'summary/assets/{datakey}/plt_lddmm_points_training.png')


In [None]:
# Rebuild the integrators around the kernel with the current `params` bound,
# batched over the leading axis of the momentum argument.
k_bound = partial(k, params=params)
carrying = partial(
    HamiltonianCarrying, k=k_bound, euler_steps=euler_steps, δt=δt)
carrying_batched = jit(vmap(carrying, in_axes=(None, 0, None)))
shooting_batched = jit(vmap(
    partial(HamiltonianShooting, k=k_bound, euler_steps=euler_steps, δt=δt),
    in_axes=(None, 0)))

In [None]:

# Four-panel summary: data, momentum, shooting result, and a map of the
# sample variance of the carried mesh.
fig, axs = plt.subplots(1,4,figsize=(20,5))
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])
ax = axs[0]
plt_points(ax, X, color=cmap(.8), marker='x', s=2)
plt_points(ax, Y, color=cmap(.2), marker='o', s=2)
plt_grid(ax, g0, gL)
# Any p0_color other than 'b' makes plt_momentum_shooting colour the arrows by Σ.
plt_momentum_shooting(axs[1:3], params, info, p0_color='x')


ax = axs[3]
meshn = 100
# NOTE(review): rebinds the module-level n_samples that loss_fn also reads.
n_samples = 100
# meshn x meshn mesh over the plotting window, flattened to one point per row.
XX, YY = np.meshgrid(np.linspace(*xlim, meshn), np.linspace(*ylim, meshn))
ZZ = np.stack((XX, YY)).reshape(2,-1).T
key, subkey = random.split(key, 2)
Qμ = params['μ']
QΣ = BijSoftplus.forward(params['Σ'])
QΣ = QΣ.reshape(Qμ.shape[0],-1)
# Sample momenta as mean + noise * scale and prepend the mean itself.
ps = Qμ + random.normal(subkey, (n_samples, *Qμ.shape))*QΣ
ps = np.vstack((Qμ[np.newaxis,...],ps)) # ps[0] is Qμ

# Carry the mesh ZZ with every sampled momentum.
qs, _, g1fine = carrying_batched(X, ps, ZZ)
# Move the leading (sample) axis to the end, then take mean / variance over it;
# the variance is additionally averaged over the remaining last axis
# (presumably the coordinate axis — TODO confirm).
F = g1fine.transpose((1,2,0))
Fμ, Fσ2 = np.mean(F, axis=-1), np.mean(np.var(F, axis=-1), axis=-1)

# Interpolate the variance from the carried mean positions back onto the mesh.
interp_Fσ2 = interp.Rbf(Fμ[:,0], Fμ[:,1], Fσ2, function='thin_plate', smooth=0)
# NOTE: extent is hard-coded to [0,1,0,1], which equals the current xlim / ylim.
im = ax.imshow(interp_Fσ2(ZZ[:,0], ZZ[:,1]).reshape(meshn,meshn), extent=[0,1,0,1], origin='lower', cmap='Greys')
plt_points(ax, Y, color=cmap(.2), marker='o', s=2)
# Annotate the panel with the mean of Fσ2.
ax.text(.975, .975, f'{np.mean(Fσ2):.1e}', fontsize=30, ha='right', va='top', backgroundcolor=plt.cm.get_cmap('Greys')(.5))


fig.tight_layout()
plt_savefig(fig, f'summary/assets/{datakey}/plt_lddmm_points.png')


In [None]:

# Redraw the variance map with a fresh set of momentum samples.
# NOTE: this cell does not define ax, fig, Qμ, QΣ, ZZ or meshn itself; it
# relies on the values left behind by a previously executed cell.
n_samples = 100
key, subkey = random.split(key, 2)
ps = Qμ + random.normal(subkey, (n_samples, *Qμ.shape))*QΣ
ps = np.vstack((Qμ[np.newaxis,...],ps)) # ps[0] is Qμ

qs, _, g1fine = carrying_batched(X, ps, ZZ)
# Sample axis moved last; mean / variance are taken over it.
F = g1fine.transpose((1,2,0))
Fμ, Fσ2 = np.mean(F, axis=-1), np.mean(np.var(F, axis=-1), axis=-1)

# mpl_norm = mpl.colors.Normalize()
# cmap_m = plt.cm.get_cmap('OrRd')
# color_by_Σ = cmap_m(mpl_norm(QΣ))
# ind = np.argsort(QΣ.flatten())
# # nn = 10; ind = np.hstack((ind[:nn], ind[-nn:]))
# for q1 in qs:
#     plt_points(ax, q1[ind,:], color=color_by_Σ[ind,:], marker='o', s=20)

# mpl_norm = mpl.colors.Normalize()
# cmap_u = plt.cm.get_cmap('viridis')
# plt_points(ax, Fμ, color=cmap_u(mpl_norm(Fσ2)), marker='o', s=10)
# fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_u),
#              cax=plt_scaled_colorbar_ax(ax), shrink=.5)

# Interpolate the variance from the carried mean positions onto the mesh ZZ.
interp_Fσ2 = interp.Rbf(Fμ[:,0], Fμ[:,1], Fσ2, function='thin_plate', smooth=0)
im = ax.imshow(interp_Fσ2(ZZ[:,0], ZZ[:,1]).reshape(meshn,meshn), extent=[0,1,0,1], origin='lower')

# qs[0] belongs to ps[0], i.e. the mean momentum Qμ.
plt_points(ax, Y, color='k', marker='o', s=5)
plt_points(ax, qs[0], color='r', marker='x', s=5)

fig.colorbar(im,cax=plt_scaled_colorbar_ax(ax), shrink=.5)

In [None]:

# Two-panel figure: left, the diagonal of the second mvn_linear output on a
# dense mesh; right, the sample variance of the carried mesh.
Qμ = params['μ']
QΣ = BijSoftplus.forward(params['Σ'])
QΣ = QΣ.reshape(Qμ.shape[0],-1)
K = k(g0, X, params=params)
# NOTE: vμ / vΣ are only referenced by the commented-out plotting code below.
vμ, vΣ = mvn_linear(K, Qμ, QΣ); vΣ = np.diag(vΣ)


fig, axs = plt.subplots(1,2,figsize=(12,20))
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])

ax = axs[0]

plt_points(ax, X, color='r', marker='x', s=5)

# meshn x meshn mesh over the plotting window, one point per row in ZZ.
meshn = 100
XX, YY = np.meshgrid(np.linspace(*xlim, meshn), np.linspace(*ylim, meshn))
ZZ = np.stack((XX, YY)).reshape(2,-1).T

# Same mvn_linear call as above, but with the kernel evaluated between the
# mesh ZZ and X; only the diagonal is kept and reshaped to the mesh.
vfineμ, vfineΣ = mvn_linear(k(ZZ, X, params=params),Qμ,QΣ)
vfineΣ = np.diag(vfineΣ).reshape((meshn, meshn))

im = ax.imshow(vfineΣ, extent=[0,1,0,1], origin='lower')
fig.colorbar(im,cax=plt_scaled_colorbar_ax(ax), shrink=.5)


# ax = axs[1]
# mpl_norm = mpl.colors.Normalize()
# cmap_v = plt.cm.get_cmap('OrRd')

# plt_points(ax, g0, color=cmap_v(mpl_norm(vΣ)))
# plt_vectorfield(ax, g0, vμ, scale=None, color=cmap_v(mpl_norm(vΣ)))
# fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_v),
#              cax=plt_scaled_colorbar_ax(ax), shrink=.5)


# mpl_norm = mpl.colors.Normalize()
# cmap_m = plt.cm.get_cmap('OrRd')
# plt_vectorfield(ax, X, Qμ, scale=scale_momentum, color=cmap_m(mpl_norm(QΣ)))


plt_grid(ax, g0, gL)


## 
ax = axs[1]

# Sample momenta as mean + noise * scale and prepend the mean itself.
n_samples = 100
key, subkey = random.split(key, 2)
ps = Qμ + random.normal(subkey, (n_samples, *Qμ.shape))*QΣ
ps = np.vstack((Qμ[np.newaxis,...],ps)) # ps[0] is Qμ

# Carry the mesh with every sample; mean / variance are over the sample axis,
# which the transpose moves to the end.
qs, _, g1fine = carrying_batched(X, ps, ZZ)
F = g1fine.transpose((1,2,0))
Fμ, Fσ2 = np.mean(F, axis=-1), np.mean(np.var(F, axis=-1), axis=-1)

# mpl_norm = mpl.colors.Normalize()
# cmap_m = plt.cm.get_cmap('OrRd')
# color_by_Σ = cmap_m(mpl_norm(QΣ))
# ind = np.argsort(QΣ.flatten())
# # nn = 10; ind = np.hstack((ind[:nn], ind[-nn:]))
# for q1 in qs:
#     plt_points(ax, q1[ind,:], color=color_by_Σ[ind,:], marker='o', s=20)

# mpl_norm = mpl.colors.Normalize()
# cmap_u = plt.cm.get_cmap('viridis')
# plt_points(ax, Fμ, color=cmap_u(mpl_norm(Fσ2)), marker='o', s=10)
# fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_u),
#              cax=plt_scaled_colorbar_ax(ax), shrink=.5)

# Interpolate the variance from the carried mean positions onto the mesh ZZ.
interp_Fσ2 = interp.Rbf(Fμ[:,0], Fμ[:,1], Fσ2, function='thin_plate', smooth=0)
im = ax.imshow(interp_Fσ2(ZZ[:,0], ZZ[:,1]).reshape(meshn,meshn), extent=[0,1,0,1], origin='lower')

# qs[0] belongs to ps[0], i.e. the mean momentum Qμ.
plt_points(ax, Y, color='k', marker='o', s=5)
plt_points(ax, qs[0], color='r', marker='x', s=5)

fig.colorbar(im,cax=plt_scaled_colorbar_ax(ax), shrink=.5)

fig.tight_layout()
plt_savefig(fig, f'summary/assets/{datakey}/plt_lddmm_points_momentumvelocity_field.png')

In [None]:
# np.abs(F-np.mean(F,axis=-1).squeeze())**2
# Squared deviation of F from its mean over the last axis, averaged over that axis.
# NOTE: `foo` is not used by the plotting code in this cell.
foo = np.abs(F - np.mean(F,axis=-1)[...,np.newaxis])**2
foo = np.mean(foo,axis=-1)

# Left: Fσ2 interpolated onto the mesh; right: Fσ2 as colours at the carried
# mean positions Fμ. Relies on F, Fμ, Fσ2, ZZ, meshn and qs from an earlier cell.
fig,axs = plt.subplots(1,2,figsize=(20,10))
ax = axs[0]

interp_Fσ2 = interp.Rbf(Fμ[:,0], Fμ[:,1], Fσ2, function='thin_plate', smooth=0)
im = ax.imshow(interp_Fσ2(ZZ[:,0], ZZ[:,1]).reshape(meshn,meshn), extent=[0,1,0,1], origin='lower')

plt_points(ax, Y, color='k', marker='o', s=5)
plt_points(ax, qs[0], color='r', marker='x', s=5)

fig.colorbar(im,cax=plt_scaled_colorbar_ax(ax), shrink=.5)

ax = axs[1]

# The same Normalize instance is used for the point colours and the colour bar.
mpl_norm = mpl.colors.Normalize()
cmap_u = plt.cm.get_cmap('viridis')
plt_points(ax, Fμ, color=cmap_u(mpl_norm(Fσ2)), marker='o', s=10)
fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_u),
             cax=plt_scaled_colorbar_ax(ax), shrink=.5)

In [None]:

# plt_savefig(fig, f'summary/assets/{datakey}/plt_lddmm_points_momentumvelocity_field.png')

In [None]:
# One panel per integration step, showing the carried points, their momenta
# and the carried grid after 1, 2, ..., euler_steps steps.
fig, axs = plt.subplots(1, euler_steps, figsize=(5*euler_steps, 5))
plt.setp(axs, xlim=xlim, ylim=ylim, xticks=[], yticks=[])

# Map HamiltonianCarrying over the step count; the kernel (argument 3) is
# static for jit.
shooting_intermediates = jit(
    vmap(HamiltonianCarrying, (None, None, None, None, 0, None)),
    static_argnums=(3,))
step_counts = np.arange(euler_steps)+1
qs, ps, gs = shooting_intermediates(
    X, Qμ, g0, partial(k, params=params), step_counts, δt)

for i, ax in enumerate(axs):
    plt_points(ax, qs[i], color=cmap(.8), marker='x', s=2)
    plt_points(ax, Y, color='k', marker='o', s=2)
    plt_vectorfield(ax, qs[i], ps[i], scale=scale_momentum*2, color='b')
#     plt_vectorfield(ax, gs[i], p2v(partial(k,params=params), qs[i], ps[i], gs[i]), scale=None, color='k')
    plt_grid(ax, gs[i], gL)

fig.tight_layout()
plt_savefig(fig, f'summary/assets/{datakey}/plt_lddmm_points_forwardint.png')


In [None]:
# # try f-bfgs
# # - get Desired error not necessarily achieved due to precision loss. 
# #   cannot be fixed ... 

# from scipy.optimize import minimize

# def loss_fn_minimize(p, x, μ, y, ν, g, λ_regu):
#     """ Compute loss    D(q1, μ, y, ν) + λ*R(q̇)
#             - geodeisc shooting obtain q1
#             - compute data matching term D(q1, y)
#     """
#     p = p.reshape(-1,2)
#     q1, p1, g1 = shooting(x, p, g)

#     C = cost_fn(q1, y)
#     π, loss_data = data_fn(μ, ν, C)
#     loss_regu = regu_fn(x, p) * λ_regu

#     loss = loss_regu + loss_data

#     return loss


# loss_fn_minimize_capture = partial(loss_fn_minimize, x=X,μ=μ,y=Y,ν=ν,g=g0,λ_regu=λ_regu)
# value_and_grad = jax.jit(jax.value_and_grad(loss_fn_minimize_capture))
    
# results = minimize(value_and_grad, p0.reshape(-1,), method='BFGS',
#                    jac=True, tol=1e-10, options={'maxiter': 5000, 'disp': True})
# print(results.nit, results.message)


# params = {'p0': results.x.reshape(-1,2)}
# _, info = loss_fn_capture(params)

# fig, axs = plt.subplots(1,2,figsize=(20,10))

# ax = axs[0]
# plt_shape(ax, X, Y)
# plt_vectorfield(ax, X, params['p0'], color=cmap(.9))

# ax = axs[1]
# plt_shape(ax, info['q1'], Y)
# plt_grid(ax, info['g1'], gL)
# plt_vectorfield(ax, info['g1'], k(info['g1'], info['q1'])@info['p1'], scale=None, color='k')

