In [None]:
import tensorflow as tf
# Sanity check: report the installed TensorFlow version and the GPUs it can see.
print(tf.__version__)
print("GPUs:", tf.config.list_physical_devices(device_type='GPU'))

In [None]:
import glob
import logging
import os
import random
import re
import shutil
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndimage
import tensorflow as tf
import tensorflow_probability as tfp
from PIL import Image
from skimage.filters import threshold_otsu
from skimage.transform import resize
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization, Flatten, Dense, UpSampling2D, Concatenate, Dropout, Layer, LayerNormalization, MultiHeadAttention, add
from tensorflow.keras.layers import concatenate, Activation
from tensorflow.keras.mixed_precision import Policy, set_global_policy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.layers import SpectralNormalization, InstanceNormalization
from tqdm import tqdm

# Configure TensorFlow for consistent float32 precision (mixed precision disabled).
tf.keras.backend.set_floatx('float32')
tf.keras.mixed_precision.set_global_policy('float32')

# GPU configuration: enable on-demand memory growth so TensorFlow does not reserve
# all GPU memory up front. Memory growth can only be set before the GPUs are
# initialized; otherwise TensorFlow raises RuntimeError, which is reported here.
# `gpus` is a module-level name that ImageProcessor reads to choose its device.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Using GPU: {gpus[0]}")
    except RuntimeError as e:
        print(e)

def build_discriminator(input_shape=(128, 128, 2)):
    """
    Construct a CNN discriminator for the conditional GAN (cGAN).

    The discriminator receives a brightfield image concatenated channel-wise with a
    fluorescence image (real or generated) and returns one unbounded score per
    sample. No activation is applied to the final Dense layer, so the output is a
    raw score (logit) rather than a probability.

    Args:
        input_shape (tuple): (height, width, channels) of the input tensor.
            Defaults to (128, 128, 2): one brightfield channel plus one
            fluorescence channel.

    Returns:
        keras.Model: An *uncompiled* Keras model mapping a (batch, H, W, C) tensor
        to a (batch, 1) tensor of scores.

    Architecture:
        - 4 strided convolutions (4x4 kernel, stride 2, 'same' padding) with
          32 -> 64 -> 128 -> 256 filters; each halves the spatial size
          (128x128 -> 8x8 for the default input).
        - The first block is Conv2D -> LeakyReLU(0.2) with no normalization;
          the remaining three are Conv2D -> BatchNorm -> LeakyReLU(0.2).
        - Flatten -> Dense(1).

    Note:
        The Input and the final Dense layer are created with an explicit float32 dtype.
    """
    inputs = Input(shape=input_shape, dtype='float32')

    # Downsample from 128×128 to 8×8 through progressive strided convolutions.
    # No normalization on the first block (applied directly to the raw input).
    x = Conv2D(32, (4,4), strides=2, padding='same')(inputs)  # 64×64
    x = LeakyReLU(0.2)(x)

    x = Conv2D(64, (4,4), strides=2, padding='same')(x)  # 32×32
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(128, (4,4), strides=2, padding='same')(x)  # 16×16
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(256, (4,4), strides=2, padding='same')(x)  # 8×8
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    # Flatten to 8×8×256 = 16384 features (for the default 128×128 input)
    x = Flatten()(x)

    # Single unactivated score per sample for real/fake discrimination
    x = Dense(1, dtype='float32')(x)

    return Model(inputs, x)


def build_generator(input_shape=(128, 128, 1)):
    """
    Construct a U-Net style generator for the conditional GAN (cGAN).

    Maps a brightfield image (ch3) to two fluorescence predictions, one for dead
    (ch1) and one for alive (ch2) cells, using an encoder-decoder with skip
    connections.

    Args:
        input_shape (tuple): (height, width, channels) of the brightfield input.
            Defaults to (128, 128, 1). Height and width should be divisible by 8
            so the upsampled decoder maps line up with the encoder skip tensors.

    Returns:
        keras.Model: Uncompiled model with two outputs ``[out_dead, out_alive]``
        (layers ``gen_output_dead`` and ``gen_output_alive``). Each output has shape
        (batch, height, width, 1) with values in [-1, 1].

    Architecture:
        Encoder (4x4 kernels, stride 2, 32 -> 64 -> 128 filters):
        - enc1: Conv2D -> LeakyReLU(0.2), no normalization      128 -> 64
        - enc2: Conv2D -> BatchNorm -> LeakyReLU(0.2)            64 -> 32
        - enc3: Conv2D -> BatchNorm -> LeakyReLU(0.2)            32 -> 16

        Decoder (UpSampling2D x2 -> 3x3 Conv2D without bias -> BatchNorm -> LeakyReLU(0.2)):
        - dec1: 256 filters, 16 -> 32, then concatenated with enc2
        - dec2: 128 filters, 32 -> 64, then concatenated with enc1
        - dec3: 64 filters, 64 -> 128, no skip connection

        Outputs:
        - Two independent 1x1 Conv2D heads with tanh activation (dead / alive).

    Design Notes:
        - UpSampling2D + Conv2D is used instead of Conv2DTranspose, a common way to
          reduce checkerboard artifacts.
        - Decoder convolutions omit the bias because BatchNormalization follows them;
          the encoder convolutions keep their (redundant) bias.
        - The tanh range [-1, 1] presumably matches how the targets are scaled by the
          data pipeline. Confirm this against the preprocessing code.
        - Skip connections carry fine spatial detail from the encoder to the decoder.
    """
    inputs = Input(shape=input_shape)

    # Encoder: progressive downsampling with feature extraction
    enc1 = Conv2D(32, (4,4), strides=2, padding='same')(inputs)  # 128→64
    enc1 = LeakyReLU(0.2)(enc1)

    enc2 = Conv2D(64, (4,4), strides=2, padding='same')(enc1)   # 64→32
    enc2 = BatchNormalization()(enc2)
    enc2 = LeakyReLU(0.2)(enc2)

    enc3 = Conv2D(128, (4,4), strides=2, padding='same')(enc2)  # 32→16
    enc3 = BatchNormalization()(enc3)
    enc3 = LeakyReLU(0.2)(enc3)

    # Decoder: progressive upsampling with skip connections (U-Net style)

    # Upsampling Block 1: 16×16 → 32×32
    dec1 = UpSampling2D(size=(2,2))(enc3)
    dec1 = Conv2D(256, (3,3), padding='same', use_bias=False)(dec1)
    dec1 = BatchNormalization()(dec1)
    dec1 = LeakyReLU(0.2)(dec1)
    dec1 = concatenate([dec1, enc2])  # Skip connection from encoder

    # Upsampling Block 2: 32×32 → 64×64
    dec2 = UpSampling2D(size=(2,2))(dec1)
    dec2 = Conv2D(128, (3,3), padding='same', use_bias=False)(dec2)
    dec2 = BatchNormalization()(dec2)
    dec2 = LeakyReLU(0.2)(dec2)
    dec2 = concatenate([dec2, enc1])  # Skip connection from encoder

    # Upsampling Block 3: 64×64 → 128×128 (no skip connection)
    dec3 = UpSampling2D(size=(2,2))(dec2)
    dec3 = Conv2D(64, (3,3), padding='same', use_bias=False)(dec3)
    dec3 = BatchNormalization()(dec3)
    dec3 = LeakyReLU(0.2)(dec3)

    # Dual output heads for dead and alive cell fluorescence channels.
    # NOTE(review): these heads have use_bias=False although no normalization
    # follows them. Enabling the bias may help, but it changes the weight layout
    # (existing checkpoints would no longer load), so it is left as-is.
    out_dead = Conv2D(1, (1,1), padding='same', activation='tanh', use_bias=False, name='gen_output_dead')(dec3)
    out_alive = Conv2D(1, (1,1), padding='same', activation='tanh', use_bias=False, name='gen_output_alive')(dec3)

    return Model(inputs, [out_dead, out_alive])


