In [95]:
import argparse
import pathlib
import json
import os
import torch
import numpy as np
import time
import sys
import matplotlib.pyplot as plt
import geomloss
from scipy.special import softmax

sys.path.append('/data/pauly2/tk/crashs')
from vtkutil import *
from lddmm import *
from crashs import MeshData, Template, ASHSFolder, Workspace, ashs_output_to_cruise_input, run_cruise, cruise_postproc

%cd /data/pauly2/tk/crashs

/data/pauly2/tk/crashs


In [96]:
# Load the template
template = dict({ s: Template('/data/pauly2/ashs_xv/manifest/template_init_dir', s) for s in ['left','right'] })

# Load the ASHS json
with open('/data/pauly2/ashs_xv/manifest/crashs_input_fold_0.json') as fd:
    ashs_input_desc = json.load(fd)

# Prepare device
# device = torch.device(args.device) if torch.cuda.is_available() else 'cpu'
device = torch.device(torch.cuda.current_device() if torch.cuda.is_available() else 'cpu')
print("Is cuda available?", torch.cuda.is_available())
print("Device count?", torch.cuda.device_count())
print("Current device?", torch.cuda.current_device())
print("Device name? ", torch.cuda.get_device_name(torch.cuda.current_device()))

# Keep track of ASHS importers and workspaces created
data = {}

# Run basic Nighres for each subject
for d in ashs_input_desc:
    id = d['id']
    side = d['side']

    # Create the output dir
    out_dir = os.path.join('/data/pauly2/ashs_xv/work/fold_0/crashs_build', id)
    os.makedirs(out_dir, exist_ok=True)
    workspace = Workspace(out_dir, id)

    # Store the data
    data[id] = { 'side': side, 'workspace': workspace }

Is cuda available? True
Device count? 1
Current device? 0
Device name?  Quadro RTX 5000


In [61]:
# From the template directory, load the left/right flip file. 
flip_lr = np.loadtxt(os.path.join(template['left'].root, 'ashs_template_flip.mat'))

# Set the sigma tensors
# sigma_varifold = torch.tensor([template['left'].get_varifold_sigma()], dtype=torch.float32, device=device)
sigma_varifold = torch.tensor([5], dtype=torch.float32, device=device)
sigma_lddmm = torch.tensor([template['left'].get_lddmm_sigma()], dtype=torch.float32, device=device)

# Load each of the meshes that will be used to build the template
md = {}
for id, sd in data.items():
    
    # Depending on the side, apply flip_lr as the transform
    transform = flip_lr if sd['side'] == 'right' else None

    # Load the mesh data (inflated avg surface in ASHS template space)
    sd['pd_input'] = load_vtk(sd['workspace'].affine_moving)

    # Downsample the mesh to a reasonable number of vertices
    md[id] = MeshData(sd['pd_input'], device, transform=transform, target_faces=20000)
    sd['pd_ds'] = vtk_make_pd(md[id].v, md[id].f)

    # Apply additional Taubin smoothing to help build a smoother template
    v,f = taubin_smooth(md[id].v, md[id].f, lam=0.5, mu=-0.45, steps=1000)
    md[id] = MeshData(vtk_set_cell_array(vtk_make_pd(v, f), 'plab', md[id].lp), device)

    pd = vtk_make_pd(md[id].v, md[id].f)
    vtk_set_cell_array(pd, 'plab', md[id].lp)
    save_vtk(pd, f'tmp/test_decimate_{id}.vtk')


Decimating mesh, target: 20000 faces
Decimation complete, 19992 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 19968 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 19976 faces
Decimating mesh, target: 20000 faces
Decimation complete, 19984 faces
Decimating mesh, target: 20000 faces
Decimation complete, 19992 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces
Decimating mesh, target: 20000 faces
Decimation complete, 20000 faces


In [4]:
# Compute the varifold loss between all pairs of atlases in the template subset before
# any registration - this is to determine the best candidate for the template
ids = list(md.keys())
dsq_sub = np.zeros((len(md), len(md)))
for i1, (k1,v1) in enumerate(md.items()):
    for i2, (k2,v2) in enumerate(md.items()):
        pair_loss = lossVarifoldSurfWithLabels(v1.ft, v2.vt, v2.ft, v1.lpt, v2.lpt, 
                                                GaussLinKernelWithLabels(sigma_varifold, v1.lp.shape[1]))
        dsq_sub[i1,i2] = pair_loss(v1.vt).item()
        print(k1, k2, dsq_sub[i1,i2])
        
# Find the index of the template candidate
i_src = np.argmin(dsq_sub.sum(axis=1))
id_src = ids[i_src]
print(f'Best template candidate: {id_src}, mean distance {dsq_sub.mean(axis=1)[i_src]} vs {dsq_sub.mean()}')



104937L 104937L 0.0
104937L 106049L 167569.28125
104937L 106312R 243772.5625
104937L 113909R 195552.71875
104937L 116748R 124347.96875
104937L 117243R 181867.875
104937L 117667R 122609.0625
104937L 118374L 182103.5
104937L 118430R 242721.0625
104937L 120126L 121219.84375
104937L 120267L 168783.46875
104937L 120937L 195741.1875
104937L 121250L 65552.875
106049L 104937L 167569.28125
106049L 106049L 0.0
106049L 106312R 171448.625
106049L 113909R 142091.8125
106049L 116748R 154462.8125
106049L 117243R 106478.28125
106049L 117667R 126409.9375
106049L 118374L 81883.84375
106049L 118430R 127655.0625
106049L 120126L 181912.90625
106049L 120267L 151421.25
106049L 120937L 181152.59375
106049L 121250L 206505.6875
106312R 104937L 243772.5625
106312R 106049L 171448.625
106312R 106312R 0.0
106312R 113909R 135512.4375
106312R 116748R 218999.09375
106312R 117243R 175810.25
106312R 117667R 217042.9375
106312R 118374L 186314.5
106312R 118430R 197615.21875
106312R 120126L 247729.3125
106312R 120267L 1879

In [5]:
def my_rotation_from_vector(x):
    """
    Generate a 3D rotation vector from three parameters.

    Args:
        x: 
            A torch tensor of shape (3). It contains the parameters of the rotation.
            [Write more detail about what the parameters mean geometrically]
    Output:
        A shape (3,3) tensor holding a rotation matrix corresponding to x
    """
    # I will use the the axis/angle representation. The norm of the vector x gives the
    # angle in radians, and the normalized vector is the axis around which the rotation
    # is performed. At x=[0,0,0], there is a degeneracy that requires special handling
    # but this should not prevent the code from being used in optimization
    
    # Compute theta, no issues here
    theta = torch.norm(x)

    # Use the trick from `torch.nn.functional.normalize`, which adds a small epsilon to
    # the denominator to avoid division by zero
    v = torch.nn.functional.normalize(x, dim=0)

    # Apply the Rodrigues formula
    A = torch.zeros(3, 3, dtype=x.dtype, device=x.device)
    A[0,1], A[0,2], A[1,2] = -v[2], v[1], -v[0]
    K = A - A.T
    R = torch.eye(3, dtype=x.dtype, device=x.device) + torch.sin(theta) * K + (1-torch.cos(theta)) * (K @ K)
    return R

In [6]:
# Compute the varifold loss between all pairs of atlases with similarity transform,
# running a quick registration between all pairs. This might get too much for large
# training sets though
ids = list(md.keys())
dsq_sub = np.zeros((len(md), len(md)))

