In [1]:
import argparse
import pathlib
import json
import os
import torch
import numpy as np
import time
import sys
import matplotlib.pyplot as plt
import geomloss
from scipy.special import softmax
import pymeshlab
from sklearn.decomposition import PCA
import SimpleITK as sitk
import nighres

sys.path.append('/data/pauly2/tk/crashs')
from vtkutil import *
from lddmm import *
from crashs import MeshData, Template, ASHSFolder, Workspace, ashs_output_to_cruise_input, run_cruise, cruise_postproc

%cd /data/pauly2/tk/crashs

/data/pauly2/tk/crashs


# Load the inputs and parameters

In [168]:
# Load the template
template = Template('/data/pauly2/ashs_xv/manifest/template_init_dir')

# Load the ASHS json (a list of records, each providing at least 'id' and 'side')
with open('/data/pauly2/ashs_xv/manifest/crashs_input_fold_0.json') as fd:
    ashs_input_desc = json.load(fd)

# Prepare device
# device = torch.device(args.device) if torch.cuda.is_available() else 'cpu'
cuda_available = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device() if cuda_available else 'cpu')
print("Is cuda available?", cuda_available)
print("Device count?", torch.cuda.device_count())
# The current-device queries raise when CUDA is not available, which would defeat
# the CPU fallback above, so only report them when a GPU is actually present
if cuda_available:
    print("Current device?", torch.cuda.current_device())
    print("Device name? ", torch.cuda.get_device_name(torch.cuda.current_device()))

# Keep track of ASHS importers and workspaces created
data = {}

# Create an output directory and a Workspace for each subject
for d in ashs_input_desc:
    id = d['id']
    side = d['side']

    # Create the output dir
    out_dir = os.path.join('/data/pauly2/ashs_xv/work/fold_0/crashs_build', id)
    os.makedirs(out_dir, exist_ok=True)
    workspace = Workspace(out_dir, id, side)

    # Store the data
    data[id] = { 'side': side, 'workspace': workspace }

Is cuda available? True
Device count? 1
Current device? 0
Device name?  Quadro RTX 5000


In [5]:
# From the template directory, load the left/right flip file. 
flip_lr = np.loadtxt(os.path.join(template.root, 'ashs_template_flip.mat'))

# Set the sigma tensors
# sigma_varifold = torch.tensor([template['left'].get_varifold_sigma()], dtype=torch.float32, device=device)
sigma_varifold = torch.tensor([5], dtype=torch.float32, device=device)
sigma_lddmm = torch.tensor([template.get_lddmm_sigma()], dtype=torch.float32, device=device)

# Make sure the scratch directory that the meshes are written to exists
os.makedirs('tmp', exist_ok=True)

# Load each of the meshes that will be used to build the template
md = {}
md_ds = {}
for id, sd in data.items():

    # Depending on the side, apply flip_lr as the transform
    transform = flip_lr if sd['side'] == 'right' else None

    # Load the mesh data (inflated avg surface in ASHS template space)
    md[id] = MeshData(load_vtk(sd['workspace'].affine_moving), device, transform=transform)

    # Also downsample the mesh for faster affine registration
    # NOTE(review): the same transform is passed again for a clone of md[id].pd; if MeshData
    # writes the transformed points back into its polydata, right-side meshes would be
    # flipped twice here - confirm against the MeshData implementation
    md_ds[id] = MeshData(vtk_clone_pd(md[id].pd), device, transform=transform, target_faces=5000)

    # Save the full-resolution mesh to disk. No smoothing is applied in this cell; the
    # file name suggests it is meant as input to a later Taubin smoothing step (TODO confirm)
    save_vtk(md[id].pd, f'tmp/tb_input_{id}.vtk')


Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces
Decimating mesh, target: 5000 faces
Decimation complete, 5000 faces


In [106]:
# Compute the varifold loss between all pairs of atlases in the template subset before
# any registration - this is to determine the best candidate for the template
# NOTE: this cell is disabled. Everything below is wrapped in a triple-quoted string
# literal, so none of it executes; the cell merely evaluates to (and echoes) that string.
""" ids = list(md.keys())
dsq_sub = np.zeros((len(md), len(md)))
for i1, (k1,v1) in enumerate(md.items()):
    for i2, (k2,v2) in enumerate(md.items()):
        pair_loss = lossVarifoldSurfWithLabels(v1.ft, v2.vt, v2.ft, v1.lpt, v2.lpt, 
                                                GaussLinKernelWithLabels(sigma_varifold, v1.lp.shape[1]))
        dsq_sub[i1,i2] = pair_loss(v1.vt).item()
        print(k1, k2, dsq_sub[i1,i2])
        
# Find the index of the template candidate
i_src = np.argmin(dsq_sub.sum(axis=1))
id_src = ids[i_src]
print(f'Best template candidate: {id_src}, mean distance {dsq_sub.mean(axis=1)[i_src]} vs {dsq_sub.mean()}') """



" ids = list(md.keys())\ndsq_sub = np.zeros((len(md), len(md)))\nfor i1, (k1,v1) in enumerate(md.items()):\n    for i2, (k2,v2) in enumerate(md.items()):\n        pair_loss = lossVarifoldSurfWithLabels(v1.ft, v2.vt, v2.ft, v1.lpt, v2.lpt, \n                                                GaussLinKernelWithLabels(sigma_varifold, v1.lp.shape[1]))\n        dsq_sub[i1,i2] = pair_loss(v1.vt).item()\n        print(k1, k2, dsq_sub[i1,i2])\n        \n# Find the index of the template candidate\ni_src = np.argmin(dsq_sub.sum(axis=1))\nid_src = ids[i_src]\nprint(f'Best template candidate: {id_src}, mean distance {dsq_sub.mean(axis=1)[i_src]} vs {dsq_sub.mean()}') "

In [107]:
def my_rotation_from_vector(x):
    """
    Build a 3D rotation matrix from an axis/angle ("rotation vector") parameterization.

    Args:
        x:
            A torch tensor of shape (3). Its direction is the axis of rotation and its
            Euclidean norm is the rotation angle in radians.
    Output:
        A shape (3,3) tensor holding the rotation matrix corresponding to x, with the
        same dtype and device as x
    """
    # Rotation angle = length of the parameter vector
    angle = torch.norm(x)

    # Unit rotation axis. `normalize` clamps the denominator with a small epsilon, so
    # x=[0,0,0] yields a zero axis (and hence the identity matrix) rather than a
    # division by zero
    axis = torch.nn.functional.normalize(x, dim=0)

    # Skew-symmetric cross-product matrix of the axis, filled entry by entry
    K = torch.zeros(3, 3, dtype=x.dtype, device=x.device)
    K[0,1], K[1,0] = -axis[2], axis[2]
    K[0,2], K[2,0] = axis[1], -axis[1]
    K[1,2], K[2,1] = -axis[0], axis[0]

    # Rodrigues formula: R = I + sin(angle) K + (1 - cos(angle)) K^2
    identity = torch.eye(3, dtype=x.dtype, device=x.device)
    return identity + torch.sin(angle) * K + (1-torch.cos(angle)) * (K @ K)

In [110]:
# Compute the varifold loss between all pairs of atlases with similarity transform,
# running a quick registration between all pairs. This might get too much for large
# training sets though
# (the two nested loops below run one optimization per ordered pair of distinct subjects,
# so the cost grows quadratically with the number of subjects)
ids = list(md.keys())
dsq_sub = np.zeros((len(md), len(md)))

# The default affine parameters
# theta_all maps (source id, target id) -> optimized 7-vector of similarity parameters
theta_all = { }
kernel = GaussLinKernelWithLabels(sigma_varifold, md[ids[0]].lp.shape[1])
for i1, (k1,v1) in enumerate(md_ds.items()):
    for i2, (k2,v2) in enumerate(md_ds.items()):
        if k1 != k2: 
            # Define the symmetric loss for this pair
            loss_ab = lossVarifoldSurfWithLabels(v1.ft, v2.vt, v2.ft, v1.lpt, v2.lpt, kernel)
            loss_ba = lossVarifoldSurfWithLabels(v2.ft, v1.vt, v1.ft, v2.lpt, v1.lpt, kernel)
            # Parameter layout: [0:3] rotation vector, [3] isotropic scale, [4:7] translation.
            # The rotation starts at a small non-zero vector, presumably to stay away from the
            # degeneracy of the axis/angle representation at zero
            pair_theta = torch.tensor([0.01, 0.01, 0.01, 1.0, 0.0, 0.0, 0.0], 
                                    dtype=torch.float32, device=device, requires_grad=True)
            
            # Create optimizer
            opt_affine = torch.optim.LBFGS([pair_theta], max_eval=10, max_iter=10, line_search_fn='strong_wolfe')

            # Define closure
            def closure():
                opt_affine.zero_grad()

                # Forward similarity transform y = R x + b (R is a scaled rotation) ...
                R = my_rotation_from_vector(pair_theta[0:3]) * pair_theta[3]
                b = pair_theta[4:]
                # ... and its inverse x = R_inv y - R_inv b
                R_inv = torch.inverse(R)
                b_inv = - R_inv @ b

                v1_to_v2 = (R @ v1.vt.t()).t() + b
                v2_to_v1 = (R_inv @ v2.vt.t()).t() + b_inv

                # Average of the forward and inverse-direction losses
                L = 0.5 * (loss_ab(v1_to_v2) + loss_ba(v2_to_v1))
                L.backward()
                return L
            
            # Run the optimization
            for i in range(10):
                opt_affine.step(closure)

            # Print loss and record the best run/best parameters
            # (closure() is called once more here to evaluate the loss at the final parameters)
            dsq_sub[i1,i2] = closure().item()
            theta_all[(k1,k2)] = pair_theta.detach()
            print(f'Pair {k1}, {k2} loss : {dsq_sub[i1,i2]:8.6f}')

Pair 104937L, 106049L loss : 100608.500000
Pair 104937L, 106312R loss : 106330.859375
Pair 104937L, 113909R loss : 96887.109375
Pair 104937L, 116748R loss : 81386.609375
Pair 104937L, 117243R loss : 113444.625000
Pair 104937L, 117667R loss : 93864.312500
Pair 104937L, 118374L loss : 81879.468750
Pair 104937L, 118430R loss : 105904.375000
Pair 104937L, 120126L loss : 75038.906250
Pair 104937L, 120267L loss : 95297.304688
Pair 104937L, 120937L loss : 118714.132812
Pair 104937L, 121250L loss : 63058.117188
Pair 106049L, 104937L loss : 100663.859375
Pair 106049L, 106312R loss : 87120.843750
Pair 106049L, 113909R loss : 86781.898438
Pair 106049L, 116748R loss : 91687.468750
Pair 106049L, 117243R loss : 75494.968750
Pair 106049L, 117667R loss : 84323.539062
Pair 106049L, 118374L loss : 47913.578125
Pair 106049L, 118430R loss : 72208.617188
Pair 106049L, 120126L loss : 88430.046875
Pair 106049L, 120267L loss : 75380.109375
Pair 106049L, 120937L loss : 108996.382812
Pair 106049L, 121250L loss 

In [111]:
# Select as template candidate the subject with the smallest summed pairwise loss
i_best = np.argmin(dsq_sub.sum(axis=1))
k_best = ids[i_best]
print(f'Best template candidate: {k_best}, mean distance {dsq_sub.mean(axis=1)[i_best]} vs {dsq_sub.mean()}')

# Bring every full-resolution mesh into the space of the candidate and write out
# both the registered mesh and the 4x4 matrix that was applied to it
for k in md_ds:
    md_full = md[k]

    # The candidate itself keeps its vertices and gets the identity matrix
    affine_mat = np.eye(4)
    v_out = md_full.vt

    if k != k_best:
        # Similarity transform (scaled rotation + translation) that moves k to k_best
        theta_pair = theta_all[(k,k_best)]
        R = my_rotation_from_vector(theta_pair[0:3]) * theta_pair[3]
        b = theta_pair[4:]

        # Record the transform in homogeneous form
        affine_mat[0:3,0:3] = R.detach().cpu().numpy()
        affine_mat[0:3,  3] = b.detach().cpu().numpy()

        # Transform the original-resolution vertices
        v_out = (R @ md_full.vt.t()).t() + b

    # Save the (registered) mesh together with its label array, then the matrix
    pd_reg = vtk_make_pd(v_out.detach().cpu().numpy(), md_full.f)
    pd_reg = vtk_set_cell_array(pd_reg, 'plab', md_full.lp)
    save_vtk(pd_reg, f'tmp/test_register_{k}.vtk')
    np.savetxt(f'tmp/affine_to_template_{k}.mat', affine_mat)

Best template candidate: 118374L, mean distance 70489.14963942308 vs 81912.80422522189


# Build the template starting from an ellipsoid

In [6]:
# Read back the registered meshes that were written to tmp/, one MeshData per subject
md_aff = {
    subject_id: MeshData(load_vtk(f'tmp/test_register_{subject_id}.vtk'), device)
    for subject_id in data
}


In [137]:
# Generate a sphere
ms = pymeshlab.MeshSet()
ms.create_sphere(subdiv = 4)
m0 = ms.mesh(0)
v_sph, f_sph = m0.vertex_matrix(), m0.face_matrix()
pd_sph = vtk_make_pd(v_sph, f_sph)
# The sphere carries a single-column, all-zero 'plab' cell array
pd_sph = vtk_set_cell_array(pd_sph, 'plab', np.zeros((f_sph.shape[0],1)))
md_sph = MeshData(pd_sph, device)

# Map the cartesian coordinates to spherical coordinates
# (assumes a unit sphere centered at the origin - TODO confirm pymeshlab's default radius)
sph_phi = np.arctan2(v_sph[:,1], v_sph[:,0])
# Clip before arccos: round-off can push |z| marginally above 1 at the poles, which
# would otherwise produce NaN
sph_theta = np.arccos(np.clip(v_sph[:,2], -1.0, 1.0))

In [138]:
# Fit an affine image of the sphere (y = A x + b) to the whole population by
# minimizing the average varifold loss to the registered meshes.

# The affine parameters are initialized from a PCA of all vertices pooled together
v_all = np.concatenate([ mesh.v for mesh in md_aff.values() ], 0)
pca = PCA(n_components=3)
pca.fit(v_all)

# One varifold loss per target mesh, all sharing the same kernel
kernel = GaussLinKernel(sigma_varifold)
loss = { id: lossVarifoldSurf(md_sph.ft, v.vt, v.ft, kernel) for (id,v) in md_aff.items() }

# Optimized parameters: translation b starts at the PCA mean, matrix A at the PCA covariance
b = torch.tensor(pca.mean_, dtype=torch.float32, device=device, requires_grad=True).contiguous()
A = torch.tensor(pca.get_covariance(), dtype=torch.float32, device=device, requires_grad=True).contiguous()