class ImageProcessor:
    """
    Image processing pipeline for multi-channel microscopy data preparation.

    Organizes raw TIFF images into a subset/channel/well directory tree and
    provides the per-image preprocessing helpers used to build training tiles.
    It is designed for brightfield (ch3) and fluorescence (ch1/ch2) images whose
    filenames contain an ``rXcYfZpW-chN`` token (row, column, frame, position,
    channel).

    Key Features:
    - File organization by imaging position and channel
    - Train/validation/test splitting at position level: all channels of one
      position always land in the same subset
    - TensorFlow-based normalization and brightness enhancement (on GPU when one
      was detected at import time)
    - Binary mask generation for region-of-interest extraction
    - Tile extraction around the mask centroid

    Attributes:
        image_path (str): Source directory containing raw TIFF images
        output_base (str): Base directory for the organized output structure
        batch_size (int): Stored configuration value (not used by this class's methods)
        channel_matrices (dict): {"chN": {"rXcY": [source file paths]}} for ch1-ch3
        device (str): TensorFlow device string ('/GPU:0' or '/CPU:0')
        split_ratios (tuple): (train, validation, test) split proportions
    """

    # Filename token "rXcYfZpW-chN". Groups 1-4 hold the r/c/f/p digit runs and
    # group 5 holds the channel digits (possibly empty). re.search finds the token
    # anywhere in the name, so prefixes such as "experiment_" are tolerated.
    _FILENAME_PATTERN = re.compile(r"r(\d+)c(\d+)f(\d+)p(\d+)-ch(\d*)")

    def __init__(self, image_path, output_base="channel_matrix", batch_size=500, split_ratios=(0.6, 0.2, 0.2)):
        """
        Initialize the ImageProcessor with configuration parameters.

        Args:
            image_path (str): Path to directory containing source TIFF images
            output_base (str): Base directory for organized output structure
            batch_size (int): Files per processing batch (stored only)
            split_ratios (tuple): (train, validation, test) split proportions

        Raises:
            ValueError: If split_ratios don't sum to 1.0 (within 1e-6).

        Note:
            image_path is not checked here. A missing directory surfaces later,
            when organize_images() lists it.
        """
        self.image_path = image_path
        self.output_base = output_base
        self.batch_size = batch_size
        self.channel_matrices = {"ch1": {}, "ch2": {}, "ch3": {}}
        # `gpus` is the module-level list of physical GPUs detected at import time.
        self.device = '/GPU:0' if gpus else '/CPU:0'
        self.split_ratios = split_ratios

        # Validate split ratios
        if abs(sum(split_ratios) - 1.0) > 1e-6:
            raise ValueError(f"Split ratios must sum to 1.0, got {sum(split_ratios)}")

    def extract_position_key(self, file_name):
        """
        Extract the standardized position identifier from a microscopy filename.

        Args:
            file_name (str): Image filename containing an ``rXcYfZpW-ch`` token,
                where X, Y, Z and W are digit runs (e.g. ``r02c03f001p01-ch1sk1.tiff``).

        Returns:
            str or None: Position key ``"rXcYfZpW"`` with the digits copied verbatim
            (leading zeros preserved), or None if the filename has no such token.

        Example:
            >>> processor.extract_position_key("sample_r02c03f001p01-ch1sk1.tiff")
            'r02c03f001p01'
        """
        match = self._FILENAME_PATTERN.search(file_name)
        if match is None:
            return None
        row, col, frame, pos = match.group(1, 2, 3, 4)
        return f"r{row}c{col}f{frame}p{pos}"

    def process_file(self, file_name, subset):
        """
        Copy a single image file into the organized directory structure.

        The file is copied to ``<output_base>/<subset>/ch<N>/r<X>c<Y>/<file_name>``.
        Its source path is recorded in ``channel_matrices["ch<N>"]["r<X>c<Y>"]``.

        Args:
            file_name (str): Name of the image file (relative to image_path)
            subset (str): Target subset ('train', 'val', or 'test')

        Side Effects:
            - Creates directories under output_base
            - Copies the file
            - Appends the source path to self.channel_matrices

        Error Handling:
            Files without a position token are ignored silently. Files for
            channels other than ch1-ch3 are skipped with a warning, before anything
            is copied. Any other failure is printed and does not stop the batch.
        """
        try:
            match = self._FILENAME_PATTERN.search(file_name)
            if match is None:
                return

            row, col, _frame, _pos, channel = match.groups()
            channel_name = f"ch{channel}"
            if channel_name not in self.channel_matrices:
                print(f"[WARNING] Skipping {file_name}: unsupported channel '{channel_name}'")
                return

            # Files are grouped per well ("rXcY"); all frames/positions share a folder.
            channel_key = f"r{row}c{col}"
            file_path = os.path.join(self.image_path, file_name)

            # Create subset-specific directory structure
            channel_dir = os.path.join(self.output_base, subset, channel_name, channel_key)
            os.makedirs(channel_dir, exist_ok=True)

            shutil.copy(file_path, os.path.join(channel_dir, file_name))
            self.channel_matrices[channel_name].setdefault(channel_key, []).append(file_path)
        except Exception as e:
            print(f"[ERROR] Failed to process file {file_name}: {e}")

    def organize_images(self):
        """
        Organize all images into train/validation/test splits with position-level grouping.

        All files sharing a position key (every channel of the same rXcYfZpW) go to
        the same subset, so the channels of one position are never split apart.

        Process:
        1. Group ``*.tiff`` files (case-insensitive extension) by position key
        2. Randomly shuffle the position groups (module-level ``random`` state;
           call ``random.seed(...)`` beforehand for a reproducible split)
        3. Split groups by split_ratios. int() truncation sends any remainder
           groups to the test subset.
        4. Copy files in parallel with a ThreadPoolExecutor (8 workers)

        Returns:
            dict: self.channel_matrices. Repeated calls keep appending to it.
        """
        os.makedirs(self.output_base, exist_ok=True)
        all_files = [f for f in os.listdir(self.image_path) if f.lower().endswith('.tiff')]

        # Group files by position (all channels for same position together)
        position_groups = defaultdict(list)
        for file in all_files:
            position_key = self.extract_position_key(file)
            if position_key:
                position_groups[position_key].append(file)

        # Convert to list of groups and shuffle for random split
        group_list = list(position_groups.values())
        random.shuffle(group_list)

        # Calculate split indices based on position groups
        total_groups = len(group_list)
        train_end = int(total_groups * self.split_ratios[0])
        val_end = train_end + int(total_groups * self.split_ratios[1])

        # Split the groups
        train_groups = group_list[:train_end]
        val_groups = group_list[train_end:val_end]
        test_groups = group_list[val_end:]

        # Flatten the groups into file lists
        train_files = [file for group in train_groups for file in group]
        val_files = [file for group in val_groups for file in group]
        test_files = [file for group in test_groups for file in group]

        print(f"[INFO] Processing {len(all_files)} files in {total_groups} position groups...")
        print(f"[INFO] Splitting into: Train ({len(train_files)}), Val ({len(val_files)}), Test ({len(test_files)})")

        # Process each subset using multithreading for efficiency
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Process train files
            print("[INFO] Processing train files...")
            list(tqdm(executor.map(lambda f: self.process_file(f, "train"), train_files), total=len(train_files)))

            # Process validation files
            print("[INFO] Processing validation files...")
            list(tqdm(executor.map(lambda f: self.process_file(f, "val"), val_files), total=len(val_files)))

            # Process test files
            print("[INFO] Processing test files...")
            list(tqdm(executor.map(lambda f: self.process_file(f, "test"), test_files), total=len(test_files)))

        return self.channel_matrices

    def normalize_image(self, image):
        """
        Robust percentile-based normalization to the 16-bit range.

        Args:
            image (np.ndarray): Input image array (any numeric dtype)

        Returns:
            np.ndarray: Normalized image as uint16 in [0, 65535]

        Algorithm:
        1. Convert to a float32 tensor on self.device
        2. Take the 0.01th and 99.99th percentiles as robust min/max
        3. Scale linearly so that [min, max] maps to [0, 65535]
        4. Clamp to [0, 65535] and cast to uint16

        Note:
            Pixels outside the percentile window map outside [0, 65535]. They are
            clamped before the cast because casting out-of-range floats to an
            unsigned integer type does not reliably saturate.
        """
        with tf.device(self.device):
            image_tf = tf.convert_to_tensor(image, dtype=tf.float32)
            b_min = tfp.stats.percentile(image_tf, 0.0001 * 100)
            b_max = tfp.stats.percentile(image_tf, 0.9999 * 100)
            b_range = tf.maximum(b_max - b_min, 1e-7)  # Prevent division by zero
            normalized = (image_tf - b_min) / b_range * 65535
            normalized = tf.clip_by_value(normalized, 0.0, 65535.0)
            return tf.cast(normalized, tf.uint16).numpy()

    def piecewise_brighten(self, image, lower_factor=1.2, upper_factor=2.0, threshold=None):
        """
        Piecewise brightness scaling for microscopy images.

        Pixels below the threshold are multiplied by lower_factor; pixels at or
        above it are multiplied by upper_factor.

        Args:
            image (np.ndarray): Input image to enhance
            lower_factor (float): Multiplier for pixels below threshold
            upper_factor (float): Multiplier for pixels at or above threshold
            threshold (float, optional): Intensity threshold. If None, the 99th
                percentile of the image is used.

        Returns:
            np.ndarray: Enhanced image as uint16, clipped to [0, 65535]
        """
        with tf.device(self.device):
            image_tf = tf.convert_to_tensor(image, dtype=tf.float32)

            # Auto-threshold at 99th percentile if not specified
            if threshold is None:
                threshold = tfp.stats.percentile(image_tf, 99)

            # Apply piecewise transformation
            result = tf.where(
                image_tf < threshold,
                image_tf * lower_factor,
                image_tf * upper_factor
            )

            # Clip and convert back to uint16
            return tf.clip_by_value(result, 0, 65535).numpy().astype(np.uint16)

    def generate_mask(self, img_norm):
        """
        Binary mask of the largest connected dark region in the image.

        Args:
            img_norm (np.ndarray): Normalized input image (2D array)

        Returns:
            np.ndarray: uint8 mask (255 = selected component, 0 = elsewhere)

        Algorithm:
        1. Otsu threshold. Pixels *at or below* the threshold (the darker
           side) form the initial binary mask.
        2. Connected-component labeling of that mask
        3. Keep only the component with the most pixels

        Note:
            Assumes the largest dark component is the structure of interest.
        """
        threshold = threshold_otsu(img_norm)
        binary_mask = img_norm <= threshold
        labeled_mask, num_features = ndimage.label(binary_mask)
        # sizes[0] belongs to label 0 (unselected pixels) and is always 0 because
        # binary_mask is False there, so argmax picks a real component.
        sizes = ndimage.sum(binary_mask, labeled_mask, range(num_features + 1))
        largest_component = (labeled_mask == np.argmax(sizes))
        return (largest_component * 255).astype(np.uint8)

    def extract_tiles(self, img_norm, final_mask, square_size=400, target_size=128):
        """
        Extract four resized quadrant tiles around the centroid of a mask.

        Args:
            img_norm (np.ndarray): Normalized source image (2D)
            final_mask (np.ndarray): Mask defining the region of interest (non-zero = inside)
            square_size (int): Side length of the square crop around the centroid
            target_size (int): Side length of each output tile after resizing

        Returns:
            list: Up to 4 tiles (top-left, top-right, bottom-left, bottom-right) as
            float arrays of shape (target_size, target_size). Returns [] if the mask
            is empty or the crop is degenerate. Empty quadrants are skipped.

        Note:
            Near the image border the crop is truncated rather than shifted, so the
            crop (and its quadrants) may be non-square. Resizing then stretches
            them to target_size x target_size.
        """
        # Convert mask to binary if needed
        if final_mask.dtype != bool:
            final_mask = final_mask > 0

        # An empty mask has no centroid (center_of_mass would return NaN).
        if not final_mask.any():
            print("Warning: Empty mask, no centroid to crop around.")
            return []

        # Get centroid coordinates
        centroid_y, centroid_x = ndimage.center_of_mass(final_mask)
        centroid_y, centroid_x = int(round(float(centroid_y))), int(round(float(centroid_x)))

        half_size = square_size // 2

        img_height, img_width = int(img_norm.shape[0]), int(img_norm.shape[1])

        # Clamp crop boundaries to the image
        y_min = max(0, int(centroid_y - half_size))
        y_max = min(img_height, int(centroid_y + half_size))
        x_min = max(0, int(centroid_x - half_size))
        x_max = min(img_width, int(centroid_x + half_size))

        # Safety check for valid boundaries
        if y_min >= y_max or x_min >= x_max:
            print(f"Warning: Invalid crop coordinates (y:{y_min}-{y_max}, x:{x_min}-{x_max})")
            return []

        cropped_region = img_norm[y_min:y_max, x_min:x_max]

        # Divide into quadrant tiles
        mid_y, mid_x = cropped_region.shape[0] // 2, cropped_region.shape[1] // 2

        tiles = [
            cropped_region[:mid_y, :mid_x],   # Top-left
            cropped_region[:mid_y, mid_x:],   # Top-right
            cropped_region[mid_y:, :mid_x],   # Bottom-left
            cropped_region[mid_y:, mid_x:]    # Bottom-right
        ]

        # Resize tiles to target dimensions (skipping empty quadrants)
        resized_tiles = []
        for tile in tiles:
            if tile.size > 0:
                resized = resize(tile, (target_size, target_size),
                                 preserve_range=True, anti_aliasing=True)
                resized_tiles.append(resized)

        return resized_tiles


