In [None]:
# General Imports
import os
import numpy as np
import subprocess
import sys

# Data
import pathlib
from pathlib import Path
import tifffile

# Architecture imports
import torch
import torch.nn as nn

from typing import List, Tuple
import shutil

from scipy.ndimage import zoom
from skimage import exposure
import warnings

In [None]:
# Report whether PyTorch can see a CUDA device; the notebook continues either way.
print(
    "GPU available for execution "
    if torch.cuda.is_available()
    else "GPU not available ! . Please check or proceed to execute in CPU"
)

In [None]:
# Set the nnU-Net folder locations as environment variables.
# Values written to os.environ are inherited by child processes, so the
# nnUNetv2_predict subprocess started later in this notebook sees them too.
# NOTE(review): these are absolute, site-specific paths - adjust before running elsewhere.
os.environ["nnUNet_raw"] = "/research/sharedresources/cbi/data_exchange/dyergrp/retinal_degeneration/Version_4_underdev/DAPI/nnunet/nnUNet_raw"
os.environ["nnUNet_preprocessed"] = "/research/sharedresources/cbi/data_exchange/dyergrp/retinal_degeneration/Version_4_underdev/DAPI/nnunet/nnUNet_preprocessed"
os.environ["nnUNet_results"] = "/research/sharedresources/cbi/data_exchange/dyergrp/retinal_degeneration/Version_4_underdev/DAPI/nnunet/nnUNet_results"

In [None]:
# Root folder that is searched for "isotropic_image/C4-DAPI-XZ.tif" inputs (see process_dapi_paths).
input_path = Path(r"/research/sharedresources/cbi/data_exchange/dyergrp/retinal_degeneration/Version_4_underdev/Outputs_DL_CBI_9-20-23/3_input_subgroups/1")
# Root folder under which "<series>/DAPI_results" folders are created.
# Note: kept as a plain string here; process_dapi_paths wraps it in Path itself.
output_path = r"/research/sharedresources/cbi/data_exchange/dyergrp/retinal_degeneration/Version_4_underdev/Outputs_DL_CBI_9-20-23/2_Full_execution_outputs"

In [None]:
# nnU-Net model selection. Both values are read as globals by
# execute_for_single_volume and forwarded to nnUNetv2_predict as "-d" and "-c".
dataset_num=104
# Alternative configuration that was used previously: config="2d"
config="3d_fullres"

In [None]:
# Read as a global by execute_for_single_volume:
#   True  -> run_nnunet_ondemand (nnUNetv2_predict wrapped in "conda run -p <env>")
#   False -> run_nnunet (nnUNetv2_predict called directly from the current environment)
on_demand=True

## Chunk the Image and Execute the model