# L-BFGS over both parameters
optimizer = torch.optim.LBFGS([A,b], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')
start = time.time()

def closure():
    optimizer.zero_grad()
    # Transformed sphere vertices
    y = (A @ md_sph.vt.T).T + b
    # Mean loss over all targets
    L = sum(loss[key](y) for key in md_aff) / len(md_aff)
    L.backward()
    return L

for i in range(30):
    print(f'Iter {i:03d}, Loss: {closure()}')
    optimizer.step(closure)

Iter 000, Loss: 12633112.0
Iter 001, Loss: 248984.203125
Iter 002, Loss: 244574.953125
Iter 003, Loss: 244538.859375
Iter 004, Loss: 244467.703125
Iter 005, Loss: 244355.515625
Iter 006, Loss: 243975.21875
Iter 007, Loss: 243362.234375
Iter 008, Loss: 243129.703125
Iter 009, Loss: 242992.734375
Iter 010, Loss: 242734.4375
Iter 011, Loss: 241542.375
Iter 012, Loss: 239994.625
Iter 013, Loss: 238290.015625
Iter 014, Loss: 235576.390625
Iter 015, Loss: 234737.625
Iter 016, Loss: 226558.734375
Iter 017, Loss: 207570.4375
Iter 018, Loss: 138724.328125
Iter 019, Loss: 117534.96875
Iter 020, Loss: 117534.1875
Iter 021, Loss: 117534.1796875
Iter 022, Loss: 117534.1796875
Iter 023, Loss: 117534.1796875
Iter 024, Loss: 117534.1796875
Iter 025, Loss: 117534.1796875
Iter 026, Loss: 117534.1796875
Iter 027, Loss: 117534.1796875
Iter 028, Loss: 117534.1796875
Iter 029, Loss: 117534.1796875


In [117]:
# Compute the new rotated sphere
# (applies the optimized affine parameters A, b to the sphere vertices)
v_sph_opt = (A @ md_sph.vt.T).T + b

# Perform remeshing of the sphere
# The transformed sphere keeps the original face list; isotropic remeshing then produces
# a new vertex/face set (v_ell, f_ell) for the ellipsoid
ms = pymeshlab.MeshSet()
ms.add_mesh(pymeshlab.Mesh(vertex_matrix=v_sph_opt.detach().cpu().numpy(), face_matrix=md_sph.f))
ms.meshing_isotropic_explicit_remeshing()
v_ell, f_ell = ms.mesh(0).vertex_matrix(), ms.mesh(0).face_matrix()

# Wrap the ellipsoid as MeshData with a single-column, all-zero 'plab' cell array
pd_ell = vtk_make_pd(v_ell, f_ell)
pd_ell = vtk_set_cell_array(pd_ell, 'plab', np.zeros((f_ell.shape[0],1)))
md_ell = MeshData(pd_ell, device)

# A second polydata built from the same vertices/faces (no 'plab' array) is what gets saved
pd_sphere_opt = vtk_make_pd(v_ell, f_ell)
save_vtk(pd_sphere_opt, 'tmp/ellipsoid_best_fit.vtk')

In [109]:
# Now we need to initialize the labeling of the sphere. We can try directly to use OMT to 
# match the sphere to each of the meshes and maybe that's going to be good enough for getting
# the initial label distributions. If not, have to deform
def to_measure(points, triangles):
    """Turns a triangle mesh into a weighted point cloud (one Dirac atom per face).

    Args:
        points: (V, 3) tensor of vertex coordinates.
        triangles: (F, 3) integer tensor of vertex indices, one row per face.

    Returns:
        A tuple (weights, centers): `weights` is an (F,) tensor of face areas
        normalized to sum to one, `centers` is an (F, 3) tensor of face centroids.
    """

    # Our mesh is given as a collection of ABC triangles:
    A, B, C = points[triangles[:, 0]], points[triangles[:, 1]], points[triangles[:, 2]]

    # Locations and weights of our Dirac atoms:
    X = (A + B + C) / 3  # centers of the faces

    # The cross product must be taken along the coordinate axis (dim=1). Without an
    # explicit dim, torch.cross uses the first dimension of size 3, which is the face
    # axis whenever the mesh has exactly three triangles - and the implicit form is
    # deprecated anyway.
    S = torch.sqrt(torch.sum(torch.cross(B - A, C - A, dim=1) ** 2, dim=1)) / 2  # areas of the faces

    # We return a (normalized) vector of weights + a "list" of points
    return S / torch.sum(S), X

# Optimal mass transport (Sinkhorn) matching between two triangle meshes
def match_omt(vs, fs, vt, ft):
    """
    Match a source mesh to a target mesh using the Sinkhorn divergence from geomloss.

    Both meshes are converted to weighted point clouds of face centers via `to_measure`.

    Args:
        vs, fs: vertices and faces of the source mesh
        vt, ft: vertices and faces of the target mesh
    Returns:
        A tuple (w_loss_value, w_match): the Sinkhorn loss value, and for every source
        face center the position obtained by moving it along the negative loss gradient,
        rescaled by the face weight (used by callers as the matched location on the target)
    """
    a_src, x_src = to_measure(vs, fs)
    a_trg, x_trg = to_measure(vt, ft)

    # Gradients with respect to the point positions are needed below
    for centers in (x_src, x_trg):
        centers.requires_grad_(True)

    # Evaluate the Sinkhorn loss and differentiate it with respect to the source points
    t_start = time.time()
    sinkhorn = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.8, backend='multiscale', verbose=True)
    w_loss_value = sinkhorn(a_src, x_src, a_trg, x_trg)
    grad_src = torch.autograd.grad(w_loss_value, x_src)[0]

    # Displace every source point by its gradient, normalized by the point's weight
    w_match = x_src - grad_src / a_src[:, None]
    elapsed = time.time() - t_start

    print(f'OMT matching distance: {w_loss_value.item()}, time elapsed: {elapsed}')
    return w_loss_value, w_match

In [119]:
# For every registered subject, match the ellipsoid to the subject with OMT and sample
# the subject's label array at the matched locations
plab_sample = []
for md_i in md_aff.values():
    w_omt = match_omt(md_ell.vt, md_ell.ft, md_i.vt, md_i.ft)[1]
    plab_sample.append(
        vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy()))

233x385 clusters, computed at scale = 3.630
Successive scales :  79.210, 79.210, 63.368, 50.695, 40.556, 32.444, 25.956, 20.764, 16.612, 13.289, 10.631, 8.505, 6.804, 5.443, 4.355, 3.484, 2.787, 2.230, 1.784, 1.427, 1.142, 0.913, 0.731, 0.584, 0.468, 0.374, 0.299, 0.239, 0.192, 0.153, 0.123, 0.098, 0.078, 0.063, 0.050, 0.050
Jump from coarse to fine between indices 14 (σ=4.355) and 15 (σ=3.484).
Keep 37203/89705 = 41.5% of the coarse cost matrix.
Keep 25109/54289 = 46.3% of the coarse cost matrix.
Keep 52585/148225 = 35.5% of the coarse cost matrix.
OMT matching distance: 14.987838745117188, time elapsed: 3.0157408714294434
290x483 clusters, computed at scale = 3.324
Successive scales :  72.530, 72.530, 58.024, 46.419, 37.135, 29.708, 23.767, 19.013, 15.211, 12.169, 9.735, 7.788, 6.230, 4.984, 3.987, 3.190, 2.552, 2.042, 1.633, 1.307, 1.045, 0.836, 0.669, 0.535, 0.428, 0.343, 0.274, 0.219, 0.175, 0.140, 0.112, 0.090, 0.072, 0.057, 0.050
Jump from coarse to fine between indices 14 (σ=3.

In [120]:
# Average the sampled label arrays over all subjects (mean over the stacking axis)
plab_sample_avg = np.stack(plab_sample).mean(0)

In [121]:
# Attach the averaged labels to the ellipsoid as the 'plab' cell array and save it. The
# averages are multiplied by 10 and passed through a softmax along axis 1 (presumably to
# sharpen them into a per-face probability distribution - TODO confirm the choice of 10).
# NOTE(review): pd_sphere_opt_2 may be the same object as pd_sphere_opt if
# vtk_set_cell_array modifies its argument in place - confirm
pd_sphere_opt_2 = vtk_set_cell_array(pd_sphere_opt, 'plab', softmax(plab_sample_avg * 10, axis=1))
save_vtk(pd_sphere_opt_2, 'tmp/ellipsoid_best_fit_plab.vtk')

In [122]:
# Use the sphere as a starting point for template fitting
# (md_src2 wraps the labeled ellipsoid and is passed as the root model to the template build)
md_src2 = MeshData(pd_sphere_opt_2, device)

In [112]:
# Map an array to new vertex locations
def vtk_sample_point_array_at_vertices(pd_src, array, x_samples):
    """
    Interpolate a per-point array of a mesh at arbitrary sample locations.

    For every sample, the closest cell of `pd_src` is located and the rows of `array`
    belonging to that cell's points are blended with the interpolation weights returned
    by the cell. Three weights per cell are used, i.e. triangle cells are assumed.

    Args:
        pd_src: VTK polydata to sample from
        array: 2D numpy array indexed by point id of `pd_src` along the first axis
        x_samples: 2D numpy array with one sample location per row
    Returns:
        Numpy array of shape (number of samples, array.shape[1])
    """
    # Spatial search structure over the cells of the source mesh
    locator = vtk.vtkCellLocator()
    locator.SetDataSet(pd_src)
    locator.BuildLocator()

    n_samples = x_samples.shape[0]
    result = np.zeros((n_samples, array.shape[1]))

    # Output arguments filled in by the VTK calls
    cellId, subId, d = vtk.reference(0), vtk.reference(0), vtk.reference(0.0)
    c, pcoord, wgt = [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]

    for j in range(n_samples):
        x_j = x_samples[j,:]

        # Closest cell first, then the interpolation weights of the sample in that cell
        locator.FindClosestPoint(x_j, c, cellId, subId, d)
        cell = pd_src.GetCell(cellId)
        cell.EvaluatePosition(x_j, c, subId, pcoord, d, wgt)

        # Weighted combination of the array rows at the cell's points
        weighted_rows = [ array[cell.GetPointId(i),:] * w for i, w in enumerate(wgt) ]
        result[j] = np.sum(np.stack(weighted_rows), 0)

    return result


# For a set of sampling locations near a triangle mesh, compute which mesh vertices,
# and with which weights, reproduce data at each location. The returned arrays can be
# reused to interpolate any per-point data (attributes, coordinates, transformed
# coordinates, ...) defined on the source mesh
def vtk_get_interpolation_arrays_for_sample(pd_src, x_samples):
    """
    Compute vertex indices and interpolation weights for sample locations on a mesh.

    Args:
        pd_src: VTK polydata to sample from (three weights per cell are stored, i.e.
            triangle cells are assumed)
        x_samples: 2D numpy array with one sample location per row
    Returns:
        A tuple (v_res, w_res) of arrays of shape (number of samples, 3): the int32 point
        ids of the cell closest to each sample and the matching double-precision weights
    """
    # Spatial search structure over the cells of the source mesh
    locator = vtk.vtkCellLocator()
    locator.SetDataSet(pd_src)
    locator.BuildLocator()

    # Return data: array of vertex indices and weights
    n_samples = x_samples.shape[0]
    v_res = np.zeros((n_samples, 3), dtype=np.int32)
    w_res = np.zeros((n_samples, 3), dtype=np.double)

    # Output arguments filled in by the VTK calls
    cellId, subId, d = vtk.reference(0), vtk.reference(0), vtk.reference(0.0)
    c, pcoord, wgt = [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]

    for j in range(n_samples):
        x_j = x_samples[j,:]

        # Closest cell first, then the interpolation weights of the sample in that cell
        locator.FindClosestPoint(x_j, c, cellId, subId, d)
        cell = pd_src.GetCell(cellId)
        cell.EvaluatePosition(x_j, c, subId, pcoord, d, wgt)

        # Store the cell's point ids next to their weights
        v_res[j,:] = [ cell.GetPointId(i) for i in range(len(wgt)) ]
        w_res[j,:] = wgt

    return v_res, w_res



# Define a function that can fit a model to a population
def fit_model_to_population_nojac(md_root, md_targets, n_iter = 10, 
                                  sigma_lddmm=5, sigma_root=20, sigma_varifold=5, gamma_lddmm=0.1):
    """
    Fit a root model to a population of target meshes, without a Jacobian penalty.

    Two sets of momenta are optimized jointly with L-BFGS: `p_root` shoots the root mesh
    to the template, and one momentum field per subject shoots the template to that
    subject. The per-subject momenta are centered (their mean over subjects is
    subtracted) before use. The objective is the average over subjects of
    `gamma_lddmm * Hamiltonian + varifold data loss`.

    Args:
        md_root: MeshData of the root model (its vertices, faces and labels are used)
        md_targets: dict mapping subject id -> MeshData of the target meshes
        n_iter: number of outer L-BFGS steps (each runs up to 16 inner iterations)
        sigma_lddmm: Gaussian kernel width for the template->subject deformations
        sigma_root: Gaussian kernel width for the root->template deformation
        sigma_varifold: kernel width of the labeled varifold data term
        gamma_lddmm: weight of the Hamiltonian (deformation energy) term
    Returns:
        A tuple (p_root, p_temp_z): the root momentum, shaped like `md_root.vt`, and the
        centered per-subject momenta with shape (number of subjects,) + md_root.vt.shape
    """

    # LDDMM kernels
    device = md_root.vt.device
    K_root = GaussKernel(sigma=torch.tensor(sigma_root, dtype=torch.float32, device=device))
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))
    K_vari = GaussLinKernelWithLabels(torch.tensor(sigma_varifold, dtype=torch.float32, device=device), md_root.lp.shape[1])

    # Create losses for each of the target meshes
    d_loss = { id: lossVarifoldSurfWithLabels(md_root.ft, v.vt, v.ft, md_root.lpt, v.lpt, K_vari) for id,v in md_targets.items() }

    # Create the root->template points/momentum, as well as the template->subject momenta.
    # The root points are copied with detach().clone() rather than torch.tensor(tensor),
    # which PyTorch warns against; the copy is a float32 leaf tensor that requires grad
    q_root = md_root.vt.detach().clone().to(dtype=torch.float32).contiguous().requires_grad_(True)
    p_root = torch.zeros(md_root.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()
    p_temp = torch.zeros((len(md_targets),) + md_root.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()

    # Create the optimizer
    start = time.time()
    optimizer = torch.optim.LBFGS([p_root, p_temp], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()

        # Shoot root->template
        _, q_temp = Shooting(p_root, q_root, K_root)[-1]

        # Make the momenta applied to the template average out to zero
        p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)

        # Compute the loss
        L = 0
        for i, (id,v) in enumerate(md_targets.items()):
            _, q_i = Shooting(p_temp_z[i,:], q_temp, K_temp)[-1]
            L = L + gamma_lddmm * Hamiltonian(K_temp)(p_temp_z[i,:], q_temp) + d_loss[id](q_i)
        L = L / len(md_targets.items())
        L.backward()
        return L

    # Perform optimization
    for i in range(n_iter):
        print(f'Iteration {i:03d}  Loss {closure()}')
        optimizer.step(closure)

    print(f'Optimization (L-BFGS) time: {round(time.time() - start, 2)} seconds')

    # Return the root model and the momenta
    p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
    return p_root, p_temp_z 


# Define a function that can fit a model to a population
def fit_model_to_population(md_root, md_targets, n_iter = 10, 
                            sigma_lddmm=5, sigma_root=20, sigma_varifold=5, 
                            gamma_lddmm=0.1, w_jacobian_penalty=1.0):
    """
    Fit a root model to a population of target meshes, with a Jacobian penalty.

    Two sets of momenta are optimized jointly with L-BFGS: `p_root` shoots the root mesh
    to the template, and one momentum field per subject shoots the template to that
    subject. The per-subject momenta are centered (their mean over subjects is
    subtracted) before use. The objective is the average over subjects of three terms:
    the Hamiltonian weighted by `gamma_lddmm`, the labeled varifold data loss, and
    `w_jacobian_penalty` times the sum over faces of the squared log10 ratio between
    the face area in the subject fit and in the template.

    Args:
        md_root: MeshData of the root model (its vertices, faces and labels are used)
        md_targets: dict mapping subject id -> MeshData of the target meshes
        n_iter: number of outer L-BFGS steps (each runs up to 16 inner iterations)
        sigma_lddmm: Gaussian kernel width for the template->subject deformations
        sigma_root: Gaussian kernel width for the root->template deformation
        sigma_varifold: kernel width of the labeled varifold data term
        gamma_lddmm: weight of the Hamiltonian (deformation energy) term
        w_jacobian_penalty: weight of the face-area distortion term
    Returns:
        A tuple (p_root, p_temp_z): the root momentum, shaped like `md_root.vt`, and the
        centered per-subject momenta with shape (number of subjects,) + md_root.vt.shape
    """

    # LDDMM kernels
    device = md_root.vt.device
    K_root = GaussKernel(sigma=torch.tensor(sigma_root, dtype=torch.float32, device=device))
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))
    K_vari = GaussLinKernelWithLabels(torch.tensor(sigma_varifold, dtype=torch.float32, device=device), md_root.lp.shape[1])

    # Create losses for each of the target meshes
    d_loss = { id: lossVarifoldSurfWithLabels(md_root.ft, v.vt, v.ft, md_root.lpt, v.lpt, K_vari) for id,v in md_targets.items() }

    # Create the root->template points/momentum, as well as the template->subject momenta.
    # The root points are copied with detach().clone() rather than torch.tensor(tensor),
    # which PyTorch warns against; the copy is a float32 leaf tensor that requires grad
    q_root = md_root.vt.detach().clone().to(dtype=torch.float32).contiguous().requires_grad_(True)
    p_root = torch.zeros(md_root.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()
    p_temp = torch.zeros((len(md_targets),) + md_root.vt.shape, dtype=torch.float32, device=device, requires_grad=True).contiguous()

    # Create the optimizer
    start = time.time()
    optimizer = torch.optim.LBFGS([p_root, p_temp], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')
    n_subj = len(md_targets.items())

    def closure(detail = False):
        optimizer.zero_grad()

        # Shoot root->template
        _, q_temp = Shooting(p_root, q_root, K_root)[-1]

        # Compute the triangle areas for the template (actually twice the area; the factor
        # cancels in the ratio below). dim=1 is given explicitly because torch.cross would
        # otherwise use the first size-3 dimension, and the implicit form is deprecated
        z0 = q_temp[md_root.ft]
        area_0 = torch.norm(torch.cross(z0[:,1,:] - z0[:,0,:], z0[:,2,:] - z0[:,0,:], dim=1),dim=1) 

        # Make the momenta applied to the template average out to zero
        p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)

        # Compute the loss
        l_ham, l_data, l_jac = 0, 0, 0
        for i, (id,v) in enumerate(md_targets.items()):
            _, q_i = Shooting(p_temp_z[i,:], q_temp, K_temp)[-1]

            # Per-face area change between the template and this subject's fit
            z = q_i[md_root.ft]
            area = torch.norm(torch.cross(z[:,1,:] - z[:,0,:], z[:,2,:] - z[:,0,:], dim=1),dim=1)
            log_jac = torch.log10(area / area_0)

            l_ham = l_ham + gamma_lddmm * Hamiltonian(K_temp)(p_temp_z[i,:], q_temp)
            l_data = l_data + d_loss[id](q_i)
            l_jac = l_jac + torch.sum(log_jac ** 2) * w_jacobian_penalty

        l_ham, l_data, l_jac = l_ham / n_subj, l_data / n_subj, l_jac / n_subj
        L = l_ham + l_data + l_jac
        L.backward()

        # Return loss or detail
        if detail:
            return l_ham, l_data, l_jac, L
        else:
            return L

    # Perform optimization
    for i in range(n_iter):
        # Extra evaluation with detail=True, used only to report the individual terms
        l_ham, l_data, l_jac, L = closure(True)
        print(f'Iteration {i:03d}  Loss H={l_ham:6.2f}  D={l_data:6.2f}  J={l_jac:6.2f}  Total={L:6.2f}')
        optimizer.step(closure)

    print(f'Optimization (L-BFGS) time: {round(time.time() - start, 2)} seconds')

    # Return the root model and the momenta
    p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
    return p_root, p_temp_z 


# Compute label probability sampling using shooting and OMT
# def map_array_to_fittedect_shooting_omt(q_root, p_root, p_temp, md_targets, array):
#     _, q_temp = Shooting(p_root, q_root, K)[-1]
#     plab_sample = []
#     p_temp_z = p_temp - torch.mean(p_temp, 0, keepdim=True)
#     for (id, md_i) in md_aff.items():
#         _, q_i = Shooting(p_temp_z[i,:], q_temp, K)[-1]
#         f_omt, w_omt = match_omt(q_i, md_src2.ft, md_i.vt, md_i.ft)
#         lp_omt = vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy())
#         plab_sample.append(lp_omt)


def shoot_root_to_template(md_root, p_root, sigma_root=20):
    """
    Shoot the root mesh along the momentum `p_root` and return the result as MeshData.

    The deformed vertices are combined with the faces and 'plab' labels of the root.

    Args:
        md_root: MeshData of the root model
        p_root: momentum attached to the root vertices
        sigma_root: width of the Gaussian kernel used for shooting
    """
    sigma = torch.tensor(sigma_root, dtype=torch.float32, device=md_root.vt.device)
    q_root = md_root.vt.clone().requires_grad_(True).contiguous()

    # Keep only the final point positions from the last element of the shooting result
    _, q_temp = Shooting(p_root, q_root, GaussKernel(sigma=sigma))[-1]

    # Same topology and labels as the root, new vertex positions
    pd_temp = vtk_set_cell_array(
        vtk_make_pd(q_temp.detach().cpu().numpy(), md_root.f), 'plab', md_root.lp)
    return MeshData(pd_temp, device=q_temp.device)


def shoot_template_to_subject(md_temp, p, sigma_root=20):
    """
    Shoot the template mesh along the momentum `p` and return the result as MeshData.

    The deformed vertices are combined with the faces and 'plab' labels of the template.

    Args:
        md_temp: MeshData of the template
        p: momentum attached to the template vertices
        sigma_root: width of the Gaussian kernel used for shooting (note the default of
            20 matches the root kernel default used elsewhere in this notebook, not the
            template->subject default of 5; pass the width the momenta were fitted with)
    """
    target_device = md_temp.vt.device
    sigma = torch.tensor(sigma_root, dtype=torch.float32, device=target_device)
    q_start = md_temp.vt.clone().requires_grad_(True).contiguous()

    # Keep only the final point positions from the last element of the shooting result
    _, q = Shooting(p, q_start, GaussKernel(sigma=sigma))[-1]

    # Same topology and labels as the template, new vertex positions
    pd_subj = vtk_set_cell_array(
        vtk_make_pd(q.detach().cpu().numpy(), md_temp.f), 'plab', md_temp.lp)
    return MeshData(pd_subj, device=target_device)


# Update the template by remeshing and updating probability labels
def update_model_by_remeshing(md_root, md_targets, p_root, p_temp_z, 
                              sigma_lddmm=5, sigma_root=20):
    """
    Build a remeshed template with refreshed label probabilities.

    The root is shot to the template with `p_root`; the template is shot to each subject
    with the matching row of `p_temp_z`; subject labels are sampled at OMT-matched
    locations and averaged; the template is then remeshed and the averaged labels are
    transferred onto the new faces.

    Args:
        md_root: MeshData of the root model
        md_targets: dict mapping subject id -> MeshData; iteration order must match the
            first axis of `p_temp_z`
        p_root: root->template momentum
        p_temp_z: template->subject momenta, indexed by subject along the first axis
        sigma_lddmm: Gaussian kernel width for the template->subject shooting
        sigma_root: Gaussian kernel width for the root->template shooting
    Returns:
        MeshData of the remeshed template carrying a 'plab' cell array
    """

    # LDDMM kernels
    device = md_root.vt.device
    K_root = GaussKernel(sigma=torch.tensor(sigma_root, dtype=torch.float32, device=device))
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))

    # Shoot from root to obtain the template
    q_root = md_root.vt.clone().requires_grad_(True).contiguous()
    _, q_temp = Shooting(p_root, q_root, K_root)[-1]
    pd_template = vtk_make_pd(q_temp.detach().cpu().numpy(), md_root.f)

    # Sample and average the plab array from the subjects using OMT
    plab_sample = []
    for i, (id, md_i) in enumerate(md_targets.items()):
        # Template deformed towards subject i, then matched to the subject; match_omt
        # returns one matched location per face of the deformed template
        _, q_i = Shooting(p_temp_z[i,:], q_temp, K_temp)[-1]
        _, w_omt = match_omt(q_i, md_root.ft, md_i.vt, md_i.ft)
        # presumably looks up the subject's label array at each matched location
        lp_omt = vtk_sample_cell_array_at_vertices(md_i.pd, md_i.lp, w_omt.detach().cpu().numpy())
        plab_sample.append(lp_omt)
    plab_sample_avg = np.stack(plab_sample).mean(0)

    # Apply remeshing to the template
    ms = pymeshlab.MeshSet()
    ms.add_mesh(pymeshlab.Mesh(vertex_matrix=q_temp.detach().cpu().numpy(), face_matrix=md_root.f))
    ms.meshing_isotropic_explicit_remeshing()
    v_remesh, f_remesh = ms.mesh(0).vertex_matrix(), ms.mesh(0).face_matrix()
    pd_remesh = vtk_make_pd(v_remesh, f_remesh)

    # Apply the remeshing to the plab array
    # (match the remeshed template back to the original template and sample the averaged
    # labels, which live on the original template faces, at the matched locations)
    _, w_omt = match_omt(torch.tensor(v_remesh, dtype=torch.float32, device=device), 
                         torch.tensor(f_remesh, dtype=torch.long, device=device),
                         q_temp, md_root.ft)
    lp_remesh = vtk_sample_cell_array_at_vertices(pd_template, plab_sample_avg, w_omt.detach().cpu().numpy())
    # Scale by 10 and softmax along axis 1 before storing (same treatment as used for the
    # initial ellipsoid labels in this notebook)
    pd_remesh = vtk_set_cell_array(pd_remesh, 'plab', softmax(lp_remesh * 10, 1))

    # Return the new template as MeshData
    return MeshData(pd_remesh, device=md_root.vt.device)