# The default affine parameters
theta_all = { }
kernel = GaussLinKernelWithLabels(sigma_varifold, md[ids[0]].lp.shape[1])
for i1, (k1,v1) in enumerate(md.items()):
    for i2, (k2,v2) in enumerate(md.items()):
        if k1 != k2: 
            # Define the symmetric loss for this pair
            loss_ab = lossVarifoldSurfWithLabels(v1.ft, v2.vt, v2.ft, v1.lpt, v2.lpt, kernel)
            loss_ba = lossVarifoldSurfWithLabels(v2.ft, v1.vt, v1.ft, v2.lpt, v1.lpt, kernel)
            pair_theta = torch.tensor([0.01, 0.01, 0.01, 1.0, 0.0, 0.0, 0.0], 
                                    dtype=torch.float32, device=device, requires_grad=True)
            
            # Create optimizer
            opt_affine = torch.optim.LBFGS([pair_theta], max_eval=10, max_iter=10, line_search_fn='strong_wolfe')

            # Define closure
            def closure():
                opt_affine.zero_grad()

                R = my_rotation_from_vector(pair_theta[0:3]) * pair_theta[3]
                b = pair_theta[4:]
                R_inv = torch.inverse(R)
                b_inv = - R_inv @ b

                v1_to_v2 = (R @ v1.vt.t()).t() + b
                v2_to_v1 = (R_inv @ v2.vt.t()).t() + b_inv

                L = 0.5 * (loss_ab(v1_to_v2) + loss_ba(v2_to_v1))
                L.backward()
                return L
            
            # Run the optimization
            for i in range(10):
                opt_affine.step(closure)

            # Print loss and record the best run/best parameters
            dsq_sub[i1,i2] = closure().item()
            theta_all[(k1,k2)] = pair_theta.detach()
            print(f'Pair {k1}, {k2} loss : {dsq_sub[i1,i2]:8.6f}')

Pair 104937L, 106049L loss : 119642.062500


KeyboardInterrupt: 

In [302]:
# Find the index of the template candidate
i_best = np.argmin(dsq_sub.sum(axis=1))
k_best = ids[i_best]
print(f'Best template candidate: {k_best}, mean distance {dsq_sub.mean(axis=1)[i_best]} vs {dsq_sub.mean()}')

# Now we need to go through and apply the best transformation to each case
for i, (k,v) in enumerate(md.items()):
    affine_mat = np.eye(4)
    if k != k_best:
        # Compute the transform to move k to k_best
        theta_pair = theta_all[(k,k_best)]
        R = my_rotation_from_vector(theta_pair[0:3]) * theta_pair[3]
        b = theta_pair[4:]

        # Save the transformation
        affine_mat[0:3,0:3] = R.detach().cpu().numpy()
        affine_mat[0:3,  3] = b.detach().cpu().numpy()

        # Save the registered mesh
        v_reg = (R @ v.vt.t()).t() + b
        pd_reg = vtk_make_pd(v_reg.detach().cpu().numpy(), v.f)
        pd_reg = vtk_set_cell_array(pd_reg, 'plab', v.lp)
        save_vtk(pd_reg, f'tmp/test_register_{k}.vtk')
    else:
        pd_reg = vtk_make_pd(v.vt.detach().cpu().numpy(), v.f)
        pd_reg = vtk_set_cell_array(pd_reg, 'plab', v.lp)
        save_vtk(pd_reg, f'tmp/test_register_{k}.vtk')

    np.savetxt(f'tmp/affine_to_template_{k}.mat', affine_mat)

Best template candidate: 113909R, mean distance 82006.578125 vs 96785.2781989645


# Start here once you did affine

In [62]:
# Now let's read these meshes all over again
md_aff = {}
for id, sd in data.items():
    md_aff[id] = MeshData(load_vtk(f'tmp/test_register_{id}.vtk'), device)


In [63]:
# Compute the varifold loss between all pairs of atlases in the template subset before
# any registration - this is to determine the best candidate for the template
dsq_sub_aff = np.zeros((len(md_aff), len(md_aff)))
for i1, (k1,v1) in enumerate(md_aff.items()):
    for i2, (k2,v2) in enumerate(md_aff.items()):
        pair_loss = lossVarifoldSurfWithLabels(v1.ft, v2.vt, v2.ft, v1.lpt, v2.lpt, 
                                                GaussLinKernelWithLabels(sigma_varifold, v1.lp.shape[1]))
        dsq_sub_aff[i1,i2] = pair_loss(v1.vt).item()
        
# Find the index of the template candidate
i_src = np.argmin(dsq_sub_aff.sum(axis=1))
id_src = list(md_aff.keys())[i_src]
print(f'Best template candidate: {id_src}, mean distance {dsq_sub_aff.mean(axis=1)[i_src]} vs {dsq_sub_aff.mean()}')

Best template candidate: 113909R, mean distance 57551.558293269234 vs 66266.58621486687


In [33]:
# Select this candidate and go with it
md_src = md_aff[id_src]

# Create losses for each of the target meshes
loss = {}
for i, (id,v) in enumerate(md_aff.items()):

    # Data loss with label similarity
    dataloss = lossVarifoldSurfWithLabels(
        md_src.ft, v.vt, v.ft, md_src.lpt, v.lpt, 
        GaussLinKernelWithLabels(sigma_varifold, md_src.lp.shape[1]))
    
    # Complete LDDMM loss
    loss[id] = LDDMMloss(GaussKernel(sigma=sigma_lddmm), dataloss, gamma=0.1)
    
# Create a storage for the template coordinates and momenta at each iteration
td = list()

