In [6]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly
import plotly.subplots as sp
import plotly.io as pio
from cloudvolume import CloudVolume as cv
from caveclient import CAVEclient
import meshparty
from meshparty import trimesh_vtk, trimesh_io 
from collections import Counter
from sklearn.cluster import KMeans
import time
import copy
from scipy.spatial import cKDTree, distance, ConvexHull, Delaunay, distance_matrix
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.ndimage import label
from itertools import combinations

# Segmentation volume handle; its path ('seg_m195') matches the
# materialization version requested below.
vol =cv('gs://zheng_mouse_hippocampus_production/v2/seg_m195',parallel=True, progress=False, use_https=True)
# CAVE client for the 'zheng_ca3' datastack, used later for synapse queries.
client = CAVEclient('zheng_ca3')
# Timestamp associated with materialization version 195.
t195 = client.materialize.get_timestamp(195)
# Alias passed as `timestamp=` to the synapse query further down.
tstamp = t195


In [15]:


def cluster_center(points):
    """Return the centroid of ``points`` wrapped in a one-element list.

    Parameters
    ----------
    points : array-like
        Coordinates with one point per row; the mean is taken over axis 0.

    Returns
    -------
    list
        ``[center]`` where ``center`` is the column-wise mean. The list
        wrapper is kept because callers index the result with ``[0]``.
    """
    # The previous version also computed every point's distance to the
    # centroid and then discarded it; that dead work is removed.
    return [np.asarray(points).mean(axis=0)]

def get_mesh_in_bbox(vertices, faces, bbox):
    """Crop a mesh to an axis-aligned box and re-index its faces.

    Parameters
    ----------
    vertices : array-like, shape (V, 3)
        Vertex coordinates.
    faces : array-like of int, shape (F, k)
        Each row lists the vertex indices of one face.
    bbox : sequence of 6 numbers
        ``[x_min, x_max, y_min, y_max, z_min, z_max]`` in the same units as
        ``vertices``. Both limits are inclusive.

    Returns
    -------
    vertices_in_bbox : ndarray, shape (V', 3)
        The vertices that fall inside the box, in their original order.
    new_faces : ndarray of int, shape (F', k)
        The faces whose corners are *all* inside the box, rewritten so that
        their indices refer to rows of ``vertices_in_bbox``. When no face
        survives the result has shape ``(0, k)``.
    """
    vertices = np.asarray(vertices)
    faces = np.asarray(faces)

    x_min, x_max, y_min, y_max, z_min, z_max = bbox
    mask = ((vertices[:, 0] >= x_min) & (vertices[:, 0] <= x_max) &
            (vertices[:, 1] >= y_min) & (vertices[:, 1] <= y_max) &
            (vertices[:, 2] >= z_min) & (vertices[:, 2] <= z_max))
    vertices_in_bbox = vertices[mask]

    if faces.size == 0:
        # Nothing to filter; keep the face width if it is known, otherwise
        # fall back to 3 columns.
        n_corners = faces.shape[1] if faces.ndim == 2 else 3
        return vertices_in_bbox, np.empty((0, n_corners), dtype=np.intp)

    faces = faces.astype(np.intp, copy=False)

    # A face is kept only if every one of its corners is inside the box.
    # Looking the corners up in the boolean mask is O(F) instead of scanning
    # the kept-index array once per corner.
    keep_face = mask[faces].all(axis=1)

    # For a kept vertex, (number of kept vertices up to and including it) - 1
    # is exactly its row in ``vertices_in_bbox``.
    old_to_new = np.cumsum(mask) - 1
    new_faces = old_to_new[faces[keep_face]]

    return vertices_in_bbox, new_faces

def cluster_vertices_kmeans(vertices, n_clusters):
    """Run k-means on ``vertices`` and report the fit.

    Returns a 3-tuple: the per-row cluster label, the cluster centres, and
    the inertia of the fitted model. The estimator is seeded with
    ``random_state=0`` so repeated calls use the same seed.
    """
    model = KMeans(n_clusters=n_clusters, random_state=0)
    assignments = model.fit_predict(vertices)
    centers, inertia = model.cluster_centers_, model.inertia_
    return assignments, centers, inertia

    
