In [None]:
from PYME.LMVis import VisGUI

# Enable IPython's wx event-loop integration; presumably required so the
# PYMEVisualize window launched in the next cell stays responsive while cells run.
%gui wx

In [None]:
# Launch PYMEVisualize from IPython. `pymevis` is used later to open the point file
# and build the octree; `pipeline` is its data pipeline (recipes, data sources).
pymevis = VisGUI.ipython_pymevisualize()
pipeline = pymevis.pipeline

import numpy as np
import os

In [None]:
"""
This compares ICTM mesh quality as a function of parameters.
"""

# Where to save the intermediate files generated 
save_fp = 'C:\\Users\\zrc4\\Downloads\\test_ictm_13'
if not os.path.exists(save_fp):
    os.mkdir(save_fp)

# three way junction generation parameters
centroid = np.array([0,0,0])
h, r = 500, 50       # capsule length, capsule radius
smoothing = r      # ~h/2 creates a three-way junction with a "sheet" (but it's puffy like seaweed)
# loc precision
psf_width = 250.0
mean_photon_count = 300.0

# octree-based reconstruction parameters
cull_inner_surfaces = True
n_points_min = 10
remesh = True
repair = False
smooth_curvature = True
density = 9e-7

# screened poisson reconstruction parameters
fulldepth=5


In [None]:
"""
Generate three-way junction
"""

from ch_shrinkwrap.shape import ThreeWayJunction

twj = ThreeWayJunction(h, r, centroid, smoothing)

In [None]:
"""
Generate and save three-way junction ponts
"""
import time

from ch_shrinkwrap import util

cap_points = twj.points(p=0.0001, psf_width=psf_width, mean_photon_count=mean_photon_count, resample=True)
cap_sigma = twj._noise

no = 0.1
scale = 1.2
bbox = [np.min(cap_points[:,0]), np.min(cap_points[:,1]), 
        np.min(cap_points[:,2]), np.max(cap_points[:,0]),
        np.max(cap_points[:,1]), np.max(cap_points[:,2])]
bbox = [scale*x for x in bbox]
xl, yl, zl, xu, yu, zu = bbox
xn, yn, zn = xu-xl, yu-yl, zu-zl
ln = int(no*len(cap_points)/(1.0-no))
noise_points = np.random.rand(ln,3)*(np.array([xn,yn,zn])[None,:]) + (np.array([xl,yl,zl])[None,:])
noise_sigma = util.noise(noise_points.shape, model='poisson', psf_width=psf_width, mean_photon_count=mean_photon_count)

points = np.vstack([cap_points,noise_points])
sigma = np.vstack([cap_sigma,noise_sigma])
s = np.sqrt((sigma*sigma).sum(1))

points_time = time.strftime('%Y%d%m_%HH%M')
points_fn = f"twj_h{h}_r{r}_smoothing{smoothing}_{points_time}".replace('.','_')+".txt"
# points_fn = "twj_h500_r50_smoothing50_20210911_22H12.txt"
points_fp = os.path.join(save_fp, points_fn)
np.savetxt(points_fp, np.vstack([points.T,s]).T, header="x y z sigma")

In [None]:
# Now we need PYMEVis
# Load the saved point file into the PYMEVisualize session; the octree cell below
# builds its surface from what pymevis/pipeline hold after this call.
pymevis.OpenFile(points_fp)


In [None]:
# Octree renderer: build an octree from the data open in pymevis, then extract the
# starting surface with dual marching cubes using the octree parameters from the config cell.
from PYME.LMVis.Extras.extra_layers import gen_octree_from_points
from PYME.recipes.surface_fitting import DualMarchingCubes

oc_name = gen_octree_from_points(pymevis)
surf_name, surf_count = pipeline.new_ds_name('surf', return_count=True)

recipe = pipeline.recipe
dmc = DualMarchingCubes(
    recipe,
    invalidate_parent=False,
    input=oc_name,
    output=surf_name,
    threshold_density=density,
    n_points_min=n_points_min,
    remesh=remesh,
    repair=repair,
    cull_inner_surfaces=cull_inner_surfaces,
    smooth_curvature=smooth_curvature,
)
recipe.add_modules_and_execute([dmc])

