In [None]:
from Dissects.io import (load_NDskl, 
                         load_image,
                         load_skeleton,
                         save_skeleton,
                         save_fits
                        )
from Dissects.image import (z_project,
                            thinning,
                            dilation)
from Dissects.geometry import Skeleton
from Dissects.segmentation.seg_2D import (segmentation, 
                                          junction_around_cell,
                                         vertices,
                                         junctions,
                                         generate_mesh)
from Dissects.segmentation.seg_3D_apical import (flatten_tissu,
                                                 binary_flatten_tissu,
                                                 generate_mesh_3D)

from Dissects.analysis.analysis import (general_analysis,
                              cellstats
                              )
from skimage import morphology

In [None]:
import os
from skimage import io
import numpy as np
import pandas as pd
import copy
import sys
import matplotlib.pyplot as plt
sys.setrecursionlimit(10000)

import plotly.express as px

from scipy.ndimage.morphology import (binary_fill_holes,
                                      binary_dilation,
                                     )
from scipy import ndimage
# %matplotlib notebook
%load_ext snakeviz

In [None]:
# Data directory: can be overridden with the DISSECTS_DATA_DIR environment
# variable so the notebook does not depend on one machine's mount point.
directory = os.environ.get('DISSECTS_DATA_DIR', '/media/admin-suz/Datas/testCellSeg/Test1/')

# Read the NDskl skeleton file; load_NDskl returns four parts that are passed
# unchanged to the Skeleton constructor.
cp, fil, point, specs = load_NDskl(os.path.join(directory, 'im_bin4_3d.fits_c1e+03.up.NDskl.a.NDskl'))
# Create the Skeleton object
skel = Skeleton(cp, fil, point, specs)

In [None]:
# Clean the skeleton in place. Run remove_lonely_cp first, then
# remove_free_filament (see Dissects.geometry.Skeleton for exact semantics).
for cleanup_step in (skel.remove_lonely_cp, skel.remove_free_filament):
    cleanup_step()

In [None]:
# Inspect the skeleton's `specs` attribute (Skeleton was constructed with the
# `specs` returned by load_NDskl above)
skel.specs

In [None]:
# Top (x, y) view of the skeleton critical points, coloured by z. Points with
# nfil == 3 are drawn in red and points with nfil > 3 in black (nfil is
# presumably the number of attached filaments).
fig, ax = plt.subplots(figsize=(10, 10))

crit = skel.critical_point
ax.scatter(crit.x, crit.y, c=crit.z, cmap='Blues', alpha=0.7, label='critical points (colour = z)')

# Distinct names for the two subsets (both were previously called data_crit_2).
three_filaments = crit[crit.nfil == 3]
ax.scatter(x=three_filaments['x'], y=three_filaments['y'], color='red', label='nfil == 3')

many_filaments = crit[crit.nfil > 3]
ax.scatter(x=many_filaments['x'], y=many_filaments['y'], color='black', label='nfil > 3')

ax.axis('equal')
ax.set(title='Skeleton critical points', xlabel='x', ylabel='y')
ax.legend()
plt.show()


In [None]:
# Get the original image; it provides the output volume's shape and dtype.
img0 = io.imread(os.path.join(directory, "C1-20171214_sqh-GFP_ap-alpha-cat-RFP_WP2h-001-dorsal_AiSc_green_bin4.tif"))

# Half-width of the cube painted around each skeleton point (0 = single voxel).
w = 0