In [None]:
def make_chuncks(volume, output_folder, tif_file, chunk_size=(128, 256, 256)):
    """
    Split a 3D volume into fixed-size chunks and save every chunk as a TIFF file.

    The volume is cut on a regular grid. Chunks on the far edge of the volume
    are zero-padded at their end so that every saved chunk has exactly
    ``chunk_size``. Files are named
    ``<stem>_z<i>_y<j>_x<k>_<NNN>_0000.tif`` where ``i, j, k`` are the grid
    indices of the chunk and ``NNN`` is a running counter (z outermost,
    x innermost).
    NOTE(review): the trailing ``_0000`` is presumably the nnU-Net channel
    suffix - confirm against the nnU-Net naming convention.

    Parameters:
    -----------
    volume : numpy.ndarray
        Input 3D volume, indexed as (axis0, axis1, axis2) = (z, y, x) grid axes
    output_folder : str or pathlib.Path
        Existing folder the chunk files are written to
    tif_file : str or pathlib.Path
        File name of the source image; its stem (name without the last
        extension) becomes the prefix of every chunk file
    chunk_size : sequence of 3 ints
        Chunk extent along each axis, default is (128, 256, 256)

    Raises:
    -------
    ValueError
        If ``volume`` is not 3D
    """
    if volume.ndim != 3:
        raise ValueError("Input volume must be 3D")

    output_folder = Path(output_folder)
    # Normalise to a tuple so the comparison with ndarray.shape below is meaningful
    # even when the caller passes a list.
    chunk_size = tuple(int(size) for size in chunk_size)
    # Path.stem handles ".tif", ".tiff" and any other extension length;
    # slicing off a fixed four characters only worked for ".tif".
    base_name = Path(tif_file).stem

    # Number of chunks needed along each axis (last chunk may be partial)
    grid = tuple(int(np.ceil(dim / size)) for dim, size in zip(volume.shape, chunk_size))

    # np.ndindex iterates in C order: z outermost, x innermost
    for chunk_num, (z, y, x) in enumerate(np.ndindex(*grid)):
        # Chunk boundaries, clipped to the volume extent
        region = tuple(
            slice(index * size, min((index + 1) * size, dim))
            for index, size, dim in zip((z, y, x), chunk_size, volume.shape)
        )
        volume_chunk = volume[region]

        # Zero-pad partial edge chunks up to the full chunk size
        if volume_chunk.shape != chunk_size:
            pad_width = [(0, size - got) for size, got in zip(chunk_size, volume_chunk.shape)]
            volume_chunk = np.pad(volume_chunk, pad_width, mode='constant')

        chunk_name = f"{base_name}_z{z}_y{y}_x{x}_{chunk_num:03}_0000.tif"
        tifffile.imwrite(output_folder / chunk_name, volume_chunk)

    print("chunks created ...")

In [None]:
def resize_volume_bicubic(volume, target_size=(256, 819)):
    """
    Resample the y and x extent of a 3D volume with third-order spline
    interpolation (scipy.ndimage.zoom, order=3). The z axis keeps a zoom
    factor of 1.0, i.e. the number of slices is unchanged.

    Parameters:
    -----------
    volume : numpy.ndarray
        Input 3D volume with shape (z, y, x)
    target_size : tuple
        Requested output size for the (y, x) dimensions, default is (256, 819)

    Returns:
    --------
    numpy.ndarray
        Resized volume; zoom factors are chosen so that the y and x extents
        become target_size[0] and target_size[1]
    """
    # Unpacking also enforces that the input is three-dimensional
    _, height, width = volume.shape

    # Per-axis zoom factors in (z, y, x) order
    scale = (1.0, target_size[0] / height, target_size[1] / width)

    return zoom(volume, scale, order=3)

In [None]:
def restore_label_volume(label_volume, original_shape):
    """
    Resize the y and x extent of a label volume to match ``original_shape``
    using nearest-neighbour interpolation (order=0), so that no new label
    values are created by the resampling.

    Only ``original_shape[1]`` and ``original_shape[2]`` are used: the z axis
    is zoomed by a factor of 1 and therefore keeps its current length.

    Parameters:
    -----------
    label_volume : numpy.ndarray
        Input 3D label volume with shape (z, y, x)
    original_shape : tuple
        Shape (z, y, x) whose y and x extents should be restored

    Returns:
    --------
    numpy.ndarray
        Label volume with the y and x extents of ``original_shape`` and the
        z extent of ``label_volume``
    """
    # Unpacking also enforces that the input is three-dimensional
    _, current_y, current_x = label_volume.shape

    # Per-axis zoom factors in (z, y, x) order; z is deliberately left untouched
    factors = (1, original_shape[1] / current_y, original_shape[2] / current_x)

    return zoom(label_volume, factors, order=0)