# Save the starting surface next to the point file.
points_stem = points_fp.split('.txt')[0]
octree_fp = points_stem + '_octree.stl'
pipeline.dataSources[dmc.output].to_stl(octree_fp)

In [None]:
"""
At this point, we will start from the original surface and iterate using
different parameters.
"""
# Shrinkwrap renderer
from ch_shrinkwrap import _membrane_mesh as membrane_mesh

max_iters = np.hstack([5, np.logspace(1,2,10).astype(int)])  # integer
step_size = np.logspace(0,2,10)                              # float
search_k = np.hstack([2,np.arange(0,105,5)[1:]])[:5]         # integer
remesh_every = np.array([5,10,20,50,100])                    # integer

# max_iters = np.array([10])
# remesh_every = np.array([5])
# step_size = np.array([12.9])
# search_k = np.array([20])

failed_count = 0
for it in max_iters:
    for lam in step_size:
        for k in search_k:
            for re in remesh_every:
                # Copy the mesh over
                mesh = membrane_mesh.MembraneMesh(mesh=pipeline.dataSources[dmc.output])

                # set params
                mesh.max_iter = it
                mesh.step_size = lam
                mesh.search_k = k
                mesh.remesh_frequency = re

                start = time.time()
                try:
                    mesh.shrink_wrap(points, s, method='ictm')
                except:
                    failed_count += 1
                    pass
                stop = time.time()
                duration = stop-start

                wrap_fp = points_fp.split('.txt')[0] + "_".join(f"_iters{it}_remesh{re}_lambda{lam:.1f}_searchk{k}_noise{no:.1f}_ntriangles{mesh.faces.shape[0]}_duration{duration:.1f}_ictm".split('.')) + ".stl"
                mesh.to_stl(wrap_fp)
print(f'# failed: {failed_count}')

In [None]:
"""
Now calculate SPR.
"""

def screened_poisson(points_fp, rowstoskip=1, strformat='X Y Z Reflectance', separator='SPACE',
                     colorformat='[0-255]', onerror='skip', k=10, smoothiter=0,
                     flipflag=False, viewpos=[0,0,0], depth=8, fulldepth=5,
                     cgdepth=0, scale=1.1, samplespernode=1.5, pointweight=4, 
                     iters=8, confidence=False, preclean=False):
    """
    Run screened poisson reconstruction on a set of points, using meshlab.
    
    For more information on these parameters, see meshlab.
    
    Parameters
    ----------
    points_fp : str
        Path to text file containing a point cloud represented as a set of XYZ coordinates
    see meshlab
    
    Returns
    -------
    str
        File path to STL of reconstruction
        
    """
    import pymeshlab as ml

    ms = ml.MeshSet()  # create a mesh
    ms.load_new_mesh(file_name=points_fp,
                     rowtoskip=rowstoskip,
                     strformat=strformat,
                     separator=separator,
                     rgbmode=colorformat,
                     onerror=onerror)  # load points
    start = time.time()
    # compute normals
    ms.compute_normals_for_point_sets(k=k,  # number of neighbors
                                      smoothiter=smoothiter,
                                      flipflag=flipflag,
                                      viewpos=viewpos)
    # run SPR
    ms.surface_reconstruction_screened_poisson(visiblelayer=False,
                                               depth=depth,
                                               fulldepth=fulldepth,
                                               cgdepth=cgdepth,
                                               scale=scale,
                                               samplespernode=samplespernode,
                                               pointweight=pointweight,
                                               iters=iters,
                                               confidence=confidence,
                                               preclean=preclean)
    stop = time.time()
    duration = stop-start
    # save surface
    surface_fp = points_fp.split('.txt')[0] + "_".join(f"_searchk{k}_depth{depth}_scale{scale:.1f}_samplespernode{samplespernode:.1f}_pointweight{pointweight:.1f}_iters{iters}_noise{no:.1f}_duration{duration:.1f}_spr".split('.')) + ".stl"
    ms.save_current_mesh(file_name=surface_fp, unify_vertices=True)
    
    return surface_fp