def rasterize_points(points, shape, dtype, half_width=0):
    """Return a zero volume of `shape`/`dtype` with 1 written at every point.

    Parameters
    ----------
    points : pandas.DataFrame
        Must have 'x', 'y', 'z' columns; the volume is indexed [z, y, x] and
        coordinates are truncated to int (same as ``int(p.x)``).
    shape, dtype :
        Shape and dtype of the returned volume.
    half_width : int
        Each point paints the cube [c - half_width, c + half_width] on every
        axis. The cube is intersected with the volume: voxels outside are
        skipped instead of wrapping to the opposite face through negative
        indices, as the former try/except loop did.
    """
    volume = np.zeros(shape, dtype=dtype)
    centres = points[['z', 'y', 'x']].to_numpy().astype(int)
    upper = np.array(shape)
    offsets = range(-half_width, half_width + 1)
    for dz in offsets:
        for dy in offsets:
            for dx in offsets:
                voxels = centres + (dz, dy, dx)
                inside = np.all((voxels >= 0) & (voxels < upper), axis=1)
                volume[tuple(voxels[inside].T)] = 1
    return volume


# Start from an all-zero volume (np.where(img > 0, 0, img) kept negative
# values of signed images).
img_output = rasterize_points(skel.point, img0.shape, img0.dtype, half_width=w)
img_binary_3d = img_output.copy()

# Need a skeletonization to avoid vertex detection problem
img_binary_3d = morphology.skeletonize(img_binary_3d)

# 3D vertex detection

In [None]:
# from Dissects.segmentation.seg_3D_apical import find_vertex

In [None]:
import scipy as sci
def find_vertex(skeleton_mask, free_edges=False,
                pattern_path='../Dissects/segmentation/3d_pattern.csv'):
    """Detect skeleton junction voxels by 3D hit-or-miss pattern matching.

    Parameters
    ----------
    skeleton_mask : 3D array-like
        Binary skeleton volume. Make sure it is skeletonized (e.g. the
        skeletonized DisPerSE output); the patterns are matched voxel-exactly.
    free_edges : bool, default False
        If True, also match the kernels returned by ``kernels_extremity()`` to
        find vertex extremities.
        NOTE(review): ``kernels_extremity`` is not defined or imported in this
        notebook; confirm where it comes from before using free_edges=True.
    pattern_path : str or path-like
        CSV of 3x3x3 kernels stored as consecutive blocks of 9 rows x 3 columns.
        The default is the previous hard-coded path, relative to the current
        working directory.

    Returns
    -------
    numpy.ndarray of float
        Same shape as ``skeleton_mask``; each voxel holds the number of kernels
        that matched at that position (0 = no match).
    """
    kernels = pd.read_csv(pattern_path, header=None).to_numpy().reshape((-1, 3, 3, 3))
    if free_edges:
        kernels = list(kernels) + list(kernels_extremity())

    output_image = np.zeros(np.shape(skeleton_mask))
    for kernel in kernels:
        # structure2 defaults to ~kernel: kernel 1s must be set, 0s must be unset.
        output_image += sci.ndimage.binary_hit_or_miss(skeleton_mask, kernel)
    return output_image


def clean_vertex(vertex_image):
    """Merge 26-connected clusters of vertex voxels into one vertex per cluster.

    Parameters
    ----------
    vertex_image : 3D array-like
        Nonzero where a vertex was detected (e.g. the output of find_vertex).

    Returns
    -------
    pandas.DataFrame
        Columns ``z, y, x`` (int64) with one row per connected cluster, indexed
        by the cluster label 1..N. Each position is the cluster's mean voxel
        coordinate truncated to int; a single-voxel cluster keeps its own
        position.
    """
    structure = sci.ndimage.generate_binary_structure(3, 3)
    labeled_array, num_features = sci.ndimage.label(vertex_image, structure=structure)
    columns = list('zyx')
    if num_features == 0:
        return pd.DataFrame(columns=columns, dtype='int64')

    # Collect every labelled voxel in one pass over the volume, instead of one
    # full-volume scan per label. Background (label 0) is excluded explicitly,
    # so a volume with no background no longer loses its first vertex.
    coords = np.nonzero(labeled_array)
    voxels = pd.DataFrame({'label': labeled_array[coords].astype('int64'),
                           'z': coords[0], 'y': coords[1], 'x': coords[2]})
    per_label = voxels.groupby('label')
    # sum // size equals the previous int-dtype np.mean (coordinates are >= 0).
    vert_df = per_label[columns].sum().floordiv(per_label.size(), axis=0).astype('int64')
    vert_df.index = vert_df.index.astype('int64')
    vert_df.index.name = None
    return vert_df

