# Whorls pose estimation

### If required install required packages

In [1]:
#!pip install ultralytics
#!python -m pip install "laspy[lazrs,laszip]"
#!pip install comet_ml 
#!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

### Load libraries

In [1]:
# Import libraries
import os, glob, shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import io
import laspy
from PIL import Image
import re
from scipy.spatial import cKDTree
import concurrent.futures


import ultralytics
from ultralytics import YOLO

### Load custom functions

In [26]:
# Define custom functions
def rotate_point_cloud(point_cloud, angle_degrees, center_point):
    """
    Rotate the point cloud around a center point by a given angle in degrees.
    """
    theta = np.radians(angle_degrees)
    c, s = np.cos(theta), np.sin(theta)
    R = np.array(((c, -s), (s, c)))
    rotated_points = np.dot(point_cloud[:, :2] - center_point, R.T) + center_point
    return np.hstack((rotated_points, point_cloud[:, 2].reshape(-1, 1)))

def slice_tree_center_thick_slices(point_cloud, slice_thickness=10):
    """
    Take thick slices in the X and Y directions, centered around the tree's center.
    """
    tree_center = point_cloud[point_cloud[:,2].argmax(), :2]
    x_slice_mask = (point_cloud[:,0] >= tree_center[0] - slice_thickness/2) & \
                   (point_cloud[:,0] <= tree_center[0] + slice_thickness/2)
    y_slice_mask = (point_cloud[:,1] >= tree_center[1] - slice_thickness/2) & \
                   (point_cloud[:,1] <= tree_center[1] + slice_thickness/2)
    x_slice = point_cloud[x_slice_mask]
    y_slice = point_cloud[y_slice_mask]
    return x_slice, y_slice



def plot_to_image(figure, dpi):
    """
    Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it as a numpy array, setting the resolution with a high DPI.
    """
    buf = io.BytesIO()
    figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(figure)
    buf.seek(0)
    image = Image.open(buf)
    return np.array(image)

def plot_section_as_image_with_alpha(slice_data, z_low, z_high, alpha=0.3, output_size=(1000, 1000), dpi=100):
    """
    Create a figure and plot the slice_data with alpha transparency.
    Dynamically adjusts plot limits based on the data and resizes the output image to a square format.
    """
    if slice_data.size == 0:
        return None  # Return None if there are no data points to plot.

    # Determine aspect ratio and figure size
    buffer = 0 # Add a buffer around data extents
    x_min, x_max = np.min(slice_data[:, 0]) - buffer, np.max(slice_data[:, 0]) + buffer
    y_min, y_max = np.min(slice_data[:, 2]) - buffer, np.max(slice_data[:, 2]) + buffer

    # Dynamically adjust xlim and ylim to include all points and maintain real tree dimensions
    x_range = x_max - x_min
    y_range = z_high - z_low  # This should be close to section_height if properly sliced

    # Determine the scale factor to use for x and y to maintain aspect ratio
    if x_range > y_range:
        scale_factor = x_range / y_range
        fig_width, fig_height = 10 * scale_factor, 10
    else:
        scale_factor = y_range / x_range
        fig_width, fig_height = 10, 10 * scale_factor

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.scatter(slice_data[:, 0], slice_data[:, 2], s=3, color='black', alpha=alpha, edgecolors='none')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_aspect('auto')  # 'auto' allows free aspect ratio that adjusts to specified limits

    ax.axis('off')
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    return plot_to_image(fig, dpi)
    