class cGANDataPipeline:
    """
    Comprehensive data pipeline for conditional GAN training on microscopy images.
    
    Manages the complete data preprocessing workflow for multi-channel microscopy data,
    creating aligned triplets of brightfield and fluorescence images, extracting tiles,
    and preparing training batches for conditional GAN models.
    
    Key Features:
    - Ensures perfect alignment between channels (ch1, ch2, ch3)
    - Handles train/validation/test data splits
    - Creates paired data for both generator and discriminator training
    - Generates tiles from region-of-interest masks
    - Saves preprocessed data in efficient batch format
    
    Attributes:
        channel_matrices (dict): Organized file paths by channel
        tile_size (int): Target size for extracted tiles
        image_processor (ImageProcessor): Instance for image processing operations
        val_split (float): Validation set proportion (deprecated - using organized splits)
    """
    
    def __init__(self, channel_matrices, image_processor, tile_size=128, val_split=0.2):
        """
        Initialize the cGAN data pipeline.
        
        Args:
            channel_matrices (dict): Organized file paths from ImageProcessor
            image_processor (ImageProcessor): Instance for image processing
            tile_size (int): Target tile size for model input
            val_split (float): Validation split ratio (deprecated)
        """
        self.channel_matrices = channel_matrices
        self.tile_size = tile_size
        self.image_processor = image_processor
        self.val_split = val_split
        self._prepare_datasets()


    def _get_aligned_triplets(self, subset="train"):
        """
        Create aligned triplets of corresponding images across all three channels.
        
        Ensures that ch1 (dead), ch2 (alive), and ch3 (brightfield) images from the 
        same microscopy position and timepoint are properly matched for training.
        Critical for maintaining correspondence between input and target images.
        
        Args:
            subset (str): Data subset ('train', 'val', or 'test')
            
        Returns:
            list: List of tuples (ch1_path, ch2_path, ch3_path) for aligned images
            
        Algorithm:
        1. Find common position keys across all three channels
        2. Sort files within each position by base filename (ignoring channel suffix)
        3. Verify alignment by comparing base filenames
        4. Create triplets only for perfectly aligned images
        
        Raises:
            ValueError: If channel directories are missing for the subset
            
        Note:
            Prints warnings for misaligned files but continues processing.
            Alignment is critical - misaligned triplets would corrupt training.
        """
        triplets = []
        base_dir = os.path.join("channel_matrix", subset)
        ch1_dir = os.path.join(base_dir, "ch1")
        ch2_dir = os.path.join(base_dir, "ch2")
        ch3_dir = os.path.join(base_dir, "ch3")

        # Verify all channel directories exist
        if not all(os.path.exists(d) for d in [ch1_dir, ch2_dir, ch3_dir]):
            raise ValueError(f"Missing channel directories for subset {subset}")

        ch1_keys = set(os.listdir(ch1_dir))
        ch2_keys = set(os.listdir(ch2_dir))
        ch3_keys = set(os.listdir(ch3_dir))
        common_keys = ch1_keys & ch2_keys & ch3_keys

        for key in common_keys:
            # Get files and sort by BASE filename (ignoring channel)
            ch1_files = sorted(
                [f for f in os.listdir(os.path.join(ch1_dir, key)) if f.endswith(".tiff")],
                key=lambda x: x.split("-ch1")[0]  
            )
            ch2_files = sorted(
                [f for f in os.listdir(os.path.join(ch2_dir, key)) if f.endswith(".tiff")],
                key=lambda x: x.split("-ch2")[0]
            )
            ch3_files = sorted(
                [f for f in os.listdir(os.path.join(ch3_dir, key)) if f.endswith(".tiff")],
                key=lambda x: x.split("-ch3")[0]
            )

            # Verify alignment after sorting
            for i, (f1, f2, f3) in enumerate(zip(ch1_files, ch2_files, ch3_files)):
                base1 = f1.split("-ch1")[0]
                base2 = f2.split("-ch2")[0]
                base3 = f3.split("-ch3")[0]
        
                if base1 == base2 == base3:
                    triplet = (
                        os.path.join(ch1_dir, key, f1),
                        os.path.join(ch2_dir, key, f2),
                        os.path.join(ch3_dir, key, f3)
                    )
                    triplets.append(triplet)
                else:
                    print(f"[WARNING] Mismatched base in {subset} set, key {key} index {i}: {base1} vs {base2} vs {base3}")

        print(f"[INFO] Found {len(triplets)} aligned triplets in {subset} set")
        return triplets

    def _process_triplet(self, ch3_path, ch1_path, ch2_path):
        """
        Turn one aligned image triplet into generator/discriminator training tiles.

        Args:
            ch3_path (str): Path to brightfield image (generator input)
            ch1_path (str): Path to dead cell fluorescence image (target 1)
            ch2_path (str): Path to alive cell fluorescence image (target 2)

        Returns:
            tuple: (generator_inputs, generator_targets_dead, generator_targets_alive,
            discriminator_real_pairs), all lists of equal length:
            - generator_inputs: ch3 tiles with a trailing channel axis
            - generator_targets_dead: matching ch1 tiles
            - generator_targets_alive: matching ch2 tiles
            - discriminator_real_pairs: (ch3_tile, ch1_tile) or (ch3_tile, ch2_tile),
              chosen at random per tile with np.random
            Four empty lists are returned if no tiles were extracted or if the
            channels produced different numbers of tiles.

        Processing Pipeline:
        1. Load the three images (files are closed right after reading)
        2. Piecewise-brighten the fluorescence channels (factors 0.8 / 3.6)
        3. Normalize all three channels with image_processor.normalize_image
        4. Build a mask from the brightfield image
        5. Extract tiles from every channel using that same mask
        6. Assemble generator pairs and random real discriminator pairs
        """
        # Load raw images; the context managers close the file handles promptly.
        with Image.open(ch3_path) as img:
            ch3_raw = np.array(img)
        with Image.open(ch1_path) as img:
            ch1_raw = np.array(img)
        with Image.open(ch2_path) as img:
            ch2_raw = np.array(img)

        # Apply brightness enhancement to fluorescence channels
        ch1_img = self.image_processor.piecewise_brighten(ch1_raw, 0.8, 3.6)
        ch2_img = self.image_processor.piecewise_brighten(ch2_raw, 0.8, 3.6)

        # Normalize all three channels to consistent ranges
        ch1_img = self.image_processor.normalize_image(ch1_img)
        ch2_img = self.image_processor.normalize_image(ch2_img)
        ch3_img = self.image_processor.normalize_image(ch3_raw)

        # Mask from brightfield, reused for every channel so the tiles stay aligned
        binary_mask = self.image_processor.generate_mask(ch3_img)

        ch1_tiles = self.image_processor.extract_tiles(ch1_img, binary_mask)
        ch2_tiles = self.image_processor.extract_tiles(ch2_img, binary_mask)
        ch3_tiles = self.image_processor.extract_tiles(ch3_img, binary_mask)

        if not ch3_tiles:
            print("Warning: No tiles extracted.")
            return [], [], [], []

        # Differing image shapes across channels could yield different tile counts;
        # pairing them by index would then misalign or fail.
        if not (len(ch1_tiles) == len(ch2_tiles) == len(ch3_tiles)):
            print(f"Warning: Tile count mismatch (ch1={len(ch1_tiles)}, ch2={len(ch2_tiles)}, "
                  f"ch3={len(ch3_tiles)}); skipping triplet.")
            return [], [], [], []

        generator_inputs = []         # ch3 tiles (brightfield input)
        generator_targets_dead = []   # ch1 tiles (dead-cell ground truth)
        generator_targets_alive = []  # ch2 tiles (alive-cell ground truth)
        discriminator_real_pairs = [] # (ch3_tile, ch1_tile) or (ch3_tile, ch2_tile)

        for ch3_tile, ch1_tile, ch2_tile in zip(ch3_tiles, ch1_tiles, ch2_tiles):
            ch3 = np.expand_dims(ch3_tile, axis=-1)  # Input to generator (brightfield)
            ch1 = np.expand_dims(ch1_tile, axis=-1)  # Dead fluorescence (target 1)
            ch2 = np.expand_dims(ch2_tile, axis=-1)  # Alive fluorescence (target 2)

            generator_inputs.append(ch3)
            generator_targets_dead.append(ch1)
            generator_targets_alive.append(ch2)

            # Discriminator real pairs: randomly choose between dead/alive channels
            if np.random.rand() > 0.5:
                discriminator_real_pairs.append((ch3, ch1))
            else:
                discriminator_real_pairs.append((ch3, ch2))

        return generator_inputs, generator_targets_dead, generator_targets_alive, discriminator_real_pairs

     
    
    def _prepare_datasets(self, output_dir="tiles", batch_size=24):
        """
        Prepare all datasets (train/validation/test) by processing triplets into batches.
        
        High-level orchestrator that processes each subset by calling _process_subset.
        Creates the complete preprocessed dataset structure for GAN training.
        
        Args:
            output_dir (str): Base directory for saving processed tile batches
            batch_size (int): Number of samples per saved batch file
        """
        for subset in ["train", "test", "val"]:
            self._process_subset(output_dir, batch_size, subset)

            
    def _process_subset(self, output_dir, batch_size, subset):
        """
        Process a complete data subset into batched, normalized training files.

        Converts aligned image triplets into normalized tile batches suitable for
        GAN training. Handles data normalization, concatenation for discriminator
        inputs, and batch-wise saving.

        Args:
            output_dir (str): Base directory for saving processed tiles
            batch_size (int): Minimum number of samples before a batch file is
                written. All tiles of one triplet are appended before the check,
                so a file may hold MORE than ``batch_size`` samples; the final
                file may hold fewer.
            subset (str): Data subset to process ('train', 'val', 'test')

        Processing Workflow:
        1. Get all aligned triplets for the subset
        2. Process each triplet into tiles and training pairs
        3. Normalize data to [-1, 1] range for stable GAN training
        4. Concatenate brightfield + fluorescence for discriminator inputs
        5. Save a batch whenever at least ``batch_size`` samples have
           accumulated, on the last triplet, and once more after the loop for
           any leftovers (e.g. when the last triplet yielded no tiles)
        6. Log range/shape statistics for every 100th batch and the final one

        Data Normalization:
            Maps [0, 65535] -> float32 [-1, 1] using formula:
            normalized = (data / 32767.5) - 1.0

        File Structure:
            - gen_inputs_{batch_idx}.npy: Generator inputs (brightfield tiles)
            - gen_targets_dead_{batch_idx}.npy: Dead fluorescence targets
            - gen_targets_alive_{batch_idx}.npy: Alive fluorescence targets
            - disc_inputs_real_{batch_idx}.npy: Concatenated real pairs for discriminator
            - disc_targets_real_{batch_idx}.npy: float32 labels (all 1.0 for real data)
        """
        subset_dir = os.path.join(output_dir, subset)
        os.makedirs(subset_dir, exist_ok=True)

        # Get triplets for the specified subset
        triplets = self._get_aligned_triplets(subset)
        last_i = len(triplets) - 1

        # Batch accumulators (cleared in place by save_batch after each write)
        gen_inputs, gen_targets_dead, gen_targets_alive = [], [], []
        disc_inputs_real_concatenated, disc_targets_real = [], []
        file_index = 0

        def normalize_to_float32(data):
            """
            Normalize [0,65535] data to float32 [-1,1] for stable GAN training.

            Args:
                data (np.ndarray): Input data on the uint16 scale

            Returns:
                np.ndarray: Normalized data in [-1,1] range as float32
            """
            # NOTE(review): this assumes `data` is still on the uint16 [0, 65535]
            # scale. _process_triplet has already run it through
            # image_processor.normalize_image, whose output range is not visible
            # here. Confirm it, or this mapping will not land in [-1, 1].
            data = data.astype(np.float32)
            return (data / 32767.5) - 1.0

        def save_batch(index, log_stats):
            """Write the accumulated samples as batch ``index``, then clear the accumulators."""
            gen_inputs_arr = np.stack(gen_inputs)
            disc_inputs_arr = np.stack(disc_inputs_real_concatenated)
            np.save(os.path.join(subset_dir, f"gen_inputs_{index}.npy"), gen_inputs_arr)
            np.save(os.path.join(subset_dir, f"gen_targets_dead_{index}.npy"), np.stack(gen_targets_dead))
            np.save(os.path.join(subset_dir, f"gen_targets_alive_{index}.npy"), np.stack(gen_targets_alive))
            np.save(os.path.join(subset_dir, f"disc_inputs_real_{index}.npy"), disc_inputs_arr)
            np.save(os.path.join(subset_dir, f"disc_targets_real_{index}.npy"),
                    np.array(disc_targets_real, dtype=np.float32))

            if log_stats:
                print(f"[{subset} Batch {index}] Range check - "
                      f"Gen inputs: [{gen_inputs_arr[0].min():.3f}, {gen_inputs_arr[0].max():.3f}] | "
                      f"Disc inputs: [{disc_inputs_arr[0].min():.3f}, {disc_inputs_arr[0].max():.3f}]")
                print(f"[{subset} Batch {index}] Shape check - "
                      f"Gen inputs shape: {gen_inputs_arr.shape} | "
                      f"Disc inputs shape: {disc_inputs_arr.shape}")

            for acc in (gen_inputs, gen_targets_dead, gen_targets_alive,
                        disc_inputs_real_concatenated, disc_targets_real):
                acc.clear()

        for i, (ch1_path, ch2_path, ch3_path) in enumerate(tqdm(triplets,
                                                          desc=f"Processing {subset} triplets",
                                                          unit="triplet")):

            # Process triplet into tiles and training pairs
            (generator_inputs_batch, generator_targets_dead_batch,
             generator_targets_alive_batch, discriminator_real_pairs_batch) = self._process_triplet(
                ch3_path, ch1_path, ch2_path)

            if not generator_inputs_batch:
                print(f"⚠️  Skipping empty tile result in {subset} set.")
                continue

            # Normalize and accumulate generator data
            gen_inputs.extend(normalize_to_float32(x) for x in generator_inputs_batch)
            gen_targets_dead.extend(normalize_to_float32(x) for x in generator_targets_dead_batch)
            gen_targets_alive.extend(normalize_to_float32(x) for x in generator_targets_alive_batch)

            # Discriminator real pairs: brightfield (ch3) concatenated with the real
            # fluorescence tile (ch1 or ch2) along the channel axis.
            for brightfield_tile, fluo_tile in discriminator_real_pairs_batch:
                concatenated_input = np.concatenate([brightfield_tile, fluo_tile], axis=-1)
                disc_inputs_real_concatenated.append(normalize_to_float32(concatenated_input))
                disc_targets_real.append(1.0)  # Label = 1 (real)

            is_last = i == last_i
            if len(gen_inputs) >= batch_size or is_last:
                save_batch(file_index, file_index % 100 == 0 or is_last)
                file_index += 1

        # If the final triplet(s) produced no tiles, the in-loop save never fired
        # for the trailing partial batch. Flush it so those samples are not lost.
        if gen_inputs:
            save_batch(file_index, True)
            file_index += 1

        print(f"🎉 {subset} dataset preparation completed. All batches saved to:", subset_dir)