def get_cluster_scores(points, n_cluster_max):
    """Collect the k-means inertia for k = 1 .. ``n_cluster_max``.

    Returns an array of shape ``(n_cluster_max, 1)`` whose row ``k - 1``
    holds the inertia obtained when ``points`` is split into ``k`` clusters.
    """
    kmeans_inertia = np.zeros((n_cluster_max, 1))
    for k in range(1, n_cluster_max + 1):
        # Only the third element (inertia) of the clustering result is needed.
        kmeans_inertia[k - 1] = cluster_vertices_kmeans(points, k)[2]
    return kmeans_inertia
        

def decide_number_of_clusters(kmeans_inertia):
    """Choose a cluster count from a sequence of k-means inertias.

    Parameters
    ----------
    kmeans_inertia : array-like
        Inertia for k = 1, 2, ... in that order (first axis indexes k).

    Returns
    -------
    n_best_cluster : int
        1 when no step from k to k+1 reduces the inertia by a factor of at
        least ``min_drop_rate``; otherwise the k that follows the largest
        drop.
    inertia_drop_rate : ndarray
        ``inertia[k] / inertia[k+1]`` for consecutive entries; same trailing
        shape as the input, one entry shorter along the first axis.
    """
    min_drop_rate = 1.8  # smallest inertia ratio accepted as a real split

    inertia = np.atleast_1d(np.asarray(kmeans_inertia, dtype=float))
    previous, following = inertia[:-1], inertia[1:]

    # An inertia of exactly 0 would otherwise emit divide warnings; x/0 is
    # deliberately left as inf (a perfect fit is the largest possible drop).
    with np.errstate(divide='ignore', invalid='ignore'):
        inertia_drop_rate = np.divide(previous, following)
    # 0/0 means the fit was already perfect and did not improve: treat it as
    # "no drop" so that NaN cannot win the argmax below.
    inertia_drop_rate = np.where((previous == 0) & (following == 0), 1.0, inertia_drop_rate)

    # With fewer than two scores there is nothing to compare (argmax on an
    # empty array would raise), so fall back to a single cluster.
    if inertia_drop_rate.size == 0 or np.all(inertia_drop_rate < min_drop_rate):
        return 1, inertia_drop_rate

    # Ratio index i compares k = i+1 against k = i+2, hence the +2.
    n_best_cluster = int(np.argmax(inertia_drop_rate)) + 2
    return n_best_cluster, inertia_drop_rate


def get_foreground_within_convhull(vertices, bw, bbx):
    """Keep only the foreground voxels of ``bw`` that lie inside a convex hull.

    Parameters
    ----------
    vertices : ndarray, shape (N, 3)
        Points whose convex hull defines the region to keep, expressed in the
        same global coordinate frame as ``bbx``.
    bw : ndarray
        Mask whose non-zero entries are the foreground voxels.
    bbx : object with ``minpt.x``, ``minpt.y``, ``minpt.z``
        Global coordinate of voxel ``(0, 0, 0)`` of ``bw``.

    Returns
    -------
    ndarray of uint8, same shape as ``bw``
        1 where a foreground voxel falls inside the hull, 0 elsewhere.
    """
    foreground_inside = np.zeros_like(bw, dtype=np.uint8)

    local_idx = np.argwhere(bw)
    if local_idx.shape[0] == 0:
        # No foreground at all: nothing can be inside the hull.
        return foreground_inside

    origin = np.array([bbx.minpt.x, bbx.minpt.y, bbx.minpt.z])
    global_coords = local_idx + origin

    hull = ConvexHull(vertices)
    # Triangulating only the hull vertices is enough for an inside/outside test.
    delaunay = Delaunay(vertices[hull.vertices])
    inside_hull = delaunay.find_simplex(global_coords) >= 0

    # Write through the original integer indices instead of converting the
    # global coordinates back; this stays valid when the origin is a float
    # and avoids a Python loop over every voxel.
    kept = local_idx[inside_hull]
    foreground_inside[tuple(kept.T)] = 1
    return foreground_inside


