In [1]:
import cv2
from pyphenotyper.logger_config import logger
import os
import numpy as np
import networkx as nx
import skan
from skan import Skeleton, summarize
import skimage
import pandas as pd
import math
from typing import List, Tuple, Union, Optional
from concurrent.futures import ThreadPoolExecutor
import warnings
from collections import deque
import matplotlib.pyplot as plt
from rich.progress import track

In [2]:
def load_in_mask(path: Union[str, os.PathLike]) -> Optional[np.ndarray]:
    """
    Load a mask image from disk as a 3-channel BGR array.

    The file is read as a single-channel grayscale image and then replicated into
    three channels, so all channels hold identical values. No segmentation or
    filtering of small objects is performed here.

    :param path: Path to the mask image (``str`` or ``os.PathLike``).
    :type path: str | os.PathLike
    :return: The mask in BGR layout, or ``None`` (after emitting a warning) when
             ``path`` is neither a string nor a path-like object.
    :rtype: np.ndarray | None
    :raises ValueError: If OpenCV cannot read the file.
    """
    # Accept pathlib.Path and friends; cv2.imread itself needs a plain string.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    if not isinstance(path, str):
        warnings.warn(
            f"Expected 'path' to be a string, but got {type(path).__name__}.")
        return None
    # Flag 0 == grayscale read.
    mask = cv2.imread(path, 0)
    if mask is None:
        raise ValueError(f"Could not read the image at {path}")
    return cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)


def load_in_seedlings_template(template_path: Union[str, os.PathLike]) -> Optional[np.ndarray]:
    """
    Load the seeding template (image marking the possible seed positions) as grayscale.

    :param template_path: Path to the template image (``str`` or ``os.PathLike``).
    :type template_path: str | os.PathLike
    :return: The single-channel template image, or ``None`` (after emitting a warning)
             when ``template_path`` is neither a string nor a path-like object.
    :rtype: np.ndarray | None
    :raises ValueError: If OpenCV cannot read the file.
    """
    if isinstance(template_path, os.PathLike):
        template_path = os.fspath(template_path)
    if not isinstance(template_path, str):
        # Name the actual parameter so the warning points at the right argument.
        warnings.warn(
            f"Expected 'template_path' to be a string, but got {type(template_path).__name__}.")
        return None
    # Flag 0 == grayscale read.
    template = cv2.imread(template_path, 0)
    if template is None:
        raise ValueError(f"Could not read the image at {template_path}")
    return template


def get_plants_bboxes(path: str, template_path: str) -> List[List[int]]:
    """
    Find the seed-search bounding box of every plant from the seeding template.

    The template is resized to the mask's size and thresholded: pixels >= 200 become
    foreground. Each connected blob is one seed position. Blobs are ordered left to
    right. Each box is then widened to the right by half of the horizontal gap to the
    next blob, so neighbouring boxes meet in the middle. The last box keeps its own
    right edge.

    :param path: Path to a mask image (only its height/width are used).
    :type path: str
    :param template_path: Path to the template image with seed positions.
    :type template_path: str
    :return: ``[y_min, y_max, x_min, x_right_bound]`` for each plant, left to right.
    :rtype: List[List[int]]
    :raises TypeError: If either path is not a string (the loaders return ``None``).
    :raises ValueError: If either image cannot be read.
    """
    mask = load_in_mask(path)
    template = load_in_seedlings_template(template_path)
    if mask is None or template is None:
        raise TypeError("Expected 'path' and 'template_path' to be strings.")

    height, width = mask.shape[:2]
    seed_mask = cv2.resize(template, (width, height))
    seed_mask = np.where(seed_mask < 200, 0, 255).astype('uint8')

    # connectedComponentsWithStats only looks at non-zero pixels, so the binary mask can
    # be labelled directly. The former skimage-label -> uint8 cast wrapped at 256 labels
    # and silently dropped every 256th component.
    _, _, stats, _ = cv2.connectedComponentsWithStats(seed_mask)
    # Row 0 is the background. OpenCV numbers blobs in scan order, not by x; the gap
    # computation below needs neighbours ordered left to right.
    boxes = sorted(stats[1:].tolist(), key=lambda stat: stat[0])

    plants = []
    for j, (left, top, box_w, box_h, _area) in enumerate(boxes):
        right = left + box_w
        if j == len(boxes) - 1:
            right_bound = right
        else:
            next_left = boxes[j + 1][0]
            right_bound = right + int((next_left - left - box_w) / 2)
        plants.append([top, top + box_h, left, right_bound])
    return plants


def determine_starting_nodes(branch: pd.DataFrame) -> List[int]:
    """
    Return the ids of nodes that occur as an edge source but never as an edge destination.

    :param branch: Skeleton summary with 'node-id-src' and 'node-id-dst' columns.
    :type branch: pd.DataFrame
    :return: Such node ids, in order of first appearance in 'node-id-src'.
    :rtype: List[int]
    """
    destinations = set(branch['node-id-dst'].unique())
    return [node for node in branch['node-id-src'].unique() if node not in destinations]


def getting_coords_for_starting_nodes(branch: pd.core.frame.DataFrame) -> List[Tuple[int, int]]:
    """
    Look up the pixel position of every starting node of a skeleton graph.

    Starting nodes come from ``determine_starting_nodes``. For each one, the first
    row in which it is the source supplies its coordinates.

    :param branch: Skeleton summary with node ids and 'image-coord-src-0/1' columns.
    :type branch: pd.DataFrame
    :return: ``(y, x)`` (row, column) per starting node.
    :rtype: List[Tuple[int, int]]
    """
    coordinates = []
    for node_id in determine_starting_nodes(branch):
        as_source = branch[branch['node-id-src'] == node_id]
        row_coord = as_source['image-coord-src-0'].iloc[0]
        col_coord = as_source['image-coord-src-1'].iloc[0]
        coordinates.append((row_coord, col_coord))
    return coordinates


def map_coords_to_node_ids(coordinates: List[Tuple[int, int]], branch: pd.DataFrame) -> List[int]:
    """
    Translate ``(y, x)`` pixel positions into source-node ids of a skeleton graph.

    A coordinate is matched against the 'image-coord-src-0' (y) and
    'image-coord-src-1' (x) columns. The first matching row's 'node-id-src' is used.
    Coordinates with no match are skipped, so the result may be shorter than the input.

    :param coordinates: Positions as ``(y, x)``.
    :type coordinates: List[Tuple[int, int]]
    :param branch: Skeleton summary (graph representation).
    :type branch: pd.DataFrame
    :return: Node ids of the matched coordinates, in input order.
    :rtype: List[int]
    """
    node_ids = []
    for row_coord, col_coord in coordinates:
        at_position = (branch['image-coord-src-1'] == col_coord) & (branch['image-coord-src-0'] == row_coord)
        hits = branch.loc[at_position, 'node-id-src']
        if not hits.empty:
            node_ids.append(hits.iloc[0])
    return node_ids