class TileDataLoader(tf.keras.utils.Sequence):
    """
    Data loader for preprocessed tile batches in GAN training.

    Implements the Keras Sequence interface. Each item is one pre-built batch,
    stored on disk as the file set named by ``_REQUIRED_PREFIXES``. Each file
    is opened with ``mmap_mode='r'`` and then copied into RAM, so only the
    requested batch is materialized at a time.

    Key Features:
    - Per-batch loading; the full dataset is never held in memory at once
    - Batch-order shuffling between epochs (``on_epoch_end``)
    - Only indices with a complete set of companion files are exposed
    - Load errors are reported with the failing batch index and re-raised

    Attributes:
        data_dir (str): Directory containing preprocessed batch files
        batch_size (int): Informational only; the actual size comes from the files
        shuffle (bool): Whether to shuffle batch order between epochs
        file_indices (list): Available batch file indices (shuffled in place
            when ``shuffle`` is True)
        device (str): '/GPU:0' if the module-level ``gpus`` list is non-empty,
            else '/CPU:0'; not used by the loader itself
        sample_shapes (dict): Array shapes of the first batch in the (possibly
            shuffled) order, for reference
    """

    # File-name prefixes that together form one complete batch on disk.
    _REQUIRED_PREFIXES = (
        "gen_inputs_",
        "gen_targets_dead_",
        "gen_targets_alive_",
        "disc_inputs_real_",
        "disc_targets_real_",
    )

    def __init__(self, data_dir, batch_size=24, shuffle=True):
        """
        Initialize the data loader with batch file discovery and validation.

        Args:
            data_dir (str): Directory containing .npy batch files
            batch_size (int): Target batch size (informational; actual size comes from the files)
            shuffle (bool): Whether to shuffle batch order between epochs

        Raises:
            FileNotFoundError: If data_dir does not exist
            ValueError: If data_dir contains no complete batch file set
            Exception: Any error raised while loading the first batch propagates
        """
        # Run the base-class initializer. Newer Keras (PyDataset-based Sequence)
        # expects it, and it is harmless on older versions.
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.file_indices = self._get_file_indices()
        self.device = '/GPU:0' if gpus else '/CPU:0'
        self.on_epoch_end()

        # Load one batch up front so shape problems surface immediately
        sample = self.__getitem__(0)
        self.sample_shapes = {k: v.shape for k, v in sample.items()}
        print(f"Initialized DataLoader with {len(self.file_indices)} batches")
        print(f"Sample shapes: {self.sample_shapes}")

    def _get_file_indices(self):
        """
        Discover batch indices that have a complete set of batch files.

        An index ``k`` is accepted only when every ``{prefix}{k}.npy`` listed in
        ``_REQUIRED_PREFIXES`` exists in ``data_dir``. Files whose suffix is not
        a plain decimal integer are ignored. Incomplete indices are skipped
        with a warning.

        Returns:
            list: Sorted list of valid batch indices

        Raises:
            FileNotFoundError: If data_dir does not exist
            ValueError: If no complete batch file set is found
        """
        files = set(os.listdir(self.data_dir))
        anchor = self._REQUIRED_PREFIXES[0]
        indices, incomplete = set(), []

        for name in files:
            if not (name.startswith(anchor) and name.endswith(".npy")):
                continue
            stem = name[len(anchor):-len(".npy")]
            if not stem.isdecimal():
                continue  # stray file that merely shares the prefix
            idx = int(stem)
            # Check using the same formatting __getitem__ uses to build paths
            if all(f"{prefix}{idx}.npy" in files for prefix in self._REQUIRED_PREFIXES):
                indices.add(idx)
            else:
                incomplete.append(idx)

        if incomplete:
            print(f"Warning: skipping {len(incomplete)} incomplete batch index(es) "
                  f"in {self.data_dir}: {sorted(incomplete)}")

        if not indices:
            raise ValueError(f"No valid batch files found in {self.data_dir}")

        return sorted(indices)

    def __len__(self):
        """Return the number of batches available for training."""
        return len(self.file_indices)

    def __getitem__(self, index):
        """
        Load and return a single batch of training data.

        Each file is opened memory-mapped and then copied into an in-memory
        array, so only this batch is read from disk.

        Args:
            index (int): Position in ``file_indices`` (not the on-disk file index)

        Returns:
            dict: Training batch containing:
                - 'gen_input': Generator input data (brightfield tiles)
                - 'gen_target_dead': Dead fluorescence targets
                - 'gen_target_alive': Alive fluorescence targets
                - 'disc_input': Discriminator input (concatenated real pairs)
                - 'disc_target': Discriminator targets (real labels, cast to float32)

        Raises:
            IndexError: If ``index`` is out of range
            Exception: If batch loading fails (file corruption, missing files, etc.)
        """
        idx = self.file_indices[index]

        try:
            data = {
                "gen_input": np.load(os.path.join(self.data_dir, f"gen_inputs_{idx}.npy"), mmap_mode='r'),
                "gen_target_dead": np.load(os.path.join(self.data_dir, f"gen_targets_dead_{idx}.npy"), mmap_mode='r'),
                "gen_target_alive": np.load(os.path.join(self.data_dir, f"gen_targets_alive_{idx}.npy"), mmap_mode='r'),
                "disc_input": np.load(os.path.join(self.data_dir, f"disc_inputs_real_{idx}.npy"), mmap_mode='r'),
                "disc_target": np.load(os.path.join(self.data_dir, f"disc_targets_real_{idx}.npy"), mmap_mode='r').astype(np.float32)
            }

            # Copy out of the memory maps into regular in-memory arrays
            return {k: np.array(v) for k, v in data.items()}

        except Exception as e:
            print(f"Error loading batch {idx}: {str(e)}")
            raise

    def on_epoch_end(self):
        """Shuffle batch order in place if shuffling is enabled. Called automatically by Keras."""
        if self.shuffle:
            np.random.shuffle(self.file_indices)
    

