In [1]:
from cloudvolume import CloudVolume
from caveclient import CAVEclient
import navis
from navis import TreeNeuron
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 
import time
from tqdm import tqdm
import json
import uuid  
from datetime import datetime
from scipy.spatial import ConvexHull
from sklearn.cluster import DBSCAN
from scipy.spatial import distance, distance_matrix, cKDTree
from scipy.spatial.distance import cdist
from scipy.sparse.csgraph import minimum_spanning_tree
import networkx as nx
import networkx as nx
import pcg_skel
import plotly.express as px
import plotly.graph_objects as go
import cloudvolume as cv
import gc

navis.patch_cloudvolume()

vol = cv.CloudVolume("precomputed://gs://zheng_mouse_hippocampus_production/v2/seg_m195", use_https=True, progress=False) # new from Will
#seg = cv.CloudVolume("graphene://https://minnie.microns-daf.com/segmentation/table/zheng_ca3", use_https=True) # old

client = CAVEclient('zheng_ca3')
auth = client.auth

# CloudVolume and Cave setup
#sv = CloudVolume('graphene://https://minnie.microns-daf.com/segmentation/table/zheng_ca3', use_https=True, lru_bytes=int(1e8))

INFO  : cloud-volume successfully patched! (navis)


In [2]:
client = CAVEclient('zheng_ca3')

In [41]:
def scale_neuron(skel_scaled, scale_x=18, scale_y=18, scale_z=45):
    # Ensure we're handling a NeuronList properly
    if isinstance(skel_scaled, navis.NeuronList):
        for neuron in skel_scaled:
            if hasattr(neuron, 'nodes') and neuron.nodes is not None:
                neuron.nodes = neuron.nodes.copy(deep=True)
                neuron.nodes.loc[:, 'x'] /= scale_x
                neuron.nodes.loc[:, 'y'] /= scale_y
                neuron.nodes.loc[:, 'z'] /= scale_z
    else:
        if skel_scaled:
            # If it's a single neuron, scale directly
            skel_scaled.nodes = skel_scaled.nodes.copy(deep=True)
            skel_scaled.nodes.loc[:, 'x'] /= scale_x
            skel_scaled.nodes.loc[:, 'y'] /= scale_y
            skel_scaled.nodes.loc[:, 'z'] /= scale_z
        else:
            return skel_scaled

    return skel_scaled


def find_soma_position(skel, single_segid, df_cells, current_date):
    """
    Find the soma position by detecting the densest cluster of nodes in the skeleton.
    Returns False if there are not enough samples to form a valid cluster.
    """

    # Filter matching row
    matching_row = df_cells[df_cells[f"updated_segids_{current_date}"] == single_segid]

    # Scale X, Y by 18 and Z by 45
    if not matching_row.empty:
        # Convert to numeric values AFTER filtering
        scaled_xyz = matching_row[['x', 'y', 'z']].astype(float).copy()  

        # Perform numeric scaling
        scaled_xyz['x'] *= 18
        scaled_xyz['y'] *= 18
        scaled_xyz['z'] *= 45

        # Convert to list
        soma_center = scaled_xyz[['x', 'y', 'z']].values.tolist()[0]  
        
    else:
        matching_row = df_cells[df_cells[f"updated_segids_{current_date}"] == fallback_segid]
        
        # Scale X, Y by 18 and Z by 45
        if not matching_row.empty:
            # Convert to numeric values AFTER filtering
            scaled_xyz = matching_row[['x', 'y', 'z']].astype(float).copy()  

            # Perform numeric scaling
            scaled_xyz['x'] *= 18
            scaled_xyz['y'] *= 18
            scaled_xyz['z'] *= 45

            # Convert to list
            soma_center = scaled_xyz[['x', 'y', 'z']].values.tolist()[0]  
        else:
            print(f"No match found for updated_segids_{current_date}: {single_segid}")
            soma_center = 0

    return soma_center


def find_and_remove_soma(skel, single_segid, fallback_segid, df_cells, current_date, pyr_cell=False, apical=False, tilt_angle=15, plot_soma=False, plot_branches=False, plot_plane=False):
            
    # Compute soma center 
    soma_center = find_soma_position(skel, single_segid, df_cells, current_date)
    
    og_skel_length = skel.cable_length
    
    if plot_soma or plot_plane:
        # Create 3D scatter plot of all nodes before filtering
        node_positions = skel.nodes[["x", "y", "z"]]
        df_nodes = pd.DataFrame(node_positions, columns=["x", "y", "z"])
        fig = px.scatter_3d(df_nodes, x="x", y="y", z="z", 
                            title="3D Visualization of Skeleton Nodes Before Filtering", opacity=0.1)
        fig.update_traces(marker=dict(size=1))
        fig.add_scatter3d(x=[soma_center[0]], y=[soma_center[1]], z=[soma_center[2]],
                         mode='markers', marker=dict(size=8, color='red', opacity=1), name='Soma Center')
        del fig

    # Compute distances of all nodes from the soma center
    distances_ = np.linalg.norm(skel.nodes[["x", "y", "z"]].values - soma_center, axis=1)

    # Remove nodes within a 15,000 radius sphere
    nodes_to_remove = skel.nodes["node_id"][distances_ <= 15000].tolist()

    # Remove the nodes within the sphere from the skeleton
    skel = navis.remove_nodes(skel, which=nodes_to_remove, inplace=False)  

    if pyr_cell and skel is not None:
        filtered_neuron_list = []

        def compute_tilted_z(x, y, soma_x, soma_y, soma_z, angle_x, angle_y):
            """ Compute the z-position of the tilted plane at a given (x, y) coordinate """
            angle_x_rad = np.radians(angle_x)
            angle_y_rad = np.radians(angle_y)
            dz_x = (x - soma_x) * np.tan(angle_x_rad)  # Change in z due to tilt in X
            dz_y = (y - soma_y) * np.tan(angle_y_rad)  # Change in z due to tilt in Y
            return soma_z + dz_x + dz_y  # Combine both tilts

        # Determine how much the neuron is "arched"
        # We are not using the arch because all of the neurons face the same way. 
        arch_x_start, arch_x_end = 35000*18, 57000*18  # The start and end X positions of the arch
        arch_y_threshold = 1*18  # The Y position where the arch peaks

        # Default tilt settings (mostly X-tilt on left and right)
        tilt_x = 70 if soma_center[0] < 45000*18 else -70
        tilt_y = 0  # No Y tilt initially

        if arch_x_start < soma_center[0] < arch_x_end and soma_center[1] > arch_y_threshold:
            # Compute weight: How close is this neuron to the arch peak?
            distance_to_top = np.sqrt((soma_center[0] - 46000*18) ** 2 + (soma_center[1] - 46000*18) ** 2)
            max_distance = np.sqrt((arch_x_end - 46000*18) ** 2 + (60000*18 - 46000*18) ** 2)  # Max arch distance
            weight = distance_to_top / max_distance  # Normalize between 0 (top) and 1 (sides)

            # Interpolate tilt based on arch location
            tilt_x = 70 * (1 - weight) + tilt_x * weight  # More tilt along X near the top
            tilt_y = 70 * (1 - weight)  # More tilt along Y closer to arch

        if isinstance(skel, navis.NeuronList):
            for skel_ in skel:
                nodes = skel_.nodes.copy()
                tilted_z_values = nodes.apply(
                    lambda row: compute_tilted_z(
                        row["x"], row["y"], soma_center[0], soma_center[1], soma_center[2], tilt_x, tilt_y
                    ), axis=1
                )
                if apical:
                    to_remove = nodes["z"] > tilted_z_values
                else:
                    to_remove = nodes["z"] < tilted_z_values
                if any(to_remove):
                    skel_ = navis.remove_nodes(skel_, which=nodes.loc[to_remove, "node_id"].tolist(), inplace=False)
                if skel_ and len(skel_.nodes) > 0:
                    filtered_neuron_list.append(skel_)
        else:
            skel_ = skel
            nodes = skel_.nodes.copy()
            tilted_z_values = nodes.apply(
                lambda row: compute_tilted_z(
                    row["x"], row["y"], soma_center[0], soma_center[1], soma_center[2], tilt_x, tilt_y
                ), axis=1
            )
            if apical:
                to_remove = nodes["z"] > tilted_z_values
            else:
                to_remove = nodes["z"] < tilted_z_values
            if any(to_remove):
                skel_ = navis.remove_nodes(skel_, which=nodes.loc[to_remove, "node_id"].tolist(), inplace=False)
            if skel_ and len(skel_.nodes) > 0:
                filtered_neuron_list.append(skel_)

        skel = navis.NeuronList(filtered_neuron_list) if filtered_neuron_list else None
        
        if not skel:
            print(f'Failed to find branches beyond Soma: {single_segid}')
            return skel, None

    if plot_plane and skel:
        x_vals = np.linspace(soma_center[0] - 30000, soma_center[0] + 30000, num=20)
        y_vals = np.linspace(soma_center[1] - 30000, soma_center[1] + 30000, num=20)
        x_grid, y_grid = np.meshgrid(x_vals, y_vals)

        # Compute Z-values with the plane centered at soma_center
        z_grid = np.array([
            compute_tilted_z(x, y, soma_center[0], soma_center[1], soma_center[2], tilt_x, tilt_y) 
            for x, y in zip(np.ravel(x_grid), np.ravel(y_grid))
        ]).reshape(x_grid.shape)

        # Shift the plane so that it is centered at the soma (not above everything)
        z_grid -= np.mean(z_grid) - soma_center[2]

        # Plot the updated plane
        fig.add_traces([go.Surface(x=x_grid, y=y_grid, z=z_grid, colorscale='blues', opacity=0.4, name='Tilted Plane')])
        fig.show()
        del fig
        
    return skel, soma_center