In [None]:
# %%snakeviz
# Time vertex detection. perf_counter is monotonic and high-resolution, unlike
# time.time() (wall clock, can jump), so it is the right clock for durations.
import time
start = time.perf_counter()
output_vertex = find_vertex(img_binary_3d)
end = time.perf_counter()
print(end - start)

In [None]:
# Collapse the per-voxel pattern matches into one (z, y, x) row per connected vertex cluster
vert_df = clean_vertex(output_vertex)

In [None]:
import plotly.graph_objects as go

# Background: every skeleton voxel, faint black.
z0, y0, x0 = np.where(img_binary_3d == 1)
fond = go.Scatter3d(
    x=x0, y=y0, z=z0,
    mode='markers',
    marker={'size': 2, 'color': 'black', 'opacity': 0.1},
)

# Foreground: the detected vertices, opaque red.
vertex_isolated = go.Scatter3d(
    x=vert_df.x, y=vert_df.y, z=vert_df.z,
    mode='markers',
    marker={'size': 2, 'color': 'red', 'opacity': 1},
)

fig = go.Figure(data=[fond, vertex_isolated])
fig.update_layout(
    title='Filament',
    autosize=False,
    width=1000,
    height=1000,
    margin={'l': 65, 'r': 50, 'b': 65, 't': 90},
)
# Keep the true voxel proportions on all three axes.
fig.update_layout(scene={'aspectmode': 'data'})
fig.show()

In [None]:
# Remove a 3x3x3 cube around every vertex so the skeleton splits into one
# connected component per edge, then label those components (26-connectivity).
img_binary_3d_without_vertex = img_binary_3d.copy()
for vz, vy, vx in vert_df[['z', 'y', 'x']].to_numpy(dtype=int):
    # Slices clip at the volume border. The former per-voxel loop wrapped
    # index -1 to the opposite face and hid errors with a bare except.
    img_binary_3d_without_vertex[max(vz - 1, 0):vz + 2,
                                 max(vy - 1, 0):vy + 2,
                                 max(vx - 1, 0):vx + 2] = 0
s = sci.ndimage.generate_binary_structure(3, 3)
labeled_array, num_features = sci.ndimage.label(img_binary_3d_without_vertex, structure=s)

In [None]:
import plotly.graph_objects as go

# Colour each edge component (label from the previous cell) with its own random
# colour. Colours are 'rgb(r, g, b)' strings: plotly treats a numeric array as
# colorscale values, not RGB triples. The palette is sized from the labels
# present (the fixed 500-row palette raised IndexError beyond 499 labels) and
# seeded so colours are reproducible across runs.
rng = np.random.default_rng(0)
edge_z, edge_y, edge_x = np.nonzero(labeled_array)
edge_labels = labeled_array[edge_z, edge_y, edge_x]
palette = rng.integers(0, 256, size=(int(labeled_array.max()) + 1, 3))
edge_colors = [f'rgb({r}, {g}, {b})' for r, g, b in palette[edge_labels]]

edge = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z,
                    mode='markers',
                    marker=dict(size=2, color=edge_colors, opacity=1))

