Goal
- Register simple point sets in 2D using LDDMM + geodesic shooting + a data loss.
- Possible extension to a probabilistic formulation (related to the GP + dynamic-model work)?

In [None]:
# %load_ext autoreload
# %autoreload 2


## Geodesic shooting for registering simple shapes
# 
from functools import partial

import os
# Written to the environment before `import jax` below so that it is already
# in place when JAX is imported.
# NOTE(review): presumably this keeps JAX in 32-bit mode — confirm that the
# string 'False' is parsed as intended by the installed JAX version.
os.environ['JAX_ENABLE_X64'] = 'False'

import jax
from jax import random, grad, jit, value_and_grad
import jax.numpy as np
import jax.numpy.linalg as linalg
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.experimental import optimizers

import matplotlib.pyplot as plt
import matplotlib as mpl
# Global figure styling applied to every plot in this notebook.
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# Shared blue-white-red colormap; plotting helpers sample it at fixed
# positions (e.g. cmap(.2), cmap(.8)) to colour the two point sets.
cmap = plt.cm.get_cmap('bwr')

import sklearn.datasets as datasets

import sys
sys.path.append('../gp')
sys.path.append('../ot')
from gpax import cholesky_jitter
from plt_utils import plt_savefig
from otax import sinkhorn_log_stabilized, sinkhorn_divergence, plt_transport_plan

from jax_registration import *
from lddmm_pytorch import Curve


In [None]:
## Parameters

# Key into the dataset table that selects which point sets are registered.
datakey = 'amoeba'
# Number of points handed to the synthetic dataset constructors.
# Strings are compared by value here: `is not` tests object identity, which
# only happens to work for interned literals and triggers a SyntaxWarning on
# Python 3.8+.
n = 4*20 if datakey != 'amoeba' else 198
fill = False
# Plot / grid extent.
xlim = (0, 1)
ylim = (0, 1)
ℓ = .1
# Integrator settings forwarded to the Hamiltonian shooting / carrying partials.
euler_steps = 7
δt = .1
grid_nlines = 11
# Settings forwarded to sinkhorn_log_stabilized.
# (An earlier `ot_ρ = 1e5` was immediately overwritten by this value and has
# been dropped.)
ot_ρ = (10)**2
ot_ϵ = (.015)**2
ot_n_iters=100
# Weight of the regularisation term in the loss.
λ_regu = .01

## partials

# Two candidate kernels are kept under distinct names so that neither
# definition silently shadows the other; `k` below selects the active one.
# NOTE(review): cov_se comes from the star import of jax_registration; σ2 / ℓ
# are presumably its variance and length-scale arguments — confirm there.

@jax.jit
def k_three_scale(X, Y=None):
    """Sum of three cov_se terms (ℓ = .01, .075, .15). Currently not selected."""
    return cov_se(X,Y,σ2=.5,ℓ=.01) + cov_se(X,Y,σ2=.3,ℓ=.075) + cov_se(X,Y,σ2=.2,ℓ=.15)

@jax.jit
def k_two_scale(X, Y=None):
    """Sum of two cov_se terms (ℓ = .025, .15)."""
    return cov_se(X,Y,σ2=1.,ℓ=.025) + cov_se(X,Y,σ2=.75,ℓ=.15)


# Active kernel: the same two-term kernel that the last definition of `k`
# previously provided.
k = k_two_scale


# Jitted callables with the kernel and integrator settings bound once, so the
# rest of the notebook only passes points / momenta.
# `shooting` is unpacked into two values and `carrying` (which also takes a
# grid) into three values where they are called further down.
shooting = jit(partial(HamiltonianShooting, k=k, euler_steps=euler_steps, δt=δt))
carrying = jit(partial(HamiltonianCarrying, k=k, euler_steps=euler_steps, δt=δt))
# Data term: called as data_fn(μ, ν, C) and unpacked as (π, loss_data).
data_fn = jit(partial(sinkhorn_log_stabilized, ϵ=ot_ϵ, ρ=ot_ρ, n_iters=ot_n_iters))
# Cost matrix between two point sets.
# NOTE(review): sqdist is presumably the pairwise squared distance — confirm.
cost_fn = jit(sqdist)


