In [8]:
import pandas as pd
import numpy as np
import re
import sleap # requires sleap in environment
import h5py
import imageio as iio

from pathlib import Path
from typing import Tuple
from rich.progress import track

In [12]:
# TODO: Input search folders to look for image folders, one folder per line.
# Blank lines and surrounding whitespace (e.g. a trailing "\r" from pasted
# Windows text) are dropped: an empty entry would become Path(""), i.e. the
# current working directory, and be searched recursively by mistake.
search_folders = [
    line.strip()
    for line in r"""
C:\Users\Elizabeth_Berrigan\Box\Phenotyping_team_GH\Experiments\Rice\main_experiments\CYL_Rice_FN900_Related_EXP\CYL_KitX_APR24\0520
""".splitlines()
    if line.strip()
]

# TODO: Output folder for h5 files and predictions
dst_folder = r"C:\Users\Elizabeth_Berrigan\Box\Phenotyping_team_GH\Experiments\Rice\main_experiments\CYL_Rice_FN900_Related_EXP\CYL_KitX_APR24\analysis_EB_20240523\h5s_preds"

# TODO: Set overwrite to True to overwrite existing files
overwrite = False

# TODO: Model folder to use for predictions--should be for the older rice plants
crown_model_folder = r"C:\Users\Elizabeth_Berrigan\Desktop\phenotyping\sleap-roots-pipeline\20240501_models\rice\older\crown\221208_113552.multi_instance.n=574.zip"

In [13]:
def natural_sort(l):
    """Return the items of ``l`` as a new list in natural (human) order.

    Adapted from https://stackoverflow.com/a/4836734.

    ``Path`` items are converted to POSIX-style strings before sorting; other
    items are used as given. Runs of the digits 0-9 compare as integers and
    all remaining text compares case-insensitively, so "2.png" sorts before
    "10.png".
    """
    def _natural_key(name):
        # Split into alternating text / digit-run chunks; the capture group
        # keeps the digit runs in the result so they can be compared as ints.
        key = []
        for chunk in re.split('([0-9]+)', name):
            key.append(int(chunk) if chunk.isdigit() else chunk.lower())
        return key

    names = []
    for item in l:
        names.append(item.as_posix() if isinstance(item, Path) else item)
    names.sort(key=_natural_key)
    return names


def convert_img_folder_to_h5(img_folder, dst_folder, overwrite=False):
    """Convert a folder of PNG slices into a single HDF5 file under dst_folder.

    Args:
        img_folder: Path to a folder filled with PNGs named sequentially.
        dst_folder: Root output folder. The file is written to
            ``dst_folder/<name of img_folder's parent>/<img_folder stem>.h5``;
            missing folders along that path are created.
        overwrite: If True, save the HDF5 file even if it already exists.

    Returns:
        Path of the HDF5 file, whether it was just written or already existed.

    Raises:
        ValueError: If img_folder contains no ``*.png`` files.

    Notes:
        The images are read in natural filename order and stacked along a new
        first axis, then a trailing singleton axis is appended. For
        single-channel images the "vol" dataset therefore has shape
        (slices, height, width, 1).
    """
    img_folder_path = Path(img_folder)

    # `.parent.name` instead of `.parent.parts[-1]`: the latter raises
    # IndexError for a bare relative folder name and returns the filesystem
    # root for a root-level folder, which would discard dst_folder when joined.
    parent_folder_path = Path(dst_folder) / img_folder_path.parent.name
    # parents=True so a dst_folder that does not exist yet is created as well.
    parent_folder_path.mkdir(parents=True, exist_ok=True)

    dst_name = parent_folder_path / f"{img_folder_path.stem}.h5"
    if not overwrite and dst_name.exists():
        return dst_name

    png_paths = list(img_folder_path.glob("*.png"))
    if not png_paths:
        # np.stack would fail on an empty list with a message that does not
        # say which folder was the problem.
        raise ValueError(f"No PNG images found in {img_folder_path.as_posix()}.")
    img_paths = natural_sort(png_paths)

    vol = np.stack([iio.imread(img_path) for img_path in img_paths], axis=0)  # (slices, height, width)

    with h5py.File(dst_name, "w") as f:
        f.create_dataset(
            "vol",
            data=np.expand_dims(vol, axis=-1),  # (slices, height, width, 1)
            compression=1,  # h5py reads an integer here as a gzip level
        )
    print(f"Saved {dst_name.as_posix()}.")
    return dst_name