fig = go.Figure(data=[edge])
fig.update_layout(title='Filament',
                  autosize=False,
                  width=1000,
                  height=1000,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.update_layout(scene=dict(aspectmode="data"))
fig.show()

In [None]:
def _clear_vertex_neighbourhoods(img_binary_3d, vert_df, radius=1):
    """Return a copy of img_binary_3d with a (2*radius+1)^3 cube zeroed around each vertex (clipped at borders)."""
    cleared = img_binary_3d.copy()
    for z, y, x in vert_df[['z', 'y', 'x']].to_numpy(dtype=int):
        cleared[max(z - radius, 0):z + radius + 1,
                max(y - radius, 0):y + radius + 1,
                max(x - radius, 0):x + radius + 1] = 0
    return cleared


def _edges_touching_vertex(vertex_zyx, binary_edges, labeled_array, structure):
    """Labels of the edges overlapped once a seed at vertex_zyx is dilated to touch >= 2 edge voxels."""
    seed = np.zeros(binary_edges.shape, dtype=bool)
    seed[tuple(vertex_zyx)] = True
    grown = binary_dilation(seed, structure=structure)
    overlap = grown & binary_edges
    while np.count_nonzero(overlap) < 2:
        bigger = binary_dilation(grown, structure=structure)
        if np.array_equal(bigger, grown):
            # Dilation has filled everything reachable; stop instead of
            # looping forever as the previous version did.
            break
        grown = bigger
        overlap = grown & binary_edges
    return np.unique(labeled_array[overlap])


def _voxel_coordinate_strings(labeled_array, labels):
    """For each label return str(np.where(labeled_array == label)), computed in one pass over the volume."""
    coords = np.nonzero(labeled_array)
    voxel_labels = labeled_array[coords]
    # A stable sort keeps voxels of one label in raster (C) order, as np.where does.
    order = np.argsort(voxel_labels, kind='stable')
    sorted_labels = voxel_labels[order]
    strings = []
    for label in labels:
        start = np.searchsorted(sorted_labels, label, side='left')
        stop = np.searchsorted(sorted_labels, label, side='right')
        members = order[start:stop]
        strings.append(str(tuple(axis[members] for axis in coords)))
    return strings


def _vertex_degrees(edge_df):
    """Count edge endpoints per vertex id (srce ids first, then new trgt ids, each in ascending order)."""
    degrees = {}
    for column in ('srce', 'trgt'):
        ids, counts = np.unique(edge_df[column].to_numpy(), return_counts=True)
        for vertex_id, count in zip(ids, counts):
            degrees[vertex_id] = degrees.get(vertex_id, 0) + count
    return degrees


def _connect_low_degree_vertices(edge_df, vert_df, max_degree=2):
    """Repeatedly link the first vertex with <= max_degree edges to its nearest vertex (vertices with no edge are ignored)."""
    from scipy.spatial import cKDTree

    if len(vert_df) < 2:
        return edge_df
    positions = vert_df[['x', 'y', 'z']].to_numpy()
    tree = cKDTree(positions)  # built once: vertex positions never change
    edge_df = edge_df.copy()
    while True:
        pending = [v for v, d in _vertex_degrees(edge_df).items() if d <= max_degree]
        if not pending:
            return edge_df
        vertex_id = int(pending[0])
        row = vert_df.index.get_loc(vertex_id)
        _, neighbours = tree.query(positions[row], k=2)
        # Skip the vertex itself (normally the first hit, at distance 0).
        nearest = next(int(j) for j in neighbours if j != row)
        new_id = edge_df.index.max() + 1
        edge_df.loc[new_id, 'srce'] = vertex_id
        edge_df.loc[new_id, 'trgt'] = vert_df.index[nearest]


def find_edges(img_binary_3d, output_vertex, vert_df):
    """Build the edge table of a 3D skeleton.

    The 3x3x3 neighbourhood of every vertex is removed from the skeleton, so
    the remaining 26-connected components are the edges. Each vertex is then
    dilated until it overlaps at least two edge voxels. The first vertex to
    reach an edge becomes its ``srce`` and the second its ``trgt``; a third
    prints a "problem" message. Edges missing an endpoint are dropped.
    Vertices left with at most two edges are linked to their nearest vertex.
    Finally every edge is oriented srce <= trgt and duplicate vertex pairs are
    removed.

    Parameters
    ----------
    img_binary_3d : 3D array
        Skeletonized binary volume indexed [z, y, x]. The seed volume for each
        vertex now uses this argument's shape; the previous version read the
        notebook-global ``img_output``.
    output_vertex : array
        Unused; kept for backward compatibility (the previous version rebuilt
        it and never read it).
    vert_df : pandas.DataFrame
        Vertices with columns z, y, x (as returned by clean_vertex); its index
        labels are used as vertex ids.

    Returns
    -------
    pandas.DataFrame
        Indexed by edge id, with int64 ``srce``/``trgt`` (vertex ids,
        srce <= trgt) and ``points``. ``points`` is
        ``str(np.where(labeled == edge_id))`` for edges traced in the image
        and NaN for edges added by the nearest-vertex repair. When a traced
        edge and a repair edge join the same vertex pair, the traced one is
        kept.
    """
    structure = sci.ndimage.generate_binary_structure(3, 3)
    without_vertex = _clear_vertex_neighbourhoods(img_binary_3d, vert_df)
    labeled_array, num_features = sci.ndimage.label(without_vertex, structure=structure)
    binary_edges = labeled_array > 0

    edge_df = pd.DataFrame(index=range(1, num_features + 1),
                           columns=['srce', 'trgt'], dtype=float)
    vertex_positions = vert_df[['z', 'y', 'x']].to_numpy(dtype=int)
    for vertex_id, zyx in zip(vert_df.index, vertex_positions):
        for e in _edges_touching_vertex(zyx, binary_edges, labeled_array, structure):
            # Write with .loc[row, col]. The former chained
            # edge_df.loc[e]['srce'] = ... assigns into a temporary and is
            # silently lost under pandas Copy-on-Write.
            if pd.isna(edge_df.loc[e, 'srce']):
                edge_df.loc[e, 'srce'] = vertex_id
            elif pd.isna(edge_df.loc[e, 'trgt']):
                edge_df.loc[e, 'trgt'] = vertex_id
            else:
                print("problem:", str(vertex_id))
                print(edge_df.loc[e])

    edge_df['points'] = _voxel_coordinate_strings(labeled_array, edge_df.index)
    edge_df = edge_df.dropna(axis=0)

    # Recover small junctions that were lost: link under-connected vertices
    # to their nearest neighbour.
    edge_df = _connect_low_degree_vertices(edge_df, vert_df)

    # Remove duplicates: orient each edge as (min, max) and keep the first
    # occurrence of every vertex pair (traced edges come before repair edges).
    endpoints = edge_df[['srce', 'trgt']].astype('int64')
    edge_df = edge_df.assign(srce=endpoints.min(axis=1), trgt=endpoints.max(axis=1))
    return edge_df.drop_duplicates(subset=['srce', 'trgt'])

In [None]:
# Build the edge table (srce/trgt vertex ids plus the edge voxel coordinates) from the skeleton
edge_df = find_edges(img_binary_3d, output_vertex, vert_df)

In [None]:
import plotly.graph_objects as go

# Skeleton voxels (faint black) and detected vertices (red).
z0, y0, x0 = np.where(img_binary_3d == 1)
fond = go.Scatter3d(x=x0, y=y0, z=z0, mode='markers',
                    marker=dict(size=2, color='black', opacity=0.1))
vertex_isolated = go.Scatter3d(x=vert_df.x, y=vert_df.y, z=vert_df.z, mode='markers',
                               marker=dict(size=2, color='red', opacity=1))

fig = go.Figure(data=[fond, vertex_isolated])
fig.update_layout(title='Filament',
                  autosize=False,
                  width=1000,
                  height=1000,
                  margin=dict(l=65, r=50, b=65, t=90))

# One straight segment per edge, from its srce vertex to its trgt vertex.
# Each segment gets a single 'rgb(...)' colour. The former rand[[i, i, i]]
# passed a 3x3 numeric array, which plotly reads as colorscale values, and it
# raised IndexError beyond 500 edges. Colours are drawn from a seeded
# generator sized to the number of edges.
srce_zyx = vert_df.loc[edge_df['srce'], ['z', 'y', 'x']].to_numpy()
trgt_zyx = vert_df.loc[edge_df['trgt'], ['z', 'y', 'x']].to_numpy()
rng = np.random.default_rng(0)
edge_rgb = rng.integers(0, 256, size=(len(edge_df), 3))
for (zs, ys, xs), (zt, yt, xt), (r, g, b) in zip(srce_zyx, trgt_zyx, edge_rgb):
    fig.add_trace(
        go.Scatter3d(
            x=[xs, xt],
            y=[ys, yt],
            z=[zs, zt],
            mode='lines',
            line={"color": f'rgb({r}, {g}, {b})', "width": 10},
        )
    )

fig.update_layout(scene=dict(aspectmode="data"), showlegend=False)
fig.show()

In [None]:
# NOTE(review): this cell contained only the bare name `zeds`, which is not
# defined anywhere in the notebook and raised NameError on Restart & Run All.
# It has been neutralised; restore the intended expression if there was one.

In [None]:
import plotly.graph_objects as go

# Skeleton voxels (faint black), the edge voxels left after removing the
# vertex neighbourhoods (blue) and the detected vertices (red).
z0, y0, x0 = np.where(img_binary_3d == 1)
fond = go.Scatter3d(x=x0, y=y0, z=z0, mode='markers',
                    marker=dict(size=2, color='black', opacity=0.1))

# This trace was previously built (as `vertex`) but never added to the figure,
# which made the cell a duplicate of the earlier vertex plot.
edge_z, edge_y, edge_x = np.where(labeled_array != 0)
edge_voxels = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='markers',
                           marker=dict(size=2, color='blue', opacity=0.5))