In [None]:
def apply_clahe_3d(volume, kernel_size=(8, 8), clip_limit=0.01, nbins=256):
    """
    Applies Contrast Limited Adaptive Histogram Equalization (CLAHE) to a 3D volume
    slice by slice along the first (z) axis.

    Non-float input is converted to float; if its maximum exceeds 1.0 it is
    min-max normalised to [0, 1] using the global minimum and maximum of the
    whole volume. float32/float64 input is copied and passed on unchanged.

    Parameters:
    -----------
    volume : numpy.ndarray
        Input 3D volume with shape (z, y, x)
    kernel_size : tuple
        Size of kernel for CLAHE in (y, x) dimensions, default is (8, 8)
    clip_limit : float
        Clipping limit passed to skimage.exposure.equalize_adapthist
    nbins : int
        Number of bins for histogram, default is 256

    Returns:
    --------
    numpy.ndarray
        CLAHE processed volume with same shape as input. An empty input is
        returned (as a float copy) without being processed.

    Raises:
    -------
    ValueError
        If ``volume`` is not 3D
    """
    # Input validation
    if volume.ndim != 3:
        raise ValueError("Input volume must be 3D")

    is_float = volume.dtype == np.float32 or volume.dtype == np.float64
    volume_norm = volume.copy() if is_float else volume.astype(float)

    # Nothing to equalize; also avoids calling max()/min() on an empty array
    if volume_norm.size == 0:
        return volume_norm

    # Normalize integer-like data to [0, 1] if not already in that range
    if not is_float:
        v_max = volume_norm.max()
        if v_max > 1.0:
            v_min = volume_norm.min()
            span = v_max - v_min
            if span > 0:
                volume_norm = (volume_norm - v_min) / span
            else:
                # Constant volume: dividing by a zero range would fill the
                # array with NaN, so map it to all zeros instead.
                volume_norm = np.zeros_like(volume_norm)

    # Process each slice
    processed_volume = np.zeros_like(volume_norm)

    # All warnings raised while equalizing are suppressed for this loop only
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for z in range(volume_norm.shape[0]):
            processed_volume[z] = exposure.equalize_adapthist(
                volume_norm[z],
                kernel_size=kernel_size,
                clip_limit=clip_limit,
                nbins=nbins
            )

    return processed_volume

In [None]:
def convert_xz_to_xy(volume):
    """
    Swap the first two axes of a 3D volume: axis order (0, 1, 2) becomes (1, 0, 2).

    Parameters:
    -----------
    volume : numpy.ndarray
        Input 3D volume

    Returns:
    --------
    numpy.ndarray
        Volume with its first two axes exchanged and the last axis unchanged
    """
    axis_order = (1, 0, 2)
    return np.transpose(volume, axis_order)