def get_bbox(bw):
    """Tight index bounds of the non-zero entries of a 3-D mask.

    Returns ``[xmin, xmax, ymin, ymax, zmin, zmax]`` where each pair is the
    smallest and largest index holding a non-zero value along that axis.
    """
    occupied = np.argwhere(bw)
    lower, upper = occupied.min(axis=0), occupied.max(axis=0)
    x_lo, y_lo, z_lo = lower
    x_hi, y_hi, z_hi = upper
    return [x_lo, x_hi, y_lo, y_hi, z_lo, z_hi]


def get_bbox_for_syn_cluster(syn_cluster, vol, ws):
    """Build a padded bounding box around a synapse cluster, clipped to the volume.

    Parameters
    ----------
    syn_cluster : ndarray, shape (N, 3)
        Synapse coordinates in microns.
    vol : object with a ``bounds`` attribute
        ``bounds`` must expose ``minpt`` / ``maxpt`` with ``x``, ``y``, ``z``.
    ws : number
        Padding added on each side in x and y, in voxels. The z padding is
        ``ws * 2 / 5``, which equals the 18/45 in-plane to z voxel-size ratio.

    Returns
    -------
    A copy of ``vol.bounds`` whose corners have been tightened to the padded
    cluster extent (in voxels). The corner values may be floats.
    """
    voxel_size_nm = np.array([18, 18, 45])

    # Work on a private copy: the corners are edited in place below, and
    # writing into vol.bounds directly would shrink the volume's own extent
    # for every later call if `bounds` hands back a shared object.
    bbx = copy.deepcopy(vol.bounds)

    # microns -> voxels
    min_corner = np.min(syn_cluster, axis=0) / voxel_size_nm * 1000
    max_corner = np.max(syn_cluster, axis=0) / voxel_size_nm * 1000
    z_margin = ws * 2 / 5

    bbx.minpt.x = max(bbx.minpt.x, min_corner[0] - ws)
    bbx.maxpt.x = min(bbx.maxpt.x, max_corner[0] + ws)
    bbx.minpt.y = max(bbx.minpt.y, min_corner[1] - ws)
    bbx.maxpt.y = min(bbx.maxpt.y, max_corner[1] + ws)
    bbx.minpt.z = max(bbx.minpt.z, min_corner[2] - z_margin)
    bbx.maxpt.z = min(bbx.maxpt.z, max_corner[2] + z_margin)
    return bbx


def divide_vertices_by_count_neighboring_voxels(vertices_bbox_microns, radius):
    """Split vertices into a dense and a sparse group by local neighbour count.

    For every vertex the number of vertices within ``radius`` is counted with
    a KD-tree. A 2-cluster k-means is fitted to those counts and the midpoint
    between its two centres is used as the cut-off.

    Returns
    -------
    bouton_vertices : ndarray
        Vertices whose count is at or above the cut-off.
    nonbouton_vertices : ndarray
        Vertices whose count is below the cut-off.
    counts : ndarray
        Neighbour count of every input vertex.
    counts_thresholded : ndarray of int
        1 where the vertex is in the dense group, 0 otherwise.
    """
    neighbor_tree = cKDTree(vertices_bbox_microns)
    counts = neighbor_tree.query_ball_point(vertices_bbox_microns, radius, return_length=True)

    # k-means expects a 2-D sample matrix, hence the single-column reshape.
    two_means = KMeans(n_clusters=2, random_state=42).fit(counts.reshape(-1, 1))
    threshold_kmeans = np.mean(two_means.cluster_centers_)

    is_dense = counts >= threshold_kmeans
    is_sparse = counts < threshold_kmeans
    counts_thresholded = is_dense.astype(int)

    return (
        vertices_bbox_microns[is_dense],
        vertices_bbox_microns[is_sparse],
        counts,
        counts_thresholded,
    )