# OMT matching yields one matched location per triangle of the fitted mesh. This
# function turns that per-triangle result into a per-vertex one and, for every vertex
# of the fitted mesh, finds the target-mesh vertices and weights that interpolate it
def omt_match_to_vertex_weights(pd_fitted, pd_target, w_omt):
    """
    Convert per-face OMT match locations into per-vertex interpolation data.

    Args:
        pd_fitted: polydata of the fitted mesh (left unmodified; a clone is used)
        pd_target: polydata of the target mesh
        w_omt: array of matched locations, one row per face of `pd_fitted`
    Returns:
        A tuple (v_omt, v_int, w_int): the matched location per fitted-mesh vertex, and
        the target vertex indices and weights that interpolate that location
    """
    # Store the match as a cell array on a copy of the fitted mesh and convert it to
    # point data
    pd_match = vtk_set_cell_array(vtk_clone_pd(pd_fitted), 'match', w_omt)
    vtk_cell_array_to_point_array(pd_match, 'match')
    matched_points = vtk_get_point_array(pd_match, 'match')

    # Express each matched point through vertices of the target mesh
    vertex_ids, vertex_weights = vtk_get_interpolation_arrays_for_sample(pd_target, matched_points)
    return matched_points, vertex_ids, vertex_weights


# Map template into subject space and return mesh
def map_template_to_subject(md_temp, md_target, p_temp, sigma_lddmm=5):
    """
    Deform the template into a subject's space and refine the fit with OMT.

    Args:
        md_temp: MeshData of the template
        md_target: MeshData of the subject mesh
        p_temp: template->subject momentum for this subject
        sigma_lddmm: Gaussian kernel width for the template->subject shooting
    Returns:
        A tuple (pd_fitted, pd_omt). `pd_fitted` is the LDDMM-deformed template with the
        template's 'plab' cell array and point arrays 'omt_v_int' / 'omt_w_int' holding
        the interpolation vertex indices and weights on the subject mesh. `pd_omt` has
        the template's faces with vertices moved to their OMT-matched locations, plus
        the 'plab' and 'match' cell arrays.
    """

    # LDDMM kernels
    device = md_temp.vt.device
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))

    # Shoot from template to subject and save as polydata/meshdata
    q_temp = md_temp.vt.clone().requires_grad_(True).contiguous()
    _, q_fitted = Shooting(p_temp, q_temp, K_temp)[-1]
    pd_fitted = vtk_make_pd(q_fitted.detach().cpu().numpy(), md_temp.f)
    pd_fitted = vtk_set_cell_array(pd_fitted, 'plab', md_temp.lp)

    # Match the subject via OMT, i.e. every template vertex is mapped to somewhere on the
    # subject mesh, this fits more closely than LDDMM but might break topology
    # _, w_omt = match_omt(md_target.vt, md_target.ft, q_fitted, md_temp.ft)
    _, w_omt = match_omt(q_fitted, md_temp.ft, md_target.vt, md_target.ft)
    # v_omt: matched location per template vertex; v_int / w_int: subject vertex indices
    # and weights that interpolate each matched location
    v_omt, v_int, w_int = omt_match_to_vertex_weights(pd_fitted, md_target.pd, w_omt.detach().cpu().numpy())

    # Create a clean model to return
    pd_omt = vtk_make_pd(v_omt, md_temp.f)
    pd_omt = vtk_set_cell_array(pd_omt, 'plab', md_temp.lp)
    pd_omt = vtk_set_cell_array(pd_omt, 'match', w_omt.detach().cpu().numpy())

    # Get the interpolation arrays from the matching and place them into the fitted template
    pd_fitted = vtk_set_point_array(pd_fitted, 'omt_v_int', v_int)
    pd_fitted = vtk_set_point_array(pd_fitted, 'omt_w_int', w_int)

    return pd_fitted, pd_omt


In [373]:
def build_template_multistage(md_root, md_targets, schedule, 
                              sigma_lddmm=5, sigma_root=20, sigma_varifold=5,
                              gamma_lddmm=0.1, w_jacobian_penalty=1.0):
    """
    Build a template in several fit/remesh stages.

    Each stage fits the current root model to the population, shoots the root to the
    template, saves it, then remeshes the template (refreshing its labels) and uses the
    remeshed mesh as the root of the next stage. Intermediate meshes are written to
    `tmp/template_iterNN.vtk` and `tmp/template_iterNN_remesh.vtk`.

    Args:
        md_root: MeshData of the initial root model
        md_targets: dict mapping subject id -> MeshData of the target meshes
        schedule: non-empty sequence giving the number of optimizer iterations per stage
        sigma_lddmm, sigma_root, sigma_varifold, gamma_lddmm, w_jacobian_penalty:
            forwarded to `fit_model_to_population` (and the sigmas to the shooting and
            remeshing helpers)
    Returns:
        A tuple (md_temp, p_temp_z) from the last stage: the template before remeshing
        and the centered template->subject momenta defined on its vertices
    Raises:
        ValueError: if `schedule` is empty (there would be nothing to return)
    """

    # Materialize the schedule so it can be checked and iterated (also accepts generators)
    schedule = list(schedule)
    if not schedule:
        raise ValueError('schedule must contain at least one stage')

    # Iterate over the schedule
    for stage, n_iter in enumerate(schedule):

        # Print iteration
        print(f'*** TEMPLATE BUILD STAGE {stage} ***')

        # Fit the model to the population
        p_root, p_temp_z = fit_model_to_population(md_root, md_targets, n_iter,
                                                   sigma_lddmm=sigma_lddmm, sigma_root=sigma_root, 
                                                   sigma_varifold=sigma_varifold, 
                                                   gamma_lddmm=gamma_lddmm,
                                                   w_jacobian_penalty=w_jacobian_penalty)

        # Compute the template by forward shooting
        md_temp = shoot_root_to_template(md_root, p_root, sigma_root=sigma_root)
        save_vtk(md_temp.pd, f'tmp/template_iter{stage:02d}.vtk')

        # Remesh the template
        md_remesh = update_model_by_remeshing(md_root, md_targets, p_root, p_temp_z, sigma_lddmm=sigma_lddmm, sigma_root=sigma_root)
        save_vtk(md_remesh.pd, f'tmp/template_iter{stage:02d}_remesh.vtk')

        # Make the template the new root
        md_root = md_remesh

    # Return the model from the last iteration. The un-remeshed template is returned
    # because the momenta are defined on its vertices, not on the remeshed ones
    return md_temp, p_temp_z


In [372]:
# Display the value returned by get_lddmm_gamma() for the 'left' entry of the template
template['left'].get_lddmm_gamma()

1.0

In [374]:
# Run the whole template building pipeline: four stages whose schedule entries are 10, 10, 10 and 50.
# sigma_lddmm comes from the 'left' template entry and sigma_root is 2.4 times that value.
# NOTE: sigma_varifold is hard-coded to 5 rather than read from template['left'].get_varifold_sigma(),
# and gamma_lddmm is hard-coded to 0.05 rather than read from template['left'].get_lddmm_gamma().
md_temp, p_temp = build_template_multistage(md_src2, md_aff, schedule=[10, 10, 10, 50], 
                                            sigma_root = 2.4 * template['left'].get_lddmm_sigma(), 
                                            sigma_lddmm = template['left'].get_lddmm_sigma(), 
                                            sigma_varifold = 5,
                                            gamma_lddmm=0.05, w_jacobian_penalty=10.0)