In [None]:
def reconstruct_volume(chunks_folder, output_folder, final_shape, chunk_size=(128, 256, 256)):
    """
    Stitch chunk TIFFs back into full volumes and save them in three variants.

    Every ``*.tif`` file in ``chunks_folder`` whose stem matches
    ``<name>_z<i>_y<j>_x<k>_<suffix>`` is treated as a chunk located at grid
    position (i, j, k). Files that do not match are reported and skipped.
    Chunks are grouped by ``<name>``; for each group the following files are
    written to ``output_folder``:

    * ``<name>_reconstructed_original_before_shape_adjustment.tif`` - the
      stitched, still padded volume (float32)
    * ``<name>_reconstructed.tif`` - the stitched volume after
      ``restore_label_volume(..., final_shape)`` and cropping to ``final_shape``
    * ``<name>_reconstructed_xy.tif`` - the previous volume passed through
      ``convert_xz_to_xy``

    Parameters:
    -----------
    chunks_folder : str or pathlib.Path
        Folder containing the chunk TIFF files
    output_folder : str or pathlib.Path
        Folder the reconstructed volumes are written to (created if missing)
    final_shape : tuple
        Shape (z, y, x) the restored volume is resized/cropped to
    chunk_size : tuple
        Extent (z, y, x) of every chunk, default is (128, 256, 256)
    """
    import re  # local import keeps this cell independent of the import cell

    chunks_folder = Path(chunks_folder)
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Greedy "<name>" makes the pattern bind to the LAST "_z..." group in the stem,
    # so source names that themselves contain "_z" are still parsed correctly.
    chunk_pattern = re.compile(r"^(?P<name>.+)_z(?P<z>\d+)_y(?P<y>\d+)_x(?P<x>\d+)_[^_]+$")

    # Group chunks by original filename (sorted for a reproducible processing order)
    chunk_groups = {}
    for chunk_file in sorted(chunks_folder.glob("*.tif")):
        match = chunk_pattern.match(chunk_file.stem)
        if match is None:
            print(f"Skipping {chunk_file.name} - file name does not match the chunk pattern")
            continue
        grid_index = (int(match["z"]), int(match["y"]), int(match["x"]))
        chunk_groups.setdefault(match["name"], []).append((*grid_index, chunk_file))

    if not chunk_groups:
        print(f"No chunk files found in {chunks_folder}")

    for original_name, chunks in chunk_groups.items():
        # Grid extent = highest grid index along each axis + 1
        grid_extent = [max(chunk[axis] for chunk in chunks) + 1 for axis in range(3)]

        # Initialize the reconstructed volume (padded to whole chunks)
        padded_shape = tuple(count * size for count, size in zip(grid_extent, chunk_size))
        reconstructed_volume = np.zeros(padded_shape, dtype=np.float32)

        # Fill the reconstructed volume with chunks
        for z, y, x, chunk_file in chunks:
            chunk = tifffile.imread(chunk_file)
            reconstructed_volume[
                z * chunk_size[0] : (z + 1) * chunk_size[0],
                y * chunk_size[1] : (y + 1) * chunk_size[1],
                x * chunk_size[2] : (x + 1) * chunk_size[2]
            ] = chunk

        # The padded volume is kept as is here; cropping happens after the resize below
        final_volume = reconstructed_volume

        # Save the reconstruction
        output_file = output_folder / f"{original_name}_reconstructed_original_before_shape_adjustment.tif"
        tifffile.imwrite(output_file, final_volume)

        # Adjust the size of the final label
        restored_final_volume = restore_label_volume(final_volume, final_shape)

        # Crop to the final shape; this removes whatever extent still exceeds
        # final_shape after the resize above.
        restored_final_volume = restored_final_volume[:final_shape[0], :final_shape[1], :final_shape[2]]

        # Save the reconstructed and cropped volume
        output_file = output_folder / f"{original_name}_reconstructed.tif"
        tifffile.imwrite(output_file, restored_final_volume)

        # Transpose from xz to xy and save
        xy_final_volume = convert_xz_to_xy(restored_final_volume)
        xy_output_file = output_folder / f"{original_name}_reconstructed_xy.tif"
        tifffile.imwrite(xy_output_file, xy_final_volume)

    print("Reconstruction complete.")

In [None]:
def run_nnunet(input_path, output_path, dataset_num, config):
    """
    Run ``nnUNetv2_predict`` as a subprocess in the current environment.

    The output folder is created if needed, the full command line is printed,
    and the command is run with ``check=True``. If running the command raises
    any exception, the error is printed to stderr and ``sys.exit(1)`` is called.

    Parameters:
    -----------
    input_path : str or pathlib.Path
        Folder passed to nnUNetv2_predict as "-i"
    output_path : str or pathlib.Path
        Folder passed to nnUNetv2_predict as "-o" (created if missing)
    dataset_num : int or str
        Dataset identifier passed as "-d"
    config : str
        Configuration name passed as "-c"
    """
    # Make sure the prediction folder exists before the tool writes into it
    Path(output_path).mkdir(parents=True, exist_ok=True)

    # Assemble the argument vector piece by piece
    cmd = ["nnUNetv2_predict"]
    cmd += ["-i", str(input_path)]
    cmd += ["-o", str(output_path)]
    cmd += ["-d", str(dataset_num)]
    cmd += ["-c", config]
    cmd.append("--save_probabilities")

    print("command is", ' '.join(cmd))

    try:
        subprocess.run(cmd, check=True, text=True)
        print("Prediction completed successfully!")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