def can_be_skeletonized(image: np.ndarray) -> bool:
    """
    Check whether a 2-D image contains a 4-connected component of at least two pixels
    whose value is 1, i.e. whether it can be skeletonized meaningfully.

    Such a component exists exactly when some pixel equal to 1 has a vertical or
    horizontal neighbour that is also 1. This is tested with two vectorised shifted
    comparisons instead of a per-pixel Python BFS over the whole image.

    :param image: 2-D array; only pixels equal to 1 (``True`` counts) are foreground.
    :type image: np.ndarray
    :return: True if a connected component of two or more pixels exists, else False.
    :rtype: bool
    :raises ValueError: If ``image`` is not two-dimensional.
    """
    if len(image.shape) != 2:
        raise ValueError(f"Expected a 2-D image, got shape {image.shape}.")
    foreground = np.asarray(image) == 1
    vertical_pair = np.logical_and(foreground[:-1, :], foreground[1:, :]).any()
    horizontal_pair = np.logical_and(foreground[:, :-1], foreground[:, 1:]).any()
    return bool(vertical_pair or horizontal_pair)


def process_plant(plant: Tuple[int, int, int, int], img: np.ndarray, starts_c: List[Tuple[int, int, int]],
                  n: int) -> None:
    """
    Choose the starting node for one plant and store it in ``starts_c[n]``.

    Two crops share the same top-left origin:
      - the seed box (``plant[0]:plant[1]``);
      - the box's columns down to the image bottom.

    Starting nodes (sources that are never destinations) come from the seed-box
    skeleton. Each one is located in the taller crop's skeleton graph. Among them the
    node that maximises (largest node id in its connected component) minus (its own
    id) is chosen. This is a heuristic for "belongs to the largest structure"; it
    does not measure centrality.

    :param plant: Bounding box ``(y1, y2, x1, x2)`` of the plant's seed area.
    :type plant: Tuple[int, int, int, int]
    :param img: BGR image of the mask.
    :type img: np.ndarray
    :param starts_c: Shared result list; ``starts_c[n]`` is set to ``(x, y, 0)`` with x/y
                     relative to the crop (not offset by the box origin). It is left
                     untouched when the seed box holds no skeletonizable structure.
    :type starts_c: List[Tuple[int, int, int]]
    :param n: Index of the current plant.
    :type n: int
    """
    subset = img[plant[0]:plant[1], plant[2]:plant[3]]
    subset_elements = img[plant[0]:, plant[2]:plant[3]]
    subset_bin = ((cv2.cvtColor(subset, cv2.COLOR_BGR2GRAY) > 0) * 1).astype('uint8')
    subset_elements_bin = ((cv2.cvtColor(subset_elements, cv2.COLOR_BGR2GRAY) > 0) * 1).astype('uint8')
    subset_skeleton = skimage.morphology.skeletonize(subset_bin)
    subset_elements_skeleton = skimage.morphology.skeletonize(subset_elements_bin)
    if len(np.unique(subset_skeleton)) > 1 and can_be_skeletonized(subset_skeleton):
        subset_branch = summarize(Skeleton(subset_skeleton))
        subset_elem_branch = summarize(Skeleton(subset_elements_skeleton))
        G = nx.from_pandas_edgelist(subset_elem_branch, source='node-id-src', target='node-id-dst',
                                    edge_attr='branch-distance')
        starting_nodes_coordinates = getting_coords_for_starting_nodes(subset_branch)
        y_start, x_start = 0, 0
        max_number_of_elements = -1
        for coords in starting_nodes_coordinates:
            # Map one coordinate at a time so each node id stays paired with its own
            # coordinate. map_coords_to_node_ids drops unmatched coordinates, so
            # indexing a batched result by position could pair the wrong coordinates.
            mapped = map_coords_to_node_ids([coords], subset_elem_branch)
            if not mapped:
                continue
            current_node = mapped[0]
            last_element = max(nx.node_connected_component(G, current_node))
            span = last_element - current_node
            if span > max_number_of_elements:
                y_start, x_start = coords
                max_number_of_elements = span
        starts_c[n] = (x_start, y_start, 0)


def get_most_central_starting_node_for_each_plant(path: str, template_path: str, mask: np.ndarray = None) -> List[
    Tuple[int, int, int]]:
    """
    Pick one starting node per plant (see ``process_plant`` for the selection rule).

    Plants are processed concurrently in a thread pool. Each worker writes only its own
    slot of the result list.

    :param path: Path to the mask image (also used to derive the seed boxes).
    :type path: str
    :param template_path: Path to the template image with seed positions.
    :type template_path: str
    :param mask: Optional already-loaded mask, used instead of reading ``path`` again.
                 A 2-D mask is binarised (non-zero -> foreground) and expanded to 3
                 channels; a 3-D mask is used as-is (expected BGR).
    :type mask: np.ndarray, optional
    :return: ``(x, y, 0)`` per plant, relative to that plant's crop; ``(0, 0, 0)`` when
             no starting node was found. The third element is always 0.
    :rtype: List[Tuple[int, int, int]]
    """
    if mask is None:
        img = load_in_mask(path)
    else:
        mask = np.asarray(mask)
        if mask.ndim == 2:
            # process_plant converts BGR -> gray and thresholds at > 0, so a 0/255
            # 3-channel image yields the same binary result as re-reading the file.
            img = cv2.cvtColor(((mask > 0) * 255).astype('uint8'), cv2.COLOR_GRAY2BGR)
        else:
            img = mask
    plants = get_plants_bboxes(path, template_path)
    starts_c = [(0, 0, 0)] * len(plants)
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_plant, plant, img, starts_c, n)
                   for n, plant in enumerate(plants)]
        for future in futures:
            # Re-raise any exception from the worker.
            future.result()
    return starts_c


def determine_primary_roots_locations(path: str, template_path: str) -> List[List[int]]:
    """
    Compute, per plant, the image region in which its primary root is searched.

    The region starts at the top and left edge of the plant's seed bounding box. It
    extends right up to the left edge of the next plant's box; for the last plant it
    uses the box's own right bound. It is meant to be sliced down to the bottom of
    the image.

    :param path: Path to the mask image.
    :type path: str
    :param template_path: Path to the template image with seed positions.
    :type template_path: str
    :return: ``[y_top, x_left, x_right_bound]`` per plant, in ``get_plants_bboxes`` order.
    :rtype: List[List[int]]
    """
    boxes = get_plants_bboxes(path, template_path)
    last_index = len(boxes) - 1
    locations = []
    for idx, (top, _bottom, left, right) in enumerate(boxes):
        right_bound = right if idx == last_index else boxes[idx + 1][2]
        locations.append([top, left, right_bound])
    return locations