*** TEMPLATE BUILD STAGE 0 ***



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Iteration 000  Loss H=  0.00  D=97949.55  J=  0.00  Total=97949.55
Iteration 001  Loss H=  0.35  D=61498.59  J=112.11  Total=61611.04
Iteration 002  Loss H=  3.46  D=42560.18  J=1072.43  Total=43636.07
Iteration 003  Loss H=  7.07  D=31487.99  J=1997.64  Total=33492.70
Iteration 004  Loss H= 10.79  D=23870.60  J=2663.31  Total=26544.71
Iteration 005  Loss H= 14.80  D=17575.81  J=3486.48  Total=21077.09
Iteration 006  Loss H= 17.06  D=13633.41  J=3604.33  Total=17254.79
Iteration 007  Loss H= 19.54  D=10460.05  J=3769.55  Total=14249.14
Iteration 008  Loss H= 19.92  D=9520.86  J=3700.99  Total=13241.77
Iteration 009  Loss H= 20.91  D=8462.33  J=3720.27  Total=12203.51
Optimization (L-BFGS) time: 546.64 seconds
385x408 clusters, computed at scale = 3.584
Successive scales :  78.208, 78.208, 62.566, 50.053, 40.042, 32.034, 25.627, 20.502, 16.401, 13.121, 10.497, 8.397, 6.718, 5.374, 4.300, 3.440, 2.752, 2.201, 1.761, 1.409, 1.127, 0.902, 0.721, 0.577, 0.462, 0.369, 0.295, 0.236, 0.189, 0.

In [375]:
# Save the template and the momenta
# The arrays are attached to the result of vtk_clone_pd (presumably a copy, so md_temp.pd stays unchanged)
pd_temp_save = vtk_clone_pd(md_temp.pd)
# One point array per subject, named momenta_<id>; p_temp[i] is paired with the i-th entry of md_aff
# NOTE(review): the loop variable `id` shadows the Python builtin of the same name
for i, (id, md_i) in enumerate(md_aff.items()):
    vtk_set_point_array(pd_temp_save, f'momenta_{id}', p_temp[i,:,:].detach().cpu().numpy())
save_vtk(pd_temp_save, 'tmp/template_final_with_momenta.vtk')

In [376]:
# Match template to each subject and save the resulting meshes.
# Per subject this writes three files under tmp/: template_fit_to_<id>.vtk,
# template_omt_to_<id>.vtk and template_omt_to_avg_ras_<id>.vtk.
for i, (id, md_i) in enumerate(md_aff.items()):
    # We want to map the template to the halfway surface of the MTL cortex in subject space.
    # But that mesh has a different number of vertices than the meshes we fitted in the template
    # building process. So we need to apply the affine transform to the full-resolution inflated
    # mesh, then do OMT matching between the warped template and this target, then use this to
    # get coordinates of template vertices in the space of the native halfway surface.
    pd_infl_fullres = load_vtk(data[id]['workspace'].affine_moving)
    A = np.loadtxt(f'tmp/affine_to_template_{id}.mat')
    v_infl_fullres = vtk_get_points(pd_infl_fullres)
    # Apply the 3x3 block of A to every vertex and add the translation column
    v_infl_fullres_affine = (A[0:3,0:3] @ v_infl_fullres.T).T + A[0:3,3]

    # Send this mesh to the device
    # NOTE(review): md_infl_fullres is not referenced again in this cell; the call that would
    # have used it is commented out further down. Confirm whether this block is still needed.
    # NOTE(review): 'plab' is stored as a CELL array but is sized by the number of vertices.
    vtk_set_points(pd_infl_fullres, v_infl_fullres_affine)
    vtk_set_cell_array(pd_infl_fullres, 'plab', np.zeros((v_infl_fullres_affine.shape[0],1)))
    md_infl_fullres = MeshData(pd_infl_fullres, device)

    # Get template mesh and OMT fitted mesh
    pd_i_fit, pd_i_omt = map_template_to_subject(md_temp, md_i, p_temp[i], 
                                                 sigma_lddmm = template['left'].get_lddmm_sigma())
    
    # These interpolation arrays map every template vertex to a face on the subject mesh. 
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # We also want to go all the way back to the original space. Load the original space mesh
    pd_hw_native = load_vtk(data[id]['workspace'].fn_cruise('mtl_avg_l2m-mesh-ras.vtk'))
    v_native = vtk_get_points(pd_hw_native)
    
    # Interpolate: for each template vertex, take the weighted sum (weights w_int) of the
    # native-mesh vertices selected by v_int
    v_temp_to_native = np.einsum('vij,vi->vj', v_native[v_int,:], w_int)

    # Disabled alternative: match against the full-resolution inflated mesh instead of md_i
    #pd_i_fit, pd_i_omt = map_template_to_subject(md_temp, md_infl_fullres, p_temp[i], 
    #                                             sigma_lddmm = template['left'].get_lddmm_sigma())
    
    # The original mesh has a different number of vertices compared to the fitted meshes
    # (disabled recomputation of the interpolation for the alternative above):
    # v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    # w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')
    # v_temp_to_native = np.einsum('vij,vi->vj', v_native[v_int,:], w_int)
    # Build a mesh with the fitted template's triangles and 'plab' labels at the native-space coordinates
    pd_temp_omt_to_native = vtk_make_pd(v_temp_to_native, vtk_get_triangles(pd_i_fit))
    pd_temp_omt_to_native = vtk_set_cell_array(pd_temp_omt_to_native, 'plab', vtk_get_cell_array(pd_i_fit, 'plab'))

    save_vtk(pd_i_fit, f'tmp/template_fit_to_{id}.vtk')
    save_vtk(pd_i_omt, f'tmp/template_omt_to_{id}.vtk')
    save_vtk(pd_temp_omt_to_native, f'tmp/template_omt_to_avg_ras_{id}.vtk')

385x410 clusters, computed at scale = 3.573
Successive scales :  77.975, 77.975, 62.380, 49.904, 39.923, 31.939, 25.551, 20.441, 16.353, 13.082, 10.466, 8.373, 6.698, 5.358, 4.287, 3.429, 2.744, 2.195, 1.756, 1.405, 1.124, 0.899, 0.719, 0.575, 0.460, 0.368, 0.295, 0.236, 0.189, 0.151, 0.121, 0.097, 0.077, 0.062, 0.050
Jump from coarse to fine between indices 14 (σ=4.287) and 15 (σ=3.429).
Keep 55103/157850 = 34.9% of the coarse cost matrix.
Keep 51577/148225 = 34.8% of the coarse cost matrix.
Keep 58802/168100 = 35.0% of the coarse cost matrix.
OMT matching distance: 0.19189341366291046, time elapsed: 4.253779888153076
462x497 clusters, computed at scale = 3.277
Successive scales :  71.518, 71.518, 57.215, 45.772, 36.617, 29.294, 23.435, 18.748, 14.999, 11.999, 9.599, 7.679, 6.143, 4.915, 3.932, 3.145, 2.516, 2.013, 1.610, 1.288, 1.031, 0.825, 0.660, 0.528, 0.422, 0.338, 0.270, 0.216, 0.173, 0.138, 0.111, 0.089, 0.071, 0.057, 0.050
Jump from coarse to fine between indices 14 (σ=3.932) 

389x426 clusters, computed at scale = 3.612
Successive scales :  78.817, 78.817, 63.053, 50.443, 40.354, 32.283, 25.827, 20.661, 16.529, 13.223, 10.579, 8.463, 6.770, 5.416, 4.333, 3.466, 2.773, 2.218, 1.775, 1.420, 1.136, 0.909, 0.727, 0.582, 0.465, 0.372, 0.298, 0.238, 0.191, 0.152, 0.122, 0.098, 0.078, 0.062, 0.050
Jump from coarse to fine between indices 14 (σ=4.333) and 15 (σ=3.466).
Keep 59778/165714 = 36.1% of the coarse cost matrix.
Keep 54883/151321 = 36.3% of the coarse cost matrix.
Keep 65134/181476 = 35.9% of the coarse cost matrix.
OMT matching distance: 0.09053988009691238, time elapsed: 4.318037986755371
364x366 clusters, computed at scale = 3.535
Successive scales :  77.133, 77.133, 61.706, 49.365, 39.492, 31.594, 25.275, 20.220, 16.176, 12.941, 10.353, 8.282, 6.626, 5.301, 4.240, 3.392, 2.714, 2.171, 1.737, 1.390, 1.112, 0.889, 0.711, 0.569, 0.455, 0.364, 0.291, 0.233, 0.186, 0.149, 0.119, 0.095, 0.076, 0.061, 0.050
Jump from coarse to fine between indices 14 (σ=4.240)

In [391]:
# Destination folder for the template
temp_save_folder = 'tmp/crashs_template'
os.makedirs(temp_save_folder, exist_ok=True)

# Compute curvature measures on the template
ms = pymeshlab.MeshSet()
ms.add_mesh(pymeshlab.Mesh(vertex_matrix=md_temp.v, face_matrix=md_temp.f))

# The curvature filter is re-run once per measure; after each run the per-vertex scalar
# array of the mesh is copied into a point array named Curvature_<name> on md_temp.pd.
# The dict keys are the curvcolormethod values passed to the filter.
for c_id,c_nm in { 0: 'Mean', 1: 'Gaussian', 4: 'ShapeIndex', 5: 'Curvedness' }.items():
    ms.compute_curvature_principal_directions_per_vertex(
        method='Scale Dependent Quadric Fitting', 
        curvcolormethod=c_id,
        scale=pymeshlab.AbsoluteValue(3.0))
    # NOTE(review): q is assigned but not used in this cell
    q = ms.mesh(0)
    vtk_set_point_array(md_temp.pd, f'Curvature_{c_nm}', ms.mesh(0).vertex_scalar_array())

# Set the label array in the template by taking argmax over plab
vtk_set_cell_array(md_temp.pd, 'label', np.argmax(md_temp.lp, axis=1))

# Save the left template
save_vtk(md_temp.pd, f'{temp_save_folder}/template_shoot_left.vtk')

# Apply a flip to the left template
# The einsum evaluates md_temp.v @ flip_lr[:3,:3] (row-vector convention) and then adds
# the translation column flip_lr[:3,3].
# NOTE(review): this equals applying flip_lr[:3,:3] to column vectors only if that block is symmetric - confirm
pd_template_flip = vtk_clone_pd(md_temp.pd)
vtk_set_points(pd_template_flip,
               np.einsum('ij,kj->ki', flip_lr[:3,:3].T, md_temp.v) + flip_lr[:3,3])
save_vtk(pd_template_flip, f'{temp_save_folder}/template_shoot_right.vtk')

# Save the JSON
with open(f'{temp_save_folder}/template.json','wt') as fd:
    json.dump(template['left'].json, fd)


average vertex num in each fit: 58.576469
average vertex num in each fit: 58.599689
average vertex num in each fit: 58.599537
average vertex num in each fit: 58.599537


In [169]:
# Write the `json` attribute of the template object to <temp_save_folder>/template.json.
# NOTE(review): the preceding cell writes template['left'].json to the same path, and this
# cell carries an older execution count - confirm which of the two is intended.
with open(f'{temp_save_folder}/template.json','wt') as fd:
    json.dump(template.json, fd)


# Compute consensus labeling

In [377]:
import SimpleITK as sitk

# Ok, so we have the template in native space, we can now sample the complete selection of labels.
# xvashs_label[j, i] holds the label sampled at template vertex j in subject i.
xvashs_label = np.zeros((md_temp.v.shape[0], len(md_aff)))
for i, (id, md_i) in enumerate(md_aff.items()):
    # Read the global segmentation
    fn_global_seg = f'/data/pauly2/ashs_xv/input/{id}/{id}_label_global.nii.gz'
    img = sitk.ReadImage(fn_global_seg, outputPixelType=sitk.sitkFloat32)

    # Load the template mesh that was mapped into this subject's native space
    pd_temp_omt_to_native = load_vtk(f'tmp/template_omt_to_avg_ras_{id}.vtk')

    # Sample all the vertices - nearest neighbor
    v = vtk_get_points(pd_temp_omt_to_native)
    for j in range(v.shape[0]):
        # Negate the first two coordinates before querying the image
        # (presumably RAS mesh coordinates -> LPS physical coordinates; confirm)
        x = [ -v[j,0], -v[j,1], v[j,2] ]
        try:
            l = img.EvaluateAtPhysicalPoint(x, sitk.sitkNearestNeighbor)
        except Exception:
            # Best effort: a vertex that cannot be sampled is treated as background.
            # Catching Exception (not a bare except) lets Ctrl-C interrupt the loop.
            l = 0
        xvashs_label[j,i] = l

    # Assign to the mesh
    vtk_set_point_array(pd_temp_omt_to_native, 'label_xv', xvashs_label[:,i])
    save_vtk(pd_temp_omt_to_native, f'tmp/template_omt_to_avg_ras_withlabel_{id}.vtk')
    # break
        
# Consensus labeling: per vertex, the label value in 0..68 observed in the most subjects
# (ties resolve to the smallest label; label values of 69 or more are never counted)
lab_consensus = np.argmax([np.sum(xvashs_label == v, 1) for v in range(69)], 0)
pd_consensus = vtk_clone_pd(md_temp.pd)
vtk_set_point_array(pd_consensus, 'xvashs', lab_consensus)

save_vtk(pd_consensus, 'tmp/template_global_label.vtk')

# More serious consensus labeling, using all Cruise layers

In [142]:
# Implements marching cubes with remeshing in voxel space
def layer_to_mesh(layer, edge_len_pct=0.75):
    """Extract the zero iso-surface of a SimpleITK image as a remeshed triangle mesh.

    The voxel values are wrapped in a vtkImageData whose origin, spacing and
    direction are left untouched, so the resulting vertices live in voxel
    index space rather than physical space.

    Parameters
    ----------
    layer : SimpleITK image; its first three size components are used as the grid dimensions.
    edge_len_pct : target edge length for the isotropic remeshing, passed to
        ``pymeshlab.Percentage``.

    Returns
    -------
    The poly data built by ``vtk_make_pd`` from the remeshed vertices and faces.
    """
    def triangles_only(upstream):
        # Triangle filter fed by `upstream` that passes neither lines nor vertices
        flt = vtk.vtkTriangleFilter()
        flt.SetInputConnection(upstream.GetOutputPort())
        flt.PassLinesOff()
        flt.PassVertsOff()
        return flt

    # Wrap the raw voxel values (as floats) into a VTK image of matching dimensions
    voxels = sitk.GetArrayFromImage(layer)
    grid = vtk.vtkImageData()
    grid.GetPointData().SetScalars(numpy_to_vtk(voxels.flatten(), array_type=vtk.VTK_FLOAT))
    size = layer.GetSize()
    grid.SetDimensions(size[0], size[1], size[2])

    # Marching cubes with a single contour at value 0.0
    contour = vtk.vtkMarchingCubes()
    contour.SetInputData(grid)
    contour.SetNumberOfContours(1)
    contour.SetValue(0, 0.0)

    # Triangulate, merge coincident points (zero tolerance), then triangulate again
    first_pass = triangles_only(contour)

    merger = vtk.vtkCleanPolyData()
    merger.SetInputConnection(first_pass.GetOutputPort())
    merger.PointMergingOn()
    merger.SetTolerance(0.0)

    second_pass = triangles_only(merger)
    second_pass.Update()
    pd_cubes = second_pass.GetOutput()

    # Isotropic explicit remeshing of the extracted surface
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(pymeshlab.Mesh(vertex_matrix=vtk_get_points(pd_cubes),
                                     face_matrix=vtk_get_triangles(pd_cubes)))
    mesh_set.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.Percentage(edge_len_pct))
    remeshed = mesh_set.mesh(0)
    return vtk_make_pd(remeshed.vertex_matrix(), remeshed.face_matrix())