In [None]:
def run_nnunet_ondemand(input_path, output_path, dataset_num, config,
                        conda_env="/research/sharedresources/cbi/public/conda_envs/nnunet"):
    """
    Run ``nnUNetv2_predict`` inside a conda environment via ``conda run -p``.

    The output folder is created if needed, the full command line is printed,
    and the command is run with ``check=True``. If running the command raises
    any exception, the error is printed to stderr and ``sys.exit(1)`` is called.

    Parameters:
    -----------
    input_path : str or pathlib.Path
        Folder passed to nnUNetv2_predict as "-i"
    output_path : str or pathlib.Path
        Folder passed to nnUNetv2_predict as "-o" (created if missing)
    dataset_num : int or str
        Dataset identifier passed as "-d"
    config : str
        Configuration name passed as "-c"
    conda_env : str or pathlib.Path
        Path of the conda environment handed to ``conda run -p``. Defaults to
        the shared nnunet environment that was previously hard-coded here.
    """
    # Create output directory
    Path(output_path).mkdir(parents=True, exist_ok=True)

    # "conda run -p <env>" prefix followed by the actual prediction command
    cmd = [
        "conda", "run", "-p", str(conda_env),
        "nnUNetv2_predict",
        "-i", str(input_path),
        "-o", str(output_path),
        "-d", str(dataset_num),
        "-c", config,
        "--save_probabilities"
    ]

    result = ' '.join(cmd)
    print("command is", result)

    try:
        subprocess.run(cmd, check=True, text=True)
        print("Prediction completed successfully!")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

In [None]:
def execute_for_single_volume(file_path, output_path, target_size=(256, 819), chunk_depth=128):
    """
    Run the full segmentation pipeline for one input volume.

    Steps: read the TIFF, resize its y/x extent to ``target_size``, apply
    CLAHE, write chunks to ``<output_path>/<stem>_chunks``, run nnU-Net into
    ``<output_path>/<stem>_segmentations`` and reconstruct the result into
    ``output_path``.

    The module-level globals ``on_demand``, ``dataset_num`` and ``config``
    select how and with which model nnU-Net is executed.

    Parameters:
    -----------
    file_path : str or pathlib.Path
        Path of the input TIFF volume
    output_path : str or pathlib.Path
        Folder that receives the chunk, segmentation and reconstruction outputs
    target_size : tuple
        (y, x) size the volume is resized to before chunking, default is (256, 819)
    chunk_depth : int
        Chunk extent along the first axis, default is 128
    """
    file_path = Path(file_path)
    print("Executing for ",file_path.stem)

    # Chunking and reconstruction must use the same chunk size, and the chunk's
    # y/x extent equals the resize target - derive it once instead of repeating
    # the numbers in three separate calls.
    chunk_size = (chunk_depth, target_size[0], target_size[1])

    # Read the volume
    volume = tifffile.imread(file_path)

    # Restructure the input volume to match nn input shape
    resized_volume = resize_volume_bicubic(volume, target_size=target_size)

    print("Preprocessing with clahe")
    # Apply clahe
    resized_preprocessed_volume = apply_clahe_3d(resized_volume)

    # Make Chuncks
    chuncked_volume_output = Path(output_path) / (str(file_path.stem) + "_chunks")
    chuncked_volume_output.mkdir(parents=True, exist_ok=True)
    make_chuncks(resized_preprocessed_volume, chuncked_volume_output, file_path.name, chunk_size=chunk_size)

    # Create Segmentation
    model_outputs = Path(output_path) / (str(file_path.stem) + "_segmentations")
    model_outputs.mkdir(parents=True, exist_ok=True)

    # Execute nn unet, either through "conda run" or directly
    if on_demand:
        run_nnunet_ondemand(chuncked_volume_output, model_outputs, dataset_num, config)
    else:
        run_nnunet(chuncked_volume_output, model_outputs, dataset_num, config)

    # Reconstruct; volume.shape (the shape before resizing) is the shape to restore
    reconstruct_volume(model_outputs, output_path, volume.shape, chunk_size=chunk_size)