def convert_sections_to_images(point_cloud, section_height, slice_thickness, tree_center, bottom_height, output_dir, base_filename):
    max_height =  np.max(point_cloud[:, 2])-bottom_height
    num_sections = int(np.ceil(max_height / section_height))

    # Normalize point cloud
    point_cloud[:, 2]=point_cloud[:, 2]-bottom_height

    for i in range(num_sections):
        # compute the lower end of the interval
        z_low=np.float64(0)
        if i>0:
            z_low = np.float64(i)*section_height

        # get upper end of the interval
        z_high = (i +1) * section_height

        # slice pointcloud
        section_mask = (point_cloud[:, 2] >= z_low) & (point_cloud[:, 2] < z_high)
        section_points = point_cloud[section_mask]

        # skip if there are no points
        if not section_points.size:
            continue

        # if the tree does not fill the frame then modify the lower and upper limits
        if (np.max(section_points[:, 2])-np.min(section_points[:, 2]))<(section_height-0.5):
            z_high = np.max(section_points[:, 2])
            z_low= z_high-section_height
            # slice pointcloud
            section_mask = (point_cloud[:, 2] >= z_low) & (point_cloud[:, 2] < z_high)
            section_points = point_cloud[section_mask]

        
        x_section_slice, y_section_slice = slice_tree_center_thick_slices(section_points, slice_thickness)
        rotated_section_points = rotate_point_cloud(section_points, 45, tree_center)
        x45_section_slice, y45_section_slice = slice_tree_center_thick_slices(rotated_section_points, slice_thickness)

        plot_limits = [tree_center[0] - slice_thickness/2, tree_center[0] + slice_thickness/2,
                       tree_center[1] - slice_thickness/2, tree_center[1] + slice_thickness/2]

        for slice_data, slice_name in zip([x_section_slice, y_section_slice, x45_section_slice, y45_section_slice],
                                          ['x', 'y', 'x45', 'y45']):
            img_array = plot_section_as_image_with_alpha(slice_data, z_low, z_high, dpi=100)
            
            # Metadata for filename
            min_x_or_y = np.min(slice_data[:, 0 if slice_name in ['x', 'x45'] else 1])
            max_x_or_y = np.max(slice_data[:, 0 if slice_name in ['x', 'x45'] else 1])
            filename = f"{output_dir}/{base_filename}_{slice_name}_section_{i}_min{min_x_or_y:.2f}_max{max_x_or_y:.2f}_zmin{z_low:.2f}_zmax{z_high:.2f}.png"
            
            Image.fromarray(img_array).save(filename)


def process_point_cloud(point_cloud, output_directory, bottom_height, base_filename):
    tree_center = point_cloud[point_cloud[:,2].argmax(), :2]
    convert_sections_to_images(point_cloud, 10, 10, tree_center,bottom_height, output_directory, base_filename)

# Function to get image size
def get_image_size(image_path):
    with Image.open(image_path) as img:
        return img.size  # (width, height)

# Function to convert normalized coordinates to real-world coordinates
def convert_to_real_world(px, py, img_width, img_height, x_min, x_max, z_min, z_max):
    real_x = px * img_width
    real_y = py * img_height
    world_x = real_x / img_width * (x_max - x_min) + x_min
    world_z = real_y / img_height * (z_max - z_min) + z_min
    return world_x, world_z

# Function to process a single file
def process_file(text_file_path, img_width, img_height, x_min, x_max, z_min, z_max):
    real_world_data = []
    with open(text_file_path, 'r') as file:
        for line in file.readlines():
            parts = line.strip().split()

            # get coordinates and confidence for the three keypoints where
            # p1: left branch; p2: whorl center; p3: right branch
            px1 = float(parts[5])
            py1 = float(parts[6])
            confidence_p1 = float(parts[7])
            px2 = float(parts[8])
            py2 = float(parts[9])
            confidence_p2 = float(parts[10])
            px3 = float(parts[11])
            py3 = float(parts[12])
            confidence_p3 = float(parts[13])


            world_px1, world_pz1 = convert_to_real_world(px1,py1, img_width, img_height, x_min, x_max, z_min, z_max)
            world_px2, world_pz2 = convert_to_real_world(px2, py2, img_width, img_height, x_min, x_max, z_min, z_max)
            world_px3, world_pz3 = convert_to_real_world(px3, py3, img_width, img_height, x_min, x_max, z_min, z_max)

            real_world_data.append((confidence_p1, world_px1, world_pz2, confidence_p2, world_px2, world_pz2,confidence_p3, world_px3, world_pz3))
    return real_world_data