class cGAN(keras.Model):
    """
    Keras model that bundles a conditional GAN's two sub-networks.

    Stores the generator and the discriminator as tracked attributes of one
    ``keras.Model`` and creates a running-mean loss metric for each. This
    class defines no forward pass or training step of its own.

    Attributes:
        generator: Network that produces fake images.
        discriminator: Network that scores image pairs as real or fake.
        gen_loss_tracker (keras.metrics.Mean): Mean metric named "generator_loss".
        disc_loss_tracker (keras.metrics.Mean): Mean metric named "discriminator_loss".
    """

    def __init__(self, generator, discriminator):
        """
        Store the sub-networks and create the loss trackers.

        Args:
            generator: Keras model that generates fake images from inputs.
            discriminator: Keras model that classifies real vs fake image pairs.
        """
        super().__init__()
        # Assigned after the base initializer so Keras tracks them as sub-layers.
        self.generator, self.discriminator = generator, discriminator

        make_mean = keras.metrics.Mean
        self.gen_loss_tracker = make_mean(name="generator_loss")
        self.disc_loss_tracker = make_mean(name="discriminator_loss")


class GANMonitor(keras.callbacks.Callback):
    """
    Training monitoring callback for GAN models.

    At every epoch end it saves a grid of generated "dead" and "alive" samples
    to ``training_progress/``. Every 5 epochs it also saves the generator and
    the discriminator to ``saved_models/`` as .h5 files.

    Attributes:
        num_img (int): Number of sample images generated per class
        latent_dim (int): Dimension of the latent noise vector
    """

    def __init__(self, num_img=4, latent_dim=256):
        """
        Initialize the monitoring callback and create its output directories.

        Args:
            num_img (int): Number of images to generate per class for monitoring
            latent_dim (int): Latent vector dimension for noise generation
        """
        # Initialize the Keras Callback base (sets up `model`, `params`, and the
        # internal flags CallbackList relies on).
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        os.makedirs("training_progress", exist_ok=True)
        os.makedirs("saved_models", exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """
        Generate sample images and save models at epoch end.

        Saves a 2 x ``num_img`` grid (top row dead, bottom row alive) to
        ``training_progress/epoch_{epoch+1}.png``. Every 5th epoch it also
        saves both sub-models.

        Args:
            epoch (int): Current epoch number (0-based)
            logs (dict): Training metrics (unused)
        """
        # NOTE(review): this calls the generator as G([noise, label]). train() in
        # this file instead calls G(brightfield) and unpacks two outputs. Confirm
        # which interface self.model.generator actually has before using this
        # callback.
        noise = tf.random.normal([self.num_img, self.latent_dim])
        dead_labels = tf.zeros([self.num_img, 1], dtype=tf.int32)
        alive_labels = tf.ones([self.num_img, 1], dtype=tf.int32)

        dead_images = self.model.generator([noise, dead_labels])
        alive_images = self.model.generator([noise, alive_labels])

        combined = np.vstack([dead_images, alive_images])
        # Map [-1, 1] -> [0, 255]. Clip before the uint8 cast so out-of-range
        # outputs saturate instead of wrapping around.
        combined = np.clip(combined * 127.5 + 127.5, 0, 255).astype(np.uint8)

        fig = plt.figure(figsize=(8, 8))
        for i in range(2 * self.num_img):
            plt.subplot(2, self.num_img, i + 1)
            plt.imshow(combined[i, :, :, 0], cmap='gray')
            plt.axis('off')
            plt.title("Dead" if i < self.num_img else "Alive")

        fig.savefig(f"training_progress/epoch_{epoch+1}.png")
        plt.close(fig)

        if (epoch + 1) % 5 == 0:
            self.model.generator.save(f"saved_models/generator_epoch_{epoch+1}.h5")
            self.model.discriminator.save(f"saved_models/discriminator_epoch_{epoch+1}.h5")


def verify_dtypes(inputs_dict, name="input"):
    """
    Data type and validity verification for model inputs.

    For each entry, prints dtype, shape, and value range. It then checks that
    the entry is float32 and contains no NaN/Inf values.

    Args:
        inputs_dict (dict): Mapping of name -> tensor or NumPy array
        name (str): Descriptive name for logging context

    Raises:
        AssertionError: If any entry has a non-float32 dtype or contains
            NaN/Inf. The checks are raised explicitly, so they still run
            under ``python -O``.

    Checks performed:
    - Ensures all entries are float32 (NumPy and TensorFlow dtypes both accepted)
    - Logs shapes and value ranges
    - Detects NaN and Infinity values
    """
    print(f"\n=== {name} DType Verification ===")
    for key, tensor in inputs_dict.items():
        dtype = tensor.dtype
        shape = tensor.shape
        min_val = tf.reduce_min(tensor).numpy()
        max_val = tf.reduce_max(tensor).numpy()
        print(f"{key}: {dtype} | Shape: {shape} | Range: [{min_val:.4f}, {max_val:.4f}]")

        # tf.as_dtype normalizes both np.dtype (e.g. arrays from TileDataLoader)
        # and tf.DType, so the comparison is reliable for either input kind.
        if tf.as_dtype(dtype) != tf.float32:
            raise AssertionError(f"{key} has wrong dtype: {dtype}")

        if tf.reduce_any(tf.math.is_nan(tensor)):
            raise AssertionError(f"NaN detected in {key}")
        if tf.reduce_any(tf.math.is_inf(tensor)):
            raise AssertionError(f"Inf detected in {key}")


def plot_samples(inputs, targets, predictions, title=""):
    """
    Show inputs, ground truth, and predictions side by side.

    Draws a 3x3 grid in which each row is one sample laid out as
    [Input | Target | Prediction]. At most the first 3 samples are drawn,
    using channel 0 of each image. The figure is displayed, not saved.

    Args:
        inputs (np.ndarray): Input images (brightfield), indexed [sample, ..., channel]
        targets (np.ndarray): Ground truth target images, same indexing
        predictions (np.ndarray): Model-generated predictions, same indexing
        title (str): Figure super-title
    """
    plt.figure(figsize=(15, 5))
    n_rows = min(3, len(inputs))
    for row in range(n_rows):
        panels = (("Input", inputs), ("Target", targets), ("Prediction", predictions))
        for col, (label, images) in enumerate(panels, start=1):
            plt.subplot(3, 3, row * 3 + col)
            plt.imshow(images[row, ..., 0], cmap='gray')
            plt.title(label)
            plt.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


def _scrub_nan_gradients(grads, variables):
    """
    Pair gradients with their variables for ``apply_gradients``.

    NaN entries are replaced with 0. Variables whose gradient is ``None`` (not
    connected to the loss) are dropped, because ``tf.math.is_nan(None)`` would
    raise. This is NaN scrubbing only; no norm clipping is applied.
    """
    return [
        (tf.where(tf.math.is_nan(g), tf.zeros_like(g), g), v)
        for g, v in zip(grads, variables)
        if g is not None
    ]


def _print_epoch_summary(epoch, epochs, epoch_metrics, history, show_val):
    """Print one epoch's averaged losses, optionally followed by the latest validation L1."""
    m = epoch_metrics
    print(f"\nEpoch {epoch+1}/{epochs}")
    print(f"D_loss: {m['d_loss']:.4f} (Real: {m['d_loss_real']:.4f}, Fake: {m['d_loss_fake']:.4f}, GP: {m['grad_penalty']:.4f})")
    print(f"G_loss: {m['g_loss']:.4f} (L1: {m['l1_dead']+m['l1_alive']:.4f}, GAN: {m['g_loss_gan']:.4f})")
    if show_val:
        print(f"Val L1 Dead: {history['val_l1_dead'][-1]:.4f} | Alive: {history['val_l1_alive'][-1]:.4f}")


def _plot_training_curves(history, log_dir, epoch, show):
    """Save (and optionally display) the 2x2 grid of training-loss curves for ``epoch``."""
    plt.figure(figsize=(15, 10))

    # Main losses
    plt.subplot(2, 2, 1)
    plt.plot(history['d_loss'], label='Discriminator Loss')
    plt.plot(history['g_loss'], label='Generator Loss')
    plt.title('Training Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Discriminator components
    plt.subplot(2, 2, 2)
    plt.plot(history['d_loss_real'], label='D Loss Real')
    plt.plot(history['d_loss_fake'], label='D Loss Fake')
    plt.plot(history['grad_penalty'], label='Gradient Penalty')
    plt.title('Discriminator Loss Components')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Generator reconstruction components
    plt.subplot(2, 2, 3)
    plt.plot(history['l1_dead'], label='L1 Dead')
    plt.plot(history['l1_alive'], label='L1 Alive')
    plt.title('Generator Reconstruction Losses')
    plt.xlabel('Epoch')
    plt.ylabel('L1')
    plt.legend()

    # Generator adversarial loss
    plt.subplot(2, 2, 4)
    plt.plot(history['g_loss_gan'], label='GAN Loss')
    plt.title('Generator GAN Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.tight_layout()
    plt.savefig(f'{log_dir}/training_plot_epoch_{epoch+1}.png')
    if show:
        plt.show()
    plt.close()


def train(generator, discriminator, gan, data_gen, epochs=250, verbose=2, plot_interval=2,
          val_dir="tiles/val"):
    """
    Training loop for a Wasserstein GAN with Gradient Penalty (WGAN-GP).

    Alternates one critic (discriminator) update and one generator update per
    batch. Logs per-epoch averages, evaluates on a fixed validation batch,
    plots progress, and checkpoints both models every 5 epochs.

    Args:
        generator: Model mapping a brightfield batch to (dead, alive) predictions
        discriminator: Critic scoring (brightfield, fluorescence) channel pairs
        gan: Combined GAN model (unused in this implementation)
        data_gen: Indexable batch source with ``__len__``/``__getitem__``
            (e.g. TileDataLoader). If it has ``on_epoch_end``, that method is
            called after every epoch (e.g. to reshuffle).
        epochs (int): Number of training epochs
        verbose (int): 0=silent, 1=every epoch, 2=every 2 epochs (+ validation L1),
            3=every 10 epochs
        plot_interval (int): Frequency of generating plots and sample images
        val_dir (str): Directory of validation batch files; the first batch
            loaded from it is used as the fixed validation batch

    Returns:
        tuple: (trained_generator, trained_discriminator, training_history)

    Raises:
        ValueError: If ``data_gen`` has no batches

    Training Algorithm:
    1. Discriminator Training:
       - Generate fake samples from current generator (inference mode)
       - Compute Wasserstein loss on real vs fake samples
       - Add gradient penalty for Lipschitz constraint
       - Update discriminator parameters
    2. Generator Training:
       - Generate fake samples (training mode)
       - Compute adversarial loss (fool discriminator)
       - Add L1 reconstruction losses for both dead/alive targets
       - Update generator parameters

    Gradients are NaN-scrubbed (NaN -> 0) before each update; no norm clipping.

    Loss Components:
    - Discriminator: Wasserstein loss + 10.0 * gradient_penalty
    - Generator: 2.0 * L1_dead + 2.0 * L1_alive + 3.0 * adversarial_loss
    - d_loss_real / d_loss_fake: label-smoothed BCE, logged for diagnostics
      only; they do not contribute to any gradient
    """
    # ====================== INITIALIZATION ======================

    # Optimizers with different learning rates for stability
    g_opt = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    d_opt = tf.keras.optimizers.Adam(learning_rate=5e-5, beta_1=0.5)

    # Fixed validation batch for consistent monitoring
    val_loader = TileDataLoader(val_dir, batch_size=16)
    val_batch = next(iter(val_loader))
    val_input = val_batch["gen_input"]
    val_target_dead = val_batch["gen_target_dead"]
    val_target_alive = val_batch["gen_target_alive"]

    # Iterate over every batch the loader actually has (previously hard-coded).
    total_batches = len(data_gen)
    if total_batches == 0:
        raise ValueError("data_gen yields no batches")

    # verbose level -> (print every N epochs, include validation line)
    verbose_schedule = {1: (1, False), 2: (2, True), 3: (10, False)}

    # Create logging directory with timestamp
    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = f"logs/{current_time}"
    os.makedirs(log_dir, exist_ok=True)

    history = {
        'd_loss': [],           # Total discriminator loss
        'g_loss': [],           # Total generator loss
        'd_loss_real': [],      # Diagnostic BCE on real samples
        'd_loss_fake': [],      # Diagnostic BCE on fake samples
        'grad_penalty': [],     # Gradient penalty component
        'l1_dead': [],          # L1 loss for dead fluorescence
        'l1_alive': [],         # L1 loss for alive fluorescence
        'g_loss_gan': [],       # Generator adversarial loss
        'val_l1_dead': [],      # Validation L1 for dead
        'val_l1_alive': []      # Validation L1 for alive
    }

    # ====================== TRAINING LOOP ======================
    for epoch in range(epochs):

        epoch_metrics = {k: 0.0 for k in history if not k.startswith('val_')}
        batch_count = 0

        for batch_idx in range(total_batches):
            batch = data_gen[batch_idx]
            real_ch3 = batch["gen_input"]
            real_dead = batch["gen_target_dead"]
            real_alive = batch["gen_target_alive"]
            disc_real = batch["disc_input"]

            # ============ DISCRIMINATOR TRAINING ============
            # Fakes for the critic step: only critic weights are updated here, so
            # the generator forward pass runs outside the tape.
            # NOTE(review): the critic only ever sees *dead* fakes, while real
            # pairs are randomly dead or alive, and the alive head gets no
            # adversarial signal. Confirm this is intended.
            fake_dead, _ = generator(real_ch3, training=False)
            fake_dead = tf.clip_by_value(fake_dead, -1.0, 1.0)
            fake_dead_concatenated_for_D = tf.concat([real_ch3, fake_dead], axis=-1)

            with tf.GradientTape() as d_tape:
                # NOTE(review): the discriminator is called without `training=`;
                # how its normalization/dropout layers behave then depends on the
                # model definition. Confirm.
                real_out = discriminator(disc_real)
                fake_out_dead = discriminator(fake_dead_concatenated_for_D)

                # Wasserstein critic loss
                d_loss_adversarial = tf.reduce_mean(fake_out_dead) - tf.reduce_mean(real_out)

                # Gradient penalty (WGAN-GP). The inner tape is nested inside d_tape so
                # the penalty is differentiable w.r.t. the critic weights.
                with tf.GradientTape() as gp_tape:
                    alpha_shape = tf.concat([
                        [tf.shape(disc_real)[0]],
                        [1] * (len(disc_real.shape) - 1)
                    ], axis=0)
                    alpha = tf.random.uniform(alpha_shape, 0., 1., dtype=tf.float32)
                    interpolated = alpha * disc_real + (1 - alpha) * fake_dead_concatenated_for_D
                    gp_tape.watch(interpolated)
                    crit_interpolated = discriminator(interpolated)

                grads = gp_tape.gradient(crit_interpolated, interpolated)
                reduce_axes = list(range(1, len(grads.shape)))
                # Small epsilon: d/dx sqrt(x) is infinite at 0, which would give NaN gradients.
                grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=reduce_axes) + 1e-12)
                grad_penalty = 10.0 * tf.reduce_mean(tf.square(grad_norms - 1.0))  # lambda = 10

                d_loss = d_loss_adversarial + grad_penalty

            d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
            d_opt.apply_gradients(_scrub_nan_gradients(d_grads, discriminator.trainable_variables))

            # Diagnostic-only BCE terms (label-smoothed). They are logged but are not
            # part of d_loss.
            d_loss_real = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(
                    tf.ones_like(real_out) * 0.9, real_out, from_logits=True
                )
            )
            d_loss_fake = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(
                    tf.zeros_like(fake_out_dead) + 0.1, fake_out_dead, from_logits=True
                )
            )

            # ============ GENERATOR TRAINING ============
            with tf.GradientTape() as g_tape:
                fake_dead, fake_alive = generator(real_ch3, training=True)
                # Only the dead head is clipped (as before); fake_alive is used as-is.
                fake_dead = tf.clip_by_value(fake_dead, -1.0, 1.0)

                fake_dead_concatenated_for_G = tf.concat([real_ch3, fake_dead], axis=-1)
                fake_out_dead_for_G_loss = discriminator(fake_dead_concatenated_for_G)

                # Adversarial term: maximize critic score on fakes
                g_loss_gan = -tf.reduce_mean(fake_out_dead_for_G_loss)

                # Reconstruction losses for both channels
                l1_dead = tf.reduce_mean(tf.abs(real_dead - fake_dead))
                l1_alive = tf.reduce_mean(tf.abs(real_alive - fake_alive))

                g_loss = 2.0 * l1_dead + 2.0 * l1_alive + 3.0 * g_loss_gan

            g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
            g_opt.apply_gradients(_scrub_nan_gradients(g_grads, generator.trainable_variables))

            # Accumulate batch metrics
            metrics = {
                'd_loss': d_loss,
                'g_loss': g_loss,
                'd_loss_real': d_loss_real,
                'd_loss_fake': d_loss_fake,
                'grad_penalty': grad_penalty,
                'l1_dead': l1_dead,
                'l1_alive': l1_alive,
                'g_loss_gan': g_loss_gan
            }
            for k, v in metrics.items():
                epoch_metrics[k] += v
            batch_count += 1

        # Epoch averages -> plain floats in history
        for k in epoch_metrics:
            epoch_metrics[k] = float(epoch_metrics[k] / batch_count)
            history[k].append(epoch_metrics[k])

        # Let the loader reshuffle / reset between epochs (Keras Sequence convention).
        if hasattr(data_gen, "on_epoch_end"):
            data_gen.on_epoch_end()

        # Validation evaluation on the fixed batch
        val_pred_dead, val_pred_alive = generator(val_input, training=False)
        history['val_l1_dead'].append(float(tf.reduce_mean(tf.abs(val_target_dead - val_pred_dead))))
        history['val_l1_alive'].append(float(tf.reduce_mean(tf.abs(val_target_alive - val_pred_alive))))

        if epoch % plot_interval == 0:
            plot_samples(
                val_input[:3],
                val_target_dead[:3],
                val_pred_dead[:3],
                title=f"Epoch {epoch} Validation"
            )

        # ============ VERBOSE OUTPUT ============
        schedule = verbose_schedule.get(verbose)
        if schedule is not None:
            every, show_val = schedule
            if (epoch + 1) % every == 0 or epoch == 0:
                _print_epoch_summary(epoch, epochs, epoch_metrics, history, show_val)

        # ============ PLOTTING ============
        if (epoch + 1) % plot_interval == 0 or epoch == epochs - 1:
            _plot_training_curves(history, log_dir, epoch, show=verbose > 0)

        # Model checkpointing every 5 epochs
        if (epoch + 1) % 5 == 0:
            generator.save(f"{log_dir}/generator_epoch{epoch+1}.h5")
            discriminator.save(f"{log_dir}/discriminator_epoch{epoch+1}.h5")

    save_final_plots(history, log_dir)

    return generator, discriminator, history