def save_bouton_split_result_3D_mesh(vertices_bbox_microns, color_weights, latest_mf, latest_root_PC):
    """Write an HTML 3-D scatter plot of mesh vertices coloured by ``color_weights``.

    The file is written to the current directory and named from the two ids
    (``<latest_mf>_<latest_root_PC>.html``). Nothing is returned.
    """
    xs, ys, zs = (vertices_bbox_microns[:, axis] for axis in range(3))
    marker_style = dict(size=1, color=color_weights, opacity=0.8, colorbar=dict(title="Neighbor Count"))
    point_cloud = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker_style, name='vertices')
    axis_titles = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')

    fig = go.Figure()
    fig.add_trace(point_cloud)
    fig.update_layout(title="3D Scatter Plots", scene=axis_titles)

    out_path = './' + str(latest_mf) + '_' + str(latest_root_PC) + '.html'
    plotly.offline.plot(fig, filename=out_path)


def save_bouton_split_result_3D_seg(binary_mask, latest_mf, latest_root_PC):
    """Write an HTML 3-D scatter plot of the voxels equal to 1 in ``binary_mask``.

    The z index is stretched by 5/2 before plotting. The file is written to
    the current directory as ``seg<latest_mf>_<latest_root_PC>.html``.
    Nothing is returned.
    """
    xc, yc, zc = np.where(binary_mask == 1)
    marker_style = dict(size=3, color='red', opacity=0.8, colorbar=dict(title="Neighbor Count"))
    voxel_cloud = go.Scatter3d(x=xc, y=yc, z=zc*5/2, mode='markers', marker=marker_style, name='vertices')
    axis_titles = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')

    fig = go.Figure()
    fig.add_trace(voxel_cloud)
    fig.update_layout(title="3D Scatter Plots", scene=axis_titles)

    out_path = './seg' + str(latest_mf) + '_' + str(latest_root_PC) + '.html'
    plotly.offline.plot(fig, filename=out_path)
    

def _syn_window_size(n_syn):
    """Pick the bounding-box padding (in voxels) from the synapse count of a cluster."""
    if n_syn < 24:
        return 400
    if n_syn < 36:
        return 500
    return 600


def _crop_mesh_around_synapses(syn_points_microns, vol, mm, latest_mf):
    """Fetch the mesh of ``latest_mf`` and crop it to a padded box around the synapses.

    Returns ``(bbx, vertices_bbox_microns)``, or ``None`` when the mesh could
    not be fetched (in which case 'no mesh' has been printed).
    """
    bbx = get_bbox_for_syn_cluster(syn_points_microns, vol, _syn_window_size(len(syn_points_microns)))
    # voxel corners -> nanometres, in [x_min, x_max, y_min, y_max, z_min, z_max] order
    bbox_nm = np.array([bbx.minpt.x, bbx.maxpt.x, bbx.minpt.y, bbx.maxpt.y, bbx.minpt.z, bbx.maxpt.z]) * np.array([18,18,18,18,45,45])

    try:
        mesh = mm.mesh(seg_id = latest_mf, remove_duplicate_vertices=True)
    except Exception:
        print('no mesh')
        return None

    vertices_bbox, _ = get_mesh_in_bbox(mesh.vertices, mesh.faces, bbox_nm)
    return bbx, vertices_bbox/1000


def _measure_bouton(vertices_bbox_microns, bbx, vol, latest_mf, radius):
    """Isolate the dense part of the cropped mesh and measure its segmentation volume.

    Returns ``(bouton_vol, bouton_bbox, counts_thresholded)``.
    """
    vertices_bouton, _, _, counts_thresholded = divide_vertices_by_count_neighboring_voxels(vertices_bbox_microns, radius)
    # microns -> voxels
    vertices_bouton_vx = np.round(vertices_bouton / np.array([18,18,45]) * 1000, decimals=3)
    # Start from a copy of the search box and shrink it to the dense vertices.
    bouton_bbox = copy.deepcopy(bbx)
    bouton_bbox.minpt.x, bouton_bbox.minpt.y, bouton_bbox.minpt.z = vertices_bouton_vx.min(axis=0)
    bouton_bbox.maxpt.x, bouton_bbox.maxpt.y, bouton_bbox.maxpt.z = vertices_bouton_vx.max(axis=0)
    seg_bouton = vol.download(bbox=bouton_bbox, label=latest_mf)
    seg_bouton = np.squeeze(seg_bouton, axis=-1)
    # voxel count x voxel volume (0.018 x 0.018 x 0.045 cubic microns)
    bouton_vol = np.round(np.sum(seg_bouton) * 0.018 * 0.018 * 0.045, decimals=4)
    return bouton_vol, bouton_bbox, counts_thresholded