# Function to calculate the angle at p2 formed by p1 and p3
def calculate_angle_at_p2(px1, pz1, px2, pz2, px3, pz3):
    # Construct vectors from p2 to p1 and p2 to p3
    vector_p2_p1 = np.array([px1 - px2, pz1 - pz2])
    vector_p2_p3 = np.array([px3 - px2, pz3 - pz2])
    # Calculate the dot product and norms of the vectors
    dot_product = np.dot(vector_p2_p1, vector_p2_p3)
    norm_p2_p1 = np.linalg.norm(vector_p2_p1)
    norm_p2_p3 = np.linalg.norm(vector_p2_p3)
    # Calculate the angle in radians and then convert to degrees
    angle = np.arccos(np.clip(dot_product / (norm_p2_p1 * norm_p2_p3), -1.0, 1.0))
    angle_degrees = np.degrees(angle)
    # Check for reflex angle (greater than 180 degrees)
    if np.cross(vector_p2_p1, vector_p2_p3) < 0:  # using cross product to determine the orientation
        angle_degrees = 360 - angle_degrees
    return angle_degrees

# Function to calculate Euclidean distance
def calculate_distance(px1, pz1, px2, pz2):
    return np.sqrt((px1 - px2) ** 2 + (pz1 - pz2) ** 2)