def get_distal_nodes(skel, start_node_id):
    """
    Traverse the neuron graph from a starting node and return all downstream nodes.
    """
    from collections import deque

    # Build child map: parent_id -> list of child_ids
    child_map = {}
    for idx, row in skel.nodes.iterrows():
        pid = row['parent_id']
        if pid not in child_map:
            child_map[pid] = []
        child_map[pid].append(row['node_id'])

    # BFS from start_node_id to get all downstream nodes
    distal_nodes = set()
    queue = deque([start_node_id])

    while queue:
        current = queue.popleft()
        distal_nodes.add(current)
        children = child_map.get(current, [])
        queue.extend(children)

    return list(distal_nodes)


def remove_proximal_branches_basal(skel, soma_center, segid, skel_og, apical, distance_threshold=40000):
    """
    Removes all nodes within a specified XY-plane cylindrical distance from the soma_center.
    Keeps only the distal part of the neuron beyond the distance_threshold.

    Parameters:
        skel (navis.TreeNeuron): Full neuron skeleton.
        soma_center (array-like): [x, y, z] coordinates of soma center.
        segid (int/str): Segment ID for labeling and debug output.
        skel_og (navis.TreeNeuron): Original skeleton for reference plotting.
        distance_threshold (float): XY-plane radial distance to exclude nodes within this radius.

    Returns:
        navis.TreeNeuron or None: Skeleton with only distal nodes beyond radius, or None if nothing remains.
    """
    import numpy as np
    import navis
    from scipy.spatial.distance import cdist

    # Convert soma center to array
    soma_center = np.asarray(soma_center).reshape(1, 3)

    # Get all node coordinates
    all_coords = skel.nodes[['x', 'y', 'z']].values
    all_node_ids = skel.nodes['node_id'].values

    # Retry with decreasing distance thresholds until successful or min reached
    
    if not apical:
        min_threshold = 30000
        step = 5000
        current_threshold = distance_threshold

        while current_threshold >= min_threshold:
            # Calculate XY-plane distances only
            dists_xy = np.linalg.norm(all_coords[:, :2] - soma_center[:, :2], axis=1)

            # Keep only nodes outside the cylindrical radius
            keep_ids = all_node_ids[dists_xy > current_threshold]

            if len(keep_ids) > 0:
                break  # Success
            else:
                print(f"[{segid}] No nodes beyond {current_threshold} XY-distance from soma — lowering threshold.")
                current_threshold -= step
    else:
        min_threshold = 90000
        
        dists_xy = np.linalg.norm(all_coords[:, :2] - soma_center[:, :2], axis=1)

        # Keep only nodes outside the cylindrical radius
        keep_ids = all_node_ids[dists_xy > min_threshold]

    if len(keep_ids) == 0:
        print(f"[{segid}] No nodes found even at minimum threshold ({min_threshold}) — skipping.")
        return None

    # Prune the skeleton
    skel_pruned = navis.subset_neuron(skel, subset=keep_ids, inplace=False)

    if skel_pruned.n_nodes == 0:
        print(f"[{segid}] Pruned skeleton is empty — skipping.")
        return None

    # Plot result
    plot_skeleton_3d(skel_pruned, segid, skel_og, apical)

    return skel_pruned


def plot_skeleton_3d(skel, filename_prefix, skel_og, apical):
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    import numpy as np

    def safe_node_id(val):
        if isinstance(val, np.ndarray):
            return int(val.flatten()[0])  # get scalar from array
        return int(val)

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    nodes = skel.nodes.set_index('node_id')
    nodes_og = skel_og.nodes.set_index('node_id')

    # Plot nodes of skel
    ax.plot(nodes['x'], nodes['y'], nodes['z'], '.', color='orange', alpha=0.5, markersize=2)

    # Plot original skeleton (skel_og) edges in blue
    for edge in skel_og.edges:
        try:
            node1, node2 = safe_node_id(edge[0]), safe_node_id(edge[1])
        except Exception as e:
            print(f"Failed to parse original edge: {edge} -> {e}")
            continue

        if node1 in nodes_og.index and node2 in nodes_og.index:
            n1 = nodes_og.loc[node1, ['x', 'y', 'z']].values
            n2 = nodes_og.loc[node2, ['x', 'y', 'z']].values
            ax.plot([n1[0], n2[0]], [n1[1], n2[1]], [n1[2], n2[2]],
                    color='blue', linewidth=0.5, alpha=0.5)

    # Plot main skeleton (skel) edges in orange
    for edge in skel.edges:
        try:
            node1, node2 = safe_node_id(edge[0]), safe_node_id(edge[1])
        except Exception as e:
            print(f"Failed to parse main edge: {edge} -> {e}")
            continue

        if node1 in nodes.index and node2 in nodes.index:
            n1 = nodes.loc[node1, ['x', 'y', 'z']].values
            n2 = nodes.loc[node2, ['x', 'y', 'z']].values
            ax.plot([n1[0], n2[0]], [n1[1], n2[1]], [n1[2], n2[2]],
                    color='orange', linewidth=2)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=20, azim=135)
    plt.title(f"{filename_prefix}")
    plt.tight_layout()
    if apical:
        plt.savefig(f"{filename_prefix}_skeleton_img_apical.png", dpi=300)
    else:
        plt.savefig(f"{filename_prefix}_skeleton_img_basal.png", dpi=300)
    plt.close()
    del fig


