# In [4]:
#!/usr/bin/env python
# coding: utf-8

import os
import json
from pathlib import Path
import numpy as np
from skimage import filters, color
from skimage.morphology import remove_small_objects, skeletonize
from tifffile import imread, imwrite
from skan import Skeleton, summarize
import pandas as pd
import imageio
from concurrent.futures import ProcessPoolExecutor, as_completed

# Utility functions
def read_config(config_path):
    """Load the JSON configuration file at *config_path*.

    Returns the decoded JSON document; ``main`` expects it to be a mapping
    with at least ``import_folder`` and ``export_folder`` keys.
    """
    with open(config_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed

def save_config_to_export_folder(config, export_folder):
    """Record *config* as ``config_record.json`` inside *export_folder*.

    The JSON is pretty-printed with a 4-space indent. The folder is not
    created here, so it must already exist.
    """
    destination = Path(export_folder).joinpath("config_record.json")
    with open(destination, 'w') as out:
        json.dump(config, out, indent=4)
    print(f"Configuration file saved to '{destination}'.")


def read_all_images_in_folder(folder_path):
    """Read every supported image file directly inside *folder_path*.

    Files are matched by extension case-insensitively, so ``.TIF`` and
    ``.tif`` both match. Directories are skipped even when their name ends in
    an image suffix. Results are sorted by path so repeated runs process
    images in the same order. A missing folder yields an empty list.

    TIFF files are read with ``tifffile.imread``. PNG/JPEG files are read
    with ``imageio.imread``, because tifffile only decodes TIFF.

    Returns:
        list of ``(file_name, image_array)`` tuples.
    """
    tiff_suffixes = {'.tif', '.tiff'}
    other_suffixes = {'.png', '.jpg', '.jpeg'}

    images = []
    for image_file in sorted(Path(folder_path).glob('*')):
        if not image_file.is_file():
            continue
        suffix = image_file.suffix.lower()
        if suffix in tiff_suffixes:
            images.append((image_file.name, imread(image_file)))
        elif suffix in other_suffixes:
            # np.asarray drops imageio's ndarray subclass so every entry is a plain array.
            images.append((image_file.name, np.asarray(imageio.imread(image_file))))
    return images

def filter_and_skeletonize(image, min_size=200, sigma=1):
    """Segment *image* and thin the foreground to a one-pixel-wide skeleton.

    Pipeline:
    1. Gaussian smoothing with ``sigma``.
    2. A single global Otsu threshold.
    3. Removal of connected components smaller than ``min_size`` pixels.
    4. Lee's skeletonization.

    Returns:
        ``(skeleton, mask)``: the skeleton image and the cleaned binary mask
        it was computed from.
    """
    smoothed = filters.gaussian(image, sigma=sigma)
    cutoff = filters.threshold_otsu(smoothed)
    mask = remove_small_objects(smoothed > cutoff, min_size=min_size)
    skeleton = skeletonize(mask, method='lee')
    return skeleton, mask

def process_branch_data(skel_image, branch_table_path):
    """Build the per-branch table for *skel_image* and save it as CSV.

    ``branch_table_path`` needs a ``.name`` attribute (a ``pathlib.Path``),
    which is used in the log message. The CSV is written without the
    DataFrame index, and the table is also returned to the caller.
    """
    table = summarize_branch_data(skel_image)
    table.to_csv(branch_table_path, index=False)
    print(f"Table exported as '{branch_table_path.name}'.")
    return table
    
def classify_nerve_endings(branch_table, min_branch_length):
    """Return pixel coordinates of the nerve endings found in *branch_table*.

    A branch qualifies when both hold:
    - its ``branch-type`` is 1 (per skan's documentation, junction-to-endpoint);
    - its ``branch-distance`` is strictly greater than ``min_branch_length``.

    The terminal point comes from ``end-node-0`` (row) and ``end-node-1``
    (column). Qualifying rows with a missing end-node value are skipped,
    because converting NaN/None to int would raise and abort the whole image.

    Returns:
        list of ``(x, y)`` = ``(column, row)`` tuples of Python ints.
    """
    is_long = branch_table["branch-distance"] > min_branch_length
    is_terminal_type = branch_table["branch-type"].isin([1])
    selected = branch_table.loc[is_long & is_terminal_type, ["end-node-0", "end-node-1"]]
    # Rows without a single degree-1 end carry no end-node coordinate.
    selected = selected.dropna()
    # astype(int) truncates toward zero, like the built-in int().
    xs = selected["end-node-1"].astype(int).tolist()
    ys = selected["end-node-0"].astype(int).tolist()
    return list(zip(xs, ys))

def create_terminal_only_image(original_image, end_node_coordinates, square_size=60):
    """Keep only the pixels of *original_image* inside squares around end nodes.

    Each ``(x, y)`` = ``(column, row)`` coordinate defines a window spanning
    ``x - square_size // 2`` up to, but excluding, ``x + square_size // 2``
    (and likewise for ``y``), clipped to the image. Pixels outside every
    window are 0.

    Output dtype: boolean and uint8 inputs give uint8 (True -> 1). Other
    dtypes are preserved, so e.g. uint16 intensities are not wrapped
    modulo 256.
    """
    height, width = original_image.shape[:2]
    out_dtype = np.uint8 if original_image.dtype == np.bool_ else original_image.dtype
    new_image = np.zeros_like(original_image, dtype=out_dtype)

    half_square = square_size // 2

    for x, y in end_node_coordinates:
        x_min, x_max = max(0, x - half_square), min(width, x + half_square)
        y_min, y_max = max(0, y - half_square), min(height, y + half_square)
        # A window lying fully outside the image would give a negative slice
        # stop, which NumPy would interpret as an offset from the end.
        if x_min >= x_max or y_min >= y_max:
            continue
        # Copy the original pixel values inside this window.
        new_image[y_min:y_max, x_min:x_max] = original_image[y_min:y_max, x_min:x_max]

    return new_image

def create_terminal_red_image(original_image, end_node_coordinates, square_size):
    """Render the foreground of *original_image* as an RGB uint8 image.

    Colouring rules:
    - foreground pixels (value > 0) inside a square of side
      ``square_size // 2 * 2`` around any ``(x, y)`` end node are red;
    - all other foreground pixels are white;
    - background pixels are black.

    Squares are clipped to the image, and squares lying entirely outside it
    are ignored. *original_image* must be 2-D (its shape is unpacked into
    height and width).
    """
    height, width = original_image.shape
    foreground = original_image > 0
    half = square_size // 2

    # Mark every pixel covered by at least one terminal square.
    in_square = np.zeros((height, width), dtype=bool)
    for x, y in end_node_coordinates:
        left, right = max(0, x - half), min(width, x + half)
        top, bottom = max(0, y - half), min(height, y + half)
        if left >= right or top >= bottom:
            continue  # empty or fully outside the image
        in_square[top:bottom, left:right] = True

    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    rgb[foreground & ~in_square] = [255, 255, 255]
    rgb[foreground & in_square] = [255, 0, 0]
    return rgb


def summarize_branch_data(skel_image):
    """Summarize branch data from a skeletonized image.

    Builds the skan branch summary and appends three columns, using the same
    rules as ``determine_values``:

    - ``end-type``: -2 if both ends of a branch are degree-1 nodes, 0 if only
      the source is, 1 if only the destination is, and -1 otherwise.
    - ``end-node-0`` / ``end-node-1``: the ``coord-*-0`` / ``coord-*-1``
      values of that single degree-1 end, or NaN when there is none or two.

    A node's degree is the number of times its id occurs across the
    ``node-id-src`` and ``node-id-dst`` columns. The ids are counted once
    with ``value_counts``. Scanning the full id list for every row would be
    quadratic in the number of branches.
    """
    branch_table = pd.DataFrame(summarize(Skeleton(skel_image)))

    src_ids = branch_table["node-id-src"]
    dst_ids = branch_table["node-id-dst"]
    node_degree = pd.concat([src_ids, dst_ids], ignore_index=True).value_counts()

    src_is_end = src_ids.map(node_degree).eq(1).to_numpy()
    dst_is_end = dst_ids.map(node_degree).eq(1).to_numpy()
    only_src = src_is_end & ~dst_is_end
    only_dst = dst_is_end & ~src_is_end

    branch_table["end-type"] = np.select(
        [src_is_end & dst_is_end, only_src, only_dst], [-2, 0, 1], default=-1
    )
    for axis in ("0", "1"):
        src_coord = branch_table[f"coord-src-{axis}"].to_numpy(dtype=float)
        dst_coord = branch_table[f"coord-dst-{axis}"].to_numpy(dtype=float)
        branch_table[f"end-node-{axis}"] = np.where(
            only_src, src_coord, np.where(only_dst, dst_coord, np.nan)
        )
    return branch_table

def determine_values(row, total_node_ids):
    """Decide which end of a branch, if any, is the only degree-1 node.

    Args:
        row: mapping or Series with ``node-id-src``, ``node-id-dst``,
            ``coord-src-0/1`` and ``coord-dst-0/1`` entries.
        total_node_ids: every src and dst node id of the table. The number of
            times an id occurs is treated as that node's degree.

    Returns:
        ``(end_type, coord_0, coord_1)``, one of:
        - ``(-2, None, None)`` when both ends have degree 1;
        - ``(0, src coords)`` when only the source has degree 1;
        - ``(1, dst coords)`` when only the destination has degree 1;
        - ``(-1, None, None)`` when neither end has degree 1.
    """
    src_is_end = total_node_ids.count(row["node-id-src"]) == 1
    dst_is_end = total_node_ids.count(row["node-id-dst"]) == 1

    if src_is_end and dst_is_end:
        return -2, None, None
    if src_is_end:
        return 0, row["coord-src-0"], row["coord-src-1"]
    if dst_is_end:
        return 1, row["coord-dst-0"], row["coord-dst-1"]
    return -1, None, None
        
def process_image_skeleton(image, skeleton_path, min_size):
    """Segment and skeletonize *image*, then write the skeleton to *skeleton_path*.

    ``skeleton_path`` needs a ``.name`` attribute, which is used in the log
    message.

    Returns:
        ``(skeleton, mask)``, the pair produced by ``filter_and_skeletonize``.
    """
    skeleton, cleaned_mask = filter_and_skeletonize(image, min_size=min_size)
    imwrite(skeleton_path, skeleton)
    print(f"Skeleton exported as '{skeleton_path.name}'.")
    return skeleton, cleaned_mask
    
def prepare_export_paths(export_folder, image_prefix):
    """Return the output file path for each artefact of one image.

    Every output goes into its own sub-folder of *export_folder*, and file
    names start with *image_prefix*. The sub-folders are created if missing.

    Returns:
        dict with keys ``skeleton``, ``branch_table``,
        ``terminal_only_image``, ``terminal_red_image`` and ``coordinates``,
        each mapping to a ``pathlib.Path``.
    """
    root = Path(export_folder)
    layout = (
        ('skeleton', 'skeletons', '_skeleton.tif'),
        ('branch_table', 'branch_tables', '_branch_table.csv'),
        ('terminal_only_image', 'terminal_only_images', '_terminal_only.tif'),
        ('terminal_red_image', 'terminal_red_images', '_terminal_red.tif'),
        ('coordinates', 'end_node_coordinates', '_coordinates.csv'),
    )
    paths = {
        key: root / subfolder / f"{image_prefix}{suffix}"
        for key, subfolder, suffix in layout
    }
    for target in paths.values():
        target.parent.mkdir(parents=True, exist_ok=True)
    return paths
    
def export_coordinates(end_node_coordinates, export_path):
    """Write ``(x, y)`` end-node coordinates to *export_path* as a CSV.

    The CSV has an ``x,y`` header and no index column. An empty list yields
    a header-only file.
    """
    frame = pd.DataFrame(data=end_node_coordinates, columns=['x', 'y'])
    frame.to_csv(export_path, index=False)
    print(f"Coordinates exported as '{export_path.name}'.")

def process_image(image_name, image, export_folder, min_size=200, square_size=100, min_branch_length=15):
    """Run the full terminal-detection pipeline on one image and save every output.

    Outputs are written under sub-folders of *export_folder*, with file names
    derived from the stem of *image_name*:
    1. the skeleton (TIFF);
    2. the branch table (CSV);
    3. the end-node coordinates (CSV);
    4. a terminal-only image (TIFF);
    5. a red-highlighted RGB image of the terminal regions.
    """
    paths = prepare_export_paths(export_folder, Path(image_name).stem)

    # Segmentation and skeleton
    skeleton, mask = process_image_skeleton(image, paths['skeleton'], min_size)

    # Branch statistics
    branches = process_branch_data(skeleton, paths['branch_table'])

    # End nodes of long, branch-type-1 branches
    endings = classify_nerve_endings(branches, min_branch_length)
    export_coordinates(endings, paths['coordinates'])

    # Images restricted to / highlighting the terminal regions
    terminal_only = create_terminal_only_image(mask, endings, square_size)
    terminal_red = create_terminal_red_image(mask, endings, square_size)

    only_path = paths['terminal_only_image']
    imwrite(only_path, terminal_only)
    print(f"Terminal-only image exported as '{only_path.name}'.")

    red_path = paths['terminal_red_image']
    imageio.imwrite(red_path, terminal_red)
    print(f"Terminal-red image exported as '{red_path.name}'.")
# Parallel processing and main entry point

def process_image_parallel(args):
    """Unpack a single task tuple and run ``process_image`` on it.

    ``ProcessPoolExecutor.submit`` passes one tuple per image. It must contain
    exactly ``(image_name, image, export_folder, min_size, square_size,
    min_branch_length)``. Returns None.
    """
    (name, pixels, out_folder, min_size, square_size, min_branch_length) = args
    process_image(
        name,
        pixels,
        out_folder,
        min_size=min_size,
        square_size=square_size,
        min_branch_length=min_branch_length,
    )


def main(config_path):
    """Run the pipeline on every image in the configured import folder.

    Config keys read from the JSON file:
    - required: ``import_folder`` and ``export_folder``;
    - optional: ``min_size`` (default 200), ``square_size`` (default 100),
      ``min_branch_length`` (default 15) and ``max_workers`` (default None,
      i.e. ProcessPoolExecutor's own default).

    The config is copied into the export folder, then images are processed
    in parallel. A failure in one image is reported with that image's name
    and does not stop the others.
    """
    config = read_config(config_path)
    import_folder = config["import_folder"]
    export_folder = config["export_folder"]
    min_size = config.get("min_size", 200)  # pixels
    square_size = config.get("square_size", 100)  # pixels
    min_branch_length = config.get("min_branch_length", 15)  # pixels
    max_workers = config.get("max_workers")  # None -> executor default

    # Prepare export folder and save configuration
    Path(export_folder).mkdir(parents=True, exist_ok=True)
    save_config_to_export_folder(config, export_folder)

    images = read_all_images_in_folder(import_folder)
    print(f"Found {len(images)} images to process.")

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Remember which image each future belongs to so errors can be attributed.
        future_to_name = {
            executor.submit(
                process_image_parallel,
                (image_name, image, export_folder, min_size, square_size, min_branch_length),
            ): image_name
            for image_name, image in images
        }
        for future in as_completed(future_to_name):
            image_name = future_to_name[future]
            try:
                future.result()  # re-raises any exception from the worker
            except Exception as e:
                # Top-level boundary: report and continue with the other images.
                print(f"An error occurred while processing '{image_name}': {e!r}")

if __name__ == "__main__":
    # JSON configuration used when this file is executed directly.
    config_path = "./example.json"
    main(config_path=config_path)


# Output of the run above:
# Configuration file saved to 'export_folder/config_record.json'.
# Found 1 images to process.
# Skeleton exported as 'example_image_skeleton.tif'.
# Table exported as 'example_image_branch_table.csv'.
# Coordinates exported as 'example_image_coordinates.csv'.
# Terminal-only image exported as 'example_image_terminal_only.tif'.
# Terminal-red image exported as 'example_image_terminal_red.tif'.