def extract_bouton_by_voxel_density(latest_mf, latest_root_PC, vol, syn_coords, radius,mm):
    """Estimate bouton volume(s) for the synapses one axon makes onto one target.

    Parameters
    ----------
    latest_mf : int
        Segment id used for the mesh request and as the label of the
        segmentation download.
    latest_root_PC : int
        Id of the target cell; only recorded in the output rows and in the
        name of the saved plot.
    vol : volume handle
        Must provide ``bounds`` and ``download(bbox=..., label=...)``.
    syn_coords : ndarray, shape (N, 3)
        Synapse positions in voxels (converted here with an 18 x 18 x 45 nm
        voxel size).
    radius : number
        Neighbour-search radius, in microns, for the vertex density split.
    mm : mesh source
        Must provide ``mesh(seg_id=..., remove_duplicate_vertices=True)``.

    Returns
    -------
    list
        One row per measured bouton. Row layouts differ by case:

        * single bouton:
          ``[latest_mf, latest_root_PC, center_vx, n_syn, volume, bbox]``
        * one row per synapse cluster when several boutons are suspected:
          ``['double_bouton', latest_mf, latest_root_PC, center_vx, n_syn, volume, bbox]``
          or, if measuring that cluster failed,
          ``['not MF', latest_mf, latest_root_PC, center_vx, n_syn]``

        ``n_syn`` is always ``len(syn_coords)`` (the total, not the cluster
        size). The list is empty when there are no synapses or when the
        synapse centre lies outside the accepted z range.

        NOTE(review): when the mesh cannot be fetched the function returns a
        *flat* list ``['no mesh available', latest_mf, latest_root_PC, center_vx]``
        and drops any rows gathered so far. That shape is kept for backward
        compatibility; callers must not assume ``result[0]`` is a row.

    Side effects
    ------------
    Prints progress messages and saves one HTML plot per measured bouton.
    The plot name depends only on the two ids, so in the multi-bouton case
    each cluster overwrites the previous cluster's file.
    """
    mf_pyr_syn_thresh = 0      # minimum number of synapses (exclusive) to process
    dsyn_outlier_thresh = 5    # [microns]
    z_min_vx, z_max_vx = 200, 2040  # centres outside this z range are skipped ('truncated bouton')
    voxel_size_nm = np.array([18,18,45])
    result = []

    # Always hand back a list, also when there is nothing to process.
    if syn_coords.shape[0] <= mf_pyr_syn_thresh:
        return result

    syn_coords_microns = syn_coords * voxel_size_nm / 1000
    syn_cluster_center_microns = cluster_center(syn_coords_microns)[0]
    syn_cluster_center_vx = syn_cluster_center_microns / voxel_size_nm * 1000

    # Count synapses far from the centre of all synapses; several of them
    # suggest that this axon makes more than one bouton on the target cell.
    dist_to_center = np.linalg.norm(syn_coords_microns - syn_cluster_center_microns, axis=1)
    n_outlier = int(np.sum(dist_to_center > dsyn_outlier_thresh))

    if n_outlier > 2:  # at least two boutons
        score = get_cluster_scores(syn_coords_microns, 3)
        n_cluster_syn, _ = decide_number_of_clusters(score)
        syn_cluster_labels, syn_cluster_centers, _ = cluster_vertices_kmeans(syn_coords_microns, n_cluster_syn)
        unique_syn_labels, cluster_sizes = np.unique(syn_cluster_labels, return_counts=True)
        valid_syn_clusters = unique_syn_labels[cluster_sizes>mf_pyr_syn_thresh]
        print('at least two boutons')

        for i in range(0, n_cluster_syn):
            if i not in valid_syn_clusters:
                continue
            syn_cluster_points = syn_coords_microns[syn_cluster_labels == i]
            this_syn_center_vx = syn_cluster_centers[i] / voxel_size_nm * 1000
            print('bouton', i)
            if not ((this_syn_center_vx[2] > z_min_vx) and (this_syn_center_vx[2] < z_max_vx)):
                continue

            cropped = _crop_mesh_around_synapses(syn_cluster_points, vol, mm, latest_mf)
            if cropped is None:
                return ['no mesh available', latest_mf, latest_root_PC, this_syn_center_vx]
            bbx, vertices_bbox_microns = cropped

            # Best effort per cluster: a failure here is recorded as 'not MF'
            # instead of aborting the remaining clusters. Only ordinary
            # exceptions are caught so that Ctrl-C still interrupts the run.
            try:
                bouton_vol, bouton_bbox, counts_thresholded = _measure_bouton(vertices_bbox_microns, bbx, vol, latest_mf, radius)
                result.append(['double_bouton',latest_mf, latest_root_PC, this_syn_center_vx, len(syn_coords), bouton_vol, bouton_bbox])
                save_bouton_split_result_3D_mesh(vertices_bbox_microns, counts_thresholded, latest_mf, latest_root_PC)
            except Exception:
                print('likely to be not MF')
                result.append(['not MF',latest_mf, latest_root_PC, this_syn_center_vx, len(syn_coords)])

    elif (syn_cluster_center_vx[2] > z_min_vx) and (syn_cluster_center_vx[2] < z_max_vx):
        print('single bouton')
        cropped = _crop_mesh_around_synapses(syn_coords_microns, vol, mm, latest_mf)
        if cropped is None:
            return ['no mesh available', latest_mf, latest_root_PC, syn_cluster_center_vx]
        bbx, vertices_bbox_microns = cropped

        bouton_vol, bouton_bbox, counts_thresholded = _measure_bouton(vertices_bbox_microns, bbx, vol, latest_mf, radius)
        result.append([latest_mf, latest_root_PC, syn_cluster_center_vx, len(syn_coords), bouton_vol, bouton_bbox])
        save_bouton_split_result_3D_mesh(vertices_bbox_microns, counts_thresholded, latest_mf, latest_root_PC)
    else:
        print('truncated bouton')

    return result