def group_and_filter_nodes(df_leafs, distance_threshold=1000):
    # Create a copy to avoid modifying the original dataframe
    df_leafs = df_leafs.copy()
    
    # Convert coordinates to NumPy array for efficiency
    coords = df_leafs[['x', 'y', 'z']].values
    node_ids = df_leafs['node_id'].values
    
    # Create a KDTree for fast nearest-neighbor search
    tree = cKDTree(coords)
    
    visited = set()  # Track visited nodes
    keep_nodes = set()  # Nodes to keep
    remove_nodes = set()  # Nodes to remove

    for i, node_id in enumerate(node_ids):
        if node_id in visited:
            continue  # Skip already processed nodes

        # Find all nodes within the distance threshold
        neighbors = tree.query_ball_point(coords[i], distance_threshold)
        
        # Extract their node IDs
        group_nodes = [node_ids[j] for j in neighbors]
        
        # Mark all nodes as visited
        visited.update(group_nodes)
        
        # Keep only the last node in the group
        last_node = group_nodes[-1]  # The last node in the group
        keep_nodes.add(last_node)

        # Remove all other nodes in the group (excluding the last one)
        remove_nodes.update(group_nodes[:-1])

    # Filter df_leafs to retain only the nodes in keep_nodes
    df_filtered = df_leafs[df_leafs['node_id'].isin(keep_nodes)].reset_index(drop=True)

    return df_filtered, list(remove_nodes)  # Return both the filtered dataframe and removed node IDs


def remove_close_leaf_nodes(skel, pruned_skeleton, min_threshold=350, max_threshold=750):
    """
    Remove nodes from skel.nodes that are:
    - Within a dynamically changing distance threshold of pruned skeleton nodes or edges.
    - The threshold decreases from `max_threshold` at the start of the skeleton to `min_threshold` at the end.
    """

    # Extract XYZ coordinates from the pruned skeleton
    pruned_coords = pruned_skeleton.nodes[['x', 'y', 'z']].values
    pruned_node_ids = pruned_skeleton.nodes['node_id'].values  # IDs of pruned nodes

    # Extract XYZ coordinates of skel.nodes
    leaf_coords = skel.nodes[['x', 'y', 'z']].values
    leaf_node_ids = skel.nodes['node_id'].values  # IDs of skel nodes

    if leaf_coords.shape[0] == 0:
        print("Error: skel.nodes is empty! No nodes to process.")
        return None

    # Build KDTree for fast nearest-neighbor searches
    tree_nodes = cKDTree(pruned_coords)

    # Ensure pruned_skeleton.edge_coords is a NumPy array
    pruned_edges = np.squeeze(np.array(pruned_skeleton.edge_coords))

    if pruned_edges.shape[0] == 0:
        print("Error: pruned_edges is empty!")
        return None

    # Check if pruned_edges has the correct shape (n, 2, 3)
    if pruned_edges.ndim == 3 and pruned_edges.shape[1] == 2 and pruned_edges.shape[2] == 3:
        # Compute midpoints if the shape is correct (n, 2, 3)
        edge_midpoints = np.mean(pruned_edges, axis=1)
    else:
        print("Warning: pruned_edges has an unexpected shape. Skipping edge proximity check.")
        edge_midpoints = np.array([])  # Set an empty array so that edge checks are skipped.

    # Create a KDTree for the edge midpoints, only if there are valid midpoints
    if edge_midpoints.shape[0] > 0 and edge_midpoints.ndim == 2 and edge_midpoints.shape[1] == 3:
        tree_edges = cKDTree(edge_midpoints)
    else:
        tree_edges = None
        print("Warning: Invalid edge midpoints shape. Skipping edge proximity check.")

    # Compute cumulative distance along the skeleton to determine how far each node is from the start
    distances_along_skeleton = np.cumsum(np.linalg.norm(np.diff(pruned_coords, axis=0), axis=1))
    distances_along_skeleton = np.insert(distances_along_skeleton, 0, 0)  # Start from zero

    # Normalize distances between 0 and 1
    normalized_distances = distances_along_skeleton / distances_along_skeleton[-1]

    # Compute dynamic thresholds based on normalized position along the skeleton
    distance_thresholds = max_threshold - (max_threshold - min_threshold) * normalized_distances

    # List to store nodes that should be removed
    nodes_to_remove = set()

    # Step 1: Check direct proximity to pruned nodes
    distances, _ = tree_nodes.query(leaf_coords)

    # Assign each leaf node a dynamic threshold based on nearest pruned node distance
    closest_indices = tree_nodes.query(leaf_coords)[1]  # Get closest pruned node index
    dynamic_thresholds = distance_thresholds[closest_indices]

    # Remove nodes that are within their assigned threshold
    nodes_to_remove.update(leaf_node_ids[distances <= dynamic_thresholds])

    # Step 2: Check proximity to edges using `pruned_skeleton.edge_coords`
    if tree_edges is not None:
        distances, _ = tree_edges.query(leaf_coords)

        # Identify nodes that are too close to edges
        edge_mask = distances <= dynamic_thresholds
        nodes_to_remove.update(leaf_node_ids[edge_mask])

    # Return list of removed nodes
    return list(nodes_to_remove)


def plot_threshold_changes(distances, thresholds, step=25000):
    """
    Plot the changing threshold along the skeleton every `step` distance.
    """
    sampled_indices = np.arange(0, len(distances), step=np.searchsorted(distances, step))
    
    sampled_distances = distances[sampled_indices]
    sampled_thresholds = thresholds[sampled_indices]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    
    ax.scatter(sampled_distances, sampled_thresholds, zs=0, zdir='z', label='Threshold', c='r')
    ax.plot(sampled_distances, sampled_thresholds, zs=0, zdir='z', linestyle='dashed')

    ax.set_xlabel("Distance Along Skeleton")
    ax.set_ylabel("Threshold (um)")
    ax.set_zlabel("Z=0 Plane")
    ax.set_title("Dynamic Threshold Along Skeleton")

    plt.show()
    del fig

    