## Find all isotropic DAPI XZ volumes that still need processing

In [None]:
def process_dapi_paths(input_dir, output_dir):
    """
    Find DAPI XZ images below ``input_dir`` and create matching result folders.

    A file qualifies when it is named "C4-DAPI-XZ.tif" and sits directly in a
    folder named "isotropic_image". For each such file the folder
    ``<output_dir>/<series>/DAPI_results`` is created, where ``<series>`` is
    the name of the parent folder of "isotropic_image". Files whose result
    folder already contains "C4-DAPI-XZ_reconstructed.tif" are skipped.

    The directory tree is walked in sorted order, so the returned lists have
    the same order on every run and on every filesystem.

    Args:
        input_dir (str or pathlib.Path): Directory that is searched recursively
        output_dir (str or pathlib.Path): Directory the result folders are created in

    Returns:
        Tuple[List[str], List[str]]: Lists of (input DAPI paths, output result
        folder paths); entries at the same index belong together
    """
    image_folder_name = "isotropic_image"
    dapi_filename = "C4-DAPI-XZ.tif"
    reconstructed_filename = "C4-DAPI-XZ_reconstructed.tif"

    # Initialize lists to store paths
    dapi_xz_paths = []
    dapi_result_paths = []

    # Convert to Path objects for easier handling
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    # Create output directory if it doesn't exist
    output_path.mkdir(parents=True, exist_ok=True)

    # Walk through all directories and subdirectories
    for root, dirs, files in os.walk(input_path):
        # Sorting in place steers os.walk's top-down descent, which makes the
        # traversal (and therefore the processing order) deterministic.
        dirs.sort()

        root_path = Path(root)

        # Only "isotropic_image" folders that contain the DAPI XZ image are of interest
        if root_path.name != image_folder_name or dapi_filename not in files:
            continue

        dapi_path = root_path / dapi_filename

        # Series folder = parent of the isotropic_image folder
        series_folder = root_path.parent.name

        # Create corresponding output folder structure
        result_folder = output_path / series_folder / "DAPI_results"
        result_folder.mkdir(parents=True, exist_ok=True)

        # Skip inputs whose reconstruction already exists
        if (result_folder / reconstructed_filename).exists():
            print(f"Skipping {dapi_path} - reconstructed file already exists")
            continue

        # Add paths to lists
        dapi_xz_paths.append(str(dapi_path))
        dapi_result_paths.append(str(result_folder))

        print(f"Found DAPI XZ image: {dapi_path}")
        print(f"Created results folder: {result_folder}")

    return dapi_xz_paths, dapi_result_paths

In [None]:
# Collect the inputs that still need processing together with their result folders.
# Side effect: process_dapi_paths creates the "<series>/DAPI_results" folders under output_path.
dapi_paths, result_paths = process_dapi_paths(input_path, output_path)

In [None]:
# Quick check before the long-running step: how many volumes were found, and which ones.
print(len(dapi_paths))
print(dapi_paths)

In [None]:
# Run the full pipeline for every pending volume, one after the other.
# dapi_paths and result_paths are index-aligned (both come from process_dapi_paths).
# Note: run_nnunet / run_nnunet_ondemand call sys.exit(1) on failure, which stops this loop.
for dapi_input_file, result_output_directory in zip(dapi_paths, result_paths):
    print("Executing DAPI - model segmentation for :")
    print(dapi_input_file)
    # Execute for single volume
    execute_for_single_volume(Path(dapi_input_file),result_output_directory)
print("Processing complete !")