In [None]:
from dotenv import load_dotenv
import os
import tifffile
import matplotlib.pyplot as plt
from pathlib import Path
import random
import numpy as np

# Notebook-friendly plot display
%matplotlib inline

# Environment variable loading
load_dotenv()

# COLOR_MODE selects the output format used when frames are written to disk
COLOR_MODE = os.getenv("COLOR_MODE")
if COLOR_MODE not in ("grayscale", "rgb"):
    raise ValueError(f"COLOR_MODE '{COLOR_MODE}' must be 'grayscale' or 'rgb'")

# SAMPLE_FRAME_INTERVAL is used as the step of a range() over frames, so it
# must be a positive integer. Validate it here: a missing variable would
# otherwise surface as an opaque TypeError from int(None), and a zero step
# would make range() raise for every input file.
_raw_interval = os.getenv("SAMPLE_FRAME_INTERVAL")
try:
    SAMPLE_FRAME_INTERVAL = int(_raw_interval)
except (TypeError, ValueError):
    raise ValueError(
        f"SAMPLE_FRAME_INTERVAL '{_raw_interval}' must be a positive integer"
    ) from None
if SAMPLE_FRAME_INTERVAL < 1:
    raise ValueError(
        f"SAMPLE_FRAME_INTERVAL '{_raw_interval}' must be a positive integer"
    )

# Set paths
input_dir = Path("raw-data")
output_base = Path("../data")

# Create output directories: each split gets an "A" folder (bright-field
# images) and a "B" folder (fluorescence images)
for subset in ["train", "val", "test"]:
    for domain in ("A", "B"):
        (output_base / subset / domain).mkdir(parents=True, exist_ok=True)

In [None]:
# Registry of every file name handed out so far in this kernel session
used_names = set()


def get_unique_filename(base_name: str) -> str:
    """Return a name of the form ``<base_name>_<number>`` not issued before.

    The number is a random 48-bit integer (about 281 trillion possibilities).
    Each returned name is recorded in ``used_names`` so that it is never
    handed out twice while all files are generated in one go.

    Args:
        base_name: Prefix placed before the underscore and random suffix.

    Returns:
        The newly reserved name (without any file extension).

    Raises:
        RuntimeError: if 1001 consecutive candidates were all already taken.
    """
    for _ in range(1001):
        candidate = f"{base_name}_{random.getrandbits(48)}"
        if candidate in used_names:
            continue
        used_names.add(candidate)
        return candidate
    raise RuntimeError("Too many naming collisions")

In [None]:
# Expected (height, width) of a single frame; frames of any other size are skipped
EXPECTED_FRAME_SHAPE = (512, 512)


def _save_image_pair(bf_img, fl_img, dir_a, dir_b, stem):
    """Write one bright-field/fluorescence pair under the same file name.

    Args:
        bf_img: 2-D bright-field image, written into ``dir_a``.
        fl_img: 2-D fluorescence image, written into ``dir_b``.
        dir_a: Output directory (``Path``) for the bright-field image.
        dir_b: Output directory (``Path``) for the fluorescence image.
        stem: File name without extension; the extension follows COLOR_MODE.

    Raises:
        ValueError: if COLOR_MODE is neither 'grayscale' nor 'rgb'.
    """
    if COLOR_MODE == "grayscale":
        # Save to high-resolution (16-bit) grayscale TIFF
        out_name = f"{stem}.tif"
        tifffile.imwrite(dir_a / out_name, bf_img.astype(np.uint16))
        tifffile.imwrite(dir_b / out_name, fl_img.astype(np.uint16))
    elif COLOR_MODE == "rgb":
        # Save to colour-mapped png
        # NOTE(review): with no vmin/vmax, imsave scales each image to its own
        # min/max, which may undo the bright-field mean normalisation in this
        # mode -- confirm this is intended.
        out_name = f"{stem}.png"
        plt.imsave(dir_a / out_name, bf_img, cmap="gray")
        plt.imsave(dir_b / out_name, fl_img, cmap="inferno")
    else:
        raise ValueError(f"COLOR_MODE '{COLOR_MODE}' must be 'grayscale' or 'rgb'")


# Wrap individual file processing in try:except to prevent halting
for current_split in ["train", "val", "test"]:
    for tif_path in (input_dir / current_split).glob("*.tif"):
        try:
            with tifffile.TiffFile(tif_path) as tif:
                arr = tif.asarray()

            bright_field_index = 0
            fluorescence_index = 1

            # View every supported layout as (frames, channels, H, W)
            if arr.ndim == 4:
                # Heuristic: a second axis of length <= 4 is taken to be the
                # channel axis; otherwise the stack is (channels, frames, H, W)
                frames = arr if arr.shape[1] <= 4 else arr.swapaxes(0, 1)
            elif arr.ndim == 3 and arr.shape[0] == 2 and arr.shape[1:] == EXPECTED_FRAME_SHAPE:
                # A single two-channel frame
                frames = arr[np.newaxis]
            else:
                print(f"ERROR: {tif_path} has unsupported shape {arr.shape}")
                continue

            # Reject empty stacks and stacks lacking one of the two channels
            # explicitly instead of relying on an IndexError further down
            n_frames, n_channels = frames.shape[:2]
            if n_frames == 0 or n_channels <= max(bright_field_index, fluorescence_index):
                print(f"ERROR: {tif_path} has unsupported shape {arr.shape}")
                continue

            # Normalize bright-field to the first frame (to prevent extreme flickering)
            ref_bf_mean = frames[0, bright_field_index].astype(np.float32).mean()

            split_A = output_base / current_split / "A"
            split_B = output_base / current_split / "B"

            # Process every nth frame individually
            for frame_index in range(0, n_frames, SAMPLE_FRAME_INTERVAL):
                raw_bf = frames[frame_index, bright_field_index]
                fl_image = frames[frame_index, fluorescence_index]

                # Check the frame size before doing any normalisation work
                if raw_bf.shape != EXPECTED_FRAME_SHAPE or fl_image.shape != EXPECTED_FRAME_SHAPE:
                    print(f"ERROR: Skipping frame {frame_index} from {tif_path} due to invalid data")
                    continue

                # Shift the bright-field frame so its mean matches the reference
                # mean, then clamp back into the uint16 range
                bf_image = raw_bf.astype(np.float32)
                bf_image = bf_image - bf_image.mean() + ref_bf_mean
                bf_image = np.clip(bf_image, 0, 65535).astype(np.uint16)

                base_name = get_unique_filename('frame')

                if current_split == "train":
                    # Training data is augmented with rotations and flips; the
                    # same transform is applied to both images of a pair
                    variants = {
                        "": (bf_image, fl_image),
                        "_90": (np.rot90(bf_image, k=1), np.rot90(fl_image, k=1)),
                        "_180": (np.rot90(bf_image, k=2), np.rot90(fl_image, k=2)),
                        "_270": (np.rot90(bf_image, k=3), np.rot90(fl_image, k=3)),
                        "_horiz": (np.fliplr(bf_image), np.fliplr(fl_image)),
                        "_vert": (np.flipud(bf_image), np.flipud(fl_image)),
                    }
                else:
                    # Validation and test frames are saved unmodified
                    variants = {"": (bf_image, fl_image)}

                for suffix, (bf_aug, fl_aug) in variants.items():
                    _save_image_pair(bf_aug, fl_aug, split_A, split_B, f"{base_name}{suffix}")
        except Exception as e:
            print(f"ERROR processing {tif_path}: {e}")