# function to process each tree and obtain the detected whorls
def pose_detection_tree(treeID, trees,non_tree, dir_root, my_model, dir_pred, min_dist_whorls=0.24):
    # skip if treeID ==255
    if treeID == 0:
        return None  # Skip if treeID is 0

    # Folder path containing  the text files
    dir_temp_imgs=dir_pred+"/orig_imgs"
    if not os.path.exists(dir_temp_imgs):
       os.makedirs(dir_temp_imgs)
    dir_orig_imgs_tree=dir_temp_imgs+"/"+str(round(treeID))
    if not os.path.exists(dir_orig_imgs_tree):
       os.makedirs(dir_orig_imgs_tree)
        
    dir_pred_out=dir_pred+"/preds"
    if not os.path.exists(dir_pred_out):
       os.makedirs(dir_pred_out)
        
    dir_labels=dir_pred_out+"/"+str(round(treeID))+"/labels"

    # select the one tree
    one_tree_np= trees[trees[:,3]==treeID]
    
    # compute tree top (will be used later)
    top_point= one_tree_np[one_tree_np[:,2]==np.max(one_tree_np[:,2])] 
    top_height=top_point[0,2]
    
    # compute tree bottom (will be used later)
    bottom_point= one_tree_np[one_tree_np[:,2]==np.min(one_tree_np[:,2])] 
    bottom_height=bottom_point[0,2]
    
    if len(non_tree)<0:
        # Separate the x, y, and z coordinates
        xy_coordinates = non_tree[:, :2]  # Extract x and y
        print(xy_coordinates)
        print(type(xy_coordinates))
        z_values = non_tree[:, 2]  # Extract z
    
        # Create a KDTree for efficient spatial search
        KDtree = cKDTree(xy_coordinates)
        
        # Given point coordinates (x, y)
        x, y = bottom_point[0,0], bottom_point[0,1]
        print(type(bottom_point[0,0]))
        print([x, y])
        print(KDtree)
        # Find the nearest point
        _, index = KDtree.query([x, y], k=1)
        
        # Extract the z value of the nearest point
        bottom_height = z_values[index]
    

    ## Create images from point clouds
    process_point_cloud(one_tree_np, dir_orig_imgs_tree,bottom_height, 'img_')



    ####################################################################################################################################
    ## Predict on new data
    model=YOLO(my_model)
    model.predict(source=dir_orig_imgs_tree,conf=0.3, imgsz=1000, save=True, save_txt=True, project=dir_pred_out, name=str(round(treeID)))  # no arguments needed, dataset and settings remembered


    
    ####################################################################################################################################
    ## Parse output to produce data for the whole tree
    
    # List all text files in the folder
    text_files = [f for f in os.listdir(dir_labels) if f.endswith('.txt')]
    print(dir_labels)
    # Process each text file
    all_data = []
    for text_file in text_files:
        # Extract the base filename to find the corresponding image file
        base_filename = text_file.replace('.txt', '')
        image_filename = base_filename + '.png'  # Assuming image extension is .png
        image_path = os.path.join(dir_orig_imgs_tree, image_filename)
        text_file_path = os.path.join(dir_labels, text_file)
        
        # Get image size
        img_width, img_height = get_image_size(image_path)
        
        # Extract metadata from the file name
        pattern = r"min(-?\d+\.\d+)_max(-?\d+\.\d+)_zmin(-?\d+\.\d+)_zmax(-?\d+\.\d+)"
    
        match = re.search(pattern, text_file)
        if match:
            x_min, x_max, z_min, z_max = map(float, match.groups())
            #x_min, x_max, z_min, z_max = map(float, re.findall(r"min(-?\d+\.\d+)_max(-?\d+\.\d+)_zmin(-?\d+\.\d+)_zmax(-?\d+\.\d+)", text_file)[0])
            print(x_min)

        else:
            x_min, x_max, z_min, z_max = map(float, re.findall(r"min(\d+\.\d+)_max(\d+\.\d+)_zmin(\d+\.\d+)_zmax(\d+\.\d+)", text_file)[0])
        
        # Process the file
        file_data = process_file(text_file_path, img_width, img_height, x_min, x_max, z_min, z_max)
        
        # Add filename string to each row and extend the all_data list
        slice_direction = base_filename.split('__')[1].split('_section')[0]
        all_data.extend([(confidence_p1, world_px1, world_pz2, confidence_p2, world_px2, world_pz2,confidence_p3, world_px3, world_pz3, slice_direction) for confidence_p1, world_px1, world_pz2, confidence_p2, world_px2, world_pz2,confidence_p3, world_px3, world_pz3 in file_data])
    
    # Create a DataFrame with all data
    df_all = pd.DataFrame(all_data, columns=['confidence_p1', 'world_px1', 'world_pz1', 'confidence_p2', 'world_px2', 'world_pz2','confidence_p3', 'world_px3', 'world_pz3','slice_direction'])
    df_all['treeID']=treeID
    df_all_sorted = df_all.sort_values(by='world_pz2')
    
    # Applying the functions to the DataFrame
    df_all_sorted['branch_opening_angle'] = df_all_sorted.apply(lambda row: calculate_angle_at_p2(row['world_px1'], row['world_pz1'],row['world_px2'], row['world_pz2'],row['world_px3'], row['world_pz3']), axis=1)
    #df_all_sorted['branch_opening_angle']=180-df_all_sorted['branch_opening_angle']
    df_all_sorted['branch_length_p1_p2'] = df_all_sorted.apply(lambda row: calculate_distance(row['world_px1'], row['world_pz1'], row['world_px3'], row['world_pz3']), axis=1)
    df_all_sorted['branch_length_p3_p2'] = df_all_sorted.apply(lambda row: calculate_distance(row['world_px3'], row['world_pz3'], row['world_px2'], row['world_pz2']), axis=1)
    #df_all_sorted['branch_length_p3_p1'] = df_all_sorted.apply(lambda row: calculate_distance(row['world_px3'], row['world_pz3'], row['world_px1'], row['world_pz1']), axis=1)

    # add tree top and tree bottom points to the sorted dataframe
    df_all_sorted.loc[len(df_all_sorted)] = [0, 0,0,1,top_point[0,0], top_height,0,0,0,0,0,0,0,0 ]
    #df_all_sorted.loc[len(df_all_sorted)] = [0, 0,0,1,bottom_point[0,0], bottom_height,0,0,0,0,0,0,0,0 ]
    
    # re-sort the dataframe
    df_all_sorted = df_all_sorted.sort_values(by='world_pz2')
    
    
    
    ####################################################################################################################################
    # Cleanup
    # subset to select only most confident predictions within each "min_dist_whorls" cm interval
    # Now, let's iterate through each row and select the one with the largest probability
    # if consecutive rows are closer than 0.05 in Z values.
    selected_rows = []
    current_row = df_all_sorted.iloc[0]
    
    for index, next_row in df_all_sorted.iterrows():
        if (next_row['world_pz2'] - current_row['world_pz2']) < min_dist_whorls:
            # If the Z values are closer than 0.05, check the confidence
            if next_row['confidence_p2'] > current_row['confidence_p2']:
                current_row = next_row
        else:
            # If they are not closer, add the current row to the selected rows
            selected_rows.append(current_row)
            current_row = next_row
    
    # Make sure to add the last row after the loop
    selected_rows.append(current_row)
    
    # Create a DataFrame with the selected rows
    df_selected = pd.DataFrame(selected_rows)
    
    #df_selected
    
    # Calculating the maximum branch length
    df_selected['max_branch_length'] = df_selected[['branch_length_p1_p2', 'branch_length_p3_p2']].max(axis=1)
    
    # Calculating the average branch length
    df_selected['average_branch_length'] = df_selected[['branch_length_p1_p2', 'branch_length_p3_p2']].mean(axis=1)
    #df_selected['crown_diam'] = df_selected[['branch_length_p3_p1', 'branch_length_p3_p2']].max(axis=1)

    
    # Now, replace values greater than 10 with 0 in both columns
    df_selected['max_branch_length'] = df_selected['max_branch_length'].apply(lambda x: 0 if x > 10 else x)
    df_selected['average_branch_length'] = df_selected['average_branch_length'].apply(lambda x: 0 if x > 10 else x)
    #df_selected['crown_diam'] = df_selected['crown_diam'].apply(lambda x: 0 if x > 10 else x)

    
    ####################################################################################################################################
    ## Create pointcloud result
    whorls_pc= df_selected[['world_px2','world_pz2','confidence_p2','branch_opening_angle','max_branch_length','average_branch_length']]
    whorls_pc['x']=top_point[0,0]
    whorls_pc['y']=top_point[0,1]
    
    # de-normalize z
    whorls_pc['z']=whorls_pc['world_pz2']+bottom_height
    # add tree ID
    #ID = re.findall(r'\d+', os.path.splitext(os.path.basename(treeID))[0])
    ID = int(treeID)
    whorls_pc['treeID']=ID    # This should include any processing and return the necessary results

    # For demonstration, returning a simple dictionary. This should be replaced with actual processing results
    return {'treeID': treeID, 'result': df_selected, 'whorl_pc':whorls_pc}  # Replace with actual result