# Propagate a mesh through the levelset shells using OMT, starting with the 
# specified level. If no level specified, middle level will be used
def profile_meshing_omt(img_ls, source_mesh=None, init_layer=None):
    """Propagate one mesh through every level-set layer of a 4D image using OMT.

    Each slice ``img_ls[:, :, :, k]`` is converted to a surface with
    ``layer_to_mesh``. Starting from ``init_layer`` the current mesh is matched
    to the neighbouring layer, and the matched result becomes the source for
    the next layer, first towards index 0 and then towards the last index.
    Every matched mesh keeps the triangles of the mesh it was propagated from.

    Parameters
    ----------
    img_ls : 4D SimpleITK image; the size of its fourth axis is the number of layers.
    source_mesh : optional mesh to match onto the initial layer. When it is
        falsy the initial layer's own surface is used as the starting mesh.
    init_layer : index of the starting layer; defaults to the middle layer.

    Returns
    -------
    (matched, init_layer) : list with one mesh per layer, and the starting index used.
    """
    n_layers = img_ls.GetSize()[3]

    def mesh_tensors(mesh):
        # Vertices (float32) and triangles (int32) of a mesh as tensors on the global device
        return (torch.tensor(vtk_get_points(mesh), dtype=torch.float32, device=device),
                torch.tensor(vtk_get_triangles(mesh), dtype=torch.int32, device=device))

    def do_omt(l_src, l_trg):
        # The OMT step that is called repeatedly: match l_src to l_trg and
        # return the matched vertices with the triangles of l_src
        v_src, f_src = mesh_tensors(l_src)
        v_trg, f_trg = mesh_tensors(l_trg)
        _, w_omt = match_omt(v_src, f_src, v_trg, f_trg)
        v_omt, _, _ = omt_match_to_vertex_weights(l_src, l_trg, w_omt.detach().cpu().numpy())
        return vtk_make_pd(v_omt, vtk_get_triangles(l_src))

    # One iso-surface per layer, plus a slot for each propagated result
    layers = [layer_to_mesh(img_ls[:, :, :, k]) for k in range(n_layers)]
    matched = [None] * len(layers)

    # Fall back to the middle layer when none is requested
    if init_layer is None:
        init_layer = n_layers // 2

    # Seed the initial layer, either from the supplied mesh or from the layer itself
    if source_mesh:
        matched[init_layer] = do_omt(source_mesh, layers[init_layer])
    else:
        matched[init_layer] = layers[init_layer]

    # Sweep downwards to layer 0, then upwards to the last layer; each sweep
    # restarts from the initial layer
    downward = range(init_layer - 1, -1, -1)
    upward = range(init_layer + 1, n_layers)
    for sweep in (downward, upward):
        prop_source = matched[init_layer]
        for k in sweep:
            print(f'Propagating to layer {k}')
            matched[k] = do_omt(prop_source, layers[k])
            prop_source = matched[k]

    return matched, init_layer

In [144]:
# Load the CSV with the label descriptions; column names are supplied here, in file column order
import pandas as pd
df_uclm = pd.read_csv('/data/pauly2/ashs_xv/manifest/label_grouping_ashs.csv',
                      names=('uclm','group','uclm_desc','group_desc'))

# These are the gray matter labels (values of the 'group' column that are kept below)
uclm_group_labels_gm = (1,3,4,5,7,8,9)

# Label to group map: 'uclm' value -> 'group' value, restricted to rows whose group is listed above
uclm_label_grouping = { q['uclm']: q['group'] for _, q in df_uclm.iterrows() if q['group'] in uclm_group_labels_gm }

In [147]:
# (re)load template from the file saved as template_shoot_left.vtk in the template folder
temp_save_folder = 'tmp/crashs_template'
md_temp = MeshData(load_vtk(f'{temp_save_folder}/template_shoot_left.vtk'), device)
# Keep the template vertices and faces under short names for the cells below
v_temp, f_temp = md_temp.v, md_temp.f

In [158]:
# Define the list of labels that we care about
uclm_labels = uclm_label_grouping.keys()

# Create a label map tensor indexed as (template vertex, layer, subject, label). The layer
# axis has 11 entries and the last label slot stores the background probability.
# NOTE(review): 11 must match the number of meshes returned by profile_meshing_omt for
# every subject - confirm against the layering images.
uclm_label_map = np.zeros((v_temp.shape[0], 11, len(md_aff), len(uclm_labels) + 1))

# Load each segmentation in turn
for i, (id, md_i) in enumerate(md_aff.items()):

    # Read the global segmentation
    fn_global_seg = f'/data/pauly2/ashs_xv/input/{id}/{id}_label_global.nii.gz'
    img = sitk.ReadImage(fn_global_seg, outputPixelType=sitk.sitkFloat32)

    # Read the cortex image - because it has the dimensions that match the mesh
    img_ref = sitk.ReadImage(data[id]['workspace'].cruise_gm_prob)

    # Split the segmentation into binary images, one per label, and append the
    # complement (1 - sum of the label masks) as the background channel
    # NOTE(review): p_lab is an integer array at this point; confirm that the smoothing
    # below yields fractional values for this pixel type
    x_img = sitk.GetArrayFromImage(img)
    p_lab = np.stack([ x_img == l for l in uclm_labels])
    p_lab = np.append(p_lab, [1 - p_lab.sum(0)], axis=0)

    # Put it back into a vector ITK image (not very efficient)
    img_p = sitk.GetImageFromArray(p_lab.transpose(1,2,3,0), isVector=True)
    img_p.CopyInformation(img)

    # Smooth the image a bit
    img_p = sitk.SmoothingRecursiveGaussian(img_p, [0.4, 0.4, 0.4])

    # Load the fitted mesh and its interpolation arrays
    pd_i_fit = load_vtk(f'tmp/template_fit_to_{id}.vtk')
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Load the midsurface mesh from cruise and map the template to this space
    pd_hw_nighres = load_vtk(data[id]['workspace'].fn_cruise('mtl_avg_l2m-mesh.vtk'))
    v_temp_to_hw_nighres = np.einsum('vij,vi->vj', vtk_get_points(pd_hw_nighres)[v_int,:], w_int)
    pd_temp_to_hw_nighres = vtk_make_pd(v_temp_to_hw_nighres, vtk_get_triangles(pd_i_fit))

    # Propagate this fitted mesh through the levelset layers using OMT
    img_ls = sitk.ReadImage(data[id]['workspace'].fn_cruise('mtl_layering-boundaries.nii.gz'))
    prof_meshes, mid_layer = profile_meshing_omt(img_ls, source_mesh=pd_temp_to_hw_nighres)

    # Repeat for each mesh layer
    print(f'Sampling {id}')
    for layer, pd_layer in enumerate(prof_meshes):

        # Vertex coordinates of the propagated mesh for that layer
        # pd_layer = load_vtk(data[id]['workspace'].fn_cruise(f'mtl_mesh-p{layer}.vtk'))
        # pd_layer = load_vtk(f'tmp/template_subdiv_{id}_mesh-p{layer}.vtk')
        v_layer = vtk_get_points(pd_layer)

        # Sample the label map at this layer, keeping in mind that the layer is going to
        # be in image coordinate space, not physical space
        n_hit, n_miss = 0, 0
        for j in range(v_layer.shape[0]):
            # Transform to physical point via reference image
            p_j = img_ref.TransformContinuousIndexToPhysicalPoint(v_layer[j, :].tolist())
            try:
                l = img_p.EvaluateAtPhysicalPoint(p_j, sitk.sitkLinear)
                n_hit = n_hit + 1
            except Exception:
                # Best effort: a vertex that cannot be sampled gets all-zero probabilities.
                # Catching Exception (not a bare except) lets Ctrl-C interrupt the loop.
                n_miss = n_miss + 1
                l = 0
            uclm_label_map[j, layer, i, :] = l

517x538 clusters, computed at scale = 7.212
Successive scales :  157.380, 157.380, 125.904, 100.723, 80.579, 64.463, 51.570, 41.256, 33.005, 26.404, 21.123, 16.899, 13.519, 10.815, 8.652, 6.922, 5.537, 4.430, 3.544, 2.835, 2.268, 1.814, 1.452, 1.161, 0.929, 0.743, 0.595, 0.476, 0.381, 0.304, 0.244, 0.195, 0.156, 0.125, 0.100, 0.080, 0.064, 0.051, 0.050
Jump from coarse to fine between indices 14 (σ=8.652) and 15 (σ=6.922).
Keep 84964/278146 = 30.5% of the coarse cost matrix.
Keep 82501/267289 = 30.9% of the coarse cost matrix.
Keep 87660/289444 = 30.3% of the coarse cost matrix.
OMT matching distance: 0.8265379667282104, time elapsed: 1.4615063667297363
Propagating to layer 4
513x528 clusters, computed at scale = 7.191
Successive scales :  156.919, 156.919, 125.535, 100.428, 80.342, 64.274, 51.419, 41.135, 32.908, 26.327, 21.061, 16.849, 13.479, 10.783, 8.627, 6.901, 5.521, 4.417, 3.533, 2.827, 2.261, 1.809, 1.447, 1.158, 0.926, 0.741, 0.593, 0.474, 0.379, 0.304, 0.243, 0.194, 0.155, 0

In [160]:
# Average the map between subjects (axis 2 of uclm_label_map); the result is indexed
# as (template vertex, layer, label slot)
uclm_label_consensus = np.mean(uclm_label_map, 2)

# Consensus labeling: one point array per label slot, named uclm_plab_NNN, holding one
# column per layer. NNN is the position of the label in the label list, not the label value.
pd_consensus = vtk_make_pd(v_temp, f_temp)
for label in range(uclm_label_consensus.shape[2]):
    layer_prob = uclm_label_consensus[:,:,label].astype(np.float32)
    vtk_set_point_array(pd_consensus, f'uclm_plab_{label:03d}', layer_prob)

save_vtk(pd_consensus, 'tmp/template_ec_test.vtk', binary=True)

### Now do the same for the nnU-Net segmentations!

In [161]:
import glob
# Test ids are the base names of the entries found directly under the nnunet work folder
test_ids = [ os.path.basename(x) for x in glob.glob('/data/pauly2/ashs_xv/work/fold_0/nnunet/*') ]

In [162]:
# Keep track of the workspaces created for the test subjects, keyed by subject id
test_data = {}

# Create a Workspace object for each test subject (nothing else is run in this cell)
for id in test_ids:
    # Get the side from the id suffix: ids ending in 'L' are left, all others are right
    side = 'left' if id.endswith('L') else 'right'

    # Folder handed to the Workspace; it is not created in this cell
    crashs_dir = os.path.join('/data/pauly2/ashs_xv/work/fold_0/crashs_apply_nnunet', id)
    workspace = Workspace(crashs_dir, id, side)

    # Store the data
    test_data[id] = { 'side': side, 'workspace': workspace }

In [None]:
from pykeops.torch import LazyTensor
def laplacian_kernel(x, y, sigma=0.1):
    """Symbolic Laplacian kernel exp(-||x_i - y_j|| / sigma) between two point sets.

    Parameters
    ----------
    x : array of M points, indexed as (M, D).
    y : array of N points, indexed as (N, D).
    sigma : kernel width dividing the Euclidean distance.

    Returns
    -------
    The (M, N) kernel as a LazyTensor expression (not a dense matrix).
    """
    rows = LazyTensor(x[:, None, :])  # (M, 1, D)
    cols = LazyTensor(y[None, :, :])  # (1, N, D)
    # Squared Euclidean distance between every pair of points
    sq_dist = ((rows - cols) ** 2).sum(-1)
    dist = sq_dist.sqrt()
    return (-dist / sigma).exp()


In [172]:
# For each test subject, relabel the gray-matter voxels of the nnU-Net segmentation using
# the consensus label probabilities carried by the template layers, and write the result
# to tmp/crashs_nnunet_fill_labels_<id>.nii.gz
for i, id in enumerate(test_ids):

    # Read the segmentation that we want to apply this labeling to
    fn_target = f'/data/pauly2/ashs_xv/work/fold_0/nnunet/{id}/{id}_ivseg_ashs_upsample.nii.gz'
    img = sitk.ReadImage(fn_target, outputPixelType=sitk.sitkFloat32)

    # Extract pixels of interest: nonzero where the voxel value equals any entry of uclm_group_labels_gm
    x_img = sitk.GetArrayViewFromImage(img)
    x_img = np.sum(np.stack([x_img == l for l in uclm_group_labels_gm]), axis=0)

    # Get the pixel coordinates - these are flipped so they correspond to ITK
    # (numpy index order is reversed column-wise to give ITK index order)
    nz = np.flip(np.transpose(np.stack(np.nonzero(x_img)).astype(np.float32)), 1)

    # Map each of these coordinates into the Nighres mesh coordinates:
    # index in img -> physical point -> continuous index in the reference image
    img_ref = sitk.ReadImage(test_data[id]['workspace'].cruise_gm_prob)
    y = np.array([img_ref.TransformPhysicalPointToContinuousIndex(
        img.TransformContinuousIndexToPhysicalPoint(nz[j,:].tolist())) for j in range(nz.shape[0]) ])
    
    # Load the fitted template with its interpolation arrays
    pd_i_fit = load_vtk(test_data[id]['workspace'].fit_omt_match)
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Load the midsurface mesh from cruise and map the template to this space
    pd_hw_nighres = load_vtk(test_data[id]['workspace'].fn_cruise('mtl_avg_l2m-mesh.vtk'))
    v_temp_to_hw_nighres = np.einsum('vij,vi->vj', vtk_get_points(pd_hw_nighres)[v_int,:], w_int)
    pd_temp_to_hw_nighres = vtk_make_pd(v_temp_to_hw_nighres, vtk_get_triangles(pd_i_fit))

    # Propagate this fitted mesh through the levelset layers using OMT
    img_ls = sitk.ReadImage(test_data[id]['workspace'].fn_cruise('mtl_layering-boundaries.nii.gz'))
    prof_meshes, mid_layer = profile_meshing_omt(img_ls, source_mesh=pd_temp_to_hw_nighres)

    # Collect, per layer, the propagated vertex coordinates and the matching slice of the
    # consensus probabilities - these are the coordinates from which we wish to interpolate
    mesh_coord, mesh_label = [], []
    for layer, pd_layer in enumerate(prof_meshes):

        # Coordinates of this layer's vertices and the consensus probabilities of the same layer
        mesh_coord.append(vtk_get_points(pd_layer))
        mesh_label.append(uclm_label_consensus[:,layer,:])

    # Put the coordinates together - these are the points we want to interpolate from
    x = np.concatenate(mesh_coord)

    # Put the label probabilities together - these are the values we want to propagate
    b = np.concatenate(mesh_label)

    # Kernel-weighted sum of the label probabilities at every voxel of interest.
    # The weights are not normalised, which does not change the per-row argmax taken below.
    K_yx = laplacian_kernel(
        torch.tensor(y, dtype=torch.float32, device=device), 
        torch.tensor(x, dtype=torch.float32, device=device),
        sigma=1.0)
    b_y = (K_yx @ torch.tensor(b, dtype=torch.float32, device=device)).detach().cpu().numpy()

    # Pick the best label per voxel, ignoring the last (background) column, then translate
    # the column position back to the label value and write it into the voxels of interest
    l_best = np.argmax(b_y[:,:-1], 1)
    l_best_remap = np.zeros_like(l_best)
    for k, label in enumerate(uclm_label_grouping.keys()):
        l_best_remap[l_best == k] = label
    x_img[x_img != 0] = l_best_remap

    # Wrap the relabelled array in an image that copies the geometry of the input segmentation
    img_result = sitk.GetImageFromArray(x_img)
    img_result.CopyInformation(img)

    # Save the image
    sitk.WriteImage(img_result, f'tmp/crashs_nnunet_fill_labels_{id}.nii.gz')

609x655 clusters, computed at scale = 6.835
Successive scales :  149.151, 149.151, 119.321, 95.457, 76.365, 61.092, 48.874, 39.099, 31.279, 25.023, 20.019, 16.015, 12.812, 10.250, 8.200, 6.560, 5.248, 4.198, 3.359, 2.687, 2.149, 1.720, 1.376, 1.101, 0.880, 0.704, 0.563, 0.451, 0.361, 0.289, 0.231, 0.185, 0.148, 0.118, 0.095, 0.076, 0.061, 0.050
Jump from coarse to fine between indices 14 (σ=8.200) and 15 (σ=6.560).
Keep 114380/398895 = 28.7% of the coarse cost matrix.
Keep 106887/370881 = 28.8% of the coarse cost matrix.
Keep 122079/429025 = 28.5% of the coarse cost matrix.
OMT matching distance: 1.2052192687988281, time elapsed: 1.6201815605163574
Propagating to layer 4
611x632 clusters, computed at scale = 6.820
Successive scales :  148.825, 148.825, 119.060, 95.248, 76.198, 60.959, 48.767, 39.014, 31.211, 24.969, 19.975, 15.980, 12.784, 10.227, 8.182, 6.545, 5.236, 4.189, 3.351, 2.681, 2.145, 1.716, 1.373, 1.098, 0.879, 0.703, 0.562, 0.450, 0.360, 0.288, 0.230, 0.184, 0.147, 0.118, 