max_iters = np.arange(7,10)            # integer
depth = 2**np.arange(1,4)              # integer
search_k = np.arange(5,30,5)           # integer
scale=np.linspace(0,1.2,5)             # float
samplespernode=np.linspace(13,19,5)    # float
pointweight=np.linspace(0,4,5)         # float

# Single-setting run (the screened_poisson defaults) for quick checks:
# max_iters = np.array([8],dtype=int)
# depth = np.array([8],dtype=int)
# search_k = np.array([10])
# scale=np.array([1.1])
# samplespernode=np.array([1.5])
# pointweight=np.array([4])

# (iters, search k, depth, scale, samples per node, point weight, error repr) per failed run
spr_failed_runs = []
for it in max_iters:
    for k in search_k:
        for d in depth:
            for sc in scale:
                for spn in samplespernode:
                    for wt in pointweight:
                        # One failing combination must not abort the remaining sweep;
                        # record it and carry on.
                        try:
                            spr_fp = screened_poisson(points_fp, k=k, depth=d, fulldepth=fulldepth, 
                                                      scale=sc, samplespernode=spn, pointweight=wt, iters=it)
                        except Exception as e:
                            spr_failed_runs.append((it, k, d, sc, spn, wt, repr(e)))
print(f'# failed: {len(spr_failed_runs)}')

In [None]:
"""
Now load and analyze the meshes.
"""

import re
import glob
from ch_shrinkwrap import _membrane_mesh as membrane_mesh

In [None]:
# Reference points sampled from the ground-truth shape with noise=None (presumably
# noise-free); every mesh is scored against them with distance_to_mesh below.
test_points = twj.points(p=0.0001, psf_width=psf_width, mean_photon_count=mean_photon_count, resample=True, noise=None)
print(len(test_points))

In [None]:
from PYME.experimental.isosurface import distance_to_mesh

In [None]:
import glob
import re  # imported here too, so `re` is the regex module whatever earlier cells bound it to

ictm_error, ictm_its, ictm_remesh, ictm_lamb, ictm_pd = [], [], [], [], []
ictm_searchk, ictm_noise, ictm_ntris, ictm_runtime = [], [], [], []

# The sweep wrote each parameter into the file name with '.' written as '_'
# (e.g. "lambda12_9" for 12.9). The patterns are raw strings so that `\d` is a regex
# escape rather than an invalid Python string escape. They are matched against the
# base name only, so directory names cannot produce false matches. sorted() makes
# the load order (and hence the array order) deterministic.
for fn in sorted(glob.glob(os.path.join(save_fp, "*ictm.stl"))):
    name = os.path.basename(fn)
    mesh = membrane_mesh.MembraneMesh.from_stl(fn)
    ictm_its.append(int(re.search(r"(?<=iters)\d+", name).group(0)))
    ictm_remesh.append(int(re.search(r"(?<=remesh)\d+", name).group(0)))
    ictm_lamb.append(float(re.search(r"(?<=lambda)\d+_\d+", name).group(0).replace('_', '.')))
    ictm_searchk.append(int(re.search(r"(?<=searchk)\d+", name).group(0)))
    ictm_noise.append(float(re.search(r"(?<=noise)\d+_\d+", name).group(0).replace('_', '.')))
    ictm_ntris.append(int(re.search(r"(?<=ntriangles)\d+", name).group(0)))
    ictm_runtime.append(float(re.search(r"(?<=duration)\d+_\d+", name).group(0).replace('_', '.')))

    # Calculate error per face: twj.sdf evaluated at each face centroid
    vecs = mesh._vertices[mesh.faces]['position']
    ictm_error.append(twj.sdf(vecs.mean(1).T))
    ictm_pd.append(distance_to_mesh(test_points, mesh))

ictm_its, ictm_remesh = np.array(ictm_its), np.array(ictm_remesh)
ictm_lamb, ictm_searchk = np.array(ictm_lamb), np.array(ictm_searchk)
ictm_noise, ictm_ntris = np.array(ictm_noise), np.array(ictm_ntris)
ictm_runtime = np.array(ictm_runtime)

In [None]:
import glob
import re  # imported here too, so `re` is the regex module whatever earlier cells bound it to