def save_final_plots(history, log_dir):
    """
    Generate and save comprehensive final training visualization plots.

    Creates a 2x3 grid of line plots covering the tracked GAN training
    metrics and writes it to ``{log_dir}/final_training_plots.png``.

    Args:
        history (dict): Training history mapping metric names to sequences
            (plotted against epoch index). Must contain the keys 'd_loss',
            'g_loss', 'd_loss_real', 'd_loss_fake', 'grad_penalty',
            'l1_dead', 'l1_alive' and 'g_loss_gan'.
        log_dir (str): Directory to save the final plot into; it must
            already exist because the figure is written there directly.

    Raises:
        KeyError: If ``history`` is missing one of the required keys. This
            is raised before any figure is created, so nothing leaks.

    Plot Panels:
    1. Main losses (D loss vs G loss)
    2. Discriminator loss components (real, fake, gradient penalty)
    3. Generator loss components (L1 dead, L1 alive, GAN loss)
    4. Reconstruction losses comparison
    5. Generator adversarial loss trend
    6. Combined discriminator loss analysis (real + fake vs. gradient penalty)
    """
    # Element-wise sum of the real and fake discriminator terms; zip()
    # truncates to the shorter series if their lengths ever differ.
    d_real_plus_fake = [
        real + fake
        for real, fake in zip(history['d_loss_real'], history['d_loss_fake'])
    ]

    # Each panel: (title, y-axis label, [(series, legend label), ...]).
    # All history lookups happen here, before a figure is allocated.
    panels = [
        ('Training Losses', 'Loss', [
            (history['d_loss'], 'D Loss'),
            (history['g_loss'], 'G Loss'),
        ]),
        ('Discriminator Loss Components', 'Loss', [
            (history['d_loss_real'], 'D Real'),
            (history['d_loss_fake'], 'D Fake'),
            (history['grad_penalty'], 'Gradient Penalty'),
        ]),
        ('Generator Loss Components', 'Loss', [
            (history['l1_dead'], 'L1 Dead'),
            (history['l1_alive'], 'L1 Alive'),
            (history['g_loss_gan'], 'GAN Loss'),
        ]),
        ('Reconstruction Losses', 'L1', [
            (history['l1_dead'], 'Dead'),
            (history['l1_alive'], 'Alive'),
        ]),
        ('Generator Adversarial Loss', 'Loss', [
            (history['g_loss_gan'], 'GAN Loss'),
        ]),
        ('Discriminator Combined Loss', 'Loss', [
            (d_real_plus_fake, 'D Real+Fake'),
            (history['grad_penalty'], 'Gradient Penalty'),
        ]),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    try:
        for ax, (title, ylabel, series) in zip(axes.flat, panels):
            for values, label in series:
                ax.plot(values, label=label)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
        fig.tight_layout()
        fig.savefig(f'{log_dir}/final_training_plots.png')
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # do not accumulate open figures in pyplot's global registry.
        plt.close(fig)


def build_and_compile_models(
    learning_rate=2e-4,
    beta_1=0.5,
    loss_weights=(2.0, 2.0, 1.0),
    image_shape=(128, 128, 1),
):
    """
    Build and compile the complete cGAN model architecture.

    Constructs generator and discriminator models, compiles them with
    placeholder losses, and wires a combined GAN model for end-to-end
    generator training. The discriminator scores the generator's "dead"
    output concatenated with the brightfield input.

    Args:
        learning_rate (float): Adam learning rate used for both the
            discriminator and the combined GAN optimizer. Default 2e-4.
        beta_1 (float): Adam ``beta_1`` for both optimizers. Default 0.5,
            a value commonly recommended for GAN training.
        loss_weights (sequence of 3 floats): Weights for the GAN model's
            outputs, in order (gen_dead, gen_alive, gan_validity).
            Default (2.0, 2.0, 1.0).
        image_shape (tuple): Brightfield input shape (H, W, C), excluding
            the batch dimension. Default (128, 128, 1).

    Returns:
        tuple: (generator, discriminator, gan) - all compiled.

    Raises:
        ValueError: If ``loss_weights`` does not have exactly 3 entries.

    Model Architecture:
    - Generator: called on one brightfield tensor; returns two outputs
      (dead, alive)
    - Discriminator: scores concatenate([brightfield, gen_dead]) pairs
    - Combined GAN: outputs [gen_dead, gen_alive, gan_validity]

    Note:
        The compiled losses are 'mse' placeholders; per the project notes,
        the actual WGAN-GP losses are implemented manually in the training
        loop. The discriminator is set to ``trainable = False`` before the
        combined GAN is built.
    """
    if len(loss_weights) != 3:
        raise ValueError(
            "loss_weights must have 3 entries (dead, alive, gan); "
            f"got {len(loss_weights)}"
        )

    # Build individual models
    generator = build_generator()
    discriminator = build_discriminator()

    # Verify model dtypes for debugging
    print("Generator input dtype:", generator.input.dtype)
    print("Discriminator input dtype:", discriminator.input.dtype)

    # Compile discriminator (loss is placeholder - WGAN-GP implemented manually).
    # Each model gets its own optimizer instance; Keras optimizers keep
    # per-variable state and must not be shared between models.
    discriminator.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1),
        loss='mse',
    )

    # Freeze the discriminator inside the combined model.
    # NOTE(review): in tf.keras this also makes
    # `discriminator.trainable_variables` empty. If the manual training loop
    # takes discriminator gradients from that list, it must set
    # `discriminator.trainable = True` for the discriminator step. Confirm
    # against train().
    discriminator.trainable = False

    # Define GAN input and data flow
    gan_input_brightfield = Input(shape=image_shape, dtype=tf.float32, name='gan_input_brightfield')

    # Generator produces dual outputs
    gen_dead, gen_alive = generator(gan_input_brightfield)

    # Only the "dead" prediction is paired with brightfield for the
    # discriminator; gen_alive is supervised by its reconstruction loss only.
    discriminator_input_for_gan = concatenate([gan_input_brightfield, gen_dead], axis=-1)
    gan_validity = discriminator(discriminator_input_for_gan)

    # Combined model with multiple outputs
    gan = Model(gan_input_brightfield, [gen_dead, gen_alive, gan_validity], name='gan_model')

    # Compile GAN with weighted placeholder losses
    gan.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1),
        loss=['mse', 'mse', 'mse'],
        loss_weights=list(loss_weights),  # (dead, alive, gan) weights
    )

    return generator, discriminator, gan
    

