# markdown


In [None]:
import os
import vg
import time
import joblib
import napari
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy import spatial
from skimage import filters
from skimage import morphology
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from tifffile import imread, imwrite
from skimage.morphology import label
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
from skimage.filters import threshold_otsu
from skimage.segmentation import clear_border
from skimage.morphology import binary_erosion
from skimage.measure import regionprops, regionprops_table
from sklearn.ensemble import RandomForestClassifier


def wipe_layers(viewer_name):
    '''
    Remove every layer from a napari viewer, leaving it empty.
    ---
    Parameters:
    viewer_name: the viewer to clear (any object exposing a list-like `layers` attribute)
    ---
    Returns:
    None
    '''
    layer_list = viewer_name.layers
    # keep dropping the front layer until none are left
    while len(layer_list):
        front_layer = layer_list[0]
        layer_list.remove(front_layer)

def remove_large_objects(labels_array: np.ndarray, max_size: int) -> np.ndarray:
    '''
    Zero out every label whose voxel count exceeds `max_size`.
    ---
    Parameters:
    labels_array: np.ndarray integer label image (0 = background); values must be
        non-negative because they are counted with np.bincount. Not modified.
    max_size: int largest voxel count a label may have and still be kept
    ---
    Returns:
    np.ndarray a copy of `labels_array` with the oversized labels set to 0
    '''
    cleaned = np.copy(labels_array)
    # voxel count per label value (index = label value)
    voxel_counts = np.bincount(labels_array.ravel())
    # per-voxel lookup: True where the voxel's label is oversized
    oversized_voxels = (voxel_counts > max_size)[labels_array]
    cleaned[oversized_voxels] = 0
    return cleaned

def return_points(labels_array: np.ndarray, label_ID: int) -> np.ndarray:
    '''
    Collect the index coordinates of every element equal to `label_ID`.
    ---
    Parameters:
    labels_array: np.ndarray an ndArray of labels
    label_ID: int the label whose coordinates are wanted
    ---
    Returns:
    points: np.ndarray of shape (n, labels_array.ndim), one row of array indices per
    matching element, in C (row-major) order; n is 0 when the label is absent.
    For a 3D array the columns are the indices along axes 0, 1 and 2.
    '''
    return np.argwhere(labels_array == label_ID)

def find_label_density(label_points: np.ndarray) -> float:
    '''
    Return the fraction of a point cloud's bounding box that its points occupy.
    ---
    Parameters:
    label_points: np.ndarray of shape (n, d) with one point's coordinates per row
        (d is 3 for the volumes in this notebook, but any d is accepted)
    ---
    Returns:
    np.nan if `label_points` contains no points
    density (float) the number of points divided by the volume of their inclusive
        bounding box
    ---
    Raises:
    ValueError if a non-empty `label_points` is not a 2D array
    '''
    points = np.asarray(label_points)
    if points.size == 0:
        # an empty label has no bounding box (np.min/np.max would raise here)
        return np.nan
    if points.ndim != 2:
        raise ValueError(f'label_points must be a 2D (n, d) array, got shape {points.shape}')
    num_points = points.shape[0]
    # +1 because the extent is inclusive (a single voxel spans 1); this also
    # guarantees the volume is never 0
    extents = points.max(axis=0) - points.min(axis=0) + 1
    return num_points / np.prod(extents)

def print_label_props(source: np.ndarray, label_num: int) -> None:
    '''
    Print the voxel count and bounding-box density of one label.
    ---
    Parameters:
    source: np.ndarray the label image containing the label
    label_num: int the label to report on
    ---
    Returns:
    None (the report is written to stdout)
    '''
    pts = return_points(source, label_num)
    dens = find_label_density(pts)
    n_voxels = pts.shape[0]
    report_lines = [
        f'Label {label_num} has:',
        f'{n_voxels:,} points.',
        f'density of {round(dens,4):,}',
    ]
    print('\n'.join(report_lines))

def get_cube(source: np.ndarray, label_num: int) -> tuple:
    '''
    Return slice bounds for a padded bounding box around one label.
    ---
    Parameters:
    source: np.ndarray 3D array containing the label
    label_num: int the label number of the label you want to isolate
    ---
    Returns:
    tuple (x_min, x_max, y_min, y_max, z_min, z_max) of half-open slice bounds for
    axes 0, 1 and 2, suitable for `apply_cube`. Each min is one voxel before the
    label (clamped at 0) and each max is two past its last index, so the crop keeps
    one voxel of padding on each side. A max may exceed the axis length; slicing clips it.
    ---
    Raises:
    ValueError if `label_num` does not occur in `source`
    '''
    label_points = return_points(source, label_num)
    if label_points.shape[0] == 0:
        raise ValueError(f'label {label_num} not found in source')
    lows = label_points.min(axis=0)
    highs = label_points.max(axis=0)
    # Clamp the padded lower bound at 0: for a label touching index 0, a start of -1
    # would wrap to the end of the axis and produce an empty or wrong crop.
    x_min, y_min, z_min = (max(int(v) - 1, 0) for v in lows[:3])
    x_max, y_max, z_max = (int(v) + 2 for v in highs[:3])
    return x_min, x_max, y_min, y_max, z_min, z_max

def apply_cube(source: np.ndarray, cube: tuple) -> np.ndarray:
    '''
    Crop the first three axes of an array to the bounds in `cube`.
    ---
    Parameters:
    source: np.ndarray the array to crop (at least 3D)
    cube: tuple of six half-open bounds (x_min, x_max, y_min, y_max, z_min, z_max)
        for axes 0, 1 and 2
    ---
    Returns:
    out: np.ndarray the cropped region (a basic-slicing view of `source`)
    '''
    x_min, x_max, y_min, y_max, z_min, z_max = cube
    crop_window = (slice(x_min, x_max), slice(y_min, y_max), slice(z_min, z_max))
    return source[crop_window]

def get_long_axis(cubed_label: np.ndarray, line_length=75, rng=None):
    '''
    Estimate the principal (longest) axis of an object with an SVD of its voxel coordinates.
    ---
    Parameters:
    cubed_label: np.ndarray boolean mask, or a label image; for a label image only the
        smallest non-zero label value is used
    line_length: half-length of the returned line segment, in voxels (default 75)
    rng: optional random generator used to shuffle the coordinates before subsampling.
        When None, NumPy's global random state is used. Pass a seeded
        np.random.Generator for reproducible results.
    ---
    Returns:
    (direction, linepts): `direction` is the unit vector of the first principal axis
    (its sign is arbitrary); `linepts` is a (2, ndim) array holding the segment's end
    points, centred on the mean of the sampled coordinates
    ---
    Raises:
    ValueError if the mask / label image has no foreground voxels
    '''
    if cubed_label.dtype == 'bool':
        coords = np.column_stack(np.where(cubed_label == True))
    else:
        foreground_ids = [i for i in np.unique(cubed_label) if i != 0]
        if not foreground_ids:
            raise ValueError('cubed_label contains no non-zero labels')
        coords = np.column_stack(np.where(cubed_label == foreground_ids[0]))
    if coords.shape[0] == 0:
        raise ValueError('cubed_label contains no foreground voxels')

    # subsample large objects to roughly 1000-2000 points to keep the SVD cheap
    sampling_interval = max(coords.shape[0] // 1000, 1)

    if rng is None:
        np.random.shuffle(coords)
    else:
        rng.shuffle(coords)
    subsampled = coords[::sampling_interval]
    datamean = subsampled.mean(axis=0)
    # only the right singular vectors are needed; full_matrices=False avoids
    # building the (n, n) left-singular matrix
    _, _, vv = np.linalg.svd(subsampled - datamean, full_matrices=False)
    direction = vv[0]
    linepts = direction * np.mgrid[-line_length:line_length:2j][:, np.newaxis]
    linepts += datamean
    return direction, linepts

def view_saved_files(file_path: str) -> None:
    '''
    Open a new napari viewer showing the files saved in one output directory.
    ---
    Parameters:
    file_path: str directory to load; entries whose names start with '.' are ignored
    ---
    Handling by file type:
    .tif  -> hidden image layer if the name contains 'tub' or 'PI', otherwise a labels layer
    .txt  -> read with np.loadtxt; a 1D result becomes a points layer, a 2D result a line shape
    other -> skipped, with a message printed
    Layer names are the file name up to its first '.'.
    ---
    Returns:
    None
    '''
    file_viewer = napari.Viewer()
    entries = [entry for entry in os.listdir(file_path) if not entry.startswith('.')]
    for entry in entries:
        full_path = os.path.join(file_path, entry)
        layer_name = entry.split('.')[0]

        if entry.endswith('.tif'):
            stack = imread(full_path)
            is_intensity_image = 'tub' in entry or 'PI' in entry
            if is_intensity_image:
                file_viewer.add_image(stack, name=layer_name, blending='additive', visible=False)
            else:
                file_viewer.add_labels(stack, name=layer_name, blending='additive')
            continue

        if entry.endswith('.txt'):
            coords = np.loadtxt(full_path)
            if coords.ndim == 1:
                file_viewer.add_points(coords, name=layer_name, face_color='white', blending='additive')
            elif coords.ndim == 2:
                file_viewer.add_shapes(coords, shape_type='line', name=layer_name, edge_color='white', blending='additive')
            continue

        print(f'file "{entry}" not imported to viewer')


#### Define the napari Viewer 

In [None]:
# IPython magic: run the Qt event loop so napari windows stay responsive in the notebook
%gui qt 
# module-level viewer; the processing cells clear it with wipe_layers(viewer) and add layers to it
viewer = napari.Viewer()

In [None]:
analysis_dir = '/Volumes/bigData/wholeMount_volDist/220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_Processed/N2V_Denoised/0_Analysis_01'

# Create the save directory. It lives inside analysis_dir, so it is excluded from the
# embryo list below; otherwise a re-run would try to parse it as an embryo directory.
data_save_name = '0_data_cubes'
data_save_dir = os.path.join(analysis_dir, data_save_name)
if not os.path.exists(data_save_dir):
    os.mkdir(data_save_dir)

subdirs = sorted(
    d for d in os.listdir(analysis_dir)
    if os.path.isdir(os.path.join(analysis_dir, d)) and d != data_save_name
)

# load the spindle classifier
# NOTE: joblib.load unpickles the file, so only load classifiers from a trusted source
classifier_path = os.path.join(os.getcwd(), 'spindle_classifier/spindle_classifier.joblib')
classifier = joblib.load(classifier_path)

# coarse whole-cell filters: voxel-count range and minimum bounding-box density
minimum_size = 250000
maximum_size = 1000000
minimum_density = 0.21

# spindle-candidate size filters, in voxels
# NOTE(review): calculate_geometries uses a 500-voxel minimum instead; confirm which is intended
min_spindle_size = 100
max_thrsh_size = 5000

for subdir in tqdm(subdirs):

    # embryo directories are named '<type>_<number>'; skip anything else
    name_parts = subdir.split('_')
    if len(name_parts) != 2:
        print(f'skipping "{subdir}": not an embryo directory')
        continue
    emb_type, emb_num = name_parts

    # make a save directory for this embryo
    emb_save_dir = os.path.join(data_save_dir, subdir)
    if not os.path.exists(emb_save_dir):
        os.makedirs(emb_save_dir)

    # define file paths and load data
    segmentations_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_{emb_num}-Z01_PI_16bit_scaleZ_sbdl2_16bit_seg.npy'
    dog_tub_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_{emb_num}-Z01_Tub_16bit_scaleZ_sbdl2_16bit.tif'
    raw_tub_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_{emb_num}-Z01_Tub_16bit_scaleZ.tif'
    pi_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_{emb_num}-Z01_PI_16bit_scaleZ_sbdl2_16bit.tif'

    data_load_dir = os.path.join(analysis_dir, subdir)
    # the _seg.npy file holds a dict in a 0-d object array, hence allow_pickle + .item()
    masks = np.load(os.path.join(data_load_dir, segmentations_name), allow_pickle=True).item()['masks']
    tub = imread(os.path.join(data_load_dir, dog_tub_name))
    tub_raw = imread(os.path.join(data_load_dir, raw_tub_name))
    pi = imread(os.path.join(data_load_dir, pi_name))

    # coarsely filter out poor segmentations: cells touching the border, cells outside
    # the voxel-count range, and sparse (low bounding-box density) cells
    filtered_masks = clear_border(masks)
    filtered_masks = morphology.remove_small_objects(filtered_masks, min_size=minimum_size, connectivity=1)
    filtered_masks = remove_large_objects(filtered_masks, max_size=maximum_size)
    remaining_labels = [lab for lab in np.unique(filtered_masks) if lab != 0]
    print('Calculating point clouds...')
    label_pcs = [return_points(masks, label_ID) for label_ID in tqdm(remaining_labels)]
    densities = [find_label_density(pc) for pc in label_pcs]

    for cell_id, density in zip(remaining_labels, densities):
        if density < minimum_density:
            filtered_masks[filtered_masks == cell_id] = 0

    # Establish the remaining labels for the embryo and loop through each cell
    final_labels = [lab for lab in np.unique(filtered_masks) if lab != 0]
    for curr_mask_id in final_labels:
        wipe_layers(viewer)

        # isolate the current mask as a boolean array
        curr_mask = filtered_masks == curr_mask_id

        # get the bounding cube for the current mask ID and apply it to the labels and images
        cube_dims = get_cube(filtered_masks, curr_mask_id)
        cubed_label = apply_cube(curr_mask, cube_dims)
        cubed_tub = apply_cube(tub, cube_dims)
        cubed_tub_raw = apply_cube(tub_raw, cube_dims)
        cubed_PI = apply_cube(pi, cube_dims)

        # get the mask coordinates, centroid, and long axis
        mask_coords = np.column_stack(np.where(cubed_label))
        cell_centroid = mask_coords.mean(axis=0)
        cell_long_vect, cell_long_line = get_long_axis(cubed_label)

        # erode the mask to eliminate some cortical signal
        eroded_mask = binary_erosion(cubed_label, footprint=np.ones((3, 3, 3)))
        for _ in range(10):
            eroded_mask = binary_erosion(eroded_mask)

        # Otsu needs at least one value; a thin cell can erode away entirely
        remaining_vals = cubed_tub[eroded_mask].ravel()
        if remaining_vals.size == 0:
            print(f'mask {curr_mask_id} is empty after erosion; skipping')
            continue

        # get the tubulin signal from the remaining region and define an Otsu threshold
        remaining_tub = np.zeros(shape=cubed_label.shape)
        remaining_tub[eroded_mask] = cubed_tub[eroded_mask]
        thresh_val = threshold_otsu(remaining_vals)
        thresh_mask = label(remaining_tub > thresh_val)

        # get number of candidate labels before size filtering
        num_tub_labels_b4_filter = [lab for lab in np.unique(thresh_mask) if lab != 0]

        if len(num_tub_labels_b4_filter) == 0:
            # no filtering required if we already have no labels
            print(f'no spindles detected in label {curr_mask_id}')
            continue

        # drop candidates smaller or larger than the expected spindle size
        # (size filtering is skipped when there is only a single candidate)
        if len(num_tub_labels_b4_filter) > 1:
            thresh_mask = morphology.remove_small_objects(thresh_mask, min_size=min_spindle_size, connectivity=1)
            thresh_mask = remove_large_objects(thresh_mask, max_size=max_thrsh_size)

        # get the candidate labels after filtering
        remaining_labels = [lab for lab in np.unique(thresh_mask) if lab != 0]

        if len(remaining_labels) == 0:
            print(f'no spindles detected in mask {curr_mask_id}')
            continue

        # calculate properties of remaining labels
        props = regionprops_table(thresh_mask, properties=('area',
                                                            'axis_major_length',
                                                            'axis_minor_length',
                                                            'label'))
        props_df = pd.DataFrame(props)

        # for each remaining label, calculate density and distance to cell centroid
        for label_num in remaining_labels:
            label_coords = np.column_stack(np.where(thresh_mask == label_num))
            label_centroid = label_coords.mean(axis=0)
            dist = spatial.distance.euclidean(cell_centroid, label_centroid)
            props_df.loc[props_df['label'] == label_num, 'dist_to_cell'] = dist

            label_density = find_label_density(label_coords)
            props_df.loc[props_df['label'] == label_num, 'density'] = label_density

        # classify each candidate; feature columns, in order: area, axis_major_length,
        # axis_minor_length, dist_to_cell, density
        for label_num in remaining_labels:
            stats = props_df.loc[props_df['label'] == label_num]
            vals = stats.drop(columns=['label']).values

            # remove the label unless the classifier predicts class 1 (spindle).
            # NOTE(review): .predict() uses the forest's majority vote; the exploratory
            # classifier cell instead requires predict_proba(vals)[0][1] > 0.9 — confirm
            # which acceptance criterion is intended.
            spindle_prediction = classifier.predict(vals)
            print(f'spindle prediction for label {label_num} in cell {curr_mask_id} of embryo {subdir}: {spindle_prediction}')
            if not spindle_prediction == 1:
                thresh_mask[thresh_mask == label_num] = 0

        # exactly one candidate must survive classification
        spindle_labels = [lab for lab in np.unique(thresh_mask) if lab != 0]
        if len(spindle_labels) != 1:
            print(f'unsatisfactory spindle segmentation in mask {curr_mask_id}')
            continue

        # define mask save directory
        mask_save_dir = os.path.join(emb_save_dir, f'cell_{curr_mask_id}')
        if not os.path.exists(mask_save_dir):
            os.mkdir(mask_save_dir)

        # Take the spindle from the label that survived classification:
        # remaining_labels[0] may have been rejected and zeroed out above.
        spindle_label_ID = spindle_labels[0]
        spindle_coords = np.column_stack(np.where(thresh_mask == spindle_label_ID))
        spindle_centroid = spindle_coords.mean(axis=0)
        spindle_long_vect, spindle_long_line = get_long_axis(thresh_mask)

        # distance between the centroids and angle between the long axes. SVD axis
        # directions have an arbitrary sign, so vg.angle (degrees by default) may return
        # theta or 180 - theta for the same pair of lines; fold into [0, 90].
        # NOTE(review): dist and ang are not written to disk by this cell.
        dist = spatial.distance.euclidean(cell_centroid, spindle_centroid)
        ang = vg.angle(spindle_long_vect, cell_long_vect)
        ang = min(ang, 180 - ang)

        # populate the viewer
        viewer.add_labels(eroded_mask, name='eroded_mask', blending='additive', visible=False)
        viewer.add_labels(cubed_label, name='curr_mask_cube', blending='additive')
        viewer.add_image(cubed_tub, name='curr_tub_cube', blending='additive', visible=False)
        viewer.add_image(cubed_tub_raw, name='curr_tub_raw_cube', blending='additive', visible=False)
        viewer.add_image(cubed_PI, name='curr_PI_cube', blending='additive', visible=False)
        viewer.add_labels(thresh_mask, name='thresh_mask', blending='additive')
        viewer.add_points(cell_centroid, name='cell centroid', face_color='magenta', blending='additive')
        viewer.add_points(spindle_centroid, name='spindle centroid', face_color='green', blending='additive')
        viewer.add_shapes(cell_long_line, shape_type='line', name='cell long axis', edge_color='red', blending='additive')
        viewer.add_shapes(spindle_long_line, shape_type='line', name='spindle long axis', edge_color='blue', blending='additive')

        images_and_layers = ['curr_mask_cube',
                            'curr_tub_cube',
                            'curr_tub_raw_cube',
                            'curr_PI_cube',
                            'eroded_mask',
                            'thresh_mask']

        # save the tif compatible layers as tifs
        for item in images_and_layers:
            viewer.layers[item].save(os.path.join(mask_save_dir, item + '.tif'))

        # save the arrays as txt files
        np.savetxt(os.path.join(mask_save_dir, 'spindle_centroid.txt'), spindle_centroid)
        np.savetxt(os.path.join(mask_save_dir, 'cell_centroid.txt'), cell_centroid)
        np.savetxt(os.path.join(mask_save_dir, 'spindle_long_axis.txt'), spindle_long_line)

    print(f'finished with embryo {subdir}')


In [None]:
from skimage.morphology import label


def calculate_geometries(curr_mask_id):
    '''
    Measure the spindle of one cell and save its crops and geometry.
    ---
    Parameters:
    curr_mask_id: label of the cell in the module-level `masks` array
    ---
    Returns:
    None. Side effects: clears and repopulates the module-level `viewer` and sets
    results[curr_mask_id] = []. When exactly one spindle candidate survives size
    filtering, it appends [centroid distance, long-axis angle in degrees] to that list
    and writes the crops (.tif) plus centroids and spindle long axis (.txt) under
    data_save_dir.
    Reads the module-level globals masks, tub, tub_raw, pi, results, cntrl, emb_num,
    data_save_dir and viewer.
    '''
    # start with a fresh slate
    wipe_layers(viewer)
    results[curr_mask_id] = []

    # Isolate the current cell as a boolean mask. It is built here so it always
    # matches curr_mask_id rather than whatever global `curr_mask` happens to exist.
    curr_mask = masks == curr_mask_id

    # get the bounding cube for the current mask ID and apply it to the labels and images
    cube_dims = get_cube(masks, curr_mask_id)
    cubed_label = apply_cube(curr_mask, cube_dims)
    cubed_tub = apply_cube(tub, cube_dims)
    cubed_tub_raw = apply_cube(tub_raw, cube_dims)
    cubed_PI = apply_cube(pi, cube_dims)

    # erode the mask to eliminate some cortical signal
    eroded_mask = binary_erosion(cubed_label, footprint=np.ones((3, 3, 3)))
    for _ in range(10):
        eroded_mask = binary_erosion(eroded_mask)

    # Otsu needs at least one value; a thin cell can erode away entirely
    remaining_vals = cubed_tub[eroded_mask].ravel()
    if remaining_vals.size == 0:
        print(f'mask {curr_mask_id} is empty after erosion; skipping')
        return

    # get the tubulin signal from the remaining region and define an Otsu threshold
    remaining_tub = np.zeros(shape=cubed_label.shape)
    remaining_tub[eroded_mask] = cubed_tub[eroded_mask]
    thresh_val = threshold_otsu(remaining_vals)
    thresh_mask = label(remaining_tub > thresh_val)

    # expected spindle-candidate size range, in voxels
    min_thrsh_size = 500
    max_thrsh_size = 5000
    num_tub_labels_b4_filter = [lab for lab in np.unique(thresh_mask) if lab != 0]

    if len(num_tub_labels_b4_filter) == 0:
        # no filtering required if we already have no labels
        print(f'no spindles detected in label {curr_mask_id}')
        return

    if len(num_tub_labels_b4_filter) == 1:
        # remove_small_objects warns on a single label, so only drop it if too large
        thresh_mask = remove_large_objects(thresh_mask, max_size=max_thrsh_size)
    else:
        # several candidates: drop both too-small and too-large ones
        thresh_mask = morphology.remove_small_objects(thresh_mask, min_size=min_thrsh_size, connectivity=1)
        thresh_mask = remove_large_objects(thresh_mask, max_size=max_thrsh_size)

    # get the candidate labels after filtering
    remaining_labels = [lab for lab in np.unique(thresh_mask) if lab != 0]
    if len(remaining_labels) > 1:
        print('more than one region remaining!')
        return

    if len(remaining_labels) == 0:
        print(f'no spindles detected in mask {curr_mask_id}')
        return

    # get the spindle coordinates, centroid and long axis
    spindle_label_ID = remaining_labels[0]
    spindle_coords = np.column_stack(np.where(thresh_mask == spindle_label_ID))
    spindle_centroid = spindle_coords.mean(axis=0)
    spindle_long_vect, spindle_long_line = get_long_axis(thresh_mask)

    # get the mask coordinates, centroid, and long axis
    mask_coords = np.column_stack(np.where(cubed_label))
    cell_centroid = mask_coords.mean(axis=0)
    cell_long_vect, cell_long_line = get_long_axis(cubed_label)

    # distance between the centroids and angle between the long axes. SVD axis
    # directions have an arbitrary sign, so vg.angle (degrees by default) may return
    # theta or 180 - theta for the same pair of lines; fold into [0, 90].
    dist = spatial.distance.euclidean(cell_centroid, spindle_centroid)
    ang = vg.angle(spindle_long_vect, cell_long_vect)
    ang = min(ang, 180 - ang)
    results[curr_mask_id].append(dist)
    results[curr_mask_id].append(ang)

    # make a save directory for this mask
    if cntrl:
        emb_save_dir = os.path.join(data_save_dir, f'Cntrl_E{emb_num}')
    else:
        emb_save_dir = os.path.join(data_save_dir, f'Exp_E{emb_num}')
    if not os.path.exists(emb_save_dir):
        os.makedirs(emb_save_dir)

    save_dir = os.path.join(emb_save_dir, f'{curr_mask_id}')
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # populate the viewer
    viewer.add_labels(eroded_mask, name='eroded_mask', blending='additive', visible=False)
    viewer.add_labels(cubed_label, name='curr_mask_cube', blending='additive')
    viewer.add_image(cubed_tub, name='curr_tub_cube', blending='additive', visible=False)
    viewer.add_image(cubed_tub_raw, name='curr_tub_raw_cube', blending='additive', visible=False)
    viewer.add_image(cubed_PI, name='curr_PI_cube', blending='additive', visible=False)
    viewer.add_labels(thresh_mask, name='thresh_mask', blending='additive')
    viewer.add_points(cell_centroid, name='cell centroid', face_color='magenta', blending='additive')
    viewer.add_points(spindle_centroid, name='spindle centroid', face_color='green', blending='additive')
    viewer.add_shapes(cell_long_line, shape_type='line', name='cell long axis', edge_color='red', blending='additive')
    viewer.add_shapes(spindle_long_line, shape_type='line', name='spindle long axis', edge_color='blue', blending='additive')

    images_and_layers = ['curr_mask_cube',
                        'curr_tub_cube',
                        'curr_tub_raw_cube',
                        'curr_PI_cube',
                        'eroded_mask',
                        'thresh_mask']

    # save the tif compatible layers as tifs
    for item in images_and_layers:
        viewer.layers[item].save(os.path.join(save_dir, item + '.tif'))

    # save the arrays as txt files
    np.savetxt(os.path.join(save_dir, 'spindle_centroid.txt'), spindle_centroid)
    np.savetxt(os.path.join(save_dir, 'cell_centroid.txt'), cell_centroid)
    np.savetxt(os.path.join(save_dir, 'spindle_long_axis.txt'), spindle_long_line)

base_dir = '/Volumes/bigData/wholeMount_volDist/220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_Processed/N2V_Denoised/16bit_scaleZ_sbdl2' 

cntrl = False
emb_nums = ['06']
main_save_dir = '/Users/bementmbp/Desktop/autosave2' 
# results for every processed embryo: {emb_num: {mask_id: [dist, ang]}}.
# `results` (rebound per embryo below) only holds the embryo currently being processed.
all_results = {}

for emb_num in emb_nums:

    # control and experimental files differ only in the embryo-type token
    emb_type = 'Cntrl' if cntrl else 'Exp'
    segmentations_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_E{emb_num}-Z01_PI_16bit_scaleZ_sbdl2_16bit_seg.npy'
    dog_tub_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_E{emb_num}-Z01_Tub_16bit_scaleZ_sbdl2_16bit.tif'
    raw_tub_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_E{emb_num}-Z01_Tub_16bit_scaleZ.tif'
    pi_name = f'220712_Fix_Emb_Flvw_Chn1GAP_PI_aTub647_{emb_type}_E{emb_num}-Z01_PI_16bit_scaleZ_sbdl2_16bit.tif'

    masks = np.load(os.path.join(base_dir, segmentations_name), allow_pickle=True).item()['masks']
    tub = imread(os.path.join(base_dir, dog_tub_name))
    tub_raw = imread(os.path.join(base_dir, raw_tub_name))
    pi = imread(os.path.join(base_dir, pi_name))

    ########################################################################################################################
    # coarse whole-cell filtering: border cells, voxel-count range, bounding-box density

    minimum_size = 250000
    maximum_size = 1000000
    minimum_density = 0.21

    filtered_masks = clear_border(masks)
    filtered_masks = morphology.remove_small_objects(filtered_masks, min_size=minimum_size, connectivity=1)
    filtered_masks = remove_large_objects(filtered_masks, max_size=maximum_size)
    remaining_labels = [lab for lab in np.unique(filtered_masks) if lab != 0]
    print('Calculating point clouds...')
    label_pcs = [return_points(masks, label_ID) for label_ID in tqdm(remaining_labels)]
    densities = [find_label_density(pc) for pc in label_pcs]

    for cell_id, density in zip(remaining_labels, densities):
        if density < minimum_density:
            filtered_masks[filtered_masks == cell_id] = 0

    ########################################################################################################################

    final_labels = [lab for lab in np.unique(filtered_masks) if lab != 0]
    results = {}

    def calculate_geometries(curr_mask_id):
        '''
        Measure the spindle of one cell of the current embryo and save its crops and geometry.

        Uses the module-level masks, tub, tub_raw, pi, results, cntrl, emb_num,
        main_save_dir and viewer. Sets results[curr_mask_id] = [] and, when exactly one
        spindle candidate survives size filtering, appends [centroid distance,
        long-axis angle in degrees] and writes the crops (.tif) plus centroids and
        spindle long axis (.txt) under main_save_dir. Returns None.
        '''
        # start with a fresh slate
        wipe_layers(viewer)
        results[curr_mask_id] = []
        curr_mask = masks == curr_mask_id

        # get the bounding cube for the current mask ID and apply it to the labels and images
        cube_dims = get_cube(masks, curr_mask_id)
        cubed_label = apply_cube(curr_mask, cube_dims)
        cubed_tub = apply_cube(tub, cube_dims)
        cubed_tub_raw = apply_cube(tub_raw, cube_dims)
        cubed_PI = apply_cube(pi, cube_dims)

        # erode the mask to eliminate some cortical signal
        eroded_mask = binary_erosion(cubed_label, footprint=np.ones((3, 3, 3)))
        for _ in range(10):
            eroded_mask = binary_erosion(eroded_mask)

        # Otsu needs at least one value; a thin cell can erode away entirely
        remaining_vals = cubed_tub[eroded_mask].ravel()
        if remaining_vals.size == 0:
            print(f'mask {curr_mask_id} is empty after erosion; skipping')
            return

        # get the tubulin signal from the remaining region and define an Otsu threshold
        remaining_tub = np.zeros(shape=cubed_label.shape)
        remaining_tub[eroded_mask] = cubed_tub[eroded_mask]
        thresh_val = threshold_otsu(remaining_vals)
        thresh_mask = label(remaining_tub > thresh_val)

        # expected spindle-candidate size range, in voxels
        min_thrsh_size = 500
        max_thrsh_size = 5000
        num_tub_labels_b4_filter = [lab for lab in np.unique(thresh_mask) if lab != 0]

        if len(num_tub_labels_b4_filter) == 0:
            # no filtering required if we already have no labels
            print(f'no spindles detected in label {curr_mask_id}')
            return

        if len(num_tub_labels_b4_filter) == 1:
            # remove_small_objects warns on a single label, so only drop it if too large
            thresh_mask = remove_large_objects(thresh_mask, max_size=max_thrsh_size)
        else:
            # several candidates: drop both too-small and too-large ones
            thresh_mask = morphology.remove_small_objects(thresh_mask, min_size=min_thrsh_size, connectivity=1)
            thresh_mask = remove_large_objects(thresh_mask, max_size=max_thrsh_size)

        # get the candidate labels after filtering
        remaining_labels = [lab for lab in np.unique(thresh_mask) if lab != 0]
        if len(remaining_labels) > 1:
            print('more than one region remaining!')
            return

        if len(remaining_labels) == 0:
            print(f'no spindles detected in mask {curr_mask_id}')
            return

        # get the spindle coordinates, centroid and long axis
        spindle_label_ID = remaining_labels[0]
        spindle_coords = np.column_stack(np.where(thresh_mask == spindle_label_ID))
        spindle_centroid = spindle_coords.mean(axis=0)
        spindle_long_vect, spindle_long_line = get_long_axis(thresh_mask)

        # get the mask coordinates, centroid, and long axis
        mask_coords = np.column_stack(np.where(cubed_label))
        cell_centroid = mask_coords.mean(axis=0)
        cell_long_vect, cell_long_line = get_long_axis(cubed_label)

        # distance between the centroids and angle between the long axes. SVD axis
        # directions have an arbitrary sign, so vg.angle (degrees by default) may return
        # theta or 180 - theta for the same pair of lines; fold into [0, 90].
        dist = spatial.distance.euclidean(cell_centroid, spindle_centroid)
        ang = vg.angle(spindle_long_vect, cell_long_vect)
        ang = min(ang, 180 - ang)
        results[curr_mask_id].append(dist)
        results[curr_mask_id].append(ang)

        # make a save directory for this mask
        emb_save_dir = os.path.join(main_save_dir, f'{emb_type}_E{emb_num}')
        if not os.path.exists(emb_save_dir):
            os.makedirs(emb_save_dir)

        save_dir = os.path.join(emb_save_dir, f'{curr_mask_id}')
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        # populate the viewer
        viewer.add_labels(eroded_mask, name='eroded_mask', blending='additive', visible=False)
        viewer.add_labels(cubed_label, name='curr_mask_cube', blending='additive')
        viewer.add_image(cubed_tub, name='curr_tub_cube', blending='additive', visible=False)
        viewer.add_image(cubed_tub_raw, name='curr_tub_raw_cube', blending='additive', visible=False)
        viewer.add_image(cubed_PI, name='curr_PI_cube', blending='additive', visible=False)
        viewer.add_labels(thresh_mask, name='thresh_mask', blending='additive')
        viewer.add_points(cell_centroid, name='cell centroid', face_color='magenta', blending='additive')
        viewer.add_points(spindle_centroid, name='spindle centroid', face_color='green', blending='additive')
        viewer.add_shapes(cell_long_line, shape_type='line', name='cell long axis', edge_color='red', blending='additive')
        viewer.add_shapes(spindle_long_line, shape_type='line', name='spindle long axis', edge_color='blue', blending='additive')

        images_and_layers = ['curr_mask_cube',
                            'curr_tub_cube',
                            'curr_tub_raw_cube',
                            'curr_PI_cube',
                            'eroded_mask',
                            'thresh_mask']

        # save the tif compatible layers as tifs
        for item in images_and_layers:
            viewer.layers[item].save(os.path.join(save_dir, item + '.tif'))

        # save the arrays as txt files
        np.savetxt(os.path.join(save_dir, 'spindle_centroid.txt'), spindle_centroid)
        np.savetxt(os.path.join(save_dir, 'cell_centroid.txt'), cell_centroid)
        np.savetxt(os.path.join(save_dir, 'spindle_long_axis.txt'), spindle_long_line)

    for curr_mask_id in final_labels:
        calculate_geometries(curr_mask_id)

    # keep this embryo's results; `results` is rebound for the next embryo
    all_results[emb_num] = results
    print(f'finished with {emb_num}')

In [None]:
# Create a fresh napari viewer, rebinding the module-level `viewer` that the next cell adds layers to
viewer = napari.Viewer()

In [None]:
p = '/Users/bementmbp/Desktop/autosave2/Exp_E02/46/curr_tub_raw_cube.tif' 
m = '/Users/bementmbp/Desktop/Scripts/volume-distribution/spindle_classifier/raw_training_data/95/curr_mask_cube.tif'
# NOTE(review): p (Exp_E02 cell 46) and m (training cell 95) come from different cells'
# directories. Unless they are the same cell, the mask centroid used for dist_to_cell
# below does not belong to this tubulin crop — confirm the paths.
raw = imread(p)
cell_mask = imread(m).astype('bool')

# minimum class-1 (spindle) probability required to keep a candidate label
spindle_prob_threshold = 0.9

# Otsu-threshold the raw tubulin crop, label connected components and drop tiny ones
thresh = threshold_otsu(raw)
binary = raw > thresh
thresh_labels = label(binary)
thresh_labels = morphology.remove_small_objects(thresh_labels, min_size=100, connectivity=1)

viewer.add_image(raw, name='raw', blending='additive')
viewer.add_labels(thresh_labels, name='thresh_labels', blending='additive')
viewer.add_labels(cell_mask, name='cell_mask', blending='additive')

# load the classifier
# NOTE: joblib.load unpickles the file, so only load classifiers from a trusted source
classifier_path = os.path.join(os.getcwd(), 'spindle_classifier/spindle_classifier.joblib')
classifier = joblib.load(classifier_path)

# get the cell mask centroid
mask_coords = np.column_stack(np.where(cell_mask))
cell_centroid = mask_coords.mean(axis=0)

props = regionprops_table(thresh_labels, properties=('area',
                                                    'axis_major_length',
                                                    'axis_minor_length',
                                                    'label'))
props_df = pd.DataFrame(props)

# add distance-to-cell-centroid and bounding-box density features for each candidate
remaining_labels = [lab for lab in np.unique(thresh_labels) if lab != 0]
for label_num in remaining_labels:
    label_coords = np.column_stack(np.where(thresh_labels == label_num))
    label_centroid = label_coords.mean(axis=0)
    dist = spatial.distance.euclidean(cell_centroid, label_centroid)
    props_df.loc[props_df['label'] == label_num, 'dist_to_cell'] = dist

    label_density = find_label_density(label_coords)
    props_df.loc[props_df['label'] == label_num, 'density'] = label_density

# Filter a copy so the 'thresh_labels' layer keeps showing the unfiltered candidates
# (napari layers may keep a reference to the array they were given rather than a copy).
cleaned_labels = thresh_labels.copy()
for label_num in remaining_labels:
    stats = props_df.loc[props_df['label'] == label_num]
    vals = stats.drop(columns=['label']).values
    # a fitted forest is deterministic, so compute the probabilities once per label;
    # presumably column 1 is the spindle class (classifier.classes_ == [0, 1]) — verify
    probabilities = classifier.predict_proba(vals)
    spindle_prob = probabilities[0][1]
    print(f'prediction for label {label_num}: {probabilities}')

    print(f'spindle prediction: {spindle_prob}')

    if not spindle_prob > spindle_prob_threshold:
        cleaned_labels[cleaned_labels == label_num] = 0

viewer.add_labels(cleaned_labels, name='cleaned labels', blending='additive')