## Define paths and create required directories for temporary or output files

In [15]:
# directory to where point cloud (*.las or *.laz) forest scenes are stored (they need to have treeIDs for each point!)
dir_root="data"

# directory to where YOLO predictions are stored
dir_output = os.path.join(dir_root, "results")  # path to where to store the final results
dir_pred = os.path.join(dir_root, "pred_temp")  # temp path to where to store the intermediate prediction
dir_temp_imgs = os.path.join(dir_pred, "orig_imgs") # Further subdirectories within 'dir_pred'

# Ensure directories exist without throwing an error if they already do
os.makedirs(dir_pred, exist_ok=True)
os.makedirs(dir_output, exist_ok=True)
os.makedirs(dir_temp_imgs, exist_ok=True)

# Find all .las and .laz files in the root folder
las_files = glob.glob(os.path.join(dir_root, '*.las'))
laz_files = glob.glob(os.path.join(dir_root, '*.laz'))

# Combine the lists of files
all_files = las_files + laz_files
all_files

['data\\demo_data.laz']

## Define parameters 

In [9]:
# minimum distance between whorls (used to remove duplicates)
min_internodal_d=0.01 # in m

# pose model
my_model="whorl_pose_nano_1000px/weights/best.pt"

# label for the column with tree instance unique identifiers
tree_id_label='treeID' 