# Initialize the first template iteration
td.append({
    'q': torch.tensor(md_src.vt, dtype=torch.float32, device=device, requires_grad=True).contiguous(),
    'p': torch.zeros((len(md_aff),) + md_src.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()
})

# Outer loop: template update iterations
for m in range(5):

    # Generate a combined objective function
    optimizer = torch.optim.LBFGS([td[m]['p']], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')
    print("performing optimization...")
    start = time.time()

    def closure():
        optimizer.zero_grad()
        L = 0
        for i, (id,v) in enumerate(md_aff.items()):
            L = L + loss[id](td[m]['p'][i,:], td[m]['q'])
        L = L / len(md_aff.items())
        print("loss", L.detach().cpu().numpy())
        L.backward()
        return L

    for i in range(10):
        print("loop", m, "it ", i, ": ", end="")
        optimizer.step(closure)

    print("Optimization (L-BFGS) time: ", round(time.time() - start, 2), " seconds")

    # Compute the new starting point for optimization
    (p_mean, q_mean)=Shooting(td[m]['p'].mean(axis=0), td[m]['q'], GaussKernel(sigma=sigma_lddmm))[-1]
    td.append({
        'q': q_mean.detach().clone().detach().requires_grad_(True),
        'p': torch.zeros_like(td[m]['p']).requires_grad_(True)
    })

# Save the template with labels
v_opt = td[-1]['q'].detach().cpu().numpy()
pd_template = vtk_make_pd(v_opt, md_src.f)

# mlp_template = mesh_label_prob_maps(md_src['pd'])
# lab_template = np.argmax(mlp_template, axis=1)
# vtk_set_cell_array(pd_template, 'plab', mlp_template)
# vtk_set_cell_array(pd_template, 'label', lab_template)

# Save the left side template
save_vtk(pd_template, f'tmp/template_shoot_built.vtk')


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



performing optimization...
loop 0 it  0 : loss 58590.33
loss 58538.117
loss 58070.35
loss 55783.45
loss 36880.926
loss 21507.525
loss 15940.157
loss 12152.864
loss 9967.647
loss 8353.243
loss 7264.7305
loss 6356.4136
loss 5804.233
loss 5108.625
loss 4657.018
loss 4310.9624
loop 0 it  1 : loss 4310.9624
loss 4037.0576
loss 3760.9807
loss 3412.2766
loss 3247.2654
loss 3000.9998
loss 2890.3289
loss 2770.5088
loss 2613.636
loss 2529.1516
loss 2419.1067
loss 2344.7275
loss 2247.0496
loss 2124.166
loss 2026.0796
loss 1932.8729
loop 0 it  2 : loss 1932.8729
loss 1838.4634
loss 1759.1803
loss 1699.0504
loss 1631.6964
loss 1575.5435
loss 1515.2178
loss 1469.5781
loss 1415.8208
loss 1379.9711
loss 1343.8052
loss 1310.4236
loss 1288.2694
loss 1252.809
loss 1213.577
loss 1184.6647
loop 0 it  3 : loss 1184.6647
loss 1166.1188
loss 1133.3737
loss 1105.3508
loss 1077.1577
loss 1046.4344
loss 1026.9844
loss 1004.40796
loss 978.8319
loss 955.822
loss 940.3514
loss 926.3813
loss 913.2293
loss 900.2513
l

In [36]:
(p_mean, q_mean)=Shooting(td[-1]['p'].mean(axis=0), td[-1]['q'], GaussKernel(sigma=sigma_lddmm))[-1]
pd_template = vtk_make_pd(q_mean.detach().cpu().numpy(), md_src.f)

# Add the original array
pd_template = vtk_set_cell_array(pd_template, 'plab', md_src.lp)

# mlp_template = mesh_label_prob_maps(md_src['pd'])
# lab_template = np.argmax(mlp_template, axis=1)
# vtk_set_cell_array(pd_template, 'plab', mlp_template)
# vtk_set_cell_array(pd_template, 'label', lab_template)

# Save the left side template
save_vtk(pd_template, f'tmp/template_shoot_built.vtk')

# Fitting to targets with OMT

In [None]:
# Now let's read these meshes all over again
md_aff = {}
for id, sd in data.items():
    md_aff[id] = MeshData(load_vtk(f'tmp/test_register_{id}.vtk'), device)

# Load the template previously saved (the left is equivalent to the right, except fo the flip)
md_template = MeshData(load_vtk(f'results_symm/template_shoot_built.vtk'), device)

# Create losses for each of the target meshes relative to the template
for i, (k,v) in enumerate(md_aff.items()):

    # Data loss with label similarity
    v['dataloss'] = lossVarifoldSurfWithLabels(
        md_template['ft'], v['vt'], v['ft'], md_template['lpt'], v['lpt'], 
        GaussLinKernelWithLabels(sigma=sigma))
    
    # Complete LDDMM loss
    v['loss'] = LDDMMloss(GaussKernel(sigma=sigma), v['dataloss'])

In [92]:
md_aff

{'104937L': <crashs.MeshData at 0x7f4d016f4490>,
 '106049L': <crashs.MeshData at 0x7f4d00dc8d30>,
 '106312R': <crashs.MeshData at 0x7f4cf002e070>,
 '113909R': <crashs.MeshData at 0x7f4cf002eb50>,
 '116748R': <crashs.MeshData at 0x7f4cf002e5b0>,
 '117243R': <crashs.MeshData at 0x7f4cf002ed90>,
 '117667R': <crashs.MeshData at 0x7f4cf002e0a0>,
 '118374L': <crashs.MeshData at 0x7f4cf002e910>,
 '118430R': <crashs.MeshData at 0x7f4cf002e6d0>,
 '120126L': <crashs.MeshData at 0x7f4cf03d8c40>,
 '120267L': <crashs.MeshData at 0x7f4cf03d8a00>,
 '120937L': <crashs.MeshData at 0x7f4cf02bc9d0>,
 '121250L': <crashs.MeshData at 0x7f4cf03f3490>}

# Try a different way to build the template, non-iteratively

In [64]:
import pymeshlab
from sklearn.decomposition import PCA


In [65]:
md_src = md_aff[id_src]
pd_result = vtk_make_pd(md_src.v, md_src.f)
pd_result = vtk_set_cell_array(pd_result, 'plab', md_src.lp)
save_vtk(pd_result, 'tmp/smooth_input.vtk')

In [66]:
# Try to generate an ellipsoid to use as a template instead of one subject
v_all = np.concatenate([ x.v for id,x in md_aff.items() ], 0)
pca = PCA(n_components=3)
pca.fit(v_all)
pca.get_covariance()

array([[  46.23693976,  -10.32145243,   22.4971508 ],
       [ -10.32145243,  263.13849297, -119.82218186],
       [  22.4971508 , -119.82218186,   77.77082017]])

In [67]:
ms = pymeshlab.MeshSet()
ms.create_sphere(subdiv = 4)
m0 = ms.mesh(0)
v_sph, f_sph = m0.vertex_matrix(), m0.face_matrix()
pd_sph = vtk_make_pd(v_sph, f_sph)
pd_sph = vtk_set_cell_array(pd_sph, 'plab', np.zeros((f_sph.shape[0],1)))
md_sph = MeshData(pd_sph, device)

In [68]:
# Find an affine transformation of the sphere that best aligns with the data 
# using the varifold measure

# Create losses for each of the target meshes
kernel = GaussLinKernel(sigma_varifold)
loss = { id: lossVarifoldSurf(md_sph.ft, v.vt, v.ft, kernel) for (id,v) in md_aff.items() }

# Create a parameter tensor for the sphere
b = torch.tensor(pca.mean_, dtype=torch.float32, device=device, requires_grad=True).contiguous()
A = torch.tensor(pca.get_covariance(), dtype=torch.float32, device=device, requires_grad=True).contiguous()

# Generate a combined objective function
optimizer = torch.optim.LBFGS([A,b], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')
start = time.time()

def closure():
    optimizer.zero_grad()
    # Apply transformation to the sphere
    x = md_sph.vt
    y = (A @ x.T).T + b
    L = 0
    for i, (id,v) in enumerate(md_aff.items()):
        L = L + loss[id](y)
    L = L / len(md_aff.items())
    L.backward()
    return L

for i in range(30):
    print(f'Iter {i:03d}, Loss: {closure()}')
    optimizer.step(closure)

Iter 000, Loss: 10108611.0
Iter 001, Loss: 213723.046875
Iter 002, Loss: 213498.46875
Iter 003, Loss: 213409.53125
Iter 004, Loss: 213280.65625
Iter 005, Loss: 213204.390625
Iter 006, Loss: 213041.046875
Iter 007, Loss: 212903.515625
Iter 008, Loss: 212737.234375
Iter 009, Loss: 212451.328125
Iter 010, Loss: 211823.859375
Iter 011, Loss: 211607.953125
Iter 012, Loss: 211356.640625
Iter 013, Loss: 210616.21875
Iter 014, Loss: 210040.828125
Iter 015, Loss: 208005.625
Iter 016, Loss: 201438.078125
Iter 017, Loss: 199097.15625
Iter 018, Loss: 176922.046875
Iter 019, Loss: 93612.8203125
Iter 020, Loss: 92065.5234375
Iter 021, Loss: 92065.5234375
Iter 022, Loss: 92065.5234375
Iter 023, Loss: 92065.5234375
Iter 024, Loss: 92065.5234375
Iter 025, Loss: 92065.5234375
Iter 026, Loss: 92065.5234375
Iter 027, Loss: 92065.5234375
Iter 028, Loss: 92065.5234375
Iter 029, Loss: 92065.5234375


In [69]:
# Compute the new rotated sphere
v_sph_opt = (A @ md_sph.vt.T).T + b

# Peform remeshing of the sphere
ms = pymeshlab.MeshSet()
ms.add_mesh(pymeshlab.Mesh(vertex_matrix=v_sph_opt.detach().cpu().numpy(), face_matrix=md_sph.f))
ms.meshing_isotropic_explicit_remeshing()
v_ell, f_ell = ms.mesh(0).vertex_matrix(), ms.mesh(0).face_matrix()

pd_ell = vtk_make_pd(v_ell, f_ell)
pd_ell = vtk_set_cell_array(pd_ell, 'plab', np.zeros((f_ell.shape[0],1)))
md_ell = MeshData(pd_ell, device)

pd_sphere_opt = vtk_make_pd(v_ell, f_ell)
save_vtk(pd_sphere_opt, 'tmp/ellipsoid_best_fit.vtk')

In [70]:
# Now we need to initialize the labeling of the sphere. We can try directly to use OMT to 
# match the sphere to each of the meshes and maybe that's going to be good enough for getting
# the initial label distributions. If not, have to deform
def to_measure(points, triangles):
    """Turns a triangle into a weighted point cloud."""

    # Our mesh is given as a collection of ABC triangles:
    A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

    # Locations and weights of our Dirac atoms:
    X = (A + B + C) / 3  # centers of the faces
    S = torch.sqrt(torch.sum(torch.cross(B - A, C - A) ** 2, dim=1)) / 2  # areas of the faces

    # We return a (normalized) vector of weights + a "list" of points
    return S / torch.sum(S), X

# Compute optimal transport matching
def match_omt(vs, fs, vt, ft):
    (a_src, x_src) = to_measure(vs, fs)
    (a_trg, x_trg) = to_measure(vt, ft)
    x_src.requires_grad_(True)
    x_trg.requires_grad_(True)

    # Generate correspondence between models using OMT
    t_start = time.time()
    w_loss = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.8, backend='multiscale', verbose=True)
    w_loss_value = w_loss(a_src, x_src, a_trg, x_trg)
    [w_loss_grad] = torch.autograd.grad(w_loss_value, x_src)
    w_match = x_src - w_loss_grad / a_src[:, None]
    t_end = time.time()

    print(f'OMT matching distance: {w_loss_value.item()}, time elapsed: {t_end-t_start}')
    return w_loss_value, w_match

In [71]:
plab_sample = []
for (id, md_i) in md_aff.items():
    f_omt, w_omt = match_omt(md_ell.vt, md_ell.ft, md_i.vt, md_i.ft)
    lp_omt = vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy())
    plab_sample.append(lp_omt)

234x387 clusters, computed at scale = 3.367
Successive scales :  73.469, 73.469, 58.775, 47.020, 37.616, 30.093, 24.074, 19.259, 15.408, 12.326, 9.861, 7.889, 6.311, 5.049, 4.039, 3.231, 2.585, 2.068, 1.654, 1.324, 1.059, 0.847, 0.678, 0.542, 0.434, 0.347, 0.278, 0.222, 0.178, 0.142, 0.114, 0.091, 0.073, 0.058, 0.050
Jump from coarse to fine between indices 14 (σ=4.039) and 15 (σ=3.231).
Keep 37619/90558 = 41.5% of the coarse cost matrix.
Keep 25434/54756 = 46.4% of the coarse cost matrix.
Keep 53261/149769 = 35.6% of the coarse cost matrix.
OMT matching distance: 11.80711555480957, time elapsed: 3.303954839706421
269x446 clusters, computed at scale = 3.169
Successive scales :  69.161, 69.161, 55.329, 44.263, 35.410, 28.328, 22.663, 18.130, 14.504, 11.603, 9.283, 7.426, 5.941, 4.753, 3.802, 3.042, 2.433, 1.947, 1.557, 1.246, 0.997, 0.797, 0.638, 0.510, 0.408, 0.327, 0.261, 0.209, 0.167, 0.134, 0.107, 0.086, 0.068, 0.055, 0.050
Jump from coarse to fine between indices 14 (σ=3.802) and 1

In [72]:
plab_sample_avg = np.stack(plab_sample).mean(0)

In [73]:
pd_sphere_opt_2 = vtk_set_cell_array(pd_sphere_opt, 'plab', softmax(plab_sample_avg * 10, axis=1))
save_vtk(pd_sphere_opt_2, 'tmp/ellipsoid_best_fit_plab.vtk')

In [74]:
# Use the sphere as a starting point for template fitting
md_src2 = MeshData(pd_sphere_opt_2, device)

In [75]:
# Map an array to new vertex locations
def vtk_sample_point_array_at_vertices(pd_src, array, x_samples):
    # Use the locator to sample from the halfway mesh
    loc = vtk.vtkCellLocator()
    loc.SetDataSet(pd_src)
    loc.BuildLocator()
    result = np.zeros((x_samples.shape[0], array.shape[1]))    
    cellId = vtk.reference(0)
    c = [0.0, 0.0, 0.0]
    subId = vtk.reference(0)
    d = vtk.reference(0.0)
    pcoord = [0.0, 0.0, 0.0]
    wgt = [0.0, 0.0, 0.0]
    xj = [0.0, 0.0, 0.0]
    for j in range(x_samples.shape[0]):
        loc.FindClosestPoint(x_samples[j,:], c, cellId, subId, d)
        cell = pd_src.GetCell(cellId)
        cell.EvaluatePosition(x_samples[j,:], c, subId, pcoord, d, wgt)
        result[j] = np.sum(np.stack([ array[cell.GetPointId(i),:] * w for i, w in enumerate(wgt) ]), 0)
    return result


# Given a set of sampling locations on a triangle mesh surface, generate arrays of
# vertex indices and weights that allow data from the source mesh to be sampled at
# the sampling locations. This can be used to interpolate point data, coordinates,
# etc from the source mesh or spatial transformations thereof 
def vtk_get_interpolation_arrays_for_sample(pd_src, x_samples):
    
    # Use the locator to sample from the halfway mesh
    loc = vtk.vtkCellLocator()
    loc.SetDataSet(pd_src)
    loc.BuildLocator()

    # Return data: array of vertex indices and weights
    v_res = np.zeros((x_samples.shape[0], 3), dtype=np.int32)
    w_res = np.zeros((x_samples.shape[0], 3), dtype=np.double)

    cellId = vtk.reference(0)
    c = [0.0, 0.0, 0.0]
    subId = vtk.reference(0)
    d = vtk.reference(0.0)
    pcoord = [0.0, 0.0, 0.0]
    wgt = [0.0, 0.0, 0.0]
    xj = [0.0, 0.0, 0.0]
    for j in range(x_samples.shape[0]):
        loc.FindClosestPoint(x_samples[j,:], c, cellId, subId, d)
        cell = pd_src.GetCell(cellId)
        cell.EvaluatePosition(x_samples[j,:], c, subId, pcoord, d, wgt)
        for i, w in enumerate(wgt):
            v_res[j,i], w_res[j,i] = cell.GetPointId(i), w

    return v_res, w_res



# Define a function that can fit a model to a population
def fit_model_to_population(md_root, md_targets, n_iter = 10, 
                            sigma_lddmm=5, sigma_root=20, sigma_varifold=5, gamma_lddmm=0.1):

    # LDDMM kernels
    device = md_root.vt.device
    K_root = GaussKernel(sigma=torch.tensor(sigma_root, dtype=torch.float32, device=device))
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))
    K_vari = GaussLinKernelWithLabels(torch.tensor(sigma_varifold, dtype=torch.float32, device=device), md_root.lp.shape[1])

    # Create losses for each of the target meshes
    d_loss = { id: lossVarifoldSurfWithLabels(md_root.ft, v.vt, v.ft, md_root.lpt, v.lpt, K_vari) for id,v in md_targets.items() }
            
    # Create the root->template points/momentum, as well as the template->subject momenta
    q_root = torch.tensor(md_root.vt, dtype=torch.float32, device=device, requires_grad=True).contiguous()
    p_root = torch.zeros(md_root.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()
    p_temp = torch.zeros((len(md_targets),) + md_root.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()

    # Create the optimizer
    start = time.time()
    optimizer = torch.optim.LBFGS([p_root, p_temp], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()

        # Shoot root->template
        _, q_temp = Shooting(p_root, q_root, K_root)[-1]

        # Make the momenta applied to the template average out to zero
        p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)

        # Compute the loss
        L = 0
        for i, (id,v) in enumerate(md_targets.items()):
            _, q_i = Shooting(p_temp_z[i,:], q_temp, K_temp)[-1]
            L = L + gamma_lddmm * Hamiltonian(K_temp)(p_temp_z[i,:], q_temp) + d_loss[id](q_i)
        L = L / len(md_targets.items())
        L.backward()
        return L

    # Perform optimization
    for i in range(n_iter):
        print(f'Iteration {i:03d}  Loss {closure()}')
        optimizer.step(closure)

    print(f'Optimization (L-BFGS) time: {round(time.time() - start, 2)} seconds')

    # Return the root model and the momenta
    p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
    return p_root, p_temp_z 


# Compute label probability sampling using shooting and OMT
# def map_array_to_fittedect_shooting_omt(q_root, p_root, p_temp, md_targets, array):
#     _, q_temp = Shooting(p_root, q_root, K)[-1]
#     plab_sample = []
#     p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
#     for (id, md_i) in md_aff.items():
#         _, q_i = Shooting(p_temp_z[i,:], q_temp, K)[-1]
#         f_omt, w_omt = match_omt(q_i, md_src2.ft, md_i.vt, md_i.ft)
#         lp_omt = vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy())
#         plab_sample.append(lp_omt)


def shoot_root_to_template(md_root, p_root, sigma_root=20):
    device = md_root.vt.device
    K_root = GaussKernel(sigma=torch.tensor(sigma_root, dtype=torch.float32, device=device))
    _, q_temp = Shooting(p_root, md_root.vt.clone().requires_grad_(True).contiguous(), K_root)[-1]
    pd = vtk_make_pd(q_temp.detach().cpu().numpy(), md_root.f)
    pd = vtk_set_cell_array(pd, 'plab', md_root.lp)
    return MeshData(pd, device=q_temp.device)


# Update the template by remeshing and updating probability labels
def update_model_by_remeshing(md_root, md_targets, p_root, p_temp_z, 
                              sigma_lddmm=5, sigma_root=20):

    # LDDMM kernels
    device = md_root.vt.device
    K_root = GaussKernel(sigma=torch.tensor(sigma_root, dtype=torch.float32, device=device))
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))

    # Shoot from root to obtain the template
    q_root = md_root.vt.clone().requires_grad_(True).contiguous()
    _, q_temp = Shooting(p_root, q_root, K_root)[-1]
    pd_template = vtk_make_pd(q_temp.detach().cpu().numpy(), md_root.f)

    # Sample and average the plab array from the subjects using OMT
    plab_sample = []
    for i, (id, md_i) in enumerate(md_targets.items()):
        _, q_i = Shooting(p_temp_z[i,:], q_temp, K_temp)[-1]
        _, w_omt = match_omt(q_i, md_root.ft, md_i.vt, md_i.ft)
        lp_omt = vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy())
        plab_sample.append(lp_omt)
    plab_sample_avg = np.stack(plab_sample).mean(0)

    # Apply remeshing to the template
    ms = pymeshlab.MeshSet()
    ms.add_mesh(pymeshlab.Mesh(vertex_matrix=q_temp.detach().cpu().numpy(), face_matrix=md_root.f))
    ms.meshing_isotropic_explicit_remeshing()
    v_remesh, f_remesh = ms.mesh(0).vertex_matrix(), ms.mesh(0).face_matrix()
    pd_remesh = vtk_make_pd(v_remesh, f_remesh)

    # Apply the remeshing to the plab array
    _, w_omt = match_omt(torch.tensor(v_remesh, dtype=torch.float32, device=device), 
                         torch.tensor(f_remesh, dtype=torch.long, device=device),
                         q_temp, md_root.ft)
    lp_remesh = vtk_sample_cell_array_at_vertices(pd_template, plab_sample_avg, w_omt.detach().cpu().numpy())
    pd_remesh = vtk_set_cell_array(pd_remesh, 'plab', softmax(lp_remesh * 10, 1))

    # Return the new template as MeshData
    return MeshData(pd_remesh, device=md_root.vt.device)