vertex_isolated = go.Scatter3d(x=vert_df.x, y=vert_df.y, z=vert_df.z, mode='markers',
                               marker=dict(size=2, color='red', opacity=1))

fig = go.Figure(data=[fond, edge_voxels, vertex_isolated])
fig.update_layout(title='Filament',
                  autosize=False,
                  width=1000,
                  height=1000,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.update_layout(scene=dict(aspectmode="data"))
fig.show()

In [None]:
from sklearn.neighbors import KDTree
from io import StringIO

In [None]:
# Import BallTree here: the preceding import cell only brings in KDTree, and
# find_edges imports BallTree inside its own scope, so the bare name raised
# NameError on a fresh kernel.
from sklearn.neighbors import BallTree

# NOTE(review): np.deg2rad only rescales the voxel coordinates by pi/180
# (likely copied from a haversine example). With the euclidean metric the
# neighbour order is unchanged but distances are in scaled units, so queries
# must apply the same scaling.
tree = BallTree(np.deg2rad(vert_df[['x', 'y', 'z']].values), metric='euclidean')


In [None]:
# Two hand-written query points (NAME, x, y, z) used to try the vertex BallTree below.
other_data = """NAME x y z
a 253 155 9
b 500 55 82"""

# Parse the space-separated text into a DataFrame with columns NAME, x, y, z
df_other = pd.read_csv(StringIO(other_data), sep = ' ')

In [None]:
# Query the 3 nearest vertices for each test point. The same np.deg2rad
# scaling that was used to build the tree is applied to the query points.
query_x, query_y, query_z = (df_other[axis] for axis in ('x', 'y', 'z'))
query_points = np.column_stack([query_x, query_y, query_z])
distances, indices = tree.query(np.deg2rad(query_points), k=3)


In [None]:
# Report each query point's nearest vertices by their vert_df label.
# BallTree returns positional row numbers; the former `index + 1` only matched
# the label when vert_df happened to be indexed 1..N.
for name, d, ind in zip(df_other['NAME'], distances, indices):
    print(f"NAME {name} closest matches:")
    for i, index in enumerate(ind):
        print(f"\t{vert_df.index[index]} with distance {d[i]:.4f}")


In [None]:
# Raw distances from tree.query above (in the np.deg2rad-scaled coordinate units used to build the tree)
distances

In [None]:
import plotly.graph_objects as go

# Background skeleton voxels (faint black) and all vertices (red).
z0, y0, x0 = np.where(img_binary_3d == 1)
fond = go.Scatter3d(x=x0, y=y0, z=z0, mode='markers',
                    marker={'size': 2, 'color': 'black', 'opacity': 0.1})
vertex_isolated = go.Scatter3d(x=vert_df.x, y=vert_df.y, z=vert_df.z, mode='markers',
                               marker={'size': 2, 'color': 'red', 'opacity': 1})

# Vertex 1 in green.
origin_rows = vert_df.loc[[1]]
vertex_origin = go.Scatter3d(x=origin_rows['x'], y=origin_rows['y'], z=origin_rows['z'],
                             mode='markers',
                             marker={'size': 2, 'color': 'green', 'opacity': 1})

# Hard-coded vertices 19 and 34 in blue.
# NOTE(review): these ids presumably come from an earlier run's output; confirm
# before re-running on other data.
l = [19, 34]
near_rows = vert_df.loc[l]
vertex_near = go.Scatter3d(x=near_rows['x'], y=near_rows['y'], z=near_rows['z'],
                           mode='markers',
                           marker={'size': 2, 'color': 'blue', 'opacity': 1})

fig = go.Figure(data=[fond, vertex_isolated, vertex_origin, vertex_near])
fig.update_layout(title='Filament',
                  autosize=False,
                  width=1000,
                  height=1000,
                  margin={'l': 65, 'r': 50, 'b': 65, 't': 90})
fig.update_layout(scene={'aspectmode': 'data'})
fig.show()

In [None]:
from scipy.spatial import ConvexHull

In [None]:
# Vertex coordinates as an (N, 3) array with one (x, y, z) row per vertex, for ConvexHull.
pts = vert_df[['x', 'y', 'z']].to_numpy()

In [None]:
# Convex hull of the vertex positions; hull.simplices holds, for each facet,
# the row indices into pts of its corners (triangles for 3D points)
hull = ConvexHull(pts)

In [None]:
# Wireframe of the convex hull: vertices as black dots, each triangular facet
# outlined in red.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

ax.plot(pts.T[0], pts.T[1], pts.T[2], "ko")

# Each simplex is a triangle of row indices into pts. Repeat the first index to
# close the outline. (The old "12 = 2 * 6 faces" comment came from a cube
# example and does not apply here.)
for s in hull.simplices:
    s = np.append(s, s[0])
    ax.plot(pts[s, 0], pts[s, 1], pts[s, 2], "r-")

# Set axis labels directly instead of building method calls as strings for eval().
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

plt.show()

In [None]:
import plotly.figure_factory as ff

# Triangulated surface of the convex hull, displayed with aspect ratio
# x:y:z = 1:1:0.3 and made semi-transparent.
fig = ff.create_trisurf(
    x=pts[:, 0],
    y=pts[:, 1],
    z=pts[:, 2],
    simplices=hull.simplices[:, :3],
    aspectratio=dict(x=1, y=1, z=0.3),
)
fig.data[0].update(opacity=0.75)
fig.show()

In [None]:
import plotly.figure_factory as ff
import plotly.graph_objects as go

# Skeleton voxels (faint black) overlaid with the convex-hull surface of the vertices.
z0, y0, x0 = np.where(img_binary_3d == 1)
fond = go.Scatter3d(x=x0, y=y0, z=z0, mode='markers',
                    marker=dict(size=2, color='black', opacity=0.1))

# ff.create_trisurf returns a whole Figure, not a trace, so its traces are taken
# from `.data`; passing the Figure itself inside data=[...] is rejected.
# plotly.tools.FigureFactory, used previously, is gone in plotly >= 4.
trisurf = ff.create_trisurf(x=pts[:, 0],
                            y=pts[:, 1],
                            z=pts[:, 2],
                            simplices=hull.simplices[:, :3])

# `fond` was previously listed twice.
fig = go.Figure(data=[fond, *trisurf.data])
fig.update_layout(title='Filament',
                  autosize=False,
                  width=1000,
                  height=1000,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.update_layout(scene=dict(aspectmode="data"))
fig.show()