def find_spines(df, skel, apical, threshold=1000, plot_spines=False):
    """
    Identify spines (leaf nodes) near postsynaptic positions and remove non-spine nodes.

    Parameters:
    - df: DataFrame with postsynaptic positions in 'post_pt_position'
    - skel_list: List of navis skeletons
    - threshold: Distance threshold to consider a node as a spine

    Returns:
    - skel_spines_only: NeuronList with only spine nodes kept
    - spines_num: Dictionary with neuron names and the number of spines found
    """

    # Convert postsynapse_positions list into a NumPy array for efficient calculations
    postsynapse_positions = np.array([np.array(pos) for pos in df['post_pt_position']])  # Convert list of lists into (N,3) array
    postsynapse_positions[:, :2] *= 18  # Scale X and Y by 18
    postsynapse_positions[:, 2] *= 45   # Scale Z by 45

    df_leafs = skel.leafs
    df_nodes = skel.nodes

    pruned_neuron = skel.prune_twigs(size=10000, inplace=False)

    # Filter out rows where 'type' is 'end'
    non_end_nodes = df_nodes[df_nodes['type'] != 'end']

    # Extract the node_id values from the non-end nodes
    non_end_node_ids = non_end_nodes['node_id'].tolist()
    skel = navis.remove_nodes(skel, which=non_end_node_ids, inplace=False)

    # Remove the nodes too close to the backbone of the skeleton 
    removed_leaf_nodes = remove_close_leaf_nodes(skel, pruned_neuron)  
    if removed_leaf_nodes is None:
        return None, None, None, None
    
    skel = navis.remove_nodes(skel, which=removed_leaf_nodes, inplace=False)

    df_nodes = skel.nodes

    # Extract x, y, z coordinates as a NumPy array
    remaining_node_coords = df_nodes[['x', 'y', 'z']].values  # These are the remaining nodes after filtering
    #print(f"Node Coords: {remaining_node_coords}")

    # Build KDTree from remaining nodes instead of postsynaptic positions
    tree = cKDTree(remaining_node_coords)

    # Query all postsynaptic positions against the tree of remaining nodes
    distances, indices = tree.query(postsynapse_positions, distance_upper_bound=threshold)
    
    # Step 1: Apply distance filter to keep only valid postsynaptic positions
    valid_mask = distances <= threshold
    matching_postsyn_positions = postsynapse_positions[valid_mask]  # These are the kept postsynaptic positions
    matching_postsyn_positions[:, :2] /= 18  # Reverse scaling for X and Y
    matching_postsyn_positions[:, 2] /= 45   # Reverse scaling for Z
    
    # Convert matching positions into a set for fast lookups
    matching_positions_set = set(tuple(pos) for pos in matching_postsyn_positions)
    
    # Step 2: Filter original df based on matching postsynaptic positions
    filtered_df = df[df['post_pt_position'].apply(lambda pos: tuple(pos) in matching_positions_set)]

    # Step 3: Extract 'pre_pt_root_id' values from the filtered dataset
    pre_partner_ids = filtered_df['pre_pt_root_id'].tolist()
    #print(f"Matching Presyn IDs: {pre_partner_ids}")

    # Step 4: Count the number of partners
    pre_partner_num = len(pre_partner_ids)

    #print(len(skel.nodes))
    
    df_nodes_2 = skel.nodes
    
    df_nodes_filtered, removed_node_ids = group_and_filter_nodes(df_nodes_2, distance_threshold=250)
    
    skel = navis.remove_nodes(skel, which=removed_node_ids, inplace=False)
    
    spines_num = len(skel.nodes)
    
    def plot_skeleton_with_nodes(skel, pruned_skeleton, df_synapses, apical):
        import navis
        import plotly.graph_objects as go

        # Plot the pruned skeleton using navis (Plotly backend)
        fig = navis.plot3d(pruned_skeleton, color='black', lw=1, backend='plotly', inline=False)
        if apical:
            filename = f"{df_synapses['post_pt_root_id'].iloc[0]}_spine_cable_apical.html"
        else:
            filename = f"{df_synapses['post_pt_root_id'].iloc[0]}_spine_cable_basal.html"
            
        # Add remaining nodes from skel.nodes as scatter points
        fig.add_trace(go.Scatter3d(
            x=skel.nodes['x'], 
            y=skel.nodes['y'], 
            z=skel.nodes['z'],
            mode='markers',
            marker=dict(size=3, color='blue'),
            name='Remaining Nodes'
        ))

        # Save the figure as HTML
        fig.write_html(filename)
        del fig

        # Show the combined plot
        #fig.show()
  
    plot_skeleton_with_nodes(skel, pruned_neuron, df, apical)
    
    return skel, spines_num, pre_partner_ids, pre_partner_num
    
    
def convert_to_JSON(skel_spines, segid, apical, id_name, print_output=False):
    """Save spine coordinates in JSON format for Neuroglancer."""

    json_path = "annotation_base_file.json"
    with open(json_path, "r") as f:
        neuroglancer_state = json.load(f)

    #df = skel_spines
    df = pd.DataFrame(skel_spines.nodes, columns=['x', 'y', 'z'])

    xyz_data_scaled = [[x / 18, y / 18, z / 45] for x, y, z in df[['x', 'y', 'z']].values]

    annotation_layer = next((layer for layer in neuroglancer_state["layers"] if layer["type"] == "annotation"), None)
    if annotation_layer:
        for xyz in xyz_data_scaled:
            annotation_layer["annotations"].append({"point": xyz, "type": "point", "id": str(uuid.uuid4())})

    seg_layer = next((layer for layer in neuroglancer_state["layers"] if layer["type"] == "segmentation" and "segments" in layer and layer.get("name") == "zheng_ca3"), None)
    if seg_layer:
        seg_layer["segments"].append(str(segid))
        
    if apical:
        output_path = f"{segid}_updated_neuroglancer_apical_spines.json"
        with open(output_path, "w") as f:
            json.dump(neuroglancer_state, f, indent=4)

    else:
        output_path = f"{segid}_updated_neuroglancer_basal_spines.json"
        with open(output_path, "w") as f:
            json.dump(neuroglancer_state, f, indent=4)

    if print_output:
        print(f"Updated JSON saved to: {output_path}")

        
def find_cable_lengths(skel, segid, apical, plot_bare_skeleton=False):
    """Find and return cable lengths of neuron segments."""
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    import navis

    pruned_neurons = []
    cable_lengths = {}

    # Ensure `skel` is a NeuronList
    if isinstance(skel, navis.NeuronList):
        for skel_ in skel:
            pruned_neuron = skel_.prune_twigs(size=10000, inplace=False)
            cable_length_n = skel_.cable_length
            cable_lengths[segid] = cable_length_n
            pruned_neurons.append(pruned_neuron)
    else:
        if not skel:
            return skel, 0
        pruned_neuron = skel.prune_twigs(size=10000, inplace=False)
        cable_length_n = skel.cable_length
        cable_lengths[segid] = cable_length_n
        pruned_neurons.append(pruned_neuron)

    pruned_neuron_list = navis.NeuronList(pruned_neurons)

    if plot_bare_skeleton:
        print('outdated')
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    for neuron in pruned_neuron_list:
        nodes = neuron.nodes.set_index('node_id')
        for edge in neuron.edges:
            n1 = nodes.loc[edge[0], ['x', 'y', 'z']].values
            n2 = nodes.loc[edge[1], ['x', 'y', 'z']].values
            ax.plot([n1[0], n2[0]], [n1[1], n2[1]], [n1[2], n2[2]], color='black', linewidth=1)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=20, azim=135)
    plt.title(f'Pruned skeleton for {segid}')
    plt.tight_layout()
    if apical:
        plt.savefig(f"{segid}_bare_skeleton_cable_length_apical.png", dpi=300)
    else:
        plt.savefig(f"{segid}_bare_skeleton_cable_length_basal.png", dpi=300)
    plt.close()
    del fig

    return pruned_neuron_list, cable_lengths


