In [None]:
import os
import tifffile
import matplotlib.pyplot as plt
from pathlib import Path
import random
import numpy as np

# Notebook-friendly plot display
%matplotlib inline

# Frame sub-sampling: keep one frame out of every SAMPLE_FRAME_INTERVAL frames
# of each film (frame indices 0, N, 2N, ...).
SAMPLE_FRAME_INTERVAL = 10

# Input TIFFs are read from input_dir/<split>/*.tif; paired PNGs are written
# to output_base/<split>/A (bright-field) and output_base/<split>/B (fluorescence).
input_dir = Path("raw-data")
output_base = Path("../data")

# Pre-create both domain folders for every split; exist_ok makes re-runs safe.
for subset in ("train", "val", "test"):
    for domain in ("A", "B"):
        (output_base / subset / domain).mkdir(parents=True, exist_ok=True)

In [None]:
# Process files from already-split directories: for every sampled frame of every
# input/<split>/*.tif film, write a bright-field PNG to <split>/A and the matching
# fluorescence PNG to <split>/B (train frames are additionally augmented).
BF_CMAP = "bone"    # Bright-field colormap
FL_CMAP = "magma"   # Fluorescence colormap
FRAME_SHAPE = (512, 512)  # Required (H, W) of every frame that gets saved

bright_field_index = 0   # channel index of the bright-field image
fluorescence_index = 1   # channel index of the fluorescence image


def _frame_accessors(arr: np.ndarray, axes: "str | None" = None):
    """Resolve a stack's layout into ``(n_frames, get_bf, get_fl)``, or ``None``.

    Supported layouts:

    * 4-D ``(frames, channels, H, W)`` or ``(channels, frames, H, W)``;
    * 3-D ``(2, H, W)`` single-frame two-channel image with ``(H, W) == FRAME_SHAPE``.

    Args:
        arr: the array returned by ``TiffFile.asarray()``.
        axes: tifffile's axes string for ``arr`` (e.g. ``"TCYX"``), if known.
            When it places ``"C"`` in one of the two leading positions, the
            channel axis is taken from it.  Otherwise the layout is guessed from
            the shape, and a second dimension of at most 4 is treated as channels.
            The guess is ambiguous for films with <= 4 frames, which is why the
            metadata is preferred.

    Returns:
        ``(n_frames, get_bf, get_fl)``, where ``get_bf(i)`` / ``get_fl(i)`` return
        the bright-field / fluorescence image of frame ``i``.  Returns ``None`` for
        any other shape.
    """
    if arr.ndim == 4:
        if axes is not None and len(axes) == 4 and "C" in axes[:2]:
            channel_axis = axes.index("C")
        else:
            channel_axis = 1 if arr.shape[1] <= 4 else 0
        if channel_axis == 1:   # (frames, channels, H, W)
            return (
                arr.shape[0],
                lambda i: arr[i, bright_field_index],
                lambda i: arr[i, fluorescence_index],
            )
        # (channels, frames, H, W)
        return (
            arr.shape[1],
            lambda i: arr[bright_field_index, i],
            lambda i: arr[fluorescence_index, i],
        )
    if arr.ndim == 3 and arr.shape[0] == 2 and arr.shape[1:] == FRAME_SHAPE:
        return (
            1,
            lambda i: arr[bright_field_index],
            lambda i: arr[fluorescence_index],
        )
    return None


def _augmented_pairs(bf: np.ndarray, fl: np.ndarray) -> dict:
    """Return ``{filename_suffix: (bf_aug, fl_aug)}`` for the train-time augmentations.

    The same geometric transform is applied to both images so the pair stays
    aligned.  Augmentation runs on the raw grayscale arrays; the colormap is
    only applied at save time.
    """
    transforms = {
        "":       lambda img: img,
        "_90":    lambda img: np.rot90(img, k=1),
        "_180":   lambda img: np.rot90(img, k=2),
        "_270":   lambda img: np.rot90(img, k=3),
        "_horiz": np.fliplr,
        "_vert":  np.flipud,
    }
    return {suffix: (t(bf), t(fl)) for suffix, t in transforms.items()}


for current_split in ["train", "val", "test"]:
    split_dir = input_dir / current_split
    if not split_dir.is_dir():
        print(f"WARNING: {split_dir} does not exist; skipping '{current_split}' split")
        continue
    split_A = output_base / current_split / "A"
    split_B = output_base / current_split / "B"

    # Sorted for a deterministic processing order across runs/platforms.
    for tif_path in sorted(split_dir.glob("*.tif")):
        try:
            with tifffile.TiffFile(tif_path) as tif:
                arr = tif.asarray()
                # Axes of the first series, the one asarray() returns by default.
                axes = tif.series[0].axes if tif.series else None

            accessors = _frame_accessors(arr, axes)
            if accessors is None:
                print(f"ERROR: {tif_path} has unsupported shape {arr.shape}")
                continue
            n_frames, get_bf, get_fl = accessors

            for frame_index in range(0, n_frames, SAMPLE_FRAME_INTERVAL):
                bf_image = get_bf(frame_index)
                fl_image = get_fl(frame_index)

                # Sanity check: only fixed-size frames are exported.
                if bf_image.shape != FRAME_SHAPE or fl_image.shape != FRAME_SHAPE:
                    print(f"ERROR: Skipping frame {frame_index} from {tif_path}")
                    continue

                base_name = f"{tif_path.stem}_{frame_index:04d}"
                if current_split == "train":
                    pairs = _augmented_pairs(bf_image, fl_image)
                else:
                    pairs = {"": (bf_image, fl_image)}

                # NOTE(review): without vmin/vmax, plt.imsave scales each image to
                # its own min/max, so intensities are not comparable across
                # frames — confirm this per-image normalisation is intended.
                for suffix, (bf_out, fl_out) in pairs.items():
                    plt.imsave(split_A / f"{base_name}{suffix}.png", bf_out, cmap=BF_CMAP)
                    plt.imsave(split_B / f"{base_name}{suffix}.png", fl_out, cmap=FL_CMAP)

        except Exception as e:
            # Best-effort batch: report the failing file and keep going.
            print(f"ERROR processing {tif_path}: {e}")