# Map template into subject space and return mesh
def map_template_to_subject(md_temp, md_target, p_temp, sigma_lddmm=5):

    # LDDMM kernels
    device = md_temp.vt.device
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))

    # Shoot from template to subject and save as polydata/meshdata
    q_temp = md_temp.vt.clone().requires_grad_(True).contiguous()
    _, q_fitted = Shooting(p_temp, q_temp, K_temp)[-1]
    pd_fitted = vtk_make_pd(q_fitted.detach().cpu().numpy(), md_temp.f)
    pd_fitted = vtk_set_cell_array(pd_fitted, 'plab', md_temp.lp)

    # Match the subject via OMT, i.e. every template vertex is mapped to somewhere on the
    # subject mesh, this fits more closely than LDDMM but might break topology
    # _, w_omt = match_omt(md_target.vt, md_target.ft, q_fitted, md_temp.ft)
    _, w_omt = match_omt(q_fitted, md_temp.ft, md_target.vt, md_target.ft)
    pd_omt = vtk_clone_pd(pd_fitted)
    pd_omt = vtk_set_cell_array(pd_omt, 'match', w_omt.detach().cpu().numpy())
    vtk_cell_array_to_point_array(pd_omt, 'match')
    v_omt = vtk_get_point_array(pd_omt, 'match')

    # Create a clean model to return
    pd_omt = vtk_make_pd(v_omt, md_temp.f)
    pd_omt = vtk_set_cell_array(pd_omt, 'plab', md_temp.lp)
    pd_omt = vtk_set_cell_array(pd_omt, 'match', w_omt.detach().cpu().numpy())

    # Get the interpolation arrays from the matching and place them into the fitted template
    v_int, w_int = vtk_get_interpolation_arrays_for_sample(md_target.pd, v_omt)
    pd_fitted = vtk_set_point_array(pd_fitted, 'omt_v_int', v_int)
    pd_fitted = vtk_set_point_array(pd_fitted, 'omt_w_int', w_int)

    return pd_fitted, pd_omt