# label for the column with semantic labels (non required)
semantic_label='semantic' 

## Run for each file in the folder

In [4]:
for filename in all_files:
    print("Processing: "+filename)

Processing: data\demo_data.laz


In [None]:
# DEBUG
# shutil.rmtree(dir_pred)

In [7]:
# Read in the forest segmented data
las = laspy.read(filename)

# convert to a numpy array 
las_np = np.vstack((las.x, las.y, las.z,  getattr(las, tree_id_label),  getattr(las, semantic_label))).transpose()

# split tree/non-tree (this assumes that the label 
trees= las_np[las_np[:,4]!=0]
non_tree= las_np[las_np[:,4]==0]

print("tree points: "+ str(round(len(trees)/len(las_np)*100))+"%")
print("Non-tree points: "+ str(round(len(non_tree)/len(las_np)*100))+"%")

tree points: 100%
Non-tree points: 0%


In [24]:
dir_preds=dir_pred

# Whorl detection

In [27]:
# Parallel individual tree whorl pose detection
unique_treeIDs = np.unique(getattr(las, tree_id_label))
results = []

#os.makedirs("test_whorls_maria/forest/pred_temp/preds", exist_ok=True)


with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    # Submitting tasks to the executor
    future_to_treeID = {
        executor.submit(
            pose_detection_tree, 
            treeID=treeID,         # Pass treeID as a keyword argument
            trees=trees, 
            non_tree=non_tree, 
            dir_root=dir_root, 
            my_model=my_model, 
            dir_pred=dir_pred, 
            min_dist_whorls=min_internodal_d  # Pass min_dist_whorls as a keyword argument
        ): treeID 
        for treeID in unique_treeIDs
    }
    
    # Collecting results as they complete
    for future in concurrent.futures.as_completed(future_to_treeID):
        treeID = future_to_treeID[future]
        try:
            data = future.result()
            if data is not None:
                results.append(data)
        except Exception as exc:
            print(f'TreeID {treeID} generated an exception: {exc}')



image 1/8 C:\Users\stpu\whorl_pose_detector\data\pred_temp\orig_imgs\25409022\img__x45_section_0_min4.53_max9.16_zmin0.00_zmax10.00.png: 1024x480 11 whorlss, 55.0ms
image 1/8 C:\Users\stpu\whorl_pose_detector\data\pred_temp\orig_imgs\25409020\img__x45_section_0_min-1.93_max1.90_zmin0.00_zmax10.00.png: 1024x416 15 whorlss, 63.0ms
image 2/8 C:\Users\stpu\whorl_pose_detector\data\pred_temp\orig_imgs\25409022\img__x45_section_1_min4.53_max9.16_zmin2.50_zmax12.50.png: 1024x480 18 whorlss, 6.0ms
image 2/8 C:\Users\stpu\whorl_pose_detector\data\pred_temp\orig_imgs\25409020\img__x45_section_1_min-1.93_max1.90_zmin1.39_zmax11.39.png: 1024x416 17 whorlss, 24.0ms
image 3/8 C:\Users\stpu\whorl_pose_detector\data\pred_temp\orig_imgs\25409022\img__x_section_0_min4.74_max8.38_zmin0.00_zmax10.00.png: 1024x384 10 whorlss, 57.0ms
image 3/8 C:\Users\stpu\whorl_pose_detector\data\pred_temp\orig_imgs\25409020\img__x_section_0_min-2.13_max2.38_zmin0.00_zmax10.00.png: 1024x480 14 whorlss, 59.0ms
image 4/8 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  whorls_pc['x']=top_point[0,0]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  whorls_pc['y']=top_point[0,1]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  whorls_pc['z']=whorls_pc['world_pz2']+bottom_height
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer

In [5]:
# DEBUG
#shutil.rmtree("C:/Users/stpu/whorl_pose_detector/test_whorls_maria/forest/pred_temp/orig_imgs")

## Post-processing of the predictions

In [28]:
# Post-processing the results
# Creating a dictionary mapping treeID to its result
results_dict = {result['treeID']: result['result'] for result in results if result is not None}

# Merging the 'whorl_pc' DataFrames from each result
whorls_pc_dfs = [result['whorl_pc'] for result in results if result is not None and 'whorl_pc' in result]
whorld_df = pd.concat(whorls_pc_dfs, ignore_index=True)

# correct z
whorld_df.loc[whorld_df['branch_opening_angle'] == 0, 'z'] = whorld_df['world_pz2']

# remove negative whorls
whorld_df= whorld_df[whorld_df['world_pz2']>=0]


# write out the output
whorld_df.to_csv(dir_output+'/'+os.path.splitext(os.path.basename(filename))[0]+'_'+str(round(min_internodal_d*100))+'cmThresh_whorls_pc_HKL2model.csv', index=False)  


In [36]:
print(whorld_df)

     world_px2  world_pz2  confidence_p2  branch_opening_angle  \
0     7.145728   0.273531       0.960763              5.969818   
1     1.752101   0.499852       0.977073              8.872340   
2     0.562580   0.766537       0.999808            159.023735   
3     7.209631   1.263350       0.999027            150.733587   
4     7.273557   1.464500       0.997177            151.758838   
..         ...        ...            ...                   ...   
101   0.275257  10.209120       0.997685            176.362206   
102   3.970931  10.256400       0.996999            191.226363   
103   3.637505  10.802320       0.987156            191.187685   
104   3.594131  10.896220       0.838436            173.924614   
105   0.324000  10.962000       1.000000              0.000000   

     max_branch_length  average_branch_length      x      y          z  \
0             2.629991               1.451761  7.219  0.675  -1.007469   
1             3.240880               2.716941  7.219  0.675

# PLOTTING (to edit!)

In [31]:
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.colors import LinearSegmentedColormap

# Define a custom colormap using blue, yellow, and red, where yellow corresponds to the midpoint (180 degrees)
colors = [(0, 0, 1), (1, 1, 0), (1, 0, 0)]  # Blue -> Yellow -> Red
n_bins = 100  # Discretizes the interpolation into bins
cmap_name = 'custom_colormap'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)

# Normalizing the branch_opening_angle to fit the range [0, 1] for our colormap
# Since yellow corresponds to 180 degrees in our scale, we adjust the normalization accordingly
#norm = plt.Normalize(0, 2*180)

df_selected_no_topBottom=df_selected[df_selected['confidence_p1']>0]
# Plotting all of the World_Z values
plt.figure(figsize=(10, 6)) 
# whorld_df
scatter = plt.scatter(range(len(df_selected)), df_selected['world_pz2'], c=df_selected['branch_opening_angle'], cmap='viridis',alpha=0.8)
#scatter = plt.scatter(range(len(df_selected)), df_selected['world_pz2'], df_selected['branch_opening_angle'], cmap=cm,alpha=0.8)
plt.colorbar(scatter, label='Branch Opening Angle (Degrees)')
plt.xlabel('Index')
plt.ylabel('world_pz2')
plt.show()

NameError: name 'df_selected' is not defined

In [32]:
np.max(df_selected_no_topBottom['branch_opening_angle'])