def compute_spines_per_micron(num_spines, cable_lengths, segid):
    """Compute spines per micron for each neuron segment."""
    
    spines_per_micron = {}
    
    cable_lengths = cable_lengths[segid]
    nm_length = cable_lengths
    microns = nm_length / 1000  # Convert to microns
    spines_per_micron[segid] = num_spines / microns if microns > 0 else 0

    return spines_per_micron


def compute_rolling_spines(num_spines, cable_lengths, spine_positions_dict, segid, window_size=10, step_size=1, threshold=10):
    """Compute rolling average of spines per micron along the neuron skeleton."""

    rolling_spines_dict = {}
    all_rolling_values = []

    nm_length = cable_lengths[segid]
    microns = nm_length / 1000  # Convert to microns

    spine_positions = np.array(spine_positions_dict[segid]) / 1000  # Convert nm to microns
    spine_positions = spine_positions[(spine_positions >= 0) & (spine_positions < microns)]

    spine_distribution = np.zeros(int(np.ceil(microns)))  

    for pos in spine_positions.astype(int):
        spine_distribution[pos] += 1  

    rolling_avg = [
        np.mean(spine_distribution[i:i+window_size])
        for i in range(0, len(spine_distribution) - window_size + 1, step_size)
    ]

    rolling_spines_dict[segid] = rolling_avg
    all_rolling_values.extend(rolling_avg)

    overall_avg_rolling_spines = np.mean(all_rolling_values) if all_rolling_values else 0
    
    return rolling_spines_dict, overall_avg_rolling_spines


def find_spines_per_micron(df_synapses, mesh_list, segid, fallback_segid, df_cells, current_date, pyr_cell=False, apical=True, distal=False, plot_soma=False, plot_branches=False, plot_bare_skeleton=False, plot_spines=False):
    """Process a single skeleton from a mesh to compute spine density per micron."""

    skel = navis.skeletonize(mesh_list)
    skel_og = skel[0]
    
    if distal:
        skel_nosoma, soma_center = find_and_remove_soma(skel[0], segid, fallback_segid, df_cells, current_date, pyr_cell=pyr_cell, apical=apical, tilt_angle=70, plot_soma=plot_soma, plot_branches=plot_branches, plot_plane=False)
    else:
        skel_nosoma, soma_center = find_and_remove_soma(skel[0], segid, fallback_segid, df_cells, current_date, pyr_cell=pyr_cell, apical=apical, tilt_angle=70, plot_soma=plot_soma, plot_branches=plot_branches, plot_plane=False)   
    
    if soma_center == 0 or skel_nosoma is None:
        print(f"Failed to find soma center for: {segid}")
        spines_per_micron = 0 
        skel_spines = 0 
        cable_lengths = 0 
        spines_num = 0 
        soma_center = soma_center
        pre_partners_ids = []  
        pre_partner_num = 0
        return spines_per_micron, skel_spines, cable_lengths, spines_num, soma_center, pre_partners_ids, pre_partner_num
    
    skel_noprox = remove_proximal_branches_basal(skel_nosoma, soma_center, segid, skel_og, apical)

    if skel_noprox is None:        
        print(f"Failed to find proximal branches for: {segid}")
        spines_per_micron = 0 
        skel_spines = 0 
        cable_lengths = 0 
        spines_num = 0 
        soma_center = soma_center
        pre_partners_ids = []  
        pre_partner_num = 0
        return spines_per_micron, skel_spines, cable_lengths, spines_num, soma_center, pre_partners_ids, pre_partner_num
         
    if isinstance(skel_noprox, TreeNeuron):
        skel_spines, spines_num, pre_partners_ids, pre_partner_num = find_spines(df_synapses, skel_noprox, apical, threshold=2000, plot_spines=plot_spines)
    else:
        skel_spines, spines_num, pre_partners_ids, pre_partner_num = find_spines(df_synapses, skel_noprox[0], apical, threshold=2000, plot_spines=plot_spines)
    
    if skel_spines is None:
        print(f"Failed to find spines for: {segid}")
        spines_per_micron = 0 
        skel_spines = 0 
        cable_lengths = 0 
        spines_num = 0 
        soma_center = soma_center
        pre_partners_ids = []  
        pre_partner_num = 0
        return spines_per_micron, skel_spines, cable_lengths, spines_num, soma_center, pre_partners_ids, pre_partner_num
        
    id_name = segid
    
    convert_to_JSON(skel_spines, segid, apical, id_name=id_name, print_output=False)

    skel_backbones, cable_lengths = find_cable_lengths(skel_noprox, segid, apical, plot_bare_skeleton=plot_bare_skeleton)
    
    spines_per_micron = compute_spines_per_micron(spines_num, cable_lengths, segid=segid)

    return spines_per_micron, skel_spines, cable_lengths, spines_num, soma_center, pre_partners_ids, pre_partner_num


def build_skeleton_from_pcg(SID, client):
    pcg_skeleton = pcg_skel.pcg_skeleton(root_id=SID, client=client, root_point_resolution=True)

    if pcg_skeleton.edges.shape[0] < 1 or pcg_skeleton.vertices.shape[0] < 2:
        raise ValueError(f"Malformed PCG skeleton for SID {SID} (too few edges or vertices)")
    
    G = nx.Graph()
    G.add_edges_from(pcg_skeleton.edges)

    # Convert root to int to avoid 'unhashable type' error
    bfs_edges = list(nx.bfs_edges(G, source=int(pcg_skeleton.root)))

    parent_map = {child: parent for parent, child in bfs_edges}
    parent_ids = [parent_map.get(i, -1) for i in range(len(pcg_skeleton.vertices))]

    df_skeleton = pd.DataFrame({
        'node_id': range(len(pcg_skeleton.vertices)),
        'x': pcg_skeleton.vertices[:, 0],
        'y': pcg_skeleton.vertices[:, 1],
        'z': pcg_skeleton.vertices[:, 2],
        'parent_id': parent_ids
    })

    navis_skel = navis.TreeNeuron(x=df_skeleton)
    healed_skel = navis.heal_skeleton(navis_skel)

    return healed_skel


def update_segids_df(df, super_voxel_col):
    # Get the current date
    current_date = datetime.now().strftime('%Y%m%d')  # Format: YYYYMMDD
    
    # Update segids
    updated_segid_list = client.chunkedgraph.get_roots(df[super_voxel_col])
    
    # Add the updated segids to the DataFrame with the date in the column name
    updated_col_name = f"updated_segids_{current_date}"
    df[updated_col_name] = updated_segid_list
    
    print(f"Number of updated segids: {len(df)}")

    return df, updated_segid_list, current_date