def map_root_to(coordinates: Tuple[int, int], plant_num: int, branch: pd.DataFrame,
                offset_y: int, offset_x: int, root_type: str, target: float) -> pd.DataFrame:
    """
    Label the branch ending at a given (offset) pixel with a root type and plant number.

    Rows are selected when all three hold: the destination equals
    ``(coordinates[0] + offset_y, coordinates[1] + offset_x)``, and 'branch-distance'
    equals ``target`` exactly. The 'root_type' and 'plant' columns of the selected rows
    are set; the columns are created if missing.

    :param coordinates: ``(y, x)`` position relative to the offset origin.
    :param plant_num: Plant number to assign.
    :param branch: Branch table; modified in place.
    :param offset_y: Added to ``coordinates[0]``.
    :param offset_x: Added to ``coordinates[1]``.
    :param root_type: Root type label to assign.
    :param target: Exact 'branch-distance' value the row must have.
    :return: The same (mutated) ``branch`` object.
    :rtype: pd.DataFrame
    """
    dst_row = coordinates[0] + offset_y
    dst_col = coordinates[1] + offset_x
    selected = (
        branch['image-coord-dst-0'].eq(dst_row)
        & branch['image-coord-dst-1'].eq(dst_col)
        & branch['branch-distance'].eq(target)
    )
    branch.loc[selected, 'root_type'] = root_type
    branch.loc[selected, 'plant'] = plant_num
    return branch


def are_consecutive(lst: List[int], num1: np.int32, num2: np.int32) -> bool:
    """
    Tell whether ``num1`` is immediately followed by ``num2`` somewhere in ``lst``.

    :param lst: Indexable sequence to scan.
    :type lst: List[int]
    :param num1: Value expected first.
    :type num1: np.int32
    :param num2: Value expected right after ``num1``.
    :type num2: np.int32
    :return: True if some ``lst[k] == num1`` and ``lst[k + 1] == num2``, else False.
    :rtype: bool
    :example:
    >>> are_consecutive([1, 2, 3, 4, 5], 2, 3)
    True
    >>> are_consecutive([1, 2, 4, 5], 3, 4)
    False
    """
    return any(
        lst[idx] == num1 and lst[idx + 1] == num2
        for idx in range(len(lst) - 1)
    )


def prepare_data_for_segmentation(path: str, template_path: str) -> List[Union[list, np.ndarray]]:
    """
    Collect the inputs needed to segment the roots of one plate image.

    NOTE(review): the mask and template are re-read from disk by each helper called
    here (each of them calls ``get_plants_bboxes``).

    :param path: Path to the mask image.
    :type path: str
    :param template_path: Path to the template image with seed positions.
    :type template_path: str
    :return: ``[primary_roots_locations, start_node_coordinates, plants_bboxes, mask]``:
        - primary_roots_locations: ``[y_top, x_left, x_right_bound]`` per plant, used
          to slice out the region where each plant's primary root is expected;
        - start_node_coordinates: ``(x, y, 0)`` per plant from
          ``get_most_central_starting_node_for_each_plant`` (``(0, 0, 0)`` if none);
        - plants_bboxes: ``[y_min, y_max, x_min, x_right_bound]`` seed box per plant;
        - mask: the mask as a single-channel 0/1 ``uint8`` array.
    :rtype: list
    """
    primary_roots_locations = determine_primary_roots_locations(path, template_path)
    # Binarise the mask (non-zero -> 1) for a cleaner skeleton downstream.
    mask = load_in_mask(path)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    mask = ((mask > 0) * 1).astype('uint8')
    start_node_coordinates = get_most_central_starting_node_for_each_plant(
        path, template_path, mask=mask)
    plants_bboxes = get_plants_bboxes(path, template_path)
    return [primary_roots_locations, start_node_coordinates, plants_bboxes, mask]

In [3]:

def determine_ending_nodes(branch: pd.DataFrame) -> List[int]:
    """
    Determine all ending nodes of a branch graph: nodes that appear as a destination
    but never as a source.

    :param branch: Skeleton summary with 'node-id-src' and 'node-id-dst' columns.
    :type branch: pd.DataFrame
    :return: Unique ending node ids, in order of first appearance in 'node-id-dst'.
    :rtype: List[int]
    """
    # A set gives O(1) membership instead of scanning the source array for every node.
    sources = set(branch['node-id-src'].unique())
    return [node for node in branch['node-id-dst'].unique() if node not in sources]

In [4]:
def find_end_node(graph):
    """
    Identify the end nodes of a graph: destinations that never appear as a source.

    Unlike ``determine_ending_nodes``, a destination is reported once per row it
    appears in, so duplicates are possible.

    Args:
    graph (DataFrame): Edge table with 'node-id-src' and 'node-id-dst' columns.

    Returns:
    list: End node ids, in row order of 'node-id-dst'.
    """
    # Set membership keeps this linear; the old list lookup was O(rows^2).
    sources = set(graph['node-id-src'])
    return [destination for destination in graph['node-id-dst'] if destination not in sources]

def root_to_graph(image, root_coordinates):
    """
    Skeletonize one plant's root region and summarise it as a branch table.

    Args:
    image (ndarray): Full root mask.
    root_coordinates (list): ``[y_top, x_left, x_right]``; the crop spans
        ``image[y_top:, x_left:x_right]`` (down to the bottom of the image).

    Returns:
    DataFrame: skan ``summarize`` output for the crop's skeleton.
    """
    top, left, right = root_coordinates
    crop = image[top:, left:right]
    return summarize(Skeleton(skimage.morphology.skeletonize(crop)))

In [5]:
def find_intersection(x1_1, y1_1, x1_2, y1_2, x2_1, y2_1, x2_2, y2_2):
    """
    Intersect two infinite lines, each defined by two points.

    Line 1 passes through (x1_1, y1_1) and (x1_2, y1_2). Line 2 passes through
    (x2_1, y2_1) and (x2_2, y2_2). Solves P1 + t*d1 = P2 + u*d2 for t and u.
    The result is not restricted to the segments between the given points.

    :return: ``(x, y)`` of the intersection, or ``None`` when the system is singular
             (parallel or coincident lines).
    """
    dir1 = (x1_2 - x1_1, y1_2 - y1_1)
    dir2 = (x2_2 - x2_1, y2_2 - y2_1)
    coefficients = np.array([[dir1[0], -dir2[0]],
                             [dir1[1], -dir2[1]]])
    rhs = np.array([x2_1 - x1_1, y2_1 - y1_1])
    try:
        t, _ = np.linalg.solve(coefficients, rhs)
    except np.linalg.LinAlgError:
        return None
    return x1_1 + t * dir1[0], y1_1 + t * dir1[1]

