In [None]:
import logging
import re
from pathlib import Path
from typing import Tuple

import h5py
import imageio as iio
import numpy as np
import pandas as pd
import sleap  # requires sleap in environment

In [None]:
# TODO: Input search folders to look for image folders (one folder per line).
# Surrounding whitespace and blank lines are ignored, so an accidental empty
# line cannot turn into "" (which pathlib treats as the current directory).
search_folders = [
    folder.strip()
    for folder in r"""
Day3_09-12-2023_FastScanner/Day3_09-12-2023_FastScanner
Day10_09-22-2023_FastScanner
""".splitlines()
    if folder.strip()
]

# TODO: Output folder for h5 files and predictions
dst_folder = r"downstream_data_analysis_and_extraction/h5_files_and_predictions"

# TODO: Set overwrite to True to overwrite existing files
overwrite = False

In [None]:
def natural_sort(l):
    """Sort strings/paths in natural (human) order, e.g. ``img2`` before ``img10``.

    Adapted from https://stackoverflow.com/a/4836734.

    Args:
        l: Iterable of ``str`` or ``pathlib.Path`` items. ``Path`` items are
            converted to POSIX-style strings before sorting.

    Returns:
        A new list of strings. Runs of ASCII digits compare numerically and
        all other text compares case-insensitively.
    """

    def _natural_key(text):
        # re.split with a capturing group alternates [text, digits, text, ...],
        # so odd positions are always ASCII-digit runs. Using parity (rather
        # than str.isdigit, which also accepts characters like "²" that int()
        # rejects) keeps ints compared with ints and strs with strs.
        return [
            int(chunk) if position % 2 else chunk.lower()
            for position, chunk in enumerate(re.split(r"([0-9]+)", text))
        ]

    items = [item.as_posix() if isinstance(item, Path) else item for item in l]
    return sorted(items, key=_natural_key)


def convert_img_folder_to_h5(img_folder, dst_folder, overwrite=False):
    """Convert an image folder to an HDF5 file and save it under ``dst_folder``.

    Args:
        img_folder: Path to a folder filled with PNGs named sequentially.
        dst_folder: Path to a folder where the HDF5 file will be saved. The
            file is written to ``dst_folder / <name of img_folder's parent>/``;
            any missing directories are created.
        overwrite: If True, write the HDF5 file even if it already exists.

    Returns:
        ``pathlib.Path`` of the HDF5 file (either pre-existing or newly written).

    Raises:
        ValueError: If a new file must be written but ``img_folder`` contains
            no ``*.png`` files.

    Notes:
        The file is named after the image folder's stem with a ``.h5``
        extension. It contains a dataset named "vol" holding the images stacked
        in natural-sort filename order, with a trailing singleton axis added:
        (slices, height, width, 1) for single-channel PNGs. Multi-channel PNGs
        would produce an extra axis, so confirm the inputs are grayscale.
    """
    img_folder_path = Path(img_folder)
    # Outputs are grouped by the image folder's parent directory name.
    # ``.name`` (not ``.parts[-1]``) is "" for "." or "/" instead of raising
    # IndexError or producing an absolute "/" component.
    parent_name = img_folder_path.parent.name
    print(parent_name)
    h5_name = f"{img_folder_path.stem}.h5"

    parent_folder_path = Path(dst_folder) / parent_name
    # parents=True so a not-yet-created dst_folder does not raise FileNotFoundError.
    parent_folder_path.mkdir(parents=True, exist_ok=True)
    dst_name = parent_folder_path / h5_name

    if not overwrite and dst_name.exists():
        return dst_name

    png_paths = list(img_folder_path.glob("*.png"))
    if not png_paths:
        # Fail with a clear message instead of np.stack's "need at least one array".
        raise ValueError(f"No PNG images found in {img_folder_path.as_posix()}")
    # Natural sort so e.g. slice 10 follows slice 9 rather than slice 1.
    img_paths = natural_sort(png_paths)

    vol = np.stack([iio.imread(img_path) for img_path in img_paths], axis=0)

    with h5py.File(dst_name, "w") as f:
        # h5py interprets an integer compression value in 0-9 as a gzip level.
        f.create_dataset(
            "vol",
            data=np.expand_dims(vol, axis=-1),  # append trailing channel axis
            compression=1,
        )
    return dst_name

In [None]:



def predict(
    video: sleap.Video,
    scan_id: str,
    models_input_dir: str,
    model_dict: dict,
    output_dir: str,
) -> Tuple[sleap.Labels, dict]:
    """Run every model in ``model_dict`` on ``video`` and save each prediction.

    Args:
        video: A SLEAP Video object, or None (nothing is predicted).
        scan_id: Unique ID of scan for naming the predictions.
        models_input_dir: Directory containing the models.
        model_dict: A dictionary where keys are model types and values are
            {'model_id': model_id, 'model_path': model_path}; ``model_path``
            is relative to ``models_input_dir``.
        output_dir: Directory to save predictions to (created if missing).

    Returns:
        A ``(labels, preds_dict)`` tuple:
            labels: the ``sleap.Labels`` predicted by the last model in
                ``model_dict`` iteration order, or None if ``video`` is None or
                ``model_dict`` is empty. Every model's labels are saved to disk.
            preds_dict: ``{"scan_id": scan_id, <model_type>: <preds filename>, ...}``
                where each filename is relative to ``output_dir``.
    """
    preds_dict = {"scan_id": scan_id}

    if video is None:
        return None, preds_dict

    # sleap.versions() prints its report rather than returning it, so log the
    # version string directly.
    logging.info(f"SLEAP version: {sleap.__version__}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Initialised so an empty model_dict returns None instead of raising
    # UnboundLocalError at the return statement.
    labels = None
    for model_type, model_info in model_dict.items():
        model_id = model_info.get("model_id")
        model_path = model_info.get("model_path")

        # Model paths in model_dict are relative to models_input_dir.
        modified_model_path = Path(models_input_dir) / model_path
        logging.info(
            f"Processing model {model_type} with model_id {model_id} from {modified_model_path}."
        )

        preds_name = f"scan{scan_id}.model{model_id}.root{model_type}.slp"
        preds_path = output_path / preds_name

        predictor = sleap.load_model(
            modified_model_path.as_posix(), progress_reporting="none"
        )
        logging.info(f"Loaded model {model_type} from {modified_model_path}.")

        labels = predictor.predict(video)
        labels.save(preds_path.as_posix())
        logging.info(f"Saved predictions to {preds_path}.")

        # Record only the filename; it is relative to output_dir.
        preds_dict[model_type] = preds_name

    return labels, preds_dict