NameError: name 'df_selected_no_topBottom' is not defined

In [None]:

# Define a custom colormap using blue, yellow, and red, where yellow corresponds to 180 degrees
colors = [(0, 0, 1), (1, 1, 0), (1, 0, 0)]  # Blue -> Yellow -> Red
cmap_name = 'custom_colormap'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=20)

# Normalizing the branch_opening_angle to fit the fixed range [0, 360] for our colormap
norm = plt.Normalize(np.min(df_all_sorted['branch_opening_angle']), np.max(df_all_sorted['branch_opening_angle']))

df_selected_no_topBottom=df_selected[df_selected['confidence_p1']>0]
# Plotting with the fixed color scale
plt.figure(figsize=(10, 6))
scatter = plt.scatter(range(len(df_all_sorted)), df_all_sorted['world_pz2'], c=df_all_sorted['branch_opening_angle'], cmap=cm, norm=norm)
plt.colorbar(scatter, label='Branch Opening Angle (Degrees)')
plt.xlabel('Row Number')
plt.ylabel('world_pz2')
plt.title('Plot of world_pz2 Colored by Branch Opening Angle Using Fixed Color Range')
plt.show()

In [None]:
# Define two different colormaps for values below and above 180 degrees
cmap_below_180 = LinearSegmentedColormap.from_list('below_180', [(1, 1, 1), (0, 0, 1)], N=100)  # White to blue
cmap_above_180 = LinearSegmentedColormap.from_list('above_180', [(1, 1, 1), (1, 0, 0)], N=100)  # White to red

# Function to apply the appropriate colormap based on the value
def apply_cmap(value):
    if value <= 180:
        return cmap_below_180(norm(value))
    else:
        return cmap_above_180(norm(value - 180))

# Apply the color mapping to each value
colors = df_selected['branch_opening_angle'].apply(apply_cmap)

plt.figure(figsize=(10, 6))
scatter = plt.scatter(range(len(df_selected)), df_selected['world_pz2'], c=colors)
plt.colorbar(scatter, label='Branch Opening Angle (Degrees)')

# Create custom colorbar
sm = plt.cm.ScalarMappable(cmap=LinearSegmentedColormap.from_list('custom', [(0, 0, 1), (1, 0, 0)]), norm=norm)
sm.set_array([])  # You have to set_array for the ScalarMappable.
#cbar = plt.colorbar(sm, ticks=[0, 180, 360])
#cbar.ax.set_yticklabels(['0°', '180°', '360°'])  # Set the tick labels

#plt.xticks(row_numbers)
plt.xlabel('Row Number')
plt.ylabel('world_pz2')
plt.title('Plot of world_pz2 Colored by Branch Opening Angle with Custom Color Bar')
plt.show()

In [None]:
import matplotlib.colors as mcolors

# Adjusting the colormap to span the full range from 0 to 1
min_angle = df_selected["branch_opening_angle"].min()
max_angle = df_selected["branch_opening_angle"].max()

# Define color ranges in terms of the angle range
color_map_data = [
    (0, 'darkblue'), 
    ((180 - min_angle) / (max_angle - min_angle), 'lightblue'),
    ((180.9 - min_angle) / (max_angle - min_angle), 'white'),
    ((250 - min_angle) / (max_angle - min_angle), 'lightpink'),
    (1, 'darkred')
]

# Create the colormap
cmap = mcolors.LinearSegmentedColormap.from_list('custom', color_map_data)

# Normalize the color map to the range of branch_opening_angle
norm = mcolors.Normalize(vmin=min_angle, vmax=max_angle)

# Plot
plt.scatter(range(len(df_selected)), df_selected["world_pz2"], c=df_selected["branch_opening_angle"], cmap=cmap, norm=norm)
#plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap))
plt.xlabel('Index')
plt.ylabel('world_pz2')
plt.title('world_pz2 vs Index Colored by Branch Opening Angle')
plt.show()