spr_error, spr_searchk, spr_depth, spr_scale, spr_pd = [], [], [], [], []
spr_spn, spr_pointweight, spr_its, spr_noise = [], [], [], []
spr_runtime = []

# Per-parameter lists, in the same order as the values parsed below.
spr_param_lists = (spr_searchk, spr_depth, spr_scale, spr_spn,
                   spr_pointweight, spr_its, spr_noise, spr_runtime)

spr_load_failures = []  # (file name, error repr) for each file that could not be processed
for fn in sorted(glob.glob(os.path.join(save_fp, "*spr.stl"))):
    name = os.path.basename(fn)
    try:
        mesh = membrane_mesh.MembraneMesh.from_stl(fn)
        # Parameters are encoded in the base name with '.' written as '_' (e.g. "scale0_6").
        params = (
            int(re.search(r"(?<=searchk)\d+", name).group(0)),
            int(re.search(r"(?<=depth)\d+", name).group(0)),
            float(re.search(r"(?<=scale)\d+_\d+", name).group(0).replace('_', '.')),
            float(re.search(r"(?<=samplespernode)\d+_\d+", name).group(0).replace('_', '.')),
            # pointweight is written with one decimal ("4_0" -> "4.0"), so it must be parsed
            # as a float: int("4.0") raises ValueError.
            float(re.search(r"(?<=pointweight)\d+_\d+", name).group(0).replace('_', '.')),
            int(re.search(r"(?<=iters)\d+", name).group(0)),
            float(re.search(r"(?<=noise)\d+_\d+", name).group(0).replace('_', '.')),
            float(re.search(r"(?<=duration)\d+_\d+", name).group(0).replace('_', '.')),
        )

        # Calculate error per face: twj.sdf evaluated at each face centroid
        vecs = mesh._vertices[mesh.faces]['position']
        face_error = twj.sdf(vecs.mean(1).T)
        point_dist = distance_to_mesh(test_points, mesh)
    except Exception as e:
        spr_load_failures.append((name, repr(e)))
        continue

    # Append only once everything for this file has succeeded, so that a failure part-way
    # through cannot leave the per-parameter lists with different lengths.
    for values, value in zip(spr_param_lists, params):
        values.append(value)
    spr_error.append(face_error)
    spr_pd.append(point_dist)

fail_count = len(spr_load_failures)
print(f'# failed: {fail_count}')

spr_searchk = np.array(spr_searchk)
spr_depth, spr_scale = np.array(spr_depth), np.array(spr_scale)
spr_spn, spr_pointweight = np.array(spr_spn), np.array(spr_pointweight)
spr_its, spr_noise = np.array(spr_its), np.array(spr_noise)
spr_runtime = np.array(spr_runtime)

In [None]:
# Per-mesh summary statistics, one entry per loaded mesh:
#   *_mses     mean of the squared twj.sdf values over the face centroids
#   *_abs_err  plain (signed) sum of those sdf values, despite the name
#   *_pd_err   sum of |distance_to_mesh| over the test points
ictm_mses = [np.sum(face_err**2)/len(face_err) for face_err in ictm_error]
ictm_abs_err = [np.sum(face_err) for face_err in ictm_error]
ictm_pd_err = [np.sum(np.abs(ictm_pd[j])) for j in range(len(ictm_error))]

spr_mses = [np.sum(face_err**2)/len(face_err) for face_err in spr_error]
spr_abs_err = [np.sum(face_err) for face_err in spr_error]
spr_pd_err = [np.sum(np.abs(spr_pd[j])) for j in range(len(spr_error))]

In [None]:
import matplotlib.pyplot as plt

# Summed |distance_to_mesh| over the test points, one marker per loaded ICTM mesh.
plt.figure()
plt.scatter(np.arange(len(ictm_pd_err)), ictm_pd_err, s=4)
plt.xlabel('ICTM mesh (load order)')
plt.ylabel('sum |distance_to_mesh|');