In [16]:
# Target cell ids and the presynaptic axon ids to measure against each of them.
pyr_ids = [648518346440178071]
mf_ids = [648518346446845973,648518346466684376]
mesh_dir = '../mesh_data/'

radius = 3  # neighbour-search radius (microns) for the vertex density split
results = []

stime = time.time()

# The mesh source does not depend on the loop variables, so build it once and
# let its cache be shared by every axon instead of recreating it per iteration.
mm = trimesh_io.MeshMeta(cv_path='gs://zheng_mouse_hippocampus_production/v2/seg_m195',disk_cache_path=mesh_dir, cache_size=20)

for pyr_id in pyr_ids:
    # One query per target cell fetches the synapses from all listed axons.
    syn = client.materialize.synapse_query(pre_ids=mf_ids, post_ids=pyr_id, bounding_box=None, bounding_box_column='post_pt_position',
                timestamp=tstamp, remove_autapses=True, include_zeros=False, limit=None, offset=None,
                split_positions=False, desired_resolution=[18,18,45], materialization_version=None,
                synapse_table='synapses_ca3_v1', datastack_name='zheng_ca3', metadata=True)

    for mf_id in mf_ids:
        syn_center_pos1 = syn.loc[(syn['pre_pt_root_id'] == mf_id) & (syn['post_pt_root_id'] == pyr_id), ['pre_pt_position']].values
        # Each selected row is a length-1 array holding the position; unwrap it.
        syn_coords1 = np.array([arr[0] for arr in syn_center_pos1])
        result = extract_bouton_by_voxel_density(mf_id, pyr_id, vol, syn_coords1, radius, mm)
        results.append(result)

        # The extractor may return None, an empty list, a flat 'no mesh' list,
        # or rows of different lengths. Only rows that carry a measurement
        # (6 or 7 entries) end in [..., volume, bbox], so the volume is read
        # as the second-to-last entry rather than a fixed index.
        rows = result if isinstance(result, list) else []
        for row in rows:
            if isinstance(row, list) and len(row) >= 6:
                print('bouton size = ', row[-2])


etime = time.time()
print('Time taken (s):', etime-stime)


single bouton
bouton size =  14.7212
single bouton
bouton size =  25.8573
Time taken (s): 5.091066837310791