In [6]:
# Inputs for this exploration.
# NOTE(review): absolute, machine-specific paths; move them into a config cell (or read
# them from an environment variable) so the notebook can run elsewhere.
MASKS_DIR = "/Users/work_uni/Documents/GitHub/AIxPlant_Science/overlapping_roots/final_pipeline_masks/masks"
INPUT_IMAGES_DIR = "/Users/work_uni/Documents/GitHub/AIxPlant_Science/overlapping_roots/final_pipeline_masks/input_images"
SEEDING_TEMPLATE_PATH = "/Users/work_uni/Documents/GitHub/2023-24d-fai2-adsai-group-cv2/pyphenotyper/assets/seeding_template.tif"

# TODO(review): `segmented_plants` is read from the "43-18" mask, while `image` and
# `roots_bboxes` use "43-19" files -- confirm this mismatch is intentional.
SEGMENTED_MASK_PATH = os.path.join(
    MASKS_DIR, "43-18-ROOT1-2023-08-08_pvd_OD001_f6h1_02-Fish Eye Corrected_root_mask.png")
segmented_plants = cv2.imread(SEGMENTED_MASK_PATH)
if segmented_plants is None:
    # cv2.imread returns None silently; fail here rather than on a later slice.
    raise FileNotFoundError(f"Could not read the image at {SEGMENTED_MASK_PATH}")

image = cv2.imread(os.path.join(
    INPUT_IMAGES_DIR, "43-19-ROOT1-2023-08-08_pvd_OD001_f6h1_02-Fish Eye Corrected_original_padded.png"))

# [y_top, x_left, x_right_bound] per plant (see determine_primary_roots_locations).
roots_bboxes = determine_primary_roots_locations(
    os.path.join(MASKS_DIR, "43-19-ROOT1-2023-08-08_pvd_OD001_f6h1_02-Fish Eye Corrected_root_mask.png"),
    SEEDING_TEMPLATE_PATH,
)


In [7]:
def find_slope(line_point1, line_point2):
    """
    Slope dy/dx of the line through two ``(x, y)`` points.

    :return: ``float('inf')`` when both points share the same x (vertical line),
             otherwise ``(y2 - y1) / (x2 - x1)``.
    """
    (x1, y1), (x2, y2) = line_point1, line_point2
    if x2 == x1:
        return float('inf')
    return (y2 - y1) / (x2 - x1)
    
def angle_with_line(xi, yi, x1, y1, x2, y2):
    """
    Acute angle in degrees, within [0, 90], between two lines:
      - the reference line through (x1, y1) -> (x2, y2);
      - the segment (x2, y2) -> (xi, yi).

    The angle comes from the direction vectors via ``atan2(|cross|, dot)`` and is folded
    into [0, 90]. This is the same quantity as ``arctan(|(m1 - m) / (1 + m1 * m)|)`` for
    non-vertical lines, but it is also correct when either line is vertical. The old
    special cases returned the angle to the horizontal, could be negative, and reported
    90 for every vertical segment. It is also defined for perpendicular lines, where the
    slope formula divided by zero. As before, a zero-length direction is treated as
    vertical.

    :return: Angle in degrees as a float.
    """
    ref_dx, ref_dy = x2 - x1, y2 - y1
    seg_dx, seg_dy = xi - x2, yi - y2
    if ref_dx == 0 and ref_dy == 0:
        ref_dy = 1
    if seg_dx == 0 and seg_dy == 0:
        seg_dy = 1
    cross = ref_dx * seg_dy - ref_dy * seg_dx
    dot = ref_dx * seg_dx + ref_dy * seg_dy
    angle = math.degrees(math.atan2(abs(cross), dot))  # in [0, 180]
    # Lines (not rays): an angle above 90 is the same as its supplement.
    return 180.0 - angle if angle > 90.0 else angle