In [None]:
def _best_mesh_index(mses, pd_errs):
    """
    Return the index of the lowest-MSE mesh whose summed point-to-mesh error is
    strictly below the median of the non-NaN summed errors.

    If no mesh passes that test, np.argmax of the all-False mask is 0, so the
    lowest-MSE mesh overall is returned.
    """
    pd_errs = np.array(pd_errs)
    by_mse = np.argsort(mses)
    median_pd = np.median(pd_errs[~np.isnan(pd_errs)])
    passes = pd_errs[by_mse] < median_pd
    return by_mse[np.argmax(passes)]


# --- ICTM ---
idx = _best_mesh_index(ictm_mses, ictm_pd_err)
print("Best ICTM mesh...")
# Rebuild the file name exactly as the ICTM sweep wrote it ('.' -> '_').
ictm_tag = f"_iters{ictm_its[idx]}_remesh{ictm_remesh[idx]}_lambda{ictm_lamb[idx]:.1f}_searchk{ictm_searchk[idx]}_noise{ictm_noise[idx]:.1f}_ntriangles{ictm_ntris[idx]}_duration{ictm_runtime[idx]:.1f}_ictm"
ictm_fn = points_fp.split('.txt')[0] + ictm_tag.replace('.', '_') + ".stl"
print(ictm_fn)
print(*[
    f"MSE : {ictm_mses[idx]} ",
    f"iters: {ictm_its[idx]}  ",
    f"remesh frequency: {ictm_remesh[idx]} ",
    f"lambda: {ictm_lamb[idx]}  ",
    f"search k: {ictm_searchk[idx]}  ",
    f"noise : {ictm_noise[idx]} ",
    f"# triangles: {ictm_ntris[idx]}",
    f"duration: {ictm_runtime[idx]}",
])

# --- SPR ---
idx = _best_mesh_index(spr_mses, spr_pd_err)
print("Best SPR mesh...")
# Rebuild the file name exactly as screened_poisson wrote it ('.' -> '_').
spr_tag = f"_searchk{spr_searchk[idx]}_depth{spr_depth[idx]}_scale{spr_scale[idx]:.1f}_samplespernode{spr_spn[idx]:.1f}_pointweight{spr_pointweight[idx]:.1f}_iters{spr_its[idx]}_noise{spr_noise[idx]:.1f}_duration{spr_runtime[idx]:.1f}_spr"
spr_fn = points_fp.split('.txt')[0] + spr_tag.replace('.', '_') + ".stl"
print(spr_fn)
print(*[
    f"MSE : {spr_mses[idx]} ",
    f"search k: {spr_searchk[idx]}  ",
    f"depth: {spr_depth[idx]}  ",
    f"scale: {spr_scale[idx]}  ",
    f"samples per node: {spr_spn[idx]}  ",
    f"pointweight: {spr_pointweight[idx]}  ",
    f"iters: {spr_its[idx]}  ",
    f"noise : {spr_noise[idx]} ",
    f"duration: {spr_runtime[idx]}",
])

In [None]:
import matplotlib.pyplot as plt

