In [None]:
import tifffile
import numpy as np
import matplotlib.pyplot as plt
from skimage import exposure
from skimage.util import img_as_ubyte
from scipy.ndimage import gaussian_filter
from skimage.restoration import denoise_nl_means
from pathlib import Path
import random

# Notebook-friendly plot display
%matplotlib inline

# Where the raw TIFF stacks are read from and where the dataset is written to
input_dir = Path("raw-data")
output_base = Path("../data")

# Relative share of frames intended for each dataset split
split_weights = {'train': 80, 'val': 10, 'test': 10}

# Every split gets its own pair of "A" and "B" image folders.
# parents/exist_ok make this cell safe to re-run.
for split in split_weights:
    for domain in ("A", "B"):
        (output_base / split / domain).mkdir(parents=True, exist_ok=True)

# Seed Python's global RNG so the random draws made later are repeatable
random.seed(2025)

# Assign splits randomly (proportions approach the configured weights only
# when the raw-data folder is large enough)
def assign_split(weights=None):
    """Randomly pick the dataset split for one frame.

    Args:
        weights: Optional mapping of split name -> relative weight. When
            omitted, the notebook-level ``split_weights`` dict is used, so
            editing that configuration actually changes the split ratio
            (previously the names and an 8:1:1 ratio were hard-coded here).

    Returns:
        The chosen split name, i.e. one of the mapping's keys.

    The draw uses the global ``random`` state, so the sequence of
    assignments follows whatever seed was set with ``random.seed``.
    """
    if weights is None:
        # Looked up at call time so later edits to the config are honoured
        weights = split_weights
    names = list(weights)
    return random.choices(names, weights=[weights[name] for name in names])[0]

In [None]:
# Names handed out so far; consulted so the same name is never issued twice
used_names = set()


def get_unique_filename(base_name: str) -> str:
    """Return a not-yet-issued name of the form ``<base_name>_<number>``.

    The numeric suffix is a random 48-bit integer (roughly 281 trillion
    possible values). Every name returned is recorded in ``used_names``,
    so uniqueness holds for as long as that set is kept (i.e. when all
    files are processed in one go).

    Raises:
        RuntimeError: if more than 1000 candidates in a row are already taken.
    """
    max_collisions = 1000
    # One initial draw plus up to `max_collisions` retries
    for _ in range(max_collisions + 1):
        candidate = "{}_{}".format(base_name, random.getrandbits(48))
        if candidate in used_names:
            continue
        used_names.add(candidate)
        return candidate
    raise RuntimeError("Too many naming collisions")

In [None]:
# Channel positions within a frame and the only accepted frame size.
# These never change between files, so they are defined once up front.
bright_field_index = 0
fluorescence_index = 1
expected_size = (512, 512)

# Convert every frame of every TIFF into a pair of PNGs with the same name:
# the bright-field channel goes to <split>/A, the processed fluorescence
# channel to <split>/B.
#
# Files are visited in sorted order: glob() yields directory entries in an
# OS-dependent order, which would otherwise make the seeded split assignment
# and generated names differ from one machine (or run) to the next.
#
# Wrap individual file processing in try/except to prevent one bad file
# from halting the whole batch.
for tif_path in sorted(input_dir.glob("*.tif")):
    try:
        with tifffile.TiffFile(tif_path) as tif:
            arr = tif.asarray()

        # Normalise the supported layouts to one (frames, channels, H, W) view
        if arr.ndim == 4:
            if arr.shape[1] <= 4:  # (frames, channels, H, W)
                frames = arr
            else:  # (channels, frames, H, W)
                frames = np.swapaxes(arr, 0, 1)
        elif arr.ndim == 3 and arr.shape[0] == 2 and arr.shape[1:] == expected_size:
            # Single frame stored as (channels, H, W)
            frames = arr[np.newaxis]
        else:
            print(f"ERROR: {tif_path} has unsupported shape {arr.shape}")
            continue

        # Both channels must be present; report it explicitly instead of
        # relying on an IndexError further down
        if frames.shape[1] <= max(bright_field_index, fluorescence_index):
            print(f"ERROR: {tif_path} has unsupported shape {arr.shape}")
            continue

        # Ensure shape is valid
        if frames.shape[2:] != expected_size:
            print(f"ERROR: {tif_path} has incorrect size {frames.shape[2:]}")
            continue

        # Process each frame individually
        for frame_index in range(frames.shape[0]):
            bf_image = frames[frame_index, bright_field_index]
            fl_image = frames[frame_index, fluorescence_index]

            # Preprocess fluorescence: stretch contrast between the 1st and
            # 99th percentiles, denoise, then smooth
            p_low, p_high = np.percentile(fl_image, (1, 99))
            contrast_stretched = exposure.rescale_intensity(fl_image, in_range=(p_low, p_high))
            denoised = denoise_nl_means(contrast_stretched, h=0.06, fast_mode=True)
            fl_processed = gaussian_filter(denoised, sigma=1.0)

            # Sanity check to prevent unusable pairs: a blank (or non-finite)
            # result cannot be normalised - dividing by a zero peak would
            # yield NaNs. Checked before a split/name is drawn so skipped
            # frames do not consume RNG output.
            peak = np.max(fl_processed)
            if not np.isfinite(peak) or peak <= 0:
                print(f"ERROR: Skipping frame {frame_index} from {tif_path} due to missing data after processing")
                continue
            fl_normalized = img_as_ubyte(fl_processed / peak)

            # Assign to split and name
            current_split = assign_split()
            out_name = get_unique_filename('frame') + ".png"

            # Save images to data folder for model (same file name in A and B
            # keeps the pair aligned)
            plt.imsave(output_base / current_split / "A" / out_name, bf_image, cmap='gray')
            plt.imsave(output_base / current_split / "B" / out_name, fl_normalized, cmap='viridis')

    except Exception as e:
        # Deliberate best-effort: report the failure and move on to the next file
        print(f"ERROR processing {tif_path}: {e}")