#### This code allows designated skeletons in specified bbox/stls to be retrived with pymaid and navis. Nodes in ROI are first extracted and re-skeletonise. Node coords are saved in .swc file, along with a sanity check with 3D visualistion.  

In [1]:
import pymaid

url = "https://neurophyla.mrc-lmb.cam.ac.uk/catmaid/fibsem/#"
token = "x"
name = "x"
password = "x"
project_id = x
rm = pymaid.CatmaidInstance(url, token, name, password, project_id)

INFO  : Global CATMAID instance set. Caching is ON. (pymaid)


In [3]:
n = pymaid.get_neurons_in_volume('f1 compartment (Right)', remote_instance=None)

INFO  : Retrieving neurons in volume f1 compartment (Right) (pymaid)
INFO  : Done. 345 unique neurons found in volume(s) f1 compartment (Right) (pymaid)


In [32]:
import os
import numpy as np
import pandas as pd
import navis
import pymaid
import trimesh


def crop_tree_neuron_by_mesh_contains(
    neuron_full: navis.TreeNeuron,
    roi_mesh: trimesh.Trimesh,
    prevent_fragments: bool = True,
) -> navis.TreeNeuron:
    pts = neuron_full.nodes[['x', 'y', 'z']].to_numpy()
    inside = roi_mesh.contains(pts)  # requires rtree installed
    node_ids_inside = neuron_full.nodes.loc[inside, 'node_id'].tolist()

    n_part = navis.subset_neuron(
        neuron_full,
        subset=node_ids_inside,
        inplace=False,
        prevent_fragments=prevent_fragments
    )
    return n_part


def process_skeleton_ids_crop_and_save(
    skeleton_ids,
    roi_mesh_or_path,
    out_dir="cropped_neurons_out",
    threshold=10,
    prevent_fragments=True,
):
    """
    For each skeleton_id:
      1) pymaid.get_neuron(skid)
      2) crop by ROI mesh using contains()->subset_neuron
      3) classify by cropped node count:
           - >= threshold => "long" (saved to capped_{threshold}nodes_long_skels)
           - <  threshold => "small" (saved to capped_{threshold}nodes_small_skels)
      4) save BOTH groups as SWC
      5) save a summary CSV
      6) return (long_list, small_list, summary_df, roi_mesh)
    """
    # ---- Load ROI mesh ----
    if isinstance(roi_mesh_or_path, str):
        roi = trimesh.load(roi_mesh_or_path, force="mesh")
        if isinstance(roi, trimesh.Scene):
            roi = trimesh.util.concatenate(tuple(roi.dump()))
    else:
        roi = roi_mesh_or_path

    # ---- Output dirs (always save both) ----
    long_dir = os.path.join(out_dir, f"capped_{threshold}nodes_long_skels")
    small_dir = os.path.join(out_dir, f"capped_{threshold}nodes_small_skels")
    os.makedirs(long_dir, exist_ok=True)
    os.makedirs(small_dir, exist_ok=True)

    long_parts = []
    small_parts = []
    rows = []

    for skid in skeleton_ids:
        skid_int = int(skid)

        try:
            # 1) Load full neuron
            neuron_full = pymaid.get_neuron(skid_int)

            # Safety: ensure it has nodes
            if not hasattr(neuron_full, "nodes") or neuron_full.nodes is None or len(neuron_full.nodes) == 0:
                rows.append({
                    "skeleton_id": skid_int,
                    "full_nodes": 0,
                    "cropped_nodes": 0,
                    "status": "error_no_nodes"
                })
                continue

            full_nodes = int(len(neuron_full.nodes))

            # 2) Crop by ROI
            n_part = crop_tree_neuron_by_mesh_contains(
                neuron_full,
                roi,
                prevent_fragments=prevent_fragments
            )

            cropped_nodes = int(len(n_part.nodes)) if hasattr(n_part, "nodes") and n_part.nodes is not None else 0

            # 3) Classify + 4) Save BOTH groups
            if cropped_nodes >= threshold:
                status = "kept"
                long_parts.append(n_part)
                out_path = os.path.join(long_dir, f"{skid_int}_cropped.swc")
                navis.write_swc(n_part, out_path)
            else:
                status = "remove"
                small_parts.append(n_part)
                out_path = os.path.join(small_dir, f"{skid_int}_cropped.swc")
                navis.write_swc(n_part, out_path)

            # 5) Record
            rows.append({
                "skeleton_id": skid_int,
                "full_nodes": full_nodes,
                "cropped_nodes": cropped_nodes,
                "status": status
            })

        except Exception as e:
            print(f"⚠️ Error processing skeleton_id={skid_int}: {e}")
            rows.append({
                "skeleton_id": skid_int,
                "full_nodes": None,
                "cropped_nodes": None,
                "status": "error"
            })

    long_list = navis.NeuronList(long_parts)
    small_list = navis.NeuronList(small_parts)
    df = pd.DataFrame(rows)

    # Save summary CSV (include threshold in filename)
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(os.path.join(out_dir, f"cropping_summary_threshold_{threshold}.csv"), index=False)

    return long_list, small_list, df, roi


In [None]:
long_skels, small_skels, summary_df, roi_mesh = process_skeleton_ids_crop_and_save(
    skeleton_ids=n,            # your list of skids
    roi_mesh_or_path="roi.stl", # insert roi
    out_dir="roi_cropped_results",
    threshold=20,
    prevent_fragments=True
)

print("Kept:", len(long_skels))
print("Fragmented:", len(small_skels))


In [37]:
#sanity check 
skid = 12366036

In [38]:
import navis
import numpy as np

# --- Full neuron ---
neuron_full = pymaid.get_neuron(skid)

# --- Node coordinates ---
pts = neuron_full.nodes[['x','y','z']].to_numpy()

# --- Inside / Outside masks ---
inside_mask = roi_mesh.contains(pts)
outside_mask = ~inside_mask

# --- Subset neurons ---
n_inside = navis.subset_neuron(
    neuron_full,
    subset=neuron_full.nodes.loc[inside_mask, "node_id"].tolist(),
    inplace=False,
    prevent_fragments=True
)

n_outside = navis.subset_neuron(
    neuron_full,
    subset=neuron_full.nodes.loc[outside_mask, "node_id"].tolist(),
    inplace=False,
    prevent_fragments=True
)

print("Inside nodes:", len(n_inside.nodes))
print("Outside nodes:", len(n_outside.nodes))

# --- Plot together ---
fig = navis.plot3d(
    [n_inside, n_outside],
    backend="plotly",
    inline=False,
    color=["red", "green"],
    linewidth=3,
    alpha=[1, 0.6]
)

# Overlay ROI mesh
navis.plot3d(
    roi_mesh,
    fig=fig,
    color="blue",
    alpha=0.15
)



INFO  : Cached data used. Use `pymaid.clear_cache()` to clear. (pymaid)
INFO  : Cached data used. Use `pymaid.clear_cache()` to clear. (pymaid)
INFO  : Use the `.show()` method to plot the figure. (navis)


Inside nodes: 26
Outside nodes: 26