In [76]:
def build_template_multistage(md_root, md_targets, schedule, 
                              sigma_lddmm=5, sigma_root=20, sigma_varifold=5,
                              gamma_lddmm=0.1):

    # Iterate over the schedule
    for i, iter in enumerate(schedule):

        # Print iteration
        print(f'*** TEMPLATE BUILD STAGE {i} ***')

        # Fit the model to the population
        p_root, p_temp_z = fit_model_to_population(md_root, md_targets, iter,
                                                   sigma_lddmm=sigma_lddmm, sigma_root=sigma_root, 
                                                   sigma_varifold=sigma_varifold, gamma_lddmm=gamma_lddmm)

        # Compute the template by forward shooting
        md_temp = shoot_root_to_template(md_root, p_root, sigma_root=sigma_root)
        save_vtk(md_temp.pd, f'tmp/template_iter{i:02d}.vtk')

        # Remesh the template
        md_remesh = update_model_by_remeshing(md_root, md_targets, p_root, p_temp_z, sigma_lddmm=sigma_lddmm, sigma_root=sigma_root)
        save_vtk(md_remesh.pd, f'tmp/template_iter{i:02d}_remesh.vtk')

        # Make the template the new root
        md_root = md_remesh

    # Return the model from the last iteration
    return md_temp, p_temp_z


In [77]:
template['left'].get_lddmm_gamma()

1.0

In [78]:
# Run the whole template building pipeline
# template['left'].get_varifold_sigma()
md_temp, p_temp = build_template_multistage(md_src2, md_aff, schedule=[10, 10, 10, 50], 
                                            sigma_root = 2.4 * template['left'].get_lddmm_sigma(), 
                                            sigma_lddmm = template['left'].get_lddmm_sigma(), 
                                            sigma_varifold = 5,
                                            gamma_lddmm=0.05)