# Older propagation code based on Nighres

In [8]:
# (re)load template from the file saved as template_shoot_left.vtk in the template folder
temp_save_folder = 'tmp/crashs_template'
md_temp = MeshData(load_vtk(f'{temp_save_folder}/template_shoot_left.vtk'), device)

In [54]:
def loop_subdivide(v, f, iterations=1):
    """Apply pymeshlab's Loop surface subdivision to a triangle mesh.

    Parameters
    ----------
    v : vertex matrix handed to ``pymeshlab.Mesh``.
    f : face matrix handed to ``pymeshlab.Mesh``.
    iterations : number of subdivision iterations passed to the filter.

    Returns
    -------
    (vertices, faces) : vertex and face matrices of the subdivided mesh.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(pymeshlab.Mesh(vertex_matrix=v, face_matrix=f))
    # The edge threshold is given as 0 percent
    mesh_set.meshing_surface_subdivision_loop(
        iterations=iterations,
        threshold=pymeshlab.Percentage(0.0))
    refined = mesh_set.mesh(0)
    return refined.vertex_matrix(), refined.face_matrix()

In [56]:
# Subdivide the template by splitting triangles
# subdivide_level is also read by the later cells that subdivide the per-subject meshes
subdivide_level = 1
v_temp_sub, f_temp_sub = loop_subdivide(md_temp.v, md_temp.f, iterations=subdivide_level)
# Show vertex/face array shapes before and after subdivision
print(md_temp.v.shape, md_temp.f.shape, v_temp_sub.shape, f_temp_sub.shape)

(6460, 3) (12916, 3) (25834, 3) (51664, 3)


In [57]:
# Load the CSV with the label descriptions; column names are supplied here, in file column order
import pandas as pd
df_uclm = pd.read_csv('/data/pauly2/ashs_xv/manifest/label_grouping_ashs.csv',
                      names=('uclm','group','uclm_desc','group_desc'))

# These are the gray matter labels (values of the 'group' column that are kept below)
uclm_group_labels_gm = (1,3,4,5,7,8,9)

# Label to group map: 'uclm' value -> 'group' value, restricted to rows whose group is listed above
uclm_label_grouping = { q['uclm']: q['group'] for _, q in df_uclm.iterrows() if q['group'] in uclm_group_labels_gm }

In [58]:
import SimpleITK as sitk
import nighres

# Define the list of labels that we care about
uclm_labels = uclm_label_grouping.keys()

# Create a label map tensor indexed as (subdivided template vertex, layer, subject, label).
# There are five Nighres layers and the last label slot stores the background probability.
uclm_label_map = np.zeros((v_temp_sub.shape[0], 5, len(md_aff), len(uclm_labels) + 1))

# Load each segmentation in turn
for i, (id, md_i) in enumerate(md_aff.items()):

    # Read the global segmentation
    fn_global_seg = f'/data/pauly2/ashs_xv/input/{id}/{id}_label_global.nii.gz'
    img = sitk.ReadImage(fn_global_seg, outputPixelType=sitk.sitkFloat32)

    # Read the cortex image - because it has the dimensions that match the mesh
    img_ref = sitk.ReadImage(data[id]['workspace'].cruise_gm_prob)

    # Split the segmentation into binary images, one per label, and append the
    # complement (1 - sum of the label masks) as the background channel
    x_img = sitk.GetArrayFromImage(img)
    p_lab = np.stack([ x_img == l for l in uclm_labels])
    p_lab = np.append(p_lab, [1 - p_lab.sum(0)], axis=0)

    # Put it back into a vector ITK image (not very efficient)
    img_p = sitk.GetImageFromArray(p_lab.transpose(1,2,3,0), isVector=True)
    img_p.CopyInformation(img)

    # Smooth the image a bit
    img_p = sitk.SmoothingRecursiveGaussian(img_p, [0.4, 0.4, 0.4])

    # Load the fitted mesh and its interpolation arrays
    pd_i_fit = load_vtk(f'tmp/template_fit_to_{id}.vtk')
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Load the midsurface mesh from cruise and map the template to this space
    pd_hw_nighres = load_vtk(data[id]['workspace'].fn_cruise('mtl_mesh-p2.vtk'))
    v_temp_to_hw_nighres = np.einsum('vij,vi->vj', vtk_get_points(pd_hw_nighres)[v_int,:], w_int)

    # Subdivide the template mapped onto the midsurface, using the fitted mesh's triangles
    v_omt_sub, f_omt_sub = loop_subdivide(v_temp_to_hw_nighres, vtk_get_triangles(pd_i_fit), iterations=subdivide_level)
    print(v_temp_to_hw_nighres.shape, vtk_get_triangles(pd_i_fit).shape, v_omt_sub.shape, f_omt_sub.shape)

    # Perform profile meshing using the subdivided mesh; the per-layer meshes are
    # written to tmp/ and read back below
    profile_meshing = nighres.laminar.profile_meshing(
                          profile_surface_image=data[id]['workspace'].fn_cruise(f'mtl_layering-boundaries.nii.gz'),
                          starting_surface_mesh={'points': v_omt_sub, 'faces': f_omt_sub},
                          save_data=True,
                          overwrite=True,
                          file_name=f'template_subdiv_{id}.vtk',
                          output_dir='tmp/')
    
    # Repeat for each mesh layer
    print(f'Sampling {id}')
    for layer in range(5):

        # Load the mesh for that layer
        # pd_layer = load_vtk(data[id]['workspace'].fn_cruise(f'mtl_mesh-p{layer}.vtk'))
        pd_layer = load_vtk(f'tmp/template_subdiv_{id}_mesh-p{layer}.vtk')
        v_layer = vtk_get_points(pd_layer)

        # Sample the label map at this layer, keeping in mind that the layer is going to
        # be in image coordinate space, not physical space
        n_hit, n_miss = 0, 0
        for j in range(v_layer.shape[0]):
            # Transform to physical point via reference image
            p_j = img_ref.TransformContinuousIndexToPhysicalPoint(v_layer[j, :].tolist())
            try:
                l = img_p.EvaluateAtPhysicalPoint(p_j, sitk.sitkLinear)
                n_hit = n_hit + 1
            except Exception:
                # Best effort: a vertex that cannot be sampled gets all-zero probabilities.
                # Catching Exception (not a bare except) lets Ctrl-C interrupt the loop.
                n_miss = n_miss + 1
                l = 0
            uclm_label_map[j, layer, i, :] = l

(6460, 3) (12916, 3) (25834, 3) (51664, 3)

Profile meshing

Outputs will be saved to tmp/

Saving tmp/template_subdiv_104937L_mesh-p0.vtk

Saving tmp/template_subdiv_104937L_mesh-p1.vtk

Saving tmp/template_subdiv_104937L_mesh-p2.vtk

Saving tmp/template_subdiv_104937L_mesh-p3.vtk

Saving tmp/template_subdiv_104937L_mesh-p4.vtk
lines: (2584, 1)
indices: (2584, 5)
lines: (2584, 6)
Sampling 104937L
(6460, 3) (12916, 3) (25834, 3) (51664, 3)

Profile meshing

Outputs will be saved to tmp/

Saving tmp/template_subdiv_106049L_mesh-p0.vtk

Saving tmp/template_subdiv_106049L_mesh-p1.vtk

Saving tmp/template_subdiv_106049L_mesh-p2.vtk

Saving tmp/template_subdiv_106049L_mesh-p3.vtk

Saving tmp/template_subdiv_106049L_mesh-p4.vtk
lines: (2584, 1)
indices: (2584, 5)
lines: (2584, 6)
Sampling 106049L
(6460, 3) (12916, 3) (25834, 3) (51664, 3)

Profile meshing

Outputs will be saved to tmp/

Saving tmp/template_subdiv_106312R_mesh-p0.vtk

Saving tmp/template_subdiv_106312R_mesh-p1.vtk

Saving tmp

In [59]:
""" Old code - no upsample

# --- backup of cell 
import SimpleITK as sitk

# Define the list of labels that we care about
uclm_labels = uclm_label_grouping.keys()

# Create a label map tensor. For each template vertex, we store for the five Nighres layers
# the probability of each label, plus the background probability
uclm_label_map = np.zeros((md_temp.v.shape[0], 5, len(md_aff), len(uclm_labels) + 1))

# Load each segmentation in turn
for i, (id, md_i) in enumerate(md_aff.items()):

    # Read the global segmentation
    fn_global_seg = f'/data/pauly2/ashs_xv/input/{id}/{id}_label_global.nii.gz'
    img = sitk.ReadImage(fn_global_seg, outputPixelType=sitk.sitkFloat32)

    # Read the cortex image - because it has the dimensions that match the mesh
    img_ref = sitk.ReadImage(data[id]['workspace'].cruise_gm_prob)

    # Split the segmentation into binary images
    x_img = sitk.GetArrayFromImage(img)
    p_lab = np.stack([ x_img == l for l in uclm_labels])
    p_lab = np.append(p_lab, [1 - p_lab.sum(0)], axis=0)

    # Put it back into a vector ITK image (not very efficient)
    img_p = sitk.GetImageFromArray(p_lab.transpose(1,2,3,0), isVector=True)
    img_p.CopyInformation(img)

    # Smooth the image a bit
    img_p = sitk.SmoothingRecursiveGaussian(img_p, [0.4, 0.4, 0.4])

    # Load the fitted template with its interpolation arrays
    pd_i_fit = load_vtk(f'tmp/template_fit_to_{id}.vtk')
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Repeat for each mesh layer
    print(f'Sampling {id}')
    for layer in range(5):

        # Load the mesh for that layer
        pd_layer = load_vtk(data[id]['workspace'].fn_cruise(f'mtl_mesh-p{layer}.vtk'))

        # Map template vertices to that layer
        v_temp_to_layer = np.einsum('vij,vi->vj', vtk_get_points(pd_layer)[v_int,:], w_int)

        # Sample the label map at this layer, keeping in mind that the layer is going to
        # be in image coordinate space, not physical space
        n_hit, n_miss = 0, 0
        for j in range(v_temp_to_layer.shape[0]):
            # Transform to physical point via reference image
            p_j = img_ref.TransformContinuousIndexToPhysicalPoint(v_temp_to_layer[j, :].tolist())
            try:
                l = img_p.EvaluateAtPhysicalPoint(p_j, sitk.sitkLinear)
                n_hit = n_hit + 1
            except:
                n_miss = n_miss + 1
                l = 0
            uclm_label_map[j, layer, i, :] = l