def predict(
    h5_path: str,
    model_input_dir: str,
    model_type: str,
    output_dir: str,
) -> Tuple["sleap.Labels", dict]:
    """Run a SLEAP model on one HDF5 image stack and save the predictions.

    Args:
        h5_path: Path to h5 file containing the image data with shape
            (slices, height, width, 1) and dataset name 'vol'. May be None,
            in which case nothing is predicted.
        model_input_dir: Path to the trained model (folder or zip) to load.
        model_type: Type of model to use for predictions: "crown", "primary", "lateral".
            Only used to label the output file and the returned dictionary.
        output_dir: Accepted for interface compatibility but not used: the
            predictions file is written next to ``h5_path``.

    Returns:
        A ``(labels, preds_dict)`` tuple. ``labels`` is the SLEAP predictions
        object and ``preds_dict`` maps "scan_id" to the series name and
        ``model_type`` to the saved ``.slp`` path. When ``h5_path`` is None
        the result is ``(None, {})``.
    """
    # Bail out before any Path conversion: Path(None) raises TypeError, so a
    # None check placed after the conversion could never be reached.
    if h5_path is None:
        return None, {}

    h5_path = Path(h5_path)
    model_input_dir = Path(model_input_dir)

    # Series name is everything before the first "." in the file name
    series_name = h5_path.name.split(".")[0]
    preds_dict = {"scan_id": series_name}

    # Log sleap version
    print(f"SLEAP version: {sleap.versions()}")

    # Model id is everything before the first "." in the model file/folder name
    model_id = model_input_dir.name.split(".")[0]

    # Build the predictions path beside the h5 file by swapping the ".h5"
    # ending for a model-specific one. Path.replace() renames files, so the
    # substitution is done on the file name string.
    h5_name = h5_path.name
    base_name = h5_name[: -len(".h5")] if h5_name.endswith(".h5") else h5_name
    crown_path = h5_path.with_name(
        f"{base_name}.model{model_id}.root{model_type}.slp"
    )

    # Load the model
    predictor = sleap.load_model(
        model_input_dir.as_posix(), progress_reporting="none"
    )
    print(f"Loaded model {model_type} from {model_input_dir}.")

    # The predictor expects video-like data, not a filename, so open the
    # "vol" dataset as a SLEAP video first.
    video = sleap.Video.from_hdf5(
        filename=h5_path.as_posix(), dataset="vol", input_format="channels_last"
    )
    labels = predictor.predict(video)
    # Save the predictions
    labels.save(crown_path.as_posix())
    print(f"Saved predictions to {crown_path.as_posix()}.")

    # Add the prediction path to the dictionary
    preds_dict[model_type] = crown_path.as_posix()

    return labels, preds_dict

In [14]:
# Search for image folders: any folder that directly contains a "1.png"
img_folders = []
for search_folder in track(search_folders, description="Searching for image folders..."):
    search_root = Path(search_folder)
    if not search_root.is_dir():
        # A mistyped or unavailable path would otherwise contribute nothing
        # without any hint as to why.
        print(f"Skipping missing search folder: {search_folder}")
        continue
    img_folders.extend(hit.parent.as_posix() for hit in search_root.rglob("1.png"))
# Overlapping search folders would list the same image folder more than once;
# keep the first occurrence of each, preserving discovery order.
img_folders = list(dict.fromkeys(img_folders))
print(f"Found {len(img_folders)} image folders")

Output()

Found 38 image folders


In [15]:
# Convert each discovered image folder to HDF5 and collect the output paths
h5_files = []
for folder in track(img_folders, description="Converting to HDF5..."):
    converted_path = convert_img_folder_to_h5(folder, dst_folder, overwrite=overwrite)
    h5_files.append(converted_path)
print(f"Converted {len(h5_files)} image folders to HDF5")

Output()

Converted 38 image folders to HDF5


In [None]:
# Run inference on every HDF5 file, once per root type, collecting the
# (labels, preds_dict) results in order
preds = []
for h5_path in track(h5_files, description="Predicting..."):
    preds.extend(
        predict(h5_path, crown_model_folder, root_type, dst_folder)
        for root_type in ["crown"]
    )
print(f"Predicted {len(preds)} HDF5 files")