*** TEMPLATE BUILD STAGE 0 ***



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Iteration 000  Loss 73180.25
Iteration 001  Loss 48189.68359375
Iteration 002  Loss 32612.52734375
Iteration 003  Loss 24865.779296875
Iteration 004  Loss 20050.994140625
Iteration 005  Loss 16614.220703125
Iteration 006  Loss 13710.7568359375
Iteration 007  Loss 11285.1953125
Iteration 008  Loss 10152.326171875
Iteration 009  Loss 9215.896484375
Optimization (L-BFGS) time: 511.72 seconds
356x409 clusters, computed at scale = 3.313
Successive scales :  72.302, 72.302, 57.842, 46.274, 37.019, 29.615, 23.692, 18.954, 15.163, 12.130, 9.704, 7.763, 6.211, 4.969, 3.975, 3.180, 2.544, 2.035, 1.628, 1.302, 1.042, 0.834, 0.667, 0.533, 0.427, 0.341, 0.273, 0.219, 0.175, 0.140, 0.112, 0.090, 0.072, 0.057, 0.050
Jump from coarse to fine between indices 14 (σ=3.975) and 15 (σ=3.180).
Keep 51769/145604 = 35.6% of the coarse cost matrix.
Keep 46330/126736 = 36.6% of the coarse cost matrix.
Keep 57027/167281 = 34.1% of the coarse cost matrix.
OMT matching distance: 2.1946306228637695, time elapsed: 3

In [79]:
# Save the template and the momenta
pd_temp_save = vtk_clone_pd(md_temp.pd)
for i, (id, md_i) in enumerate(md_aff.items()):
    vtk_set_point_array(pd_temp_save, f'momenta_{id}', p_temp[i,:,:].detach().cpu().numpy())
save_vtk(pd_temp_save, 'tmp/template_final_with_momenta.vtk')

In [80]:
# Match template to each subject and save the resulting meshes
for i, (id, md_i) in enumerate(md_aff.items()):
    # We want to map the template to the halfway surface of the MTL cortex in subject space
    # But that mesh has different number of vertices than the meshes we fitted in the template
    # building process. So we need to apply affine transform to the full-resolution inflated
    # mesh, then do OMT matching between the warped template and this target, then use this to
    # get coordinates of template vertices the space of the native halfway surface
    pd_infl_fullres = load_vtk(data[id]['workspace'].affine_moving)
    A = np.loadtxt(f'tmp/affine_to_template_{id}.mat')
    v_infl_fullres = vtk_get_points(pd_infl_fullres)
    v_infl_fullres_affine = (A[0:3,0:3] @ v_infl_fullres.T).T + A[0:3,3]

    # Send this mesh to the device
    vtk_set_points(pd_infl_fullres, v_infl_fullres_affine)
    vtk_set_cell_array(pd_infl_fullres, 'plab', np.zeros((v_infl_fullres_affine.shape[0],1)))
    md_infl_fullres = MeshData(pd_infl_fullres, device)

    # Get template mesh and OMT fitted mesh
    pd_i_fit, pd_i_omt = map_template_to_subject(md_temp, md_i, p_temp[i], 
                                                 sigma_lddmm = template['left'].get_lddmm_sigma())
    
    # These interpolation arrays map every template vertex to a face on the subject mesh. But this
    # mesh has been downsampled from a full-resolution mesh. In order to map to the original mesh,
    # we have to account for the downsampling
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Get the full-resolution mesh and corresponding downsampled mesh
    pd_fullres, pd_lowres = data[id]['pd_input'], data[id]['pd_ds']

    # Perform OMT matching between these two meshes
    _, w_omt = match_omt(torch.tensor(vtk_get_points(pd_lowres), dtype=torch.float32, device=device), 
                         torch.tensor(vtk_get_triangles(pd_lowres), dtype=torch.long, device=device), 
                         torch.tensor(vtk_get_points(pd_fullres), dtype=torch.float32, device=device), 
                         torch.tensor(vtk_get_triangles(pd_fullres), dtype=torch.long, device=device))
    
    # The target locations stored in w_omt are at the triangle centers, need to remap them to
    # vertices before we can get interpolation weights
    pd_omt = vtk_clone_pd(pd_lowres)
    pd_omt = vtk_set_cell_array(pd_omt, 'match', w_omt.detach().cpu().numpy())
    vtk_cell_array_to_point_array(pd_omt, 'match')
    v_omt = vtk_get_point_array(pd_omt, 'match')
    v_int_fr, w_int_fr = vtk_get_interpolation_arrays_for_sample(pd_fullres, v_omt)

    # We also want to go all the way back to the original space. Load the original space mesh
    pd_hw_native = load_vtk(data[id]['workspace'].fn_cruise('mtl_avg_l2m-mesh-ras.vtk'))
    v_native = vtk_get_points(pd_hw_native)
    
    # Now we need to perform two interpolations
    v_fullres_to_lowres = np.einsum('vij,vi->vj', v_native[v_int_fr,:], w_int_fr)
    v_temp_to_native = np.einsum('vij,vi->vj', v_fullres_to_lowres[v_int,:], w_int)

    #pd_i_fit, pd_i_omt = map_template_to_subject(md_temp, md_infl_fullres, p_temp[i], 
    #                                             sigma_lddmm = template['left'].get_lddmm_sigma())
    
    # The original mesh has a different number of vertices compared to the 
    # v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    # w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')
    # v_temp_to_native = np.einsum('vij,vi->vj', v_native[v_int,:], w_int)
    pd_temp_omt_to_native = vtk_make_pd(v_temp_to_native, vtk_get_triangles(pd_i_fit))
    pd_temp_omt_to_native = vtk_set_cell_array(pd_temp_omt_to_native, 'plab', vtk_get_cell_array(pd_i_fit, 'plab'))

    save_vtk(pd_i_fit, f'tmp/template_fit_to_{id}.vtk')
    save_vtk(pd_i_omt, f'tmp/template_omt_to_{id}.vtk')
    save_vtk(pd_temp_omt_to_native, f'tmp/template_omt_to_avg_ras_{id}.vtk')

400x412 clusters, computed at scale = 3.303
Successive scales :  72.074, 72.074, 57.659, 46.127, 36.902, 29.521, 23.617, 18.894, 15.115, 12.092, 9.674, 7.739, 6.191, 4.953, 3.962, 3.170, 2.536, 2.029, 1.623, 1.298, 1.039, 0.831, 0.665, 0.532, 0.425, 0.340, 0.272, 0.218, 0.174, 0.139, 0.112, 0.089, 0.071, 0.057, 0.050
Jump from coarse to fine between indices 14 (σ=3.962) and 15 (σ=3.170).
Keep 56998/164800 = 34.6% of the coarse cost matrix.
Keep 56258/160000 = 35.2% of the coarse cost matrix.
Keep 57742/169744 = 34.0% of the coarse cost matrix.
OMT matching distance: 0.29587888717651367, time elapsed: 4.4686267375946045
389x393 clusters, computed at scale = 3.378
Successive scales :  73.725, 73.725, 58.980, 47.184, 37.747, 30.198, 24.158, 19.327, 15.461, 12.369, 9.895, 7.916, 6.333, 5.066, 4.053, 3.242, 2.594, 2.075, 1.660, 1.328, 1.062, 0.850, 0.680, 0.544, 0.435, 0.348, 0.279, 0.223, 0.178, 0.143, 0.114, 0.091, 0.073, 0.058, 0.050
Jump from coarse to fine between indices 14 (σ=4.053) 

IndexError: index 9993 is out of bounds for axis 0 with size 9990

In [94]:
data['118374L']['pd_ds'].GetNumberOfPoints(), md_aff['118374L'].v.shape

(9990, (9994, 3))

In [87]:
load_vtk(data['104937L']['workspace'].affine_moving).GetNumberOfPoints()