def main(
    image_path=r"YOUR_IMAGE_PATH\Images",
    tiles_dir="tiles/train",
    batch_size=16,
    epochs=250,
    verbose=2,
    plot_interval=2,
    save_final_models=False,
):
    """
    Main execution pipeline for the complete cGAN training workflow.

    Orchestrates the entire process from raw image organization through
    model training:
    1. Image organization and train/val/test splitting
    2. Dataset preprocessing and tile extraction
    3. Model building and compilation
    4. Data loading and a one-batch dtype sanity check
    5. Training execution with monitoring

    Args:
        image_path (str): Source image directory passed to ImageProcessor.
            The default is a placeholder; point it at your actual data.
        tiles_dir (str): Directory of training tiles read by TileDataLoader.
        batch_size (int): Batch size for TileDataLoader.
        epochs (int): Number of training epochs.
        verbose (int): Verbosity level forwarded to train().
        plot_interval (int): Plotting interval forwarded to train().
        save_final_models (bool): If True, save the trained generator and
            discriminator to "final_generator.h5" and
            "final_discriminator.h5" in the working directory.

    Returns:
        tuple: (trained_generator, trained_discriminator, history) as
        returned by train(). This is convenient for notebook use;
        script callers may ignore it.
    """
    # Step 1: Organize images into train/val/test splits
    processor = ImageProcessor(image_path)
    print("Organizing images...")
    channel_matrices = processor.organize_images()

    # Step 2: Prepare datasets with tile extraction and normalization.
    # The pipeline object is not used afterwards; it is kept bound so its
    # lifetime matches the original code.
    print("Preparing datasets...")
    _data_pipeline = cGANDataPipeline(channel_matrices, processor)

    # Step 3: Build and compile models
    print("Building models...")
    generator, discriminator, gan = build_and_compile_models()

    # Step 4: Initialize data loader
    data_gen = TileDataLoader(tiles_dir, batch_size=batch_size)

    # Verify data pipeline integrity (batches are expected to be dicts of arrays).
    # NOTE(review): if TileDataLoader.__iter__ returns a shared iterator
    # rather than a fresh one, this consumes the first training batch.
    # Confirm.
    sample_batch = next(iter(data_gen))
    for k, v in sample_batch.items():
        print(f"{k}: {v.dtype}")

    # Step 5: Execute training
    trained_generator, trained_discriminator, history = train(
        generator=generator,
        discriminator=discriminator,
        gan=gan,
        data_gen=data_gen,
        epochs=epochs,
        verbose=verbose,
        plot_interval=plot_interval,
    )

    print("Training completed!")

    if save_final_models:
        trained_generator.save("final_generator.h5")
        trained_discriminator.save("final_discriminator.h5")

    return trained_generator, trained_discriminator, history

if __name__ == "__main__":
    # Entry point: run the full organize -> preprocess -> build -> train
    # pipeline only when this file is executed directly, not on import.
    main()