def make_scatterplots(x, y, label_x = None, label_y = None, title=None):
    """
    Draw one scatter panel per (x[i], y[i]) pair on a grid with two columns.

    Parameters
    ----------
    x, y : sequence of array-like
        Equal-length sequences; panel i plots y[i] against x[i].
    label_x, label_y : sequence of str, optional
        Per-panel x / y axis labels.
    title : str, optional
        Figure title.
    """
    assert(len(x) == len(y))
    cols = 2
    rows = max(1, (len(x) + 1) // cols)
    # squeeze=False keeps `axs` 2-D even for a single row; otherwise plt.subplots(1, 2)
    # returns a 1-D array and axs[0][0] fails.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for i, (xi, yi) in enumerate(zip(x, y)):
        ax = axs[i // cols][i % cols]
        ax.scatter(xi, yi, s=0.1)
        if label_x:
            ax.set_xlabel(label_x[i])
        if label_y:
            ax.set_ylabel(label_y[i])
    # Hide the spare panel left over when the number of plots is odd.
    for ax in axs.flat[len(x):]:
        ax.set_visible(False)
    if title:
        fig.suptitle(title, fontsize=16)
        
# One panel per ICTM parameter, each plotted against the per-mesh MSE.
ictm_x = [ictm_its, ictm_remesh, ictm_lamb, ictm_searchk]
ictm_label_x = ['iterations', 'remesh frequency', 'lambda', 'search k']
ictm_y = [ictm_mses for _ in ictm_x]
ictm_label_y = ['MSE' for _ in ictm_x]
make_scatterplots(ictm_x, ictm_y, ictm_label_x, ictm_label_y, 'ICTM')

# Same for every SPR parameter.
spr_x = [spr_searchk, spr_depth, spr_scale, spr_spn, spr_pointweight, spr_its]
spr_label_x = ['search k', 'depth', 'scale', 'samples per node', 'point weight', 'iterations']
spr_y = [spr_mses for _ in spr_x]
spr_label_y = ['MSE' for _ in spr_x]
make_scatterplots(spr_x, spr_y, spr_label_x, spr_label_y, 'SPR')

In [None]:
import itertools
def make_comparative_plot(x, y, list_fixed=None, vals_fixed=None, ax=None):
    """
    Plot y against x once for every combination of the values in list_fixed.

    Each curve holds every array in list_fixed at one value, so it shows how y
    varies with x alone. Curves are drawn in ascending order of the fixed values,
    which matches a legend built from np.unique(...) of the first varying parameter.

    Parameters
    ----------
    x, y : np.ndarray
        Equal-length 1-D arrays (they are indexed with a boolean mask).
    list_fixed : list of np.ndarray, optional
        Parameter arrays, each the same length as x, to hold fixed.
    vals_fixed : list of list, optional
        One entry per array in list_fixed restricting which of its values get
        curves; an empty list means "all values". Values absent from the data
        are ignored.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current pyplot axes.
    """
    if list_fixed is None:
        list_fixed = []
    assert(len(x) == len(y))
    for arr in list_fixed:
        assert(len(x) == len(arr))
    if ax is None:
        ax = plt.gca()

    unique_fixed = [np.unique(arr) for arr in list_fixed]
    if vals_fixed:
        # sorted(): iterating a bare set gives an arbitrary curve order, which can
        # silently mismatch the legend labels.
        unique_fixed = [sorted(set(u).intersection(u if len(v) == 0 else v))
                        for u, v in zip(unique_fixed, vals_fixed)]

    for opt in itertools.product(*unique_fixed):
        mask = np.ones(len(x), dtype=bool)
        for arr, val in zip(list_fixed, opt):
            mask &= (arr == val)
        xi, yi = x[mask], y[mask]
        order = np.argsort(xi)
        # Label each curve with its fixed values; an explicit plt.legend([...]) overrides this.
        label = ', '.join(str(val) for val in opt) if opt else None
        ax.plot(xi[order], yi[order], label=label)

# ICTM: MSE against iteration count, one curve per lambda, at remesh frequency 50 and search k 10.
lambda_values = np.unique(ictm_lamb)
plt.figure()
make_comparative_plot(ictm_its, np.array(ictm_mses), [ictm_lamb, ictm_remesh, ictm_searchk], [[], [50], [10]])
plt.legend([f'lambda={lam}' for lam in lambda_values])
plt.xlabel('# iterations')
plt.ylabel('MSE')

In [None]:
# SPR: MSE against point weight, one curve per search k, with depth 4, 8 iterations,
# 16 samples per node and scale fixed.
# The fixed values must lie on the sweep grid. scale was swept over
# np.linspace(0, 1.2, 5) = 0, 0.3, 0.6, 0.9, 1.2, so e.g. 0.5 selects no meshes and
# leaves the figure empty.
spr_fixed_scale = 0.6
plt.figure()
make_comparative_plot(spr_pointweight, np.array(spr_mses),
                      [spr_depth, spr_searchk, spr_its, spr_scale, spr_spn],
                      [[4], [], [8], [spr_fixed_scale], [16]])
plt.xlabel('point weight')
plt.ylabel('MSE')
plt.title('SPR');

In [None]:
# Median summed point-to-mesh error over the ICTM meshes, ignoring NaNs
# (the cutoff used when picking the best ICTM mesh).
np.nanmedian(ictm_pd_err)