26828

# Compute consensus labeling

In [1]:
import SimpleITK as sitk

# Load the template again


# Ok, so we have the template in native space, we can now sample the complete selection of labels. Load the scans.
xvashs_label = np.zeros((md_temp.v.shape[0], len(md_aff)))
for i, (id, md_i) in enumerate(md_aff.items()):
    # Read the global segmentation
    fn_global_seg = f'/data/pauly2/ashs_xv/input/{id}/{id}_label_global.nii.gz'
    img = sitk.ReadImage(fn_global_seg, outputPixelType=sitk.sitkFloat32)

    # Load the mesh
    pd_temp_omt_to_native = load_vtk(f'tmp/template_omt_to_avg_ras_{id}.vtk')

    # Sample all the vertices - nearest neighbor
    v = vtk_get_points(pd_temp_omt_to_native)
    v_label = np.zeros(v.shape[0])
    for j in range(v.shape[0]):
        x = [ -v[j,0], -v[j,1], v[j,2] ]
        try:
            l = img.EvaluateAtPhysicalPoint(x, sitk.sitkNearestNeighbor)
        except:
            l = 0
        xvashs_label[j,i] = l

    # Assign to the mesh
    # vtk_set_point_array(pd_temp_omt_to_native, 'label_xv', v_label);
    # save_vtk(pd_temp_omt_to_native, f'tmp/template_omt_to_avg_ras_withlabel_{id}.vtk')
    # break
        
# Consensus labeling
lab_consensus = np.argmax([np.sum(xvashs_label == v, 1) for v in range(69)], 0)
pd_consensus = vtk_clone_pd(md_temp.pd)
vtk_set_point_array(pd_consensus, 'xvashs', lab_consensus)
save_vtk(pd_consensus, 'tmp/template_global_label.vtk')


NameError: name 'np' is not defined

array([ 0, 68, 21, ..., 51, 12,  9])

In [174]:
# Once the template has been constructed, we need to warp it to the population
# using shooting and then shooting plus OMT



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



loop <pymeshlab.pmeshlab.Mesh object at 0x7f4ce1e39330> it  0 : loss 68562.21
loss 68269.29
loss 65726.28
loss 49957.67
loss 41948.418
loss 35315.43
loss 31865.328
loss 27524.482
loss 24512.904
loss 21901.518
loss 21002.684
loss 18434.291
loss 17767.508
loss 16425.61
loss 14754.634
loss 14320.603
loop <pymeshlab.pmeshlab.Mesh object at 0x7f4ce1e39330> it  1 : loss 14320.603
loss 12964.27
loss 12482.753
loss 11793.728
loss 11102.531
loss 10320.76
loss 9811.943
loss 9344.442
loss 8759.904
loss 8178.0654
loss 7786.9263
loss 7374.3496
loss 6948.0205
loss 6523.957
loss 6197.33
loss 5943.7188
loop <pymeshlab.pmeshlab.Mesh object at 0x7f4ce1e39330> it  2 : loss 5943.7188
loss 5600.515
loss 5372.085
loss 5126.4575
loss 4972.19
loss 4778.7173
loss 4606.7837
loss 4421.307
loss 4305.4272
loss 4178.4614
loss 3998.2537
loss 3849.0493
loss 3742.4143
loss 3622.4014
loss 3509.8123
loss 3404.9758
loop <pymeshlab.pmeshlab.Mesh object at 0x7f4ce1e39330> it  3 : loss 3404.9758
loss 3299.0818
loss 3222.964

In [243]:
md_src3 = update_model_by_remeshing(md_src2, md_aff, p_root, p_temp_z)
save_vtk(md_src3.pd, 'tmp/template_stage2.vtk')

327x356 clusters, computed at scale = 3.157
Successive scales :  68.903, 68.903, 55.122, 44.098, 35.278, 28.222, 22.578, 18.062, 14.450, 11.560, 9.248, 7.398, 5.919, 4.735, 3.788, 3.030, 2.424, 1.939, 1.552, 1.241, 0.993, 0.794, 0.636, 0.508, 0.407, 0.325, 0.260, 0.208, 0.167, 0.133, 0.107, 0.085, 0.068, 0.055, 0.050
Jump from coarse to fine between indices 14 (σ=3.788) and 15 (σ=3.030).
Keep 42597/116412 = 36.6% of the coarse cost matrix.
Keep 39065/106929 = 36.5% of the coarse cost matrix.
Keep 46508/126736 = 36.7% of the coarse cost matrix.
OMT matching distance: 2.1609036922454834, time elapsed: 2.9839351177215576
349x377 clusters, computed at scale = 3.066
Successive scales :  66.898, 66.898, 53.519, 42.815, 34.252, 27.402, 21.921, 17.537, 14.030, 11.224, 8.979, 7.183, 5.747, 4.597, 3.678, 2.942, 2.354, 1.883, 1.506, 1.205, 0.964, 0.771, 0.617, 0.494, 0.395, 0.316, 0.253, 0.202, 0.162, 0.129, 0.104, 0.083, 0.066, 0.053, 0.050
Jump from coarse to fine between indices 14 (σ=3.678) a

In [245]:
p_root_3, p_temp_z_3 = fit_model_to_population(md_src3, md_aff, 10)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Iteration 000  Loss 60426.01171875
Iteration 001  Loss 12709.44921875
Iteration 002  Loss 6302.05908203125
Iteration 003  Loss 3513.724609375
Iteration 004  Loss 2372.8330078125
Iteration 005  Loss 1814.6024169921875
Iteration 006  Loss 1471.04736328125
Iteration 007  Loss 1235.9661865234375
Iteration 008  Loss 1092.425537109375
Iteration 009  Loss 981.1105346679688
Optimization (L-BFGS) time: 465.75 seconds


In [249]:
K = GaussKernel(sigma=sigma_lddmm)
_, q_temp_3 = Shooting(p_root_3, md_src3.vt.clone().requires_grad_(True).contiguous(), K)[-1]
pd_result_3 = vtk_make_pd(q_temp_3.detach().cpu().numpy(), md_src3.f)
pd_result_3 = vtk_set_cell_array(pd_result_3, 'plab', md_src3.lp)
save_vtk(pd_result, 'tmp/template_stage2_def.vtk')

for i, (id,v) in enumerate(md_aff.items()):
    _, q_i = Shooting(p_temp_z_3[i,:], q_temp_3, K)[-1]
    # loss_h = 0.1 * Hamiltonian(K)(p_temp_z[i,:], q_temp)
    # loss_d = d_loss[id](q_i)
    # print(f'{id}: loss_h: {loss_h}, loss_d: {loss_d}')
    pd_result = vtk_make_pd(q_i.detach().cpu().numpy(), md_src3.f)
    pd_result = vtk_set_cell_array(pd_result, 'plab', md_src3.lp)
    save_vtk(pd_result, f'tmp/template_stage2_def_to_{id}.vtk')

In [176]:
_, q_temp = Shooting(p_root, q_root, K)[-1]
pd_result = vtk_make_pd(q_temp.detach().cpu().numpy(), md_src2.f)
pd_result = vtk_set_cell_array(pd_result, 'plab', md_src2.lp)
save_vtk(pd_result, 'tmp/template_from_taubin.vtk')

In [177]:
p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
for i, (id,v) in enumerate(md_aff.items()):
    _, q_i = Shooting(p_temp_z[i,:], q_temp, K)[-1]
    loss_h = 0.1 * Hamiltonian(K)(p_temp_z[i,:], q_temp)
    loss_d = d_loss[id](q_i)
    print(f'{id}: loss_h: {loss_h}, loss_d: {loss_d}')
    pd_result = vtk_make_pd(q_i.detach().cpu().numpy(), md_src2.f)
    pd_result = vtk_set_cell_array(pd_result, 'plab', md_src2.lp)
    save_vtk(pd_result, f'tmp/template_to_{id}.vtk')