In [25]:
def calculate_curvature(x, y):
    """
    Mean unsigned curvature of a sampled planar curve.

    Uses kappa = |x' y'' - y' x''| / (x'^2 + y'^2)^(3/2) (https://en.wikipedia.org/wiki/Curvature).
    Derivatives are taken with ``np.gradient`` with respect to the sample index; the
    formula is invariant to the parametrisation. Curvature is undefined where
    x' = y' = 0. Those points are excluded from the mean; before, a single such point
    made the whole result NaN.

    :param x: x coordinates of the samples, in order along the curve.
    :param y: y coordinates of the samples.
    :return: Mean curvature over the points where it is defined, or NaN when there are
             fewer than two samples or no such point.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # np.gradient needs at least two samples.
    if x.size < 2:
        return float('nan')

    dx = np.gradient(x)
    dy = np.gradient(y)
    ddx = np.gradient(dx)
    ddy = np.gradient(dy)

    speed_sq = dx ** 2 + dy ** 2
    defined = speed_sq > 0
    if not defined.any():
        return float('nan')
    curvature = np.abs(dx * ddy - dy * ddx)[defined] / speed_sq[defined] ** 1.5
    return float(np.mean(curvature))

In [26]:
def get_coords(x_start, y_start, branch, mask, points_total, blue_nodes, points = None):
    """
    Greedy skeleton walk (first variant). From (x_start, y_start), step to the outgoing
    destination whose segment makes the smallest angle with the last traced segment
    (``points_total[-2] -> points_total[-1]``). Recurse until the current node has no
    outgoing edge.

    NOTE(review): superseded. ``get_coords`` is redefined in later cells, and the
    recursive call below looks the name up globally, so after those cells run it
    dispatches to the newer version.

    :param x_start: Column of the current node.
    :param y_start: Row of the current node.
    :param branch: skan branch table with 'image-coord-src-0/1' and 'image-coord-dst-0/1'.
    :param mask: Image to draw on: candidates get color (255, 0, 0), the chosen node (0, 0, 255).
    :param points_total: Traced points; needs at least two entries; extended in place.
    :param blue_nodes: Receives every candidate that was not chosen; extended in place.
    :param points: Nodes chosen during this walk; a new list when omitted.
    :return: ``(mask, points, points_total, blue_nodes)``.
    """
    if points is None:
        # Fresh list per top-level call; the old ``points=[]`` default was shared between calls.
        points = []
    at_node = (branch['image-coord-src-1'] == x_start) & (branch['image-coord-src-0'] == y_start)
    x_dsts = branch.loc[at_node, 'image-coord-dst-1']
    y_dsts = branch.loc[at_node, 'image-coord-dst-0']

    if len(x_dsts) == 0:
        return mask, points, points_total, blue_nodes

    smallest_angle = np.inf
    x_smallest, y_smallest = 0, 0
    for i in range(len(x_dsts)):
        x = int(x_dsts.iloc[i])
        y = int(y_dsts.iloc[i])
        blue_nodes.append((x, y))
        # y values are flipped as 2618 - y before measuring angles.
        # NOTE(review): 2618 is presumably the image height -- confirm.
        curr_angle = angle_with_line(
            x, 2618 - y, points_total[-2][0], 2618 - points_total[-2][1], points_total[-1][0], 2618 - points_total[-1][1]
        )
        mask = cv2.circle(mask, (x, y), 5, (255, 0, 0), 3)
        if curr_angle < smallest_angle:
            x_smallest, y_smallest, smallest_angle = x, y, curr_angle

    # Remove the chosen node once the choice is final. The old code removed the
    # *previous* best inside the loop, starting with the never-added (0, 0), which
    # raised ValueError on the first candidate.
    blue_nodes.remove((x_smallest, y_smallest))
    points_total.append((x_smallest, y_smallest))
    points.append((x_smallest, y_smallest))
    mask = cv2.circle(mask, (x_smallest, y_smallest), 5, (0, 0, 255), 3)

    return get_coords(x_smallest, y_smallest, branch, mask, points_total, blue_nodes, points)

In [27]:
def get_coords(x_start, y_start, branch, mask, points_total, blue_nodes, all_red, hierarchy):
    """
    Greedy skeleton walk (second variant) that also records, per step, the chosen node
    and its unchosen siblings.

    From (x_start, y_start), every destination of an edge leaving that node is a
    candidate. The one whose segment makes the smallest angle (ties keep the first)
    with ``points_total[-2] -> points_total[-1]`` is chosen, and the walk recurses
    until no edge leaves the node.

    NOTE(review): superseded by a later definition of ``get_coords``; the recursive call
    resolves the name globally.

    :param x_start: Column of the current node.
    :param y_start: Row of the current node.
    :param branch: skan branch table with 'image-coord-src-0/1' and 'image-coord-dst-0/1'.
    :param mask: Image to draw on: candidates get color (255, 0, 0), the chosen node (0, 0, 255).
    :param points_total: Traced points (at least two); extended in place.
    :param blue_nodes: Receives all unchosen candidates; extended in place.
    :param all_red: Receives every chosen node; extended in place.
    :param hierarchy: Receives ``[[chosen], [unchosen siblings]]`` per step.
    :return: ``(mask, points_total, blue_nodes, all_red, hierarchy)``.
    """
    at_node = (branch['image-coord-src-1'] == x_start) & (branch['image-coord-src-0'] == y_start)
    x_dsts = branch[at_node]['image-coord-dst-1']
    y_dsts = branch[at_node]['image-coord-dst-0']

    if len(x_dsts) == 0:
        return mask, points_total, blue_nodes, all_red, hierarchy

    prev_x, prev_y = points_total[-2]
    last_x, last_y = points_total[-1]
    best = (0, 0)
    best_angle = np.inf
    siblings = []
    for dst_col, dst_row in zip(x_dsts, y_dsts):
        candidate = (int(dst_col), int(dst_row))
        blue_nodes.append(candidate)
        siblings.append(candidate)
        angle = angle_with_line(
            candidate[0], 2618 - candidate[1], prev_x, 2618 - prev_y, last_x, 2618 - last_y
        )
        mask = cv2.circle(mask, candidate, 5, (255, 0, 0), 3)
        if angle < best_angle:
            best = candidate
            best_angle = angle

    blue_nodes.remove(best)
    siblings.remove(best)
    points_total.append(best)
    all_red.append(best)
    hierarchy.append([[best], siblings])

    mask = cv2.circle(mask, best, 5, (0, 0, 255), 3)

    return get_coords(best[0], best[1], branch, mask, points_total, blue_nodes, all_red, hierarchy)

In [28]:
def get_coords(x_start, y_start, branch, mask, points_total, red = None, blue = None, flip_height=2618):
    """
    Trace one path through a skeleton graph by repeatedly taking the straightest
    continuation.

    At each node, every destination of an edge leaving the current node is a candidate.
    The angle between the candidate segment (current node -> candidate) and the last
    traced segment (``points_total[-2] -> points_total[-1]``) is computed with
    ``angle_with_line`` after flipping y as ``flip_height - y``. The smallest angle wins;
    ties keep the first candidate. The walk stops when the current node has no outgoing
    edge. As a defensive measure it also stops when it would revisit a node (e.g. a
    self-loop or cycle edge); the former recursive version never terminated in that case.
    The loop is iterative, so long paths cannot hit Python's recursion limit.

    :param x_start: Column of the start node.
    :param y_start: Row of the start node.
    :param branch: skan branch table with 'image-coord-src-0/1' and 'image-coord-dst-0/1'.
    :param mask: Image to draw on: candidates get color (255, 0, 0), chosen nodes (0, 0, 255).
    :param points_total: Traced points so far (at least two entries); extended in place.
    :param red: Chosen nodes, extended in place; created on the first step if None.
    :param blue: One list per step with the unchosen candidates; created on the first step if None.
    :param flip_height: Value used to flip y before measuring angles. The default keeps
                        the previously hard-coded 2618.
                        NOTE(review): presumably the source image height -- confirm.
    :return: ``(mask, red, blue, points_total)``. ``red``/``blue`` stay None when they
             were passed as None and no step was possible.
    """
    visited = {(x_start, y_start)}
    while True:
        outgoing = branch[
            (branch['image-coord-src-1'] == x_start) & (branch['image-coord-src-0'] == y_start)
        ]
        if len(outgoing) == 0:
            return mask, red, blue, points_total

        if red is None:
            red = []
        if blue is None:
            blue = []

        (prev_x, prev_y), (last_x, last_y) = points_total[-2], points_total[-1]
        best_x, best_y, best_angle = 0, 0, np.inf
        current_blue = []
        for dst_col, dst_row in zip(outgoing['image-coord-dst-1'], outgoing['image-coord-dst-0']):
            x, y = int(dst_col), int(dst_row)
            current_blue.append((x, y))
            angle = angle_with_line(
                x, flip_height - y, prev_x, flip_height - prev_y, last_x, flip_height - last_y
            )
            mask = cv2.circle(mask, (x, y), 5, (255, 0, 0), 3)
            if angle < best_angle:
                best_x, best_y, best_angle = x, y, angle

        if (best_x, best_y) in visited:
            # The chosen continuation loops back onto this path; stop instead of cycling.
            return mask, red, blue, points_total
        visited.add((best_x, best_y))

        current_blue.remove((best_x, best_y))
        red.append((best_x, best_y))
        blue.append(current_blue)
        points_total.append((best_x, best_y))
        mask = cv2.circle(mask, (best_x, best_y), 5, (0, 0, 255), 3)

        x_start, y_start = best_x, best_y

[[454, 932, 1489],
 [454, 1489, 2048],
 [454, 2048, 2607],
 [454, 2607, 3164],
 [454, 3164, 3424]]

[(243, 61, 0), (218, 78, 0), (160, 136, 0), (138, 102, 0), (69, 148, 0)]

In [29]:

# Trace a path from the end node of every junction-to-junction branch (skan branch-type 2)
# in the first plant's crop, and collect each traced segment under its own id.
# NOTE(review): `segmented_plants` is loaded as a 3-channel BGR image, so `skeletonize`
# receives an (H, W, 3) array here rather than a 2-D mask -- confirm this is intended
# (converting to grayscale first would change the traced coordinates).
mask = segmented_plants[454:, 932:1489]
plant_skeleton = skimage.morphology.skeletonize(mask)
plant_branch = summarize(Skeleton(plant_skeleton))
counter = 0

all_segments_dir = {}
all_red = []

for i in range(len(plant_branch)):
    # (x, y) start node of this plant, taken from the start-node output above.
    points_total = [
        (243, 61)
    ]
    # Fresh, un-annotated crop for each walk. These are the same pixels as re-reading the
    # mask file, without the disk I/O; .copy() keeps the drawing off `segmented_plants`.
    mask = segmented_plants[454:, 932:1489].copy()
    row = plant_branch.iloc[i]
    if row['branch-type'] != 2:
        continue

    x_start = row['image-coord-dst-1']
    y_start = row['image-coord-dst-0']
    # Skip nodes already traced as part of an earlier segment.
    if (int(x_start), int(y_start)) in all_red:
        continue

    points_total.append(
        (int(x_start), int(y_start))
    )
    mask, red, blue, points_total = get_coords(
        x_start, y_start, plant_branch, mask, points_total, [(int(x_start), int(y_start))], [[]]
    )
    all_segments_dir[str(counter)] = {'red': red,
                                      'blue': blue}
    all_red = all_red + red
    # Each segment needs its own key; without this increment every segment overwrote '0'.
    counter += 1
        

In [30]:
def find_starting_nodes(dict):
    """
    Return the first chosen ('red') node of every traced segment.

    :param dict: Mapping of segment id -> ``{'red': [...], 'blue': [...]}``. The
                 parameter name shadows the builtin and is kept for keyword callers.
    :return: ``segment['red'][0]`` for each segment, in the mapping's iteration order.
    """
    return [segment['red'][0] for segment in dict.values()]

In [31]:
def map_starting_nodes_to_unassigned(starting_nodes, dict):
    """
    Find where segment start nodes appear among other segments' unchosen candidates.

    :param starting_nodes: Start node of each segment, in the mapping's key order.
    :param dict: Mapping of segment id -> ``{'red': [...], 'blue': [[...], ...]}``.
    :return: ``(candidate_list, segment_key, start_node_index)`` for every
             ``blue`` candidate list that contains a start node. Ordered by segment key,
             then start node index, then step.
    """
    return [
        (candidates, key, node_index)
        for key, segment in dict.items()
        for node_index, node in enumerate(starting_nodes)
        for candidates in segment['blue']
        if node in candidates
    ]

In [32]:
def get_the_last_id_from_dict(dict):
    """
    Largest segment id, where ids are the mapping's keys parsed as ints.

    :raises ValueError: If the mapping is empty.
    """
    return max(int(key) for key in dict)

In [33]:
def match_branches(dict_p, previous_dict = None):
    """
    Recursively combine segments that branch off each other into longer segments.

    Every segment is a dict with a ``'red'`` list (path coordinates) and a ``'blue'`` list whose
    entries are lists of coordinates left unassigned along the path; the two lists are treated
    as index-aligned. When the first ``'red'`` coordinate of segment B occurs in a ``'blue'``
    entry of segment A, a new segment is built from A's path prefix before that entry followed
    by all of B. The newly built segments are matched again until a round produces nothing new.

    :param dict_p: Segments to match in this round, keyed by integer-like string ids.
    :type dict_p: dict
    :param previous_dict: Result accumulated by earlier rounds. When ``None`` the result starts
        from ``dict_p``; otherwise only the newly built segments are added to it.
    :type previous_dict: dict or None
    :return: Accumulated segments, including every combined segment under a fresh string id.
        An empty ``dict_p`` returns ``previous_dict`` (or ``{}`` when that is ``None``).
    :rtype: dict
    """
    if not dict_p:
        # Returning None here would crash callers that immediately treat the result as a dict.
        return previous_dict if previous_dict is not None else {}

    base = dict_p if previous_dict is None else previous_dict
    # Fresh ids must start strictly above every id already in use; starting at the current
    # maximum id made the first combined segment overwrite an existing one.
    next_id = get_the_last_id_from_dict({**base, **dict_p}) + 1
    new_dict = _combine_branches_once(dict_p, next_id)

    accumulated = dict(base, **new_dict)
    if not new_dict:
        return accumulated
    # NOTE(review): if segments reference each other's starting nodes cyclically, every round
    # produces new segments and this recursion only stops at Python's recursion limit.
    return match_branches(new_dict, accumulated)


def _combine_branches_once(segments, first_new_id):
    """Build one round of combined segments from ``segments``, numbering them from ``first_new_id``."""
    segment_keys = list(segments.keys())
    starting_coordinates = find_starting_nodes(segments)
    nodes_mapped = map_starting_nodes_to_unassigned(starting_coordinates, segments)

    combined = {}
    for offset, (coordinate, first_key, second_index) in enumerate(nodes_mapped):
        first = segments[first_key]
        # ``second_index`` indexes ``starting_coordinates``, which follows the key order of ``segments``.
        second = segments[segment_keys[second_index]]

        # ``coordinate`` was taken from ``first['blue']``, so it is always present.
        unassigned_position = first['blue'].index(coordinate)
        # Keep the path before the branch point (same ``- 1`` offset as the original code;
        # NOTE(review): confirm against how get_coords aligns red/blue). Clamping at 0 stops
        # position 0 from turning into the slice ``red[0:-1]`` (all but the last point).
        cut = max(unassigned_position - 1, 0)

        combined[str(first_new_id + offset)] = {
            'red': first['red'][:cut] + second['red'],
            # One fresh list per step; ``[[]] * cut`` would share a single list object.
            'blue': [[] for _ in range(cut)] + second['blue'],
        }
    return combined

In [34]:
# Stitch the traced segments collected above into combined paths. The variable keeps its
# original spelling because later cells reference it.
everything_mathced = match_branches(all_segments_dir)

In [35]:
def get_the_lowest_point(dict_matched):
    """
    Keep the segments that end at the deepest point and, among those, start the highest.

    Points in each segment's ``'red'`` path are unpacked as ``(x, y)``; a larger ``y`` is
    treated as deeper. ``red[0]`` is the start of the path and ``red[-1]`` its end.

    :param dict_matched: Segments keyed by id, each with a non-empty ``'red'`` list of points.
    :type dict_matched: dict
    :return: Subset of ``dict_matched`` whose last point equals the deepest end point and whose
        first point equals the highest start point among those segments. Ties on ``y`` for the
        end point go to the segment seen last; ties for the start point go to the one seen first.
        Empty when ``dict_matched`` is empty.
    :rtype: dict
    """
    if not dict_matched:
        return {}

    deepest_end = None
    for segment in dict_matched.values():
        end_x, end_y = segment['red'][-1]
        if deepest_end is None or end_y >= deepest_end[1]:
            deepest_end = (end_x, end_y)

    # The top point is chosen only among segments that actually reach ``deepest_end``. Nesting
    # this update inside the depth check (as before) made the result depend on iteration order
    # and could leave no segment matching both criteria.
    candidates = []
    highest_start = None
    for key, segment in dict_matched.items():
        end_x, end_y = segment['red'][-1]
        if (end_x, end_y) != deepest_end:
            continue
        start_x, start_y = segment['red'][0]
        candidates.append((key, segment, (start_x, start_y)))
        if highest_start is None or start_y < highest_start[1]:
            highest_start = (start_x, start_y)

    return {key: segment for key, segment, start in candidates if start == highest_start}

In [36]:
# Candidate paths: the matched segments that reach the deepest end point selected by get_the_lowest_point.
max_dp = get_the_lowest_point(everything_mathced)

In [37]:
def determine_the_straightest_path(filtered_dict):
    """
    Return the key of the segment whose ``'red'`` path has the lowest curvature.

    Curvature is computed by ``calculate_curvature`` (defined elsewhere in this notebook) from
    the x and y components of each path.

    :param filtered_dict: Segments keyed by id, each with a ``'red'`` list of ``(x, y)`` points.
    :type filtered_dict: dict
    :return: Key of the least-curved segment; the first such key wins ties. ``''`` when
        ``filtered_dict`` is empty or no curvature compares below infinity (e.g. NaN).
    :rtype: str
    """
    straightest_path = ''
    # Infinity instead of a magic 1e6 so that any finite curvature can be selected.
    min_curv = float('inf')

    for key, segment in filtered_dict.items():
        x = [point[0] for point in segment['red']]
        y = [point[1] for point in segment['red']]

        curvature = calculate_curvature(x, y)
        if curvature < min_curv:
            straightest_path = key
            min_curv = curvature

    return straightest_path

In [38]:
# Key of the least-curved candidate in ``everything_mathced``; '' when there are no candidates.
sp = determine_the_straightest_path(max_dp)

In [39]:

# Draw the straightest matched path on the cropped mask and save the result.
STRAIGHTEST_MASK_PATH = (
    "/Users/work_uni/Documents/GitHub/AIxPlant_Science/overlapping_roots/final_pipeline_masks/masks/43-18-ROOT1-2023-08-08_pvd_OD001_f6h1_02-Fish Eye Corrected_root_mask.png"
)
mask = cv2.imread(STRAIGHTEST_MASK_PATH)
if mask is None:
    raise ValueError(f"Could not read the image at {STRAIGHTEST_MASK_PATH}")
mask = mask[454:, 932:1489]
current_subset = everything_mathced[sp]['red']

for coord in current_subset:
    # cv2.circle expects integer pixel coordinates for the center.
    mask = cv2.circle(mask, (int(coord[0]), int(coord[1])), 5, (0,0,255),2)
# Write once after all circles are drawn instead of rewriting the file on every iteration.
cv2.imwrite(f'/Users/work_uni/Documents/GitHub/2023-24d-fai2-adsai-group-cv2/pyphenotyper/pipelines/images_withbpoints/{sp}-mask.png', mask)

In [None]:
# NOTE(review): ``nodes_mapped`` is only a local variable inside match_branches, so this cell
# depends on a value left over from an earlier (possibly deleted) cell.
# Entries follow map_starting_nodes_to_unassigned: (blue entry containing a starting node,
# key of the segment that owns that entry, index of the starting node).
nodes_mapped

In [None]:
# NOTE(review): the saved output below uses a 'hierarchy' key, while the segments built in
# this section are stored as {'red': ..., 'blue': ...}; the output is stale and will differ on re-run.
all_segments_dir

{'0': {'blue': [(108, 312),
   (356, 418),
   (190, 463),
   (357, 646),
   (43, 599),
   (87, 619),
   (118, 707),
   (104, 748),
   (227, 898),
   (155, 837),
   (117, 972),
   (91, 1041),
   (82, 1133)],
  'hierarchy': [[[(152, 249)], []],
   [[(144, 344)], [(108, 312)]],
   [[(134, 372)], [(356, 418)]],
   [[(135, 381)], []],
   [[(121, 430)], [(190, 463)]],
   [[(95, 577)], [(357, 646)]],
   [[(91, 614)], [(43, 599)]],
   [[(98, 670)], [(87, 619)]],
   [[(102, 741)], [(118, 707)]],
   [[(104, 748)], [(104, 748)]],
   [[(104, 751)], []],
   [[(108, 799)], [(227, 898)]],
   [[(106, 832)], [(155, 837)]],
   [[(90, 1040)], [(117, 972)]],
   [[(78, 1133)], [(91, 1041)]],
   [[(4, 1185)], [(82, 1133)]]]},
 '1': {'blue': [(165, 175), (190, 188), (193, 222)],
  'hierarchy': [[[(192, 185)], [(165, 175)]],
   [[(210, 206)], [(190, 188)]],
   [[(262, 269)], [(193, 222)]]]},
 '2': {'blue': [(190, 463),
   (357, 646),
   (43, 599),
   (87, 619),
   (118, 707),
   (104, 748),
   (227, 898),
   (155, 837),
   (117, 972),
   (91, 1041),
   (82, 1133)],
  'hierarchy': [[[(134, 372)], []],
   [[(135, 381)], []],
   [[(121, 430)], [(190, 463)]],
   [[(95, 577)], [(357, 646)]],
   [[(91, 614)], [(43, 599)]],
   [[(98, 670)], [(87, 619)]],
   [[(102, 741)], [(118, 707)]],
   [[(104, 748)], [(104, 748)]],
   [[(104, 751)], []],
   [[(108, 799)], [(227, 898)]],
   [[(106, 832)], [(155, 837)]],
   [[(90, 1040)], [(117, 972)]],
   [[(78, 1133)], [(91, 1041)]],
   [[(4, 1185)], [(82, 1133)]]]},
 '3': {'blue': [(357, 1105), (91, 1041), (356, 1202), (111, 1123)],
  'hierarchy': [[[(114, 1027)], [(357, 1105)]],
   [[(117, 1068)], [(91, 1041)]],
   [[(113, 1072)], [(356, 1202)]],
   [[(118, 1090)], []],
   [[(208, 1226)], [(111, 1123)]],
   [[(357, 1420)], []]]},
 '4': {'blue': [(111, 1123)],
  'hierarchy': [[[(113, 1072)], []],
   [[(118, 1090)], []],
   [[(208, 1226)], [(111, 1123)]],
   [[(357, 1420)], []]]},
 '5': {'blue': [(111, 1153),
   (208, 1226),
   (118, 1310),
   (118, 1310),
   (356, 1434),
   (4, 1491),
   (2, 1569),
   (151, 1541)],
  'hierarchy': [[[(82, 1133)], [(111, 1153)]],
   [[(111, 1153)], []],
   [[(116, 1161)], []],
   [[(125, 1212)], [(208, 1226)]],
   [[(142, 1282)], [(118, 1310)]],
   [[(150, 1332)], [(118, 1310)]],
   [[(143, 1389)], [(356, 1434)]],
   [[(143, 1409)], []],
   [[(148, 1485)], [(4, 1491)]],
   [[(152, 1490)], [(2, 1569)]],
   [[(191, 1520)], [(151, 1541)]]]},
 '6': {'blue': [(4, 1491), (2, 1569), (151, 1541)],
  'hierarchy': [[[(143, 1389)], []],
   [[(143, 1409)], []],
   [[(148, 1485)], [(4, 1491)]],
   [[(152, 1490)], [(2, 1569)]],
   [[(191, 1520)], [(151, 1541)]]]},
 '7': {'blue': [(260, 1597),
   (4, 1650),
   (357, 1761),
   (355, 1850),
   (146, 1849),
   (4, 1937),
   (355, 1994),
   (356, 2094),
   (1, 2117),
   (357, 2199),
   (273, 2234),
   (62, 2260),
   (108, 2316)],
  'hierarchy': [[[(146, 1580)], [(260, 1597)]],
   [[(142, 1657)], [(4, 1650)]],
   [[(140, 1723)], [(357, 1761)]],
   [[(135, 1822)], [(355, 1850)]],
   [[(134, 1851)], [(146, 1849)]],
   [[(146, 1917)], [(4, 1937)]],
   [[(153, 1975)], [(355, 1994)]],
   [[(156, 2021)], [(356, 2094)]],
   [[(167, 2101)], [(1, 2117)]],
   [[(164, 2105)], [(357, 2199)]],
   [[(169, 2189)], []],
   [[(166, 2206)], [(273, 2234)]],
   [[(172, 2277)], [(62, 2260)]],
   [[(178, 2307)], [(108, 2316)]]]},
 '8': {'blue': [], 'hierarchy': [[[(4, 2303)], []]]}}

In [None]:
# NOTE(review): ``all_starting_node`` is not defined in this section; confirm the cell that
# creates it still runs before this one on a fresh kernel.
all_starting_node

In [None]:
# Pair every starting node with each segment whose 'blue' list contains it directly:
# (segment key, first index of the node in all_starting_node, the node itself).
plants = [
    (key, all_starting_node.index(starting_node), starting_node)
    for key in all_segments_dir.keys()
    for starting_node in all_starting_node
    if starting_node in all_segments_dir[key]['blue']
]

In [None]:
# Each entry: (key of a segment whose 'blue' list contains the node,
#              index of the node in all_starting_node, the starting node itself)
plants

In [None]:
# Current value of the notebook-global ``counter``; the stitching cell below keeps incrementing it.
print(counter)

In [None]:
# For every match in ``plants``, stitch the red points of segment ``first_key`` that precede the
# level listing ``coordinates`` as unassigned, followed by all red points of segment ``second_index``.
# NOTE(review): this cell reads the 'hierarchy' key, but the segments built at the top of this
# section only store 'red'/'blue', so it appears to target an older data layout.
new_segments = {}
new_full_branch = {}

# The loop variable must not be called ``sp``: that would overwrite the straightest-path key
# ``sp`` that the drawing cell above indexes with.
for first_key, second_index, coordinates in plants:
    first_hierarchy = all_segments_dir[first_key]['hierarchy']
    for i, level in enumerate(first_hierarchy):
        if coordinates in level[1]:
            # ``counter`` is notebook-global, so ids keep growing on every re-run of this cell.
            counter += 1
            second_segment = all_segments_dir[str(second_index)]
            reds = [entry[0] for entry in first_hierarchy[:i]]
            reds.extend(entry[0] for entry in second_segment['hierarchy'])
            # NOTE(review): this stores the appended segment's own blue/hierarchy, not the
            # stitched result; confirm that is intended.
            new_full_branch[counter] = {
                'blue': second_segment['blue'],
                'hierarchy': second_segment['hierarchy'],
            }
            new_segments[counter] = reds

In [None]:
# Per new id: the 'blue' and 'hierarchy' of the appended segment (the stitched points are in new_segments).
new_full_branch

In [None]:
# Draw every stitched branch in ``new_segments`` on its own copy of the cropped mask.
SEGMENTS_MASK_PATH = (
    "/Users/work_uni/Documents/GitHub/AIxPlant_Science/overlapping_roots/final_pipeline_masks/masks/43-19-ROOT1-2023-08-08_pvd_OD001_f6h1_02-Fish Eye Corrected_root_mask.png"
)
if new_segments:
    # Read and crop the mask once; each branch is drawn on a fresh copy.
    base_mask = cv2.imread(SEGMENTS_MASK_PATH)
    if base_mask is None:
        raise ValueError(f"Could not read the image at {SEGMENTS_MASK_PATH}")
    base_mask = base_mask[454:, 3064:3424]
    for key, value in new_segments.items():
        mask = base_mask.copy()
        for item in value:
            # ``item`` is a hierarchy red entry of the form [(x, y)]; cv2.circle needs integers.
            mask = cv2.circle(mask, (int(item[0][0]), int(item[0][1])), 5, (0,255,0),2)
        # Write once per branch after all its circles are drawn.
        cv2.imwrite(f'/Users/work_uni/Documents/GitHub/2023-24d-fai2-adsai-group-cv2/pyphenotyper/pipelines/images_withbpoints/{key}-mask.png', mask)

In [None]:
# Inspect the segment stored under id '0'.
all_segments_dir['0']

In [None]:
# Re-display the starting nodes (duplicate of an earlier inspection cell).
all_starting_node

In [None]:
# NOTE(review): ``blue_nodes`` is not defined in this section; it must come from an earlier cell.
blue_nodes

In [None]:
# Report each starting node that is also a blue node, once for every per-segment hierarchy
# that lists it among a level's unassigned coordinates (``h[1]``).
for start_nodes in all_starting_node:
    if start_nodes not in blue_nodes:
        continue
    for h_per_segment in hierarchy_all:
        if any(start_nodes in h[1] for h in h_per_segment):
            print('found')
            print(start_nodes)

In [None]:
# NOTE(review): the ``hierarchy_all.append(...)`` in the tracing loop is commented out, so this
# shows whatever a previous run left behind.
hierarchy_all[0]