## Data
# Root PRNG key (seed 5); later cells split it to draw momentum noise.
key = random.PRNGKey(5)

def dataset_to_XY(dataset_cls, scale=.25, center=np.array([.5,.5])):
    """Build a source/target pair from a labelled point generator.

    `dataset_cls()` must return `(points, labels)`. The points are multiplied
    by `scale`, translated so that their mean lands on `center`, and then
    split by label: rows with label 0 are returned first, rows with label 1
    second.
    """
    points, labels = dataset_cls()
    points = scale*points
    # Shift so that the centroid of the scaled cloud coincides with `center`.
    offset = center - np.mean(points, axis=0)
    points = points + offset
    return points[labels==0], points[labels==1]

# Source (Q0) and target (Xt) curves are loaded from two image files.
# NOTE(review): both files are read unconditionally, even when `datakey`
# selects a synthetic dataset, so they must exist for the notebook to run.
f1_name = 'data/amoeba_1.png'
f2_name = 'data/amoeba_2.png'
Q0 = Curve.from_file(f1_name)
Xt = Curve.from_file(f2_name)
    
def get_amoeba():
    """Return the (source, target) point arrays of the two loaded curves."""
    return Q0.points, Xt.points

# Table of zero-argument constructors, one per dataset name; each returns the
# (source, target) pair. Only the entry selected by `datakey` is called.
make_dataaset_dict = {
    # Two synthetic outlines (a square and a circle) placed side by side.
    'shapes': partial(make_two_shapes, shapes=(square, circle), n=n, 
        center=((.25, .5), (.75,  .5)), radius=(.15, .15), fill=fill, nlayers=4),
    # scikit-learn toy sets, split by class label via dataset_to_XY.
    'moons': partial(dataset_to_XY, dataset_cls=partial(
        datasets.make_moons, n_samples=(n//2, n-n//2), noise=.05)),
    'circles': partial(dataset_to_XY, dataset_cls=partial(
        datasets.make_circles, n_samples=2*n,  factor=.5, noise=.05)),
    # Curves loaded from the two image files.
    'amoeba': get_amoeba
}

get_dataset = make_dataaset_dict[datakey]
# X: source points (deformed by shooting), Y: target points.
X, Y = get_dataset()

# Background grid used to visualise the deformation: g0 is carried along with
# the points and plotted through plt_grid(ax, g0, gL).
# NOTE(review): gL presumably holds the grid-line connectivity — confirm in GridData.
g0, gL = GridData(nlines=grid_nlines, xlim=xlim, ylim=ylim, nsubticks=6)

# Initial momentum: zeros with the same shape as the source points.
p0 = np.zeros(X.shape) * 1.
# Compare by value: `is` on a string literal tests object identity, which
# only works by accident of interning and triggers a SyntaxWarning on 3.8+.
if datakey == 'amoeba':

    def line_vertex_area(X, L):
        """Gives vertex weight as average of neighboring edges.

        X: vertex coordinates, indexed by the entries of L.
        L: two-column integer array; each row holds the two vertex indices
           of one edge.
        Returns one weight per vertex.
        """
        v1 = X[L[:,0]]
        v2 = X[L[:,1]]
        # Length of every edge.
        a = np.sqrt(np.sum((v2-v1)**2, axis=1))

        # For each vertex id (in sorted order): index of the first edge that
        # starts at it (ind0) and of the first edge that ends at it (ind1).
        # NOTE(review): assumes every vertex appears in both columns, as on a
        # closed curve — otherwise the two index arrays differ in length.
        _, ind0 = np.unique(L[:,0], return_index=True)
        _, ind1 = np.unique(L[:,1], return_index=True)
        ind = np.column_stack((ind0,ind1))
        # Average of the two incident edge lengths.
        a = np.sum(a[ind], axis=1) / 2
        return a

    # Point weights proportional to local curve length.
    μ = line_vertex_area(X, Q0.connectivity)
    ν = line_vertex_area(Y, Xt.connectivity)
else:
    # Plain point clouds: every point gets unit weight.
    μ = np.ones((X.shape[0],))
    ν = np.ones((Y.shape[0],))



def plt_shape(ax, q, y):
    """Scatter two 2-D point sets on `ax` and apply the shared axes styling.

    `y` is drawn first with 'x' markers in cmap(.2); `q` is drawn second with
    'o' markers in cmap(.8). The axes are then clipped to the global
    `xlim` / `ylim`, ticks are removed and the aspect ratio is set to equal.
    """
    # y is scattered before q, in the same call order as always.
    for pts, shade, marker in ((y, .2, 'x'), (q, .8, 'o')):
        ax.scatter(pts[:,0], pts[:,1], color=cmap(shade), marker=marker)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    ax.set_aspect('equal')

# Sanity-check figure: left panel shows the two point sets over the grid,
# right panel shows the kernel matrix evaluated on the source points.
fig, axs = plt.subplots(1,2,figsize=(20,10))
plt_shape(axs[0], X, Y)
plt_grid(axs[0], g0, gL)
axs[1].imshow(k(X))


In [None]:
# n_samples = 1
# qs = []
# for i in range(n_samples):
#     key, subkey = random.split(key, 2)
#     p = Qμ + QΣ*random.normal(subkey, Qμ.shape)
#     q1, p1 = shooting(X, p)
#     qs.append(q1)
# qs = np.stack(qs)
# Scratch cell: evaluates the pieces of the regulariser outside the loss.
# NOTE(review): `params` is only assigned in a later cell of this file, so
# this cell relies on that cell having been run first.
Qμ = params['μ']
# 'Σ' is stored unconstrained and mapped through BijSoftplus.forward
# (presumably to keep it positive — confirm in BijSoftplus).
QΣ = BijSoftplus.forward(params['Σ'])
# One row per row of Qμ so that it broadcasts against Qμ.
QΣ = QΣ.reshape(Qμ.shape[0],-1)
print(Hqp(X, Qμ, k))


# print(Hqv(X, p2v(k,X,Qμ,X), k))




# Kernel matrix on the source points and its (jittered) Cholesky factor.
K = k(X)
L = cholesky_jitter(K, jitter=5e-5)

# Mean and covariance after applying K to the momentum parameters.
vμ = K@Qμ
vΣ = K@np.diag(QΣ.flatten())@K.T
vL = cholesky_jitter(vΣ, jitter=5e-5)

# NOTE(review): the quadratic term is solved against L (factor of K) while
# the log-determinant term uses vL (factor of vΣ) — confirm this mix is intended.
α = solve_triangular(L, vμ, lower=True)
mahan = -.5*np.sum(np.square(α))
lgdet = -2*np.sum(np.log(np.diag(vL)))

print(mahan, lgdet)


#     def log_prob(self, x):
#         """x: (n, m) -> (n*m,)"""
#         x = x.reshape(self.μ.shape)
#         α = solve_triangular(self.L, (x-self.μ), lower=True)
#         mahan = -.5*np.sum(np.square(α))
#         lgdet = -np.sum(np.log(np.diag(self.L)))
#         const = -.5*self.d*np.log(2*np.pi)
#         return mahan + const + lgdet
    



In [None]:
## Plotting
# Edge lookup for the vertex weights, built once at module level from the
# source curve's connectivity. The previous code indexed a global `L`, a name
# that is also bound to a Cholesky factor elsewhere in this notebook, so the
# lookup depended on which cell had run last. The loss passes
# Q0.connectivity to line_vertex_area, so the lookup is built from it too.
# NOTE(review): presumably precomputed because np.unique has a data-dependent
# output size and cannot be traced inside the jitted loss — confirm.
_, ind0 = np.unique(Q0.connectivity[:,0], return_index=True)
_, ind1 = np.unique(Q0.connectivity[:,1], return_index=True)
line_vertex_area_ind = np.column_stack((ind0,ind1))

def line_vertex_area(X, L):
    """Gives vertex weight as average of neighboring edges.

    X: vertex coordinates, indexed by the entries of L.
    L: two-column integer array; each row holds the two vertex indices of one
       edge. The per-vertex edge lookup is the precomputed global
       `line_vertex_area_ind`, so L must describe the same connectivity that
       lookup was built from.
    Returns one weight per vertex.
    """
    v1 = X[L[:,0]]
    v2 = X[L[:,1]]
    # Length of every edge.
    a = np.sqrt(np.sum((v2-v1)**2, axis=1))
    # Average the two edge lengths selected for each vertex.
    a = np.sum(a[line_vertex_area_ind], axis=1) / 2
    return a

n_samples = 3
# When True, params['μ'] is treated as a velocity and converted with v2p.
optimize_v = False

def loss_fn(params, key, x, μ, y, ν, g, λ_regu):
    """ Compute loss    D(q1, μ, y, ν) + λ*R(q̇)
            - geodesic shooting to obtain q1
            - compute data matching term D(q1, y)

        params: dict with entries 'μ' and 'Σ' of the initial-momentum
                distribution ('Σ' is passed through BijSoftplus.forward).
        key:    PRNG key used to draw one momentum sample.
        x, μ:   source points and their weights.
        y, ν:   target points and their weights.
        g:      grid points carried along with x.
        λ_regu: weight of the regularisation term.

        Returns `(loss, info)`; `info` is a dict with the loss terms, the
        transport plan π, the source weights μ actually used, and the carried
        outputs q1, p1, g1.
    """
    if optimize_v:
        # NOTE(review): `v` and `p` computed here are never used — `p` is
        # reassigned from the sampled momentum a few lines below.
        v = params['μ']
        p = v2p(k, x, v)
        

    Qμ = params['μ']
    QΣ = BijSoftplus.forward(params['Σ'])
    # One row per row of Qμ so that it broadcasts in the sample below.
    QΣ = QΣ.reshape(Qμ.shape[0],-1)
    key, subkey = random.split(key, 2)
    # Single sample of the initial momentum: mean plus scaled Gaussian noise.
    p = Qμ + QΣ*random.normal(subkey, Qμ.shape)
    
#     p0s = []; q1s = []
#     for i in range(n_samples):
#         key, subkey = random.split(key, 2)
#         p = Qμ + QΣ*random.normal(subkey, Qμ.shape)
#         q1 = carrying(x, p, g)
#         p0s.append(p); q1s.append(q1)
    
    
    # Shoot the points (and carry the grid) with the sampled momentum.
    q1, p1, g1 = carrying(x, p, g)

    # Cost matrix between the deformed source points and the target.
    C = cost_fn(q1, y)
    if datakey == 'amoeba':
        # Overrides the μ argument: weights are recomputed on the deformed points.
        μ = line_vertex_area(q1, Q0.connectivity)
    π, loss_data = data_fn(μ, ν, C)
    
    ## negative log prob ... 
    
    K = k(x)
    L = cholesky_jitter(K, jitter=5e-5)

    vμ = K@Qμ
    vΣ = K@np.diag(QΣ.flatten())@K.T
    vL = cholesky_jitter(vΣ, jitter=5e-5)

    # NOTE(review): the quadratic term is solved against L (factor of K) while
    # the log-determinant term uses vL (factor of vΣ) — confirm this is intended.
    α = solve_triangular(L, vμ, lower=True)
    mahan = -.5*np.sum(np.square(α))
    lgdet = -2*np.sum(np.log(np.diag(vL)))
    loss_regu = -(mahan+lgdet) * λ_regu
    
#     loss_regu = Hqp(x, Qμ, k) * λ_regu

    loss = loss_regu + loss_data

    return loss, {'loss': loss,
                  'loss_regu': loss_regu,
                  'loss_data': loss_data,
                  'π': π, 'μ': μ,
                  'q1': q1, 'p1': p1, 'g1': g1}

# Bind the data once so the optimiser only sees (params, key).
loss_fn_capture = jit(partial(loss_fn, x=X, μ=μ, y=Y, ν=ν, g=g0, λ_regu=λ_regu))
# has_aux=True: loss_fn returns (loss, info) and only `loss` is differentiated.
value_and_grad_fn = jit(value_and_grad(loss_fn_capture, has_aux=True))
    

n_steps = 600; lr = .001
if optimize_v:
    # NOTE(review): this branch stores no 'Σ' entry, but loss_fn reads
    # params['Σ'] unconditionally — confirm before enabling optimize_v.
    params = {'μ': v2p(k, X, p0)}
else:
    # 'Σ' holds one value per source point, initialised so that
    # BijSoftplus.forward maps it back to .01.
    params = {'μ': p0, 'Σ': BijSoftplus.reverse(np.ones((len(p0)*1,))*.01)}

opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
opt_state = opt_init(params)

# Column index of the next snapshot to draw in the training figure.
axi = 0
# Iterations at which a snapshot is drawn: fixed fractions of n_steps, the
# last entry being intended as the final iteration.
display_its = [int(x*n_steps) 
               for x in [0.,.03, .1,.3,1-1/n_steps]]
# Two rows per snapshot column, filled by plt_momentum_shooting.
fig, axs = plt.subplots(2, len(display_its),
                        figsize=(5*len(display_its), 5*2), sharex=True, sharey=True)


def plt_momentum_shooting(axs, params, info):
    """Draw one before/after snapshot of the registration on two axes.

    axs[0]: initial grid, the velocity field induced on it by the initial
            momentum, both point sets, and the momentum arrows at the source
            points coloured by the softplus-transformed params['Σ'].
    axs[1]: carried grid, thresholded transport plan, and the shot points
            against the target.

    `params` is the optimiser parameter dict; `info` is the aux dict
    returned by the loss (entries 'q1', 'g1' and 'π' are read).
    """
    momentum = v2p(k, X, params['μ']) if optimize_v else params['μ']
    spread = BijSoftplus.forward(params['Σ'])
    shot_points, carried_grid, plan = info['q1'], info['g1'], info['π']

    # Panel 1: state at t=0.
    ax_start = axs[0]
    plt_grid(ax_start, g0, gL)
    grid_velocity = p2v(k, X, momentum, g0)
    plt_vectorfield(ax_start, g0, grid_velocity, scale=None, color='k')
    plt_shape(ax_start, X, Y)
    arrow_colors = plt.cm.get_cmap('OrRd')(mpl.colors.Normalize()(spread))
    plt_vectorfield(ax_start, X, momentum, scale=.3, color=arrow_colors)

    # Panel 2: state after shooting.
    ax_end = axs[1]
    plt_grid(ax_end, carried_grid, gL)
    plt_transport_plan(ax_end, plan, shot_points, Y, thresh=np.mean(plan))
    plt_shape(ax_end, shot_points, Y)
    
# Shot points recorded at every iteration.
qs = []

for it in range(n_steps):
    # Fresh subkey per step so each step draws a new momentum sample.
    key, subkey = random.split(key)
    # Parameters are read *before* the update, so `params` / `info` always
    # describe the state the loss was evaluated at.
    params = get_params(opt_state)
    (loss, info), grads = value_and_grad_fn(params, subkey)
    opt_state = opt_update(it, grads, opt_state)
    qs.append(info['q1'])
    
    # Progress line ten times over the run.
    if it%(n_steps//10) == 0:
        print(f'[{it:4}] loss={info["loss"]:7.3f}'
              f'({info["loss_data"]:7.4f} +{info["loss_regu"]:7.4f})'
              f'\tsum(μ)={np.sum(info["μ"]):.3f}'
              f'\tlogdet={-2*np.sum(np.log(BijSoftplus.forward(params["Σ"]))):.3f}')
    
    # Snapshot into the next free column of the training figure.
    if it in display_its:
        plt_momentum_shooting(axs[:,axi], params, info)
        axs[0,axi].set_title(f't={it}', fontsize=40)
        axi += 1
    

fig.tight_layout()
plt_savefig(fig, f'summary/assets/plt_lddmm_points_{datakey}_training.png')


# Summary figure: input configuration, then the before/after panels drawn
# from the `params` / `info` left over from the last training iteration.
fig, axs = plt.subplots(1,3,figsize=(15,5))
ax = axs[0]
plt_shape(ax, X, Y)
plt_grid(ax, g0, gL)
plt_momentum_shooting(axs[1:], params, info)
fig.tight_layout()
plt_savefig(fig, f'summary/assets/plt_lddmm_points_{datakey}.png')

In [None]:


    
def plt_points(ax, X, **kwargs):
    """Scatter the rows of `X` on `ax`, using column 0 as x and column 1 as y.

    Extra keyword arguments are forwarded unchanged to `ax.scatter`.
    """
    xs, ys = X[:,0], X[:,1]
    ax.scatter(xs, ys, **kwargs)
    
def plt_scaled_colorbar_ax(ax, **kwargs):
    """Create a colorbar axes attached to the right-hand side of `ax`.

    Usage: `fig.colorbar(im, cax=plt_scaled_colorbar_ax(ax))`.
    The new axes takes 5% of the width of `ax` with a padding of 0.1.
    `**kwargs` is accepted but not used.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    return make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1,)
    

In [None]:
# Learned parameters of the initial-momentum distribution.
Qμ = params['μ']
QΣ = BijSoftplus.forward(params['Σ'])
# One row per row of Qμ so that it broadcasts in the sample below.
QΣ = QΣ.reshape(Qμ.shape[0],-1)

# Draw momentum samples and shoot the source points with each of them.
n_samples = 1
qs = []
for i in range(n_samples):
    key, subkey = random.split(key, 2)
    p = Qμ + QΣ*random.normal(subkey, Qμ.shape)
    q1, p1 = shooting(X, p)
    qs.append(q1)
# Stacked along a new leading axis: one entry per sample.
qs = np.stack(qs)


fig, axs = plt.subplots(2,1,figsize=(10,20))

## Top panel: mean initial momentum at the source points, coloured by QΣ.
ax = axs[0]

ind = np.argsort(QΣ.flatten())
# nn = 3; ind = np.hstack((ind[:nn], ind[-nn:]))

plt_points(ax, X, color=cmap(.8), marker='x')

mpl_norm = mpl.colors.Normalize()
cmap_OrRd = plt.cm.get_cmap('OrRd')
# QΣ has been reshaped to one row per point; flatten it so the colormap
# returns one RGBA row per point ((n, 4)) rather than an (n, 1, 4) array.
# This is the same form plt_momentum_shooting passes to plt_vectorfield.
# NOTE(review): assumes one Σ value per point, as in the current params setup.
color_by_Σ = cmap_OrRd(mpl_norm(QΣ.flatten()))
# for q1 in qs:
#     plt_points(ax, q1[ind,:], color=color_by_Σ[ind,:], marker='o')
plt_vectorfield(ax, X, Qμ, scale=.2, color=color_by_Σ)
    
plt_grid(ax, g0, gL)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')

# mpl_norm was autoscaled by the call above, so the colorbar shares its range.
fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_OrRd),
             cax=plt_scaled_colorbar_ax(ax), shrink=.5)



## Bottom panel: velocity field on the grid, coloured by the diagonal of
## K diag(QΣ) Kᵀ evaluated between grid and source points.
ax = axs[1]

mpl_norm = mpl.colors.Normalize()
cmap_RdPu = plt.cm.get_cmap('RdPu')

K = k(g0, X)
vΣ = np.diag(K@np.diag(QΣ.flatten())@K.T)
plt_vectorfield(ax, g0, p2v(k, X, Qμ, g0), scale=5, color=cmap_RdPu(mpl_norm(vΣ)))
fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_RdPu),
             cax=plt_scaled_colorbar_ax(ax), shrink=.5)

plt_grid(ax, g0, gL)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')

fig.tight_layout()

In [None]:




fig, axs = plt.subplots(2,3,figsize=(15,10))

# The distribution parameters and the colours derived from them are the same
# for every panel, so they are computed once instead of once per sample.
Qμ = params['μ']
QΣ = BijSoftplus.forward(params['Σ'])
QΣ = QΣ.reshape(Qμ.shape[0],-1)

mpl_norm = mpl.colors.Normalize()
cmap_OrRd = plt.cm.get_cmap('OrRd')
# Flatten so the colormap returns one RGBA row per point ((n, 4)) rather than
# an (n, 1, 4) array; this is the form plt_momentum_shooting passes as well.
# NOTE(review): assumes one Σ value per point, as in the current params setup.
color_by_Σ = cmap_OrRd(mpl_norm(QΣ.flatten()))

# One independent momentum sample per panel, in row-major panel order.
for ax in axs.flat:
    key, subkey = random.split(key, 2)
    p = Qμ + QΣ*random.normal(subkey, Qμ.shape)
    q1, p1 = shooting(X, p)

    plt_points(ax, X, color=cmap(.8), marker='x')
    plt_points(ax, q1, color=color_by_Σ, marker='o')
    plt_vectorfield(ax, X, Qμ, scale=.3, color=color_by_Σ)
    plt_grid(ax, g0, gL)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

    fig.colorbar(plt.cm.ScalarMappable(norm=mpl_norm, cmap=cmap_OrRd),
                 cax=plt_scaled_colorbar_ax(ax), shrink=.5)
    




In [None]:

# Input configuration followed by the before/after registration panels.
fig, axs = plt.subplots(1,3,figsize=(15,5))
ax = axs[0]
plt_shape(ax, X, Y)
plt_grid(ax, g0, gL)
# Not used below in the visible notebook; kept so the global `ind` is still
# assigned by this cell.
ind = random.randint(key, (10,), 0, len(X))
# This cell used an outdated interface: it read the undefined flag
# `optimize_p` and the keys 'p0' / 'v0' that `params` never contains, and it
# passed five positional arguments. plt_momentum_shooting takes
# (axs, params, info) and handles the optimize_v switch itself.
plt_momentum_shooting(axs[1:], params, info)
fig.tight_layout()

In [None]:
# # try f-bfgs
# # - get Desired error not necessarily achieved due to precision loss. 
# #   cannot be fixed ... 

# from scipy.optimize import minimize

# def loss_fn_minimize(p, x, μ, y, ν, g, λ_regu):
#     """ Compute loss    D(q1, μ, y, ν) + λ*R(q̇)
#             - geodeisc shooting obtain q1
#             - compute data matching term D(q1, y)
#     """
#     p = p.reshape(-1,2)
#     q1, p1, g1 = shooting(x, p, g)

#     C = cost_fn(q1, y)
#     π, loss_data = data_fn(μ, ν, C)
#     loss_regu = regu_fn(x, p) * λ_regu

#     loss = loss_regu + loss_data

#     return loss


# loss_fn_minimize_capture = partial(loss_fn_minimize, x=X,μ=μ,y=Y,ν=ν,g=g0,λ_regu=λ_regu)
# value_and_grad = jax.jit(jax.value_and_grad(loss_fn_minimize_capture))
    
# results = minimize(value_and_grad, p0.reshape(-1,), method='BFGS',
#                    jac=True, tol=1e-10, options={'maxiter': 5000, 'disp': True})
# print(results.nit, results.message)


# params = {'p0': results.x.reshape(-1,2)}
# _, info = loss_fn_capture(params)

# fig, axs = plt.subplots(1,2,figsize=(20,10))

# ax = axs[0]
# plt_shape(ax, X, Y)
# plt_vectorfield(ax, X, params['p0'], color=cmap(.9))

# ax = axs[1]
# plt_shape(ax, info['q1'], Y)
# plt_grid(ax, info['g1'], gL)
# plt_vectorfield(ax, info['g1'], k(info['g1'], info['q1'])@info['p1'], scale=None, color='k')