"""

" Old code - no upsample\n\n# --- backup of cell \nimport SimpleITK as sitk\n\n# Define the list of labels that we care about\nuclm_labels = uclm_label_grouping.keys()\n\n# Create a label map tensor. For each template vertex, we store for the five Nighres layers\n# the probability of each label, plus the background probability\nuclm_label_map = np.zeros((md_temp.v.shape[0], 5, len(md_aff), len(uclm_labels) + 1))\n\n# Load each segmentation in turn\nfor i, (id, md_i) in enumerate(md_aff.items()):\n\n    # Read the global segmentation\n    fn_global_seg = f'/data/pauly2/ashs_xv/input/{id}/{id}_label_global.nii.gz'\n    img = sitk.ReadImage(fn_global_seg, outputPixelType=sitk.sitkFloat32)\n\n    # Read the cortex image - because it has the dimensions that match the mesh\n    img_ref = sitk.ReadImage(data[id]['workspace'].cruise_gm_prob)\n\n    # Split the segmentation into binary images\n    x_img = sitk.GetArrayFromImage(img)\n    p_lab = np.stack([ x_img == l for l in uclm_labels])\n   

In [63]:
# Average the map between subjects (axis 2 of uclm_label_map); the result is indexed
# as (subdivided template vertex, layer, label slot)
uclm_label_consensus = np.mean(uclm_label_map, 2)

# Consensus labeling: per layer, store the label-slot probabilities and the index of the
# most probable slot. The index is a position in the label list (the last position being
# the background slot), not a label value.
# NOTE(review): this writes to the same file name as the newer consensus cell above
pd_consensus = vtk_make_pd(v_temp_sub, f_temp_sub)
for layer in range(5):
    layer_prob = uclm_label_consensus[:,layer,:].astype(np.float32)
    vtk_set_point_array(pd_consensus, f'layer_{layer}_prob_uclm', layer_prob)
    vtk_set_point_array(pd_consensus, f'layer_{layer}_label_uclm', np.argmax(layer_prob, 1))

save_vtk(pd_consensus, 'tmp/template_ec_test.vtk', binary=True)

Now the real magic: map the consensus distribution into subject space. Here we can try to use RBF interpolation.

In [66]:
from pykeops.torch import LazyTensor
def laplacian_kernel(x, y, sigma=0.1):
    """Symbolic Laplacian kernel matrix exp(-|x_i - y_j| / sigma) between two point sets.

    Args:
        x: 2-D tensor of M points, one point per row.
        y: 2-D tensor of N points with the same number of columns as `x`.
        sigma: kernel width; the Euclidean distance is divided by it.

    Returns:
        A KeOps LazyTensor standing for the (M, N) kernel matrix. Nothing is
        evaluated until it is reduced, e.g. by `K @ values`.
    """
    # Insert singleton axes so that the difference below broadcasts over all pairs
    targets = LazyTensor(x[:, None, :])  # (M, 1, D)
    sources = LazyTensor(y[None, :, :])  # (1, N, D)

    # Symbolic (M, N) matrix of squared Euclidean distances
    sq_dist = ((targets - sources) ** 2).sum(-1)
    return (-sq_dist.sqrt() / sigma).exp()

# For every training subject: propagate the template's consensus label probabilities
# from the subject's layer meshes onto the voxels of the subject's segmentation,
# and save the relabeled segmentation to tmp/crashs_fill_labels_{id}.nii.gz.
# NOTE(review): the loop variable `id` shadows the Python builtin and stays defined
# after the loop; other cells in this notebook read it.
for i, (id, md_i) in enumerate(md_aff.items()):

    # Read the segmentation that we want to apply this labeling to
    fn_target = f'/data/pauly2/ashs_xv/work/fold_0/work/{id}/{id}_ivseg_ashs_upsample.nii.gz'
    img = sitk.ReadImage(fn_target, outputPixelType=sitk.sitkFloat32)

    # Extract pixels of interest
    # After this, x_img counts, per voxel, how many of the labels in uclm_group_labels_gm
    # the voxel equals, so it is nonzero exactly where the voxel carries one of them
    x_img = sitk.GetArrayViewFromImage(img)
    x_img = np.sum(np.stack([x_img == l for l in uclm_group_labels_gm]), axis=0)

    # Get the pixel coordinates - these are flipped so they correspond to ITK
    # (one row per nonzero voxel, with the array index order reversed)
    nz = np.flip(np.transpose(np.stack(np.nonzero(x_img)).astype(np.float32)), 1)

    # Map each of these coordinates into the Nighres mesh coordinates
    # (segmentation index -> physical point -> continuous index of the CRUISE GM image)
    img_ref = sitk.ReadImage(data[id]['workspace'].cruise_gm_prob)
    y = np.array([img_ref.TransformPhysicalPointToContinuousIndex(
        img.TransformContinuousIndexToPhysicalPoint(nz[j,:].tolist())) for j in range(nz.shape[0]) ])
    
    # Load the fitted mesh and its interpolation arrays
    pd_i_fit = load_vtk(f'tmp/template_fit_to_{id}.vtk')
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Load the midsurface mesh from cruise and map the template to this space
    # Each template vertex becomes the w_int-weighted sum of the mesh points indexed by v_int
    pd_hw_nighres = load_vtk(data[id]['workspace'].fn_cruise('mtl_mesh-p2.vtk'))
    v_temp_to_hw_nighres = np.einsum('vij,vi->vj', vtk_get_points(pd_hw_nighres)[v_int,:], w_int)

    # Load the template that is OMT matched to the target inflated mesh, and subdivide it
    # (subdivide_level is defined in an earlier cell)
    v_omt_sub, f_omt_sub = loop_subdivide(v_temp_to_hw_nighres, vtk_get_triangles(pd_i_fit), iterations=subdivide_level)

    # Perform profile meshing using the subdivided mesh
    # The per-layer results are written to tmp/template_subdiv_{id}_mesh-p{layer}.vtk
    # and read back from disk in the layer loop below
    profile_meshing = nighres.laminar.profile_meshing(
                          profile_surface_image=data[id]['workspace'].fn_cruise(f'mtl_layering-boundaries.nii.gz'),
                          starting_surface_mesh={'points': v_omt_sub, 'faces': f_omt_sub},
                          save_data=True,
                          overwrite=True,
                          file_name=f'template_subdiv_{id}.vtk',
                          output_dir='tmp/')
    
    # Load the fitted template with its interpolation arrays
    #pd_i_fit = load_vtk(f'tmp/template_fit_to_{id}.vtk')
    #v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    #w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Load the Nighres meshes - these are the coordinates from which we wish to interpolate
    mesh_coord, mesh_label = [], []
    for layer in range(5):

        # Load the mesh for that layer
        # pd_layer = load_vtk(data[id]['workspace'].fn_cruise(f'mtl_mesh-p{layer}.vtk'))
        pd_layer = load_vtk(f'tmp/template_subdiv_{id}_mesh-p{layer}.vtk')
        v_layer = vtk_get_points(pd_layer)

        # Map template vertices to that layer and append to list of all interpolated coords
        # mesh_coord.append(np.einsum('vij,vi->vj', vtk_get_points(pd_layer)[v_int,:], w_int))
        # NOTE(review): assumes the profile mesh keeps the vertex order of the subdivided
        # template, so that row k of v_layer pairs with row k of the consensus - confirm
        mesh_coord.append(v_layer)
        mesh_label.append(uclm_label_consensus[:,layer,:])

    # Put the coordinates together - these are the points we want to interpolate from
    x = np.concatenate(mesh_coord)

    # Put the label probabilities together - these are the values we want to propagate
    b = np.concatenate(mesh_label)

    # Perform RBF interpolation, so easy!
    # b_y is the kernel-weighted sum of the mesh probabilities at every voxel; the rows
    # are not normalized, which does not change the per-row argmax taken below
    K_yx = laplacian_kernel(
        torch.tensor(y, dtype=torch.float32, device=device), 
        torch.tensor(x, dtype=torch.float32, device=device), 
        sigma=1.0)
    b_y = (K_yx @ torch.tensor(b, dtype=torch.float32, device=device)).detach().cpu().numpy()

    # Assign the values to the vertices
    # The last column is left out of the argmax (presumably the background probability -
    # TODO confirm); column k is then translated to the k-th key of uclm_label_grouping
    l_best = np.argmax(b_y[:,:-1], 1)
    l_best_remap = np.zeros_like(l_best)
    for k, label in enumerate(uclm_label_grouping.keys()):
        l_best_remap[l_best == k] = label
    # Boolean-mask assignment visits voxels in the same row-major order that
    # np.nonzero used to build `nz`, so the labels land on the matching voxels
    x_img[x_img != 0] = l_best_remap

    # Wrap the label array in an image that has the geometry of the input segmentation
    img_result = sitk.GetImageFromArray(x_img)
    img_result.CopyInformation(img)

    # Save the image
    sitk.WriteImage(img_result, f'tmp/crashs_fill_labels_{id}.nii.gz')


Profile meshing

Outputs will be saved to tmp/

Saving tmp/template_subdiv_104937L_mesh-p0.vtk

Saving tmp/template_subdiv_104937L_mesh-p1.vtk

Saving tmp/template_subdiv_104937L_mesh-p2.vtk

Saving tmp/template_subdiv_104937L_mesh-p3.vtk

Saving tmp/template_subdiv_104937L_mesh-p4.vtk
lines: (2584, 1)
indices: (2584, 5)
lines: (2584, 6)
[KeOps] Generating code for formula Sum_Reduction(Exp(-Sqrt(Sum((Var(0,3,0)-Var(1,3,1))**2)))*Var(2,33,1),0) ... OK

Profile meshing

Outputs will be saved to tmp/

Saving tmp/template_subdiv_106049L_mesh-p0.vtk

Saving tmp/template_subdiv_106049L_mesh-p1.vtk

Saving tmp/template_subdiv_106049L_mesh-p2.vtk

Saving tmp/template_subdiv_106049L_mesh-p3.vtk

Saving tmp/template_subdiv_106049L_mesh-p4.vtk
lines: (2584, 1)
indices: (2584, 5)
lines: (2584, 6)

Profile meshing

Outputs will be saved to tmp/

Saving tmp/template_subdiv_106312R_mesh-p0.vtk

Saving tmp/template_subdiv_106312R_mesh-p1.vtk

Saving tmp/template_subdiv_106312R_mesh-p2.vtk

Saving tm

### Now do the same for the nnU-Net segmentations

In [403]:
import glob
# The test subject ids are the names of the entries directly under the nnU-Net folder.
# glob returns them in arbitrary order.
test_ids = [ os.path.basename(x) for x in glob.glob('/data/pauly2/ashs_xv/work/fold_0/nnunet/*') ]

In [407]:
# Keep track of the workspaces created for the nnU-Net test subjects
test_data = {}

# Set up a workspace for each test subject
for id in test_ids:
    # Get the side from the id suffix: 'L' means left, anything else is treated as right
    side = 'left' if id.endswith('L') else 'right'

    # Point a workspace at the subject's CRASHS directory. The side is passed along
    # with the id, the same way the training workspaces are constructed at the top
    # of this notebook (Workspace(out_dir, id, side)); it was previously computed
    # here but never handed to the workspace.
    crashs_dir = os.path.join('/data/pauly2/ashs_xv/work/fold_0/crashs_apply_nnunet', id)
    workspace = Workspace(crashs_dir, id, side)

    # Store the data
    test_data[id] = { 'side': side, 'workspace': workspace }

In [164]:
# Same label propagation as for the training subjects, applied to the nnU-Net
# segmentations of the test subjects; results go to tmp/crashs_nnunet_fill_labels_{id}.nii.gz.
# NOTE(review): the saved output of this cell is `KeyError: '100551R'`, and its execution
# count (164) is lower than that of the cell that builds `test_data` (407) - re-run in order.
# NOTE(review): unlike the training-subject loop, this one maps the template through
# v_int/w_int without subdivision and calls laplacian_kernel with its default sigma (0.1,
# not 1.0). If uclm_label_consensus is defined on the subdivided template, `x` and `b`
# will have different numbers of rows - confirm before relying on this cell.
for i, id in enumerate(test_ids):

    # Read the segmentation that we want to apply this labeling to
    fn_target = f'/data/pauly2/ashs_xv/work/fold_0/nnunet/{id}/{id}_ivseg_ashs_upsample.nii.gz'
    img = sitk.ReadImage(fn_target, outputPixelType=sitk.sitkFloat32)

    # Extract pixels of interest
    # x_img ends up nonzero exactly where the voxel equals one of uclm_group_labels_gm
    x_img = sitk.GetArrayViewFromImage(img)
    x_img = np.sum(np.stack([x_img == l for l in uclm_group_labels_gm]), axis=0)

    # Get the pixel coordinates - these are flipped so they correspond to ITK
    nz = np.flip(np.transpose(np.stack(np.nonzero(x_img)).astype(np.float32)), 1)

    # Map each of these coordinates into the Nighres mesh coordinates
    # (segmentation index -> physical point -> continuous index of the CRUISE GM image)
    img_ref = sitk.ReadImage(test_data[id]['workspace'].cruise_gm_prob)
    y = np.array([img_ref.TransformPhysicalPointToContinuousIndex(
        img.TransformContinuousIndexToPhysicalPoint(nz[j,:].tolist())) for j in range(nz.shape[0]) ])
    
    # Load the fitted template with its interpolation arrays
    pd_i_fit = load_vtk(test_data[id]['workspace'].fit_omt_match)
    v_int = vtk_get_point_array(pd_i_fit, 'omt_v_int').astype(np.int32)
    w_int = vtk_get_point_array(pd_i_fit, 'omt_w_int')

    # Load the Nighres meshes - these are the coordinates from which we wish to interpolate
    mesh_coord, mesh_label = [], []
    for layer in range(5):

        # Load the mesh for that layer
        pd_layer = load_vtk(test_data[id]['workspace'].fn_cruise(f'mtl_mesh-p{layer}.vtk'))

        # Map template vertices to that layer and append to list of all interpolated coords
        # (each template vertex is the w_int-weighted sum of the layer points indexed by v_int)
        mesh_coord.append(np.einsum('vij,vi->vj', vtk_get_points(pd_layer)[v_int,:], w_int))
        mesh_label.append(uclm_label_consensus[:,layer,:])

    # Put the coordinates together - these are the points we want to interpolate from
    x = np.concatenate(mesh_coord)

    # Put the label probabilities together - these are the values we want to propagate
    b = np.concatenate(mesh_label)

    # Perform RBF interpolation, so easy!
    # b_y is the (unnormalized) kernel-weighted sum of mesh probabilities at every voxel
    K_yx = laplacian_kernel(torch.tensor(y, dtype=torch.float32, device=device), 
                        torch.tensor(x, dtype=torch.float32, device=device))
    b_y = (K_yx @ torch.tensor(b, dtype=torch.float32, device=device)).detach().cpu().numpy()

    # Assign the values to the vertices
    # The last column is excluded from the argmax (presumably background - TODO confirm);
    # column k is translated to the k-th key of uclm_label_grouping
    l_best = np.argmax(b_y[:,:-1], 1)
    l_best_remap = np.zeros_like(l_best)
    for k, label in enumerate(uclm_label_grouping.keys()):
        l_best_remap[l_best == k] = label
    # Boolean-mask assignment follows the same row-major voxel order as np.nonzero above
    x_img[x_img != 0] = l_best_remap

    # Wrap the label array in an image that has the geometry of the input segmentation
    img_result = sitk.GetImageFromArray(x_img)
    img_result.CopyInformation(img)

    # Save the image
    sitk.WriteImage(img_result, f'tmp/crashs_nnunet_fill_labels_{id}.nii.gz')

KeyError: '100551R'

# Debug extreme stretching in template to subject fitting

In [325]:
# Which id to test on?
test_id = '106049L'
# Load the final template and this subject's registered mesh from the tmp/ folder.
# NOTE: this rebinds the notebook-wide `md_temp`.
md_temp = MeshData(load_vtk(f'tmp/template_final_with_momenta.vtk'), device=device)
md_subj = MeshData(load_vtk(f'tmp/test_register_{test_id}.vtk'), device=device)

In [367]:
def lddmm_fit_subject(md_temp, md_subj, n_iter=50, sigma_lddmm=5, sigma_varifold=10, gamma_lddmm=0.1):
    """Fit the template mesh to one subject mesh by optimizing LDDMM initial momenta.

    The objective is `gamma_lddmm * Hamiltonian(p, q) + varifold_loss(shot template)`,
    minimized over the momenta with L-BFGS. The losses are printed before every step.

    Args:
        md_temp: template mesh data; reads `.vt`, `.ft`, `.lp` and `.lpt`
            (presumably vertices, faces and label probabilities - confirm against MeshData).
        md_subj: subject mesh data; reads `.vt`, `.ft` and `.lpt`.
        n_iter: number of outer L-BFGS steps (each runs up to 16 inner iterations).
        sigma_lddmm: width of the Gaussian deformation kernel.
        sigma_varifold: width of the varifold data-term kernel.
        gamma_lddmm: weight of the Hamiltonian (deformation energy) term.

    Returns:
        The optimized momentum tensor, with the same shape as `md_temp.vt`.
    """

    # LDDMM kernels
    device = md_temp.vt.device
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))
    K_vari = GaussLinKernelWithLabels(torch.tensor(sigma_varifold, dtype=torch.float32, device=device), md_temp.lp.shape[1])

    # Data term between the deformed template and the subject mesh
    d_loss = lossVarifoldSurfWithLabels(md_temp.ft, md_subj.vt, md_subj.ft, md_temp.lpt, md_subj.lpt, K_vari)

    # Starting points are a detached, contiguous copy of the template vertices. This
    # uses detach/clone rather than torch.tensor(tensor), which emits a UserWarning.
    q_temp = (md_temp.vt.detach()
              .to(dtype=torch.float32, device=device)
              .clone(memory_format=torch.contiguous_format)
              .requires_grad_(True))

    # Momenta are the only optimized variable and start at zero (identity deformation)
    p_temp = torch.zeros(md_temp.vt.shape, dtype=torch.float32, device=device, requires_grad=True)

    # Create the optimizer
    start = time.time()
    optimizer = torch.optim.LBFGS([p_temp], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')

    def closure(detail = False):
        # Evaluate the objective and its gradient; with detail=True also return the terms
        optimizer.zero_grad()
        _, q_i = Shooting(p_temp, q_temp, K_temp)[-1]
        l_ham = gamma_lddmm * Hamiltonian(K_temp)(p_temp, q_temp)
        l_data = d_loss(q_i)
        L = l_ham + l_data
        L.backward()
        if detail:
            return l_ham, l_data, L
        else:
            return L

    # Perform optimization
    for i in range(n_iter):
        # Extra evaluation purely for logging; its gradients are cleared by the
        # zero_grad() at the start of the next closure call
        l_ham, l_data, L = closure(True)
        print(f'Iteration {i:03d}  Losses: H={l_ham:8.6f}  D={l_data:8.6f}  Total={L:8.6f}')
        optimizer.step(closure)

    print(f'Optimization (L-BFGS) time: {round(time.time() - start, 2)} seconds')

    # Return the optimized momenta
    return p_temp


def lddmm_fit_subject_jac_penalty(md_temp, md_subj, n_iter=50, sigma_lddmm=5, sigma_varifold=10, 
                                  gamma_lddmm=0.1, w_jac_penalty=1.0):
    """Fit the template to a subject with LDDMM, penalizing triangle area change.

    The objective is the one of `lddmm_fit_subject` plus
    `w_jac_penalty * sum(log10(area / area_0) ** 2)` over the template triangles,
    where `area_0` and `area` are the triangle areas (up to a common factor of 2)
    before and after the deformation. The losses are printed before every step.

    Args:
        md_temp: template mesh data; reads `.vt`, `.ft`, `.lp` and `.lpt`
            (presumably vertices, faces and label probabilities - confirm against MeshData).
        md_subj: subject mesh data; reads `.vt`, `.ft` and `.lpt`.
        n_iter: number of outer L-BFGS steps (each runs up to 16 inner iterations).
        sigma_lddmm: width of the Gaussian deformation kernel.
        sigma_varifold: width of the varifold data-term kernel.
        gamma_lddmm: weight of the Hamiltonian (deformation energy) term.
        w_jac_penalty: weight of the squared log area-ratio penalty.

    Returns:
        The optimized momentum tensor, with the same shape as `md_temp.vt`.
    """

    def doubled_area(corners):
        # Norm of the edge cross product = twice the triangle area. `dim=1` is given
        # explicitly: torch.cross without `dim` is deprecated and would choose the
        # wrong axis for a mesh with exactly three triangles.
        e1 = corners[:,1,:] - corners[:,0,:]
        e2 = corners[:,2,:] - corners[:,0,:]
        return torch.norm(torch.cross(e1, e2, dim=1), dim=1)

    # LDDMM kernels
    device = md_temp.vt.device
    K_temp = GaussKernel(sigma=torch.tensor(sigma_lddmm, dtype=torch.float32, device=device))
    K_vari = GaussLinKernelWithLabels(torch.tensor(sigma_varifold, dtype=torch.float32, device=device), md_temp.lp.shape[1])

    # Data term between the deformed template and the subject mesh
    d_loss = lossVarifoldSurfWithLabels(md_temp.ft, md_subj.vt, md_subj.ft, md_temp.lpt, md_subj.lpt, K_vari)

    # Starting points are a detached, contiguous copy of the template vertices. This
    # uses detach/clone rather than torch.tensor(tensor), which emits a UserWarning.
    q_temp = (md_temp.vt.detach()
              .to(dtype=torch.float32, device=device)
              .clone(memory_format=torch.contiguous_format)
              .requires_grad_(True))

    # Momenta are the only optimized variable and start at zero (identity deformation)
    p_temp = torch.zeros(md_temp.vt.shape, dtype=torch.float32, device=device, requires_grad=True)

    # Create the optimizer
    start = time.time()
    optimizer = torch.optim.LBFGS([p_temp], max_eval=16, max_iter=16, line_search_fn='strong_wolfe')

    # Reference triangle areas of the undeformed template
    # NOTE(review): a degenerate (zero-area) template triangle makes the ratio below infinite
    area_0 = doubled_area(md_temp.vt[md_temp.ft])

    def closure(detail = False):
        # Evaluate the objective and its gradient; with detail=True also return the terms
        optimizer.zero_grad()
        _, q_i = Shooting(p_temp, q_temp, K_temp)[-1]
        log_jac = torch.log10(doubled_area(q_i[md_temp.ft]) / area_0)

        l_ham = gamma_lddmm * Hamiltonian(K_temp)(p_temp, q_temp)
        l_data = d_loss(q_i)
        l_jac = torch.sum(log_jac ** 2) * w_jac_penalty
        L = l_ham + l_data + l_jac
        L.backward()
        if detail:
            return l_ham, l_data, l_jac, L
        else:
            return L

    # Perform optimization
    for i in range(n_iter):
        # Extra evaluation purely for logging; its gradients are cleared by the
        # zero_grad() at the start of the next closure call
        l_ham, l_data, l_jac, L = closure(True)
        print(f'Iteration {i:03d}  Losses: H={l_ham:8.6f}  D={l_data:8.6f}  J={l_jac:8.6f}  Total={L:8.6f}')
        optimizer.step(closure)

    print(f'Optimization (L-BFGS) time: {round(time.time() - start, 2)} seconds')

    # Return the optimized momenta
    return p_temp


In [368]:
# Examine pairwise registration
# Fit the template to the test subject with the area-change (Jacobian) penalty; the
# deformation kernel width comes from the left-side template settings
p = lddmm_fit_subject_jac_penalty(md_temp, md_subj, n_iter=50,
                      sigma_lddmm = template['left'].get_lddmm_sigma(), 
                      sigma_varifold = 5,
                      gamma_lddmm=0.05, w_jac_penalty=10.)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Iteration 000  Losses: H=0.000000  D=69699.453125  J=0.000000  Total=69699.453125
Iteration 001  Losses: H=18.435059  D=5901.156250  J=3948.298096  Total=9867.889648
Iteration 002  Losses: H=27.894758  D=2977.500000  J=3456.654785  Total=6462.049805
Iteration 003  Losses: H=35.484474  D=2132.656250  J=3025.152832  Total=5193.293457
Iteration 004  Losses: H=40.652882  D=1759.687500  J=2814.804688  Total=4615.145020
Iteration 005  Losses: H=43.886608  D=1567.750000  J=2646.022461  Total=4257.659180
Iteration 006  Losses: H=46.895168  D=1437.000000  J=2548.185791  Total=4032.081055
Iteration 007  Losses: H=50.572807  D=1359.375000  J=2472.786377  Total=3882.734131
Iteration 008  Losses: H=52.279385  D=1293.437500  J=2437.608398  Total=3783.325195
Iteration 009  Losses: H=53.251942  D=1234.125000  J=2431.110840  Total=3718.487793
Iteration 010  Losses: H=54.692726  D=1203.250000  J=2404.950195  Total=3662.893066
Iteration 011  Losses: H=55.352955  D=1167.250000  J=2388.989014  Total=3611.5

In [369]:
# Shoot the template to the subject with the fitted momenta
md_fit = shoot_template_to_subject(md_temp, p, template['left'].get_lddmm_sigma())

# Triangle corner coordinates after (z) and before (z0) the deformation. They are
# computed here from the fresh fit, so the Jacobian does not depend on values left
# behind by another cell or by a previous fit.
z = md_fit.vt[md_fit.ft]
z0 = md_temp.vt[md_temp.ft]

# Per-triangle area ratio (deformed / template); `dim=1` selects the xyz axis explicitly
jac = torch.norm(torch.cross(z[:,1,:] - z[:,0,:], z[:,2,:] - z[:,0,:], dim=1), dim=1) / torch.norm(torch.cross(z0[:,1,:] - z0[:,0,:], z0[:,2,:] - z0[:,0,:], dim=1), dim=1)

# Store the ratio as a cell array and save the fitted mesh for inspection
vtk_set_cell_array(md_fit.pd, 'jacobian', jac.detach().cpu().numpy())
save_vtk(md_fit.pd, f'tmp/test_bad_fit_{test_id}.vtk')

In [353]:
# Corner coordinates of every triangle, indexed [triangle, corner, xyz]:
# z for the fitted mesh, z0 for the undeformed template
z = md_fit.vt[md_fit.ft]
z0 = md_temp.vt[md_temp.ft]

In [355]:
# Per-triangle area ratio (deformed / template) from the corner arrays `z` and `z0`.
# `dim=1` is passed explicitly because torch.cross without `dim` is deprecated.
jac = torch.norm(torch.cross(z[:,1,:] - z[:,0,:], z[:,2,:] - z[:,0,:], dim=1), dim=1) / torch.norm(torch.cross(z0[:,1,:] - z0[:,0,:], z0[:,2,:] - z0[:,0,:], dim=1), dim=1)

# NOTE(review): the two statements below look unrelated to the Jacobian check. The
# first overwrites the saved template with `pd_temp_save`, and the second depends on
# whatever `id` was left behind by an earlier loop. Confirm they belong in this cell.
save_vtk(pd_temp_save, 'tmp/template_final_with_momenta.vtk')
md_aff[id] = MeshData(load_vtk(f'tmp/test_register_{id}.vtk'), device)

In [362]:
# Largest natural-log area ratio over all triangles (the most stretched triangle)
np.log(jac.detach().cpu().numpy()).max()

3.3452086

# Debug CRUISE mesh propagation
Layer mesh propagation does not match the average surface all the way down to the GM/CSF surface: some profiles never reach the outer boundary, possibly because of a low gradient. Try using the CRUISE layer level sets to compute a correspondence between concentric layers instead (done below by OMT matching of consecutive layer surfaces).

In [132]:
# Define the subject to experiment with
id_test = '118374L'
ws_test = data[id_test]['workspace']

# Load the subject's layer levelset image
# (4-D: the last axis indexes the layer boundaries, see the per-layer slicing below)
img_ls = sitk.ReadImage(ws_test.fn_cruise('mtl_layering-boundaries.nii.gz'))

In [137]:
# Pull out the levelset volume with index 5 along the last axis, for experimenting
layer = img_ls[:,:,:,5]

def layer_to_mesh(layer, edge_len_pct=0.75, to_ras=True):
    """Extract the zero level set of a 3-D levelset image as a remeshed triangle surface.

    Args:
        layer: SimpleITK image holding one levelset volume.
        edge_len_pct: target edge length for isotropic remeshing, passed to
            pymeshlab as a percentage.
        to_ras: currently unused. The image origin, spacing and direction are not
            copied to the VTK image, so the mesh is built on VTK's default grid.

    Returns:
        The polydata that `vtk_make_pd` builds from the remeshed vertices and faces.
    """
    # Wrap the voxel values in a VTK image of the same dimensions
    voxels = sitk.GetArrayFromImage(layer)
    size = layer.GetSize()
    grid = vtk.vtkImageData()
    grid.GetPointData().SetScalars(numpy_to_vtk(voxels.flatten(), array_type=vtk.VTK_FLOAT))
    grid.SetDimensions(size[0], size[1], size[2])

    def triangles_only(upstream):
        # Triangle filter on `upstream` that drops line and vertex cells
        tri = vtk.vtkTriangleFilter()
        tri.SetInputConnection(upstream.GetOutputPort())
        tri.PassLinesOff()
        tri.PassVertsOff()
        return tri

    # Isosurface at value 0.0
    iso = vtk.vtkMarchingCubes()
    iso.SetInputData(grid)
    iso.SetNumberOfContours(1)
    iso.SetValue(0, 0.0)

    # Triangulate, merge coincident points, and triangulate again
    tri_before = triangles_only(iso)
    merged = vtk.vtkCleanPolyData()
    merged.SetInputConnection(tri_before.GetOutputPort())
    merged.PointMergingOn()
    merged.SetTolerance(0.0)
    tri_after = triangles_only(merged)

    tri_after.Update()
    surface = tri_after.GetOutput()

    # Isotropic remeshing of the extracted surface with pymeshlab
    meshes = pymeshlab.MeshSet()
    meshes.add_mesh(pymeshlab.Mesh(vertex_matrix=vtk_get_points(surface), face_matrix=vtk_get_triangles(surface)))
    meshes.meshing_isotropic_explicit_remeshing(targetlen = pymeshlab.Percentage(edge_len_pct))
    remeshed = meshes.mesh(0)
    return vtk_make_pd(remeshed.vertex_matrix(), remeshed.face_matrix())

In [138]:
# Surfaces of two consecutive levelset layers
layer_5 = layer_to_mesh(img_ls[:,:,:,5])
layer_6 = layer_to_mesh(img_ls[:,:,:,6])

# OMT match from layer 5 to layer 6; only the second return value is used
_, w_omt = match_omt(
    torch.tensor(vtk_get_points(layer_5), dtype=torch.float32, device=device),
    torch.tensor(vtk_get_triangles(layer_5), dtype=torch.int32, device=device),
    torch.tensor(vtk_get_points(layer_6), dtype=torch.float32, device=device),
    torch.tensor(vtk_get_triangles(layer_6), dtype=torch.int32, device=device))

# Turn the match into new vertex positions, and rebuild a mesh that keeps the
# triangles of layer 5
v_omt, _, _ = omt_match_to_vertex_weights(layer_5, layer_6, w_omt.detach().cpu().numpy())
layer5_omt_to_6 = vtk_make_pd(v_omt, vtk_get_triangles(layer_5))

553x580 clusters, computed at scale = 6.971
Successive scales :  152.127, 152.127, 121.702, 97.361, 77.889, 62.311, 49.849, 39.879, 31.903, 25.523, 20.418, 16.335, 13.068, 10.454, 8.363, 6.691, 5.352, 4.282, 3.426, 2.740, 2.192, 1.754, 1.403, 1.122, 0.898, 0.718, 0.575, 0.460, 0.368, 0.294, 0.235, 0.188, 0.151, 0.121, 0.096, 0.077, 0.062, 0.050
Jump from coarse to fine between indices 14 (σ=8.363) and 15 (σ=6.691).
Keep 102453/320740 = 31.9% of the coarse cost matrix.
Keep 97693/305809 = 31.9% of the coarse cost matrix.
Keep 107462/336400 = 31.9% of the coarse cost matrix.
OMT matching distance: 0.5250161290168762, time elapsed: 1.624000072479248


In [139]:
def profile_meshing_omt(img_ls, source_mesh=None, init_layer=None):
    """Propagate a mesh through the levelset shells using OMT.

    One surface is extracted per volume along the last axis of `img_ls`. Starting
    at `init_layer`, the mesh is matched to the neighbouring layer with OMT, and
    each result is the source for the next layer, in both directions.

    Args:
        img_ls: 4-D SimpleITK levelset image; the 4th axis indexes the layers.
        source_mesh: optional mesh to propagate. If None, the surface extracted
            at `init_layer` is used as the starting mesh.
        init_layer: index of the starting layer. If None, the middle layer is used.

    Returns:
        Tuple `(matched, init_layer)`: the list of propagated meshes, one per
        layer, and the starting layer index that was actually used.

    Raises:
        IndexError: if `init_layer` is not a valid layer index.
    """

    # The OMT algorithm that is called repeatedly
    def do_omt(l_src, l_trg):
        _, w_omt = match_omt(
            torch.tensor(vtk_get_points(l_src), dtype=torch.float32, device=device),
            torch.tensor(vtk_get_triangles(l_src), dtype=torch.int32, device=device),
            torch.tensor(vtk_get_points(l_trg), dtype=torch.float32, device=device),
            torch.tensor(vtk_get_triangles(l_trg), dtype=torch.int32, device=device))
        v_omt, _, _ = omt_match_to_vertex_weights(l_src, l_trg, w_omt.detach().cpu().numpy())
        return vtk_make_pd(v_omt, vtk_get_triangles(l_src))

    # Determine and validate the init layer before doing any expensive meshing
    n_layers = img_ls.GetSize()[3]
    if init_layer is None:
        init_layer = n_layers // 2
    if not 0 <= init_layer < n_layers:
        raise IndexError(f'init_layer {init_layer} is outside the valid range 0..{n_layers - 1}')

    # Extract all the isolayers from the image
    layers = [ layer_to_mesh(img_ls[:,:,:,k]) for k in range(n_layers) ]
    matched = [None] * n_layers

    # Propagate from the mesh to the initial layer (unless we want to use a layer as source).
    # Tested against None explicitly, so the choice does not hinge on mesh truthiness.
    if source_mesh is None:
        matched[init_layer] = layers[init_layer]
    else:
        matched[init_layer] = do_omt(source_mesh, layers[init_layer])

    # Propagate in both directions: first down to layer 0, then up to the last layer
    downward = range(init_layer - 1, -1, -1)
    upward = range(init_layer + 1, n_layers)
    for seq in downward, upward:
        prop_source = matched[init_layer]
        for k in seq:
            print(f'Propagating to layer {k}')
            matched[k] = do_omt(prop_source, layers[k])
            prop_source = matched[k]

    # Return the propagated meshes
    return matched, init_layer



In [140]:
# Propagate the middle layer surface through all levelset layers of the test subject
prof_meshes, mid_layer = profile_meshing_omt(img_ls)

Propagating to layer 4
564x551 clusters, computed at scale = 6.935
Successive scales :  151.332, 151.332, 121.066, 96.853, 77.482, 61.986, 49.589, 39.671, 31.737, 25.389, 20.311, 16.249, 12.999, 10.399, 8.320, 6.656, 5.325, 4.260, 3.408, 2.726, 2.181, 1.745, 1.396, 1.117, 0.893, 0.715, 0.572, 0.457, 0.366, 0.293, 0.234, 0.187, 0.150, 0.120, 0.096, 0.077, 0.061, 0.050
Jump from coarse to fine between indices 14 (σ=8.320) and 15 (σ=6.656).
Keep 97692/310764 = 31.4% of the coarse cost matrix.
Keep 100152/318096 = 31.5% of the coarse cost matrix.
Keep 95287/303601 = 31.4% of the coarse cost matrix.
OMT matching distance: 0.4317602515220642, time elapsed: 1.5990068912506104
Propagating to layer 3
561x543 clusters, computed at scale = 6.904
Successive scales :  150.673, 150.673, 120.538, 96.431, 77.144, 61.716, 49.372, 39.498, 31.598, 25.279, 20.223, 16.178, 12.943, 10.354, 8.283, 6.627, 5.301, 4.241, 3.393, 2.714, 2.171, 1.737, 1.390, 1.112, 0.889, 0.712, 0.569, 0.455, 0.364, 0.291, 0.233, 

In [141]:
# Save every propagated layer mesh under a zero-padded layer index
for i, prop_layer in enumerate(prof_meshes):
    save_vtk(prop_layer, f'tmp/test_mesh_profiles_dist_{i:02d}.vtk')

In [114]:
# Save the two layer surfaces and the layer-5 mesh matched onto layer 6
save_vtk(layer_5, 'tmp/test_levelset_5.vtk')
save_vtk(layer_6, 'tmp/test_levelset_6.vtk')
save_vtk(layer5_omt_to_6, 'tmp/test_levelset_5_to_6.vtk')

In [82]:
# NOTE(review): scratch call that depends on whatever `img` and `layer` happen to be
# bound to in the kernel. The same call is commented out inside layer_to_mesh.
# Probably a leftover experiment - confirm and remove.
img.SetDirectionMatrix(*layer.GetDirection())

In [91]:
# Size of the levelset image; the 4th entry is the number of layer volumes
img_ls.GetSize()

(118, 112, 120, 11)