104937L: loss_h: 92.9029769897461, loss_d: 299.8125
106049L: loss_h: 53.60837936401367, loss_d: 566.75
106312R: loss_h: 61.9499626159668, loss_d: 592.5625
113909R: loss_h: 42.26411819458008, loss_d: 158.8125
116748R: loss_h: 65.7442626953125, loss_d: 283.09375
117243R: loss_h: 47.26437759399414, loss_d: 273.9375
117667R: loss_h: 74.99248504638672, loss_d: 444.53125
118374L: loss_h: 45.83708572387695, loss_d: 170.875
118430R: loss_h: 73.35762023925781, loss_d: 232.75
120126L: loss_h: 43.98615646362305, loss_d: 253.875
120267L: loss_h: 46.03990173339844, loss_d: 218.84375
120937L: loss_h: 59.33658981323242, loss_d: 399.78125
121250L: loss_h: 50.62836837768555, loss_d: 181.0


In [190]:
# Use OMT after registration to remap labels to the template
_, q_temp = Shooting(p_root, q_root, K)[-1]
plab_sample = []
p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
for (id, md_i) in md_aff.items():
    _, q_i = Shooting(p_temp_z[i,:], q_temp, K)[-1]
    f_omt, w_omt = match_omt(q_i, md_src2.ft, md_i.vt, md_i.ft)
    lp_omt = vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy())
    plab_sample.append(lp_omt)
    
plab_sample_avg = np.stack(plab_sample).mean(0)

327x356 clusters, computed at scale = 3.157
Successive scales :  68.903, 68.903, 55.122, 44.098, 35.278, 28.222, 22.578, 18.062, 14.450, 11.560, 9.248, 7.398, 5.919, 4.735, 3.788, 3.030, 2.424, 1.939, 1.552, 1.241, 0.993, 0.794, 0.636, 0.508, 0.407, 0.325, 0.260, 0.208, 0.167, 0.133, 0.107, 0.085, 0.068, 0.055, 0.050
Jump from coarse to fine between indices 14 (σ=3.788) and 15 (σ=3.030).
Keep 42596/116412 = 36.6% of the coarse cost matrix.
Keep 39065/106929 = 36.5% of the coarse cost matrix.
Keep 46508/126736 = 36.7% of the coarse cost matrix.
OMT matching distance: 2.1609044075012207, time elapsed: 2.936350107192993
349x377 clusters, computed at scale = 3.066
Successive scales :  66.898, 66.898, 53.519, 42.815, 34.252, 27.402, 21.921, 17.537, 14.030, 11.224, 8.979, 7.183, 5.747, 4.597, 3.678, 2.942, 2.354, 1.883, 1.506, 1.205, 0.964, 0.771, 0.617, 0.494, 0.395, 0.316, 0.253, 0.202, 0.162, 0.129, 0.104, 0.083, 0.066, 0.053, 0.050
Jump from coarse to fine between indices 14 (σ=3.678) an

In [233]:
pd_template = vtk_make_pd(q_temp.detach().cpu().numpy(), md_src2.f)
pd_template = vtk_set_cell_array(pd_template, 'plab_orig', md_src2.lp)
pd_template = vtk_set_cell_array(pd_template, 'plab', softmax(plab_sample_avg * 10, 1))
for i, (id, md_i) in enumerate(md_aff.items()):
    pd_template = vtk_set_point_array(pd_template, f'momenta_{id}', p_temp_z[i,:,:].detach().cpu().numpy())

save_vtk(pd_template, 'tmp/template_before_remesh.vtk')

In [234]:
# Let's try to remesh the template
ms = pymeshlab.MeshSet()
ms.add_mesh(pymeshlab.Mesh(vertex_matrix=q_temp.detach().cpu().numpy(), face_matrix=md_src2.f))
ms.meshing_isotropic_explicit_remeshing()
v_remesh, f_remesh = ms.mesh(0).vertex_matrix(), ms.mesh(0).face_matrix()
pd_remesh = vtk_make_pd(v_remesh, f_remesh)

# Apply remeshing to the momenta
p_temp_remesh_src = []
for i, (id, md_i) in enumerate(md_aff.items()):
    p_i = vtk_sample_point_array_at_vertices(pd_template, p_temp_z[i,:,:].detach().cpu().numpy(), v_remesh)
    p_temp_remesh_src.append(p_i)
    pd_remesh = vtk_set_point_array(pd_remesh, f'momenta_{id}', p_i)

# Apply the remeshing to the plab array
f_omt, w_omt = match_omt(torch.tensor(v_remesh, dtype=torch.float32, device=device), 
                         torch.tensor(f_remesh, dtype=torch.long, device=device),
                         q_temp, md_src2.ft)

lp_remesh = vtk_sample_cell_array_at_vertices(pd_template, plab_sample_avg, w_omt.detach().cpu().numpy())
pd_remesh = vtk_set_cell_array(pd_remesh, 'plab', softmax(lp_remesh * 10, 1))

# Save the remeshed array
save_vtk(pd_remesh, 'tmp/template_after_remesh.vtk')


356x350 clusters, computed at scale = 2.803
Successive scales :  61.166, 61.166, 48.933, 39.146, 31.317, 25.054, 20.043, 16.034, 12.827, 10.262, 8.210, 6.568, 5.254, 4.203, 3.363, 2.690, 2.152, 1.722, 1.377, 1.102, 0.881, 0.705, 0.564, 0.451, 0.361, 0.289, 0.231, 0.185, 0.148, 0.118, 0.095, 0.076, 0.061, 0.050
Jump from coarse to fine between indices 14 (σ=3.363) and 15 (σ=2.690).
Keep 43915/124600 = 35.2% of the coarse cost matrix.
Keep 44984/126736 = 35.5% of the coarse cost matrix.
Keep 42898/122500 = 35.0% of the coarse cost matrix.
OMT matching distance: 0.02683163434267044, time elapsed: 2.844156503677368


In [235]:
md_remesh = MeshData(pd_remesh, device)
_, q_test = Shooting(torch.tensor(p_temp_remesh_src[0], dtype=torch.float32, device=device, requires_grad=True).contiguous(),
                     torch.tensor(v_remesh, dtype=torch.float32, device=device, requires_grad=True).contiguous(), K)[-1]

pd_test = vtk_make_pd(q_test.detach().cpu().numpy(), f_remesh)
save_vtk(pd_test,'tmp/remesh_to_subj_shooting_test.vtk')

116x350 clusters, computed at scale = 2.803
Successive scales :  61.166, 61.166, 48.933, 39.146, 31.317, 25.054, 20.043, 16.034, 12.827, 10.262, 8.210, 6.568, 5.254, 4.203, 3.363, 2.690, 2.152, 1.722, 1.377, 1.102, 0.881, 0.705, 0.564, 0.451, 0.361, 0.289, 0.231, 0.185, 0.148, 0.118, 0.095, 0.076, 0.061, 0.050
Jump from coarse to fine between indices 14 (σ=3.363) and 15 (σ=2.690).
Keep 19067/40600 = 47.0% of the coarse cost matrix.
Keep 7906/13456 = 58.8% of the coarse cost matrix.
Keep 42898/122500 = 35.0% of the coarse cost matrix.
OMT matching distance: 50.010459899902344, time elapsed: 1.9429802894592285


In [None]:
# Perform OMT between remeshed mesh and the 
pd_sphere_opt_2 = vtk_set_cell_array(pd_sphere_opt, 'plab', softmax(plab_sample_avg * 10, axis=1))
save_vtk(pd_sphere_opt_2, 'tmp/ellipsoid_best_fit_plab.vtk')

In [183]:
f_remesh.shape

(12542, 3)

In [181]:
md_sph.f.shape

(5120, 3)