def save_data_(segid_results, current_date, apical=True):
    today_date = current_date
    data_list = []

    for segid, values in segid_results.items():
        data_list.append({
            "segid": segid,
            "spines_per_micron": values.get("spines_per_micron"),
            "cable_length": values.get("cable_length"),
            "spines_num": values.get("spines_num"),
            "soma_center_x": values.get("soma_center")[0] if values.get("soma_center") else None,
            "soma_center_y": values.get("soma_center")[1] if values.get("soma_center") else None,
            "soma_center_z": values.get("soma_center")[2] if values.get("soma_center") else None,
            "curve": values.get("curve"),
            "depth": values.get("depth"),
            "pre_partners_ids": values.get("pre_partners_ids"),
            "pre_partner_num": values.get("pre_partner_num"),
            "date": values.get("date"),
            "segid_mesh": values.get("segid_mesh"),
            "error_type": values.get("error_type")
        })

    df = pd.DataFrame(data_list)

    if apical:
        df.to_csv(f"Pyr_spines_results_{today_date}_apical.csv", index=False)
    else:
        df.to_csv(f"Pyr_spines_results_{today_date}_basal.csv", index=False)

    return df


def run_for_all_segids(segids, df, current_date, pyr_cell=False, apical=False,
                       distal=False, plot_soma=False, plot_branches=False,
                       plot_bare_skeleton=False, plot_spines=False):
    import os
    from datetime import datetime, timedelta

    today_file = f"Pyr_spines_results_{current_date}_{'apical' if apical else 'basal'}.csv"
    parsed_date = datetime.strptime(current_date, "%Y%m%d")
    yesterday = (parsed_date - timedelta(days=1)).strftime("%m%d")
    yesterday_file = f"Pyr_spines_results_{yesterday}_{'apical' if apical else 'basal'}.csv"

    df_existing = pd.DataFrame()
    processed_segids = set()

    if os.path.exists(today_file):
        df_existing = pd.read_csv(today_file)
        processed_segids = set(df_existing["segid"].astype(str))
        print(f"[INFO] Resuming from today's file: {today_file}")
    elif os.path.exists(yesterday_file):
        df_existing = pd.read_csv(yesterday_file)
        processed_segids = set(df_existing["segid"].astype(str))
        print(f"[INFO] Resuming from yesterday's file: {yesterday_file}")
    else:
        print("[INFO] No existing result file found. Starting fresh.")

    segid_results = {row["segid"]: row for _, row in df_existing.iterrows()}

    for segid in tqdm(segids, desc="Processing SegIDs"):
        if str(segid) in processed_segids:
            continue

        error_type = None

        match = df.loc[df[f'updated_segids_{current_date}'] == segid, 'segid_0114']
        fallback_segid = match.values[0] if not match.empty else None

        df_synapses = client.materialize.synapse_query(
            pre_ids=None,
            post_ids=segid,
            synapse_table="synapses_ca3_v1",
            desired_resolution=[18, 18, 45]
        )

        segid_mesh = fallback_segid

        if df_synapses.empty and fallback_segid is not None:
            print(f"[INFO] No synapses for {segid}. Retrying with fallback {fallback_segid}...")
            df_synapses = client.materialize.synapse_query(
                pre_ids=None,
                post_ids=fallback_segid,
                synapse_table="synapses_ca3_v1",
                desired_resolution=[18, 18, 45]
            )

        if df_synapses.empty:
            error_type = "no_synapses"

        try:
            mesh_list = vol.mesh.get_navis(segid_mesh)
        except Exception as e:
            print(f"[ERROR] Could not get mesh for {segid_mesh}: {e}")
            mesh_list = None
            if error_type is None:
                error_type = "no_mesh"

        if not df_synapses.empty and mesh_list:
            try:
                spines_per_micron, skel_spine_pos, cable_lengths, spines_num, soma_center, pre_partners_ids, pre_partner_num = find_spines_per_micron(
                    df_synapses, mesh_list, segid, fallback_segid, df, current_date,
                    pyr_cell=pyr_cell, apical=apical, distal=distal, plot_soma=plot_soma,
                    plot_branches=plot_branches, plot_bare_skeleton=plot_bare_skeleton, plot_spines=plot_spines
                )

                df_current = df[df[f"updated_segids_{current_date}"] == segid]
                current_curve = df_current['curve_distance'].iloc[0] if not df_current.empty else None
                current_depth = df_current['depth_microns'].iloc[0] if not df_current.empty else None

                row_data = {
                    "spines_per_micron": spines_per_micron.get(segid) if spines_per_micron else 0,
                    "skel_spine_pos": skel_spine_pos,
                    "cable_length": cable_lengths.get(segid) if cable_lengths else 0,
                    "spines_num": spines_num,
                    "soma_center": soma_center,
                    "curve": current_curve,
                    "depth": current_depth,
                    "pre_partners_ids": pre_partners_ids,
                    "pre_partner_num": pre_partner_num,
                    "date": current_date,
                    "segid_mesh": segid_mesh,
                    "error_type": None
                }

            except Exception as e:
                print(f"[ERROR] Spine analysis failed for {segid}: {e}")
                error_type = "processing_error"
                row_data = None
        else:
            row_data = None

        # Handle fallback row
        if row_data is None:
            row_data = {
                "spines_per_micron": None,
                "skel_spine_pos": None,
                "cable_length": None,
                "spines_num": None,
                "soma_center": None,
                "curve": None,
                "depth": None,
                "pre_partners_ids": None,
                "pre_partner_num": None,
                "date": current_date,
                "segid_mesh": segid_mesh if 'segid_mesh' in locals() else None,
                "error_type": error_type
            }

        segid_results[segid] = row_data

        # Save after each iteration
        save_data_(segid_results, current_date, apical=apical)
        processed_segids.add(str(segid))

    print(f"[INFO] Finished. Final data saved to {today_file}")
    return save_data_(segid_results, current_date, apical=apical)


In [42]:
# Load Nuclei DF data and remove NA supervoxels 
current_date = datetime.now().strftime('%Y%m%d')  # Format: YYYYMMDD

df_apical = pd.read_csv('Pyr_MF_DF - MF-pyr.csv')
df_basal = df_apical
df_apical = df_apical[df_apical['apical branch'] == 'Yes']

# Updating segids for nuclei 
df_apical_updated, updated_segid_list_apical, current_date = update_segids_df(df_apical, super_voxel_col='supervoxel')
df_basal_updated, updated_segid_list_basal, current_date = update_segids_df(df_basal, super_voxel_col='supervoxel')
single_df, updated_segid_list_single, current_date = update_segids_df(single_df, super_voxel_col='supervoxel')


# Running spine finder (single segid and full)
#df_results_single = run_for_all_segids([648518346434011987], df_apical_updated, current_date, pyr_cell=True, apical=True, distal=False, plot_soma=False, plot_branches=False, plot_bare_skeleton=False, plot_spines=False)
df_results = run_for_all_segids(updated_segid_list_apical, df_apical_updated, current_date, pyr_cell=True, apical=True, distal=False, plot_soma=False, plot_branches=False, plot_bare_skeleton=False, plot_spines=False)
#df_results = run_for_all_segids(updated_segid_list_basal, df_basal_updated, current_date, pyr_cell=True, apical=False, distal=False, plot_soma=False, plot_branches=False, plot_bare_skeleton=False, plot_spines=False)





A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



Number of updated segids: 266
Number of updated segids: 636
Number of updated segids: 1
[INFO] No existing result file found. Starting fresh.


Processing SegIDs:   0%|                                | 0/266 [00:00<?, ?it/s]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   0%|                      | 1/266 [00:47<3:30:27, 47.65s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   1%|▏                     | 2/266 [01:25<3:03:22, 41.68s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   1%|▏                     | 3/266 [02:18<3:25:26, 46.87s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   2%|▎                     | 4/266 [02:31<2:27:33, 33.79s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   2%|▍                     | 5/266 [02:41<1:49:17, 25.12s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   2%|▍                     | 6/266 [03:02<1:41:46, 23.49s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing Se

[648518346445521015] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346445521015


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   6%|█▎                   | 16/266 [10:05<2:17:33, 33.01s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   6%|█▎                   | 17/266 [10:39<2:17:38, 33.17s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   7%|█▍                   | 18/266 [11:11<2:16:03, 32.92s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   8%|█▌                   | 20/266 [11:54<1:46:25, 25.96s/it]

[648518346457367570] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346457367570


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   8%|█▋                   | 21/266 [13:24<3:04:03, 45.08s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   9%|█▊                   | 23/266 [14:12<2:13:30, 32.96s/it]

[648518346443472004] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346443472004


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   9%|█▉                   | 24/266 [14:47<2:16:10, 33.76s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:   9%|█▉                   | 25/266 [16:25<3:32:31, 52.91s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  10%|██▏                  | 27/266 [17:07<2:20:37, 35.30s/it]

[648518346469873197] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346469873197


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  11%|██▏                  | 28/266 [18:14<2:57:31, 44.75s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  11%|██▎                  | 29/266 [18:57<2:55:14, 44.37s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  11%|██▎                  | 30/266 [19:57<3:12:40, 48.98s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  12%|██▍                  | 31/266 [20:41<3:06:00, 47.49s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  12%|██▌                  | 32/266 [21:21<2:55:55, 45.11s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  12%|██▌                  | 33/266 [21:53<2:40:18, 41.28s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  13%|██▋                  | 34/266 [22:09<2:09:47, 33.57s/it]INFO  : Use t

[648518346441397979] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346441397979


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  15%|███▏                 | 41/266 [27:29<3:19:59, 53.33s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  16%|███▎                 | 42/266 [28:35<3:33:25, 57.17s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  16%|███▍                 | 43/266 [29:25<3:24:01, 54.89s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  17%|███▍                 | 44/266 [29:53<2:52:57, 46.75s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  17%|███▌                 | 45/266 [30:25<2:36:35, 42.51s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  17%|███▋                 | 46/266 [31:07<2:35:17, 42.35s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  18%|███▋                 | 47/266 [31:20<2:01:56, 33.41s/it]INFO  : Use t

[648518346450806708] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346450806708


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  19%|████                 | 51/266 [34:32<2:44:24, 45.88s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  20%|████                 | 52/266 [35:13<2:37:48, 44.24s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  20%|████▏                | 53/266 [35:48<2:27:38, 41.59s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  20%|████▎                | 54/266 [37:08<3:07:50, 53.16s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  21%|████▎                | 55/266 [37:56<3:00:50, 51.42s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  21%|████▍                | 56/266 [38:21<2:32:38, 43.61s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  22%|████▌                | 58/266 [39:04<1:47:07, 30.90s/it]

[648518346442423797] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346442423797


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  22%|████▋                | 59/266 [39:32<1:43:37, 30.04s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  23%|████▋                | 60/266 [41:21<3:05:01, 53.89s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  23%|████▉                | 62/266 [42:44<2:33:37, 45.18s/it]

[648518346449601907] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346449601907


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  24%|████▉                | 63/266 [43:24<2:27:08, 43.49s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  24%|█████                | 64/266 [43:45<2:04:29, 36.98s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  25%|█████▏               | 66/266 [44:18<1:24:54, 25.47s/it]

[ERROR] Could not get mesh for 648518346473791448: [1;91mManifest not found for segment 648518346473791448.[m


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  25%|█████▎               | 67/266 [44:46<1:26:45, 26.16s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  26%|█████▎               | 68/266 [45:24<1:37:42, 29.61s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  26%|█████▍               | 69/266 [46:18<2:01:47, 37.09s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  26%|█████▌               | 70/266 [47:03<2:08:25, 39.31s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  27%|█████▌               | 71/266 [47:54<2:20:03, 43.09s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  27%|█████▊               | 73/266 [49:38<2:20:55, 43.81s/it]

[648518346442061919] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346442061919


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  28%|█████▊               | 74/266 [49:57<1:56:39, 36.46s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  28%|█████▉               | 75/266 [51:49<3:08:09, 59.11s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  29%|██████               | 76/266 [52:38<2:57:10, 55.95s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  29%|██████               | 77/266 [53:25<2:48:22, 53.45s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  29%|██████▏              | 78/266 [54:09<2:38:40, 50.64s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  30%|██████▏              | 79/266 [55:21<2:57:00, 56.79s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  30%|██████▎              | 80/266 [56:12<2:51:21, 55.28s/it]INFO  : Use t

[648518346443209520] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346443209520


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  31%|██████▌              | 83/266 [57:29<1:49:12, 35.81s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  32%|██████▋              | 84/266 [58:00<1:44:49, 34.56s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  32%|██████▋              | 85/266 [58:17<1:28:24, 29.31s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  32%|██████▊              | 86/266 [58:46<1:27:32, 29.18s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  33%|██████▊              | 87/266 [59:09<1:20:58, 27.14s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  33%|██████▎            | 88/266 [1:01:57<3:26:11, 69.50s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  33%|██████▎            | 89/266 [1:02:39<3:00:58, 61.35s/it]INFO  : Use t

[648518346442205986] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346442205986


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  38%|██████▊           | 100/266 [1:10:28<2:07:36, 46.13s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  38%|██████▊           | 101/266 [1:12:41<3:18:31, 72.19s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  38%|██████▉           | 102/266 [1:13:09<2:41:10, 58.97s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  39%|██████▉           | 103/266 [1:13:28<2:07:42, 47.01s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  39%|███████           | 104/266 [1:13:53<1:48:40, 40.25s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  39%|███████           | 105/266 [1:14:38<1:52:37, 41.97s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  40%|███████▏          | 106/266 [1:15:28<1:58:14, 44.34s/it]INFO  : Use t

[648518346438058169] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346438058169


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  47%|████████▍         | 124/266 [1:28:26<1:27:22, 36.92s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  47%|████████▍         | 125/266 [1:29:36<1:49:54, 46.77s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  47%|████████▌         | 126/266 [1:32:07<3:02:18, 78.13s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  48%|████████▌         | 127/266 [1:32:34<2:25:25, 62.78s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  48%|████████▋         | 128/266 [1:33:08<2:04:01, 53.92s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  48%|████████▋         | 129/266 [1:34:05<2:05:44, 55.07s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  49%|████████▊         | 130/266 [1:34:48<1:56:22, 51.34s/it]INFO  : Use t

[648518346441538519] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346441538519


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  52%|█████████▎        | 138/266 [1:45:57<2:47:16, 78.41s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  52%|█████████▍        | 139/266 [1:47:04<2:39:13, 75.23s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  53%|█████████▍        | 140/266 [1:48:13<2:33:45, 73.22s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  53%|█████████▌        | 141/266 [1:48:42<2:05:02, 60.02s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  53%|█████████▌        | 142/266 [1:49:16<1:47:33, 52.04s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  54%|█████████▋        | 143/266 [1:49:56<1:39:23, 48.49s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  54%|█████████▋        | 144/266 [1:50:32<1:31:17, 44.89s/it]INFO  : Use t

[648518346451966076] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346451966076


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  59%|██████████▌       | 156/266 [2:07:12<2:45:06, 90.06s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  59%|██████████▌       | 157/266 [2:08:19<2:31:09, 83.21s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  59%|██████████▋       | 158/266 [2:10:31<2:56:21, 97.98s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  60%|██████████▊       | 159/266 [2:11:09<2:22:24, 79.85s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  60%|██████████▊       | 160/266 [2:11:43<1:56:54, 66.18s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  61%|██████████▉       | 161/266 [2:12:29<1:45:10, 60.10s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  61%|██████████▉       | 162/266 [2:13:00<1:28:47, 51.23s/it]INFO  : Use t

Error: pruned_edges is empty!
Failed to find spines for: 648518346443212041


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  68%|████████████▎     | 182/266 [2:44:31<2:02:58, 87.84s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  69%|████████████▍     | 183/266 [2:46:15<2:08:18, 92.75s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  69%|████████████▍     | 184/266 [2:46:56<1:45:25, 77.15s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  70%|████████████▌     | 185/266 [2:48:55<2:01:24, 89.94s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  70%|████████████▌     | 186/266 [2:50:05<1:51:52, 83.90s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  71%|████████████▋     | 188/266 [2:51:29<1:17:18, 59.47s/it]

[648518346441355558] No nodes found even at minimum threshold (90000) — skipping.
Failed to find proximal branches for: 648518346441355558


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  71%|████████████▊     | 189/266 [2:52:31<1:17:14, 60.18s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  71%|████████████▊     | 190/266 [2:53:13<1:09:23, 54.78s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  72%|████████████▉     | 191/266 [2:54:25<1:14:54, 59.93s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  72%|████████████▉     | 192/266 [2:55:00<1:04:30, 52.30s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  73%|████████████▍    | 194/266 [3:00:30<2:26:43, 122.27s/it]

Error: pruned_edges is empty!
Failed to find spines for: 648518346437033808


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  73%|█████████████▏    | 195/266 [3:01:09<1:55:09, 97.32s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  74%|█████████████▎    | 196/266 [3:02:23<1:45:29, 90.43s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  74%|█████████████▎    | 197/266 [3:04:02<1:46:50, 92.91s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  74%|█████████████▍    | 198/266 [3:05:47<1:49:14, 96.39s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  75%|█████████████▍    | 199/266 [3:07:00<1:40:02, 89.58s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  75%|█████████████▌    | 200/266 [3:07:28<1:18:11, 71.08s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  76%|█████████████▌    | 201/266 [3:08:33<1:14:59, 69.23s/it]INFO  : Use t

[ERROR] Could not get mesh for 648518346433868883: [1;91mManifest not found for segment 648518346433868883.[m


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  87%|█████████████████▎  | 231/266 [3:51:59<52:51, 90.60s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  87%|█████████████████▍  | 232/266 [3:53:23<50:12, 88.60s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  88%|██████████████▉  | 233/266 [3:56:18<1:02:53, 114.35s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  88%|████████████████▋  | 234/266 [3:58:04<59:41, 111.92s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  88%|████████████████▊  | 235/266 [3:59:34<54:28, 105.45s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  89%|█████████████████▋  | 236/266 [4:00:35<46:02, 92.09s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  89%|█████████████████▊  | 237/266 [4:01:10<36:09, 74.83s/it]INFO  : Use t

[ERROR] Could not get mesh for 648518346442917710: [1;91mManifest not found for segment 648518346442917710.[m


INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  98%|███████████████████▌| 260/266 [4:38:31<07:03, 70.51s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  98%|███████████████████▌| 261/266 [4:39:02<04:53, 58.71s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  98%|███████████████████▋| 262/266 [4:39:44<03:34, 53.59s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  99%|██████████████████▊| 263/266 [4:48:02<09:20, 186.79s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs:  99%|██████████████████▊| 264/266 [4:49:13<05:04, 152.07s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs: 100%|██████████████████▉| 265/266 [4:50:10<02:03, 123.74s/it]INFO  : Use the `.show()` method to plot the figure. (navis)
Processing SegIDs: 100%|████████████████████| 266/266 [4:51:13<00:00, 65.69s/it]


[INFO] Finished. Final data saved to Pyr_spines_results_20250414_apical.csv


In [43]:
df_results.to_csv('Apical_Dendrite_Spines_Per_Micron_250414.csv')

In [59]:
df_results = pd.read_csv('Apical_Dendrite_Spines_Per_Micron_250414.csv')

In [45]:
def filter_bad_spine_cells(df_results, low=True, num_=10):
    df_results = df_results[df_results['spines_per_micron'].notna()]
    df_results = df_results[df_results['spines_per_micron'] != 0]
    
    if low:
        lowest_ = df_results.nsmallest(num_, 'spines_per_micron')
        print(len(df_results))
        #df_results.head(10)
        print(lowest_)
        
        return lowest_
    
    else:
        highest_ = df_results.nlargest(num_, 'spines_per_micron')
        print(len(df_results))
        #df_results.head(10)
        print(highest_)
        
        return highest_
    
def remove_bad_ids(df_results, low_high):
    print("Before removal:", len(df_results))
    
    # Handle if low_high is a DataFrame
    if isinstance(low_high, pd.DataFrame):
        ids_to_remove = low_high['SegID'].tolist()
    else:
        # Assume it's already a list-like of IDs
        ids_to_remove = low_high

    # Remove rows where SegID is in the list
    df_results = df_results[~df_results['segid'].isin(ids_to_remove)]
    
    print("After removal:", len(df_results))
    return df_results


In [60]:
print(len(df_results))
filtered_results_apical = df_results[df_results['spines_num'] >= 50]
print(len(filtered_results_apical))

266
216


In [61]:
low_ = filter_bad_spine_cells(filtered_results_apical, low=False, num_=10)

216
     Unnamed: 0               segid  spines_per_micron  cable_length  \
73           73  648518346443775693           0.464057  1.249847e+05   
205         205  648518346437543686           0.460742  1.692920e+05   
226         226  648518346442778357           0.448374  2.252584e+05   
140         140  648518346467131377           0.446799  2.327670e+05   
173         173  648518346448511198           0.445816  1.623987e+06   
143         143  648518346439431876           0.441934  8.213905e+05   
168         168  648518346442530823           0.439073  6.172097e+05   
141         141  648518346448344747           0.437081  4.095352e+05   
189         189  648518346448504058           0.436057  1.719960e+05   
202         202  648518346445202966           0.435934  7.088234e+05   

     spines_num  soma_center_x  soma_center_y  soma_center_z        curve  \
73         58.0       632736.0      1112832.0        92790.0   264.231009   
205        78.0       586944.0      1202688.0    

In [57]:
filtered_results_apical.to_csv('Apical_Dendrite_Spines_Per_Micron_Filtered_250414.csv')