# **Notebook Overview and Setup Guide**

Welcome to the interactive demonstration for the research paper: *"A Robust and Explainable Deep Learning Pipeline for Lung Cancer Classification: Integrating Transfer Learning, Ensemble Methods, and Interpretability with Systematic Debugging."*

This notebook allows you to step through the key components of the pipeline, including data loading and preprocessing, model definition, training, evaluation, and the generation of explainability reports.

## **Prerequisites and Setup**

### 1. Accessing the Project Files and Dataset
All necessary files for this project, including the dataset (IQ-OTH/NCCD Lung Cancer Dataset), Python scripts (`.py`), and this notebook (`.ipynb`), are contained within a shared Google Drive folder.

*   **Access Link:** [https://drive.google.com/drive/folders/1wNcH36vbfWjVZY8H9Zqg8RPvgYFKfxgM?usp=drive_link]
*   **Instructions for Use (Google Colab - Recommended):**
    1.  Open the shared Google Drive folder using the link above.
    2.  **Make a copy of the entire project folder** to your own Google Drive. This is important so you can run the notebook and save outputs without affecting the original shared files. Right-click the main project folder (e.g., `CNN_Medical_Imaging_Project`) and select "Make a copy."
    3.  Navigate to your copy of the folder in your Google Drive.
    4.  Open this notebook (`FINAL_CNN_PROJECT_DEMO.ipynb` or similar name) from your copied folder in Google Colab.
    5.  The notebook is designed to work with the project structure within this folder. The initial code cells will attempt to mount your Google Drive and set the project path to where it expects the files (e.g., `/content/drive/MyDrive/CNN_Medical_Imaging_Project/` if you copied it to the root of your "My Drive"). If you place your copied project folder elsewhere, you might need to adjust the `PROJECT_PATH_COLAB` variable in the first code cell.
    6.  Select a GPU runtime for faster execution: `Runtime -> Change runtime type -> Hardware accelerator -> GPU`.

### 2. Alternative Environments (Kaggle, Local)
While Google Colab is recommended due to the integrated Drive access and GPU availability:

*   **Kaggle Kernels:**
    1.  Download the entire project folder from the shared Google Drive link.
    2.  Upload the notebook, all `.py` script files, and the dataset (specifically the `data/raw/` contents) to a new Kaggle Kernel.
    3.  You will need to modify the `base_dir` variable in the final cell (where `main(args)` is called) to point to Kaggle's input paths (e.g., `../input/your-dataset-name/`).

*   **Local Environment (e.g., Jupyter Notebook/Lab):**
    1.  Download the entire project folder.
    2.  Ensure you have Python 3.x installed.
    3.  The notebook includes cells to install dependencies. You might prefer to set up a virtual environment (`venv` or `conda`) and install them there.
    4.  The `base_dir` in the final cell will likely default to a local path like `./CNN_Medical_Imaging_Project`. Ensure this matches your local project structure.

### 3. Dependencies
The notebook contains cells (`!pip install ...`) to install the necessary Python libraries. Key dependencies include:
`tensorflow`, `pandas`, `scikit-learn`, `matplotlib`, `seaborn`, `opencv-python`, `lime`, and `google-generativeai`.
These will be installed when you run those specific cells.

## **Running the Notebook**

1.  **Run Cells Sequentially:** Execute the cells in order from top to bottom.
2.  **Step 1 (Mount Drive & Setup):** This cell handles initial environment checks, Google Drive mounting (if in Colab), and adds the project path to `sys.path`.
3.  **Step 3 (Install Dependencies):** These cells install the required libraries.
4.  **Code Definition Cells (Steps 4-8):** These cells define the Python classes and functions (`DataProcessor`, model architectures, `ModelTrainer`, visualization functions, `LimeExplainer`) as described in the research paper. They essentially replicate the content of the `.py` files from your project folder.
5.  **Step 9 (main.py Logic & Execution):** This final large cell contains the `main.py` script's logic.
    *   **Configuration:** At the end of this cell, within the `if __name__ == "__main__":` block, you can find an `if is_notebook:` section. This is where you can easily modify arguments for different modes (`train`, `evaluate`, `predict`) and parameters (e.g., number of epochs, which models to train/use, whether to run XAI features).
    *   **Execution:** Running this cell will execute the `main(args)` function, thereby running the entire pipeline based on the configured arguments.

## **Important Notes**
*   **File Paths:** The notebook is primarily configured for Google Colab and expects the project structure from the shared Drive link. If running in other environments, path adjustments in the argument parsing section of the final cell might be necessary.
*   **Execution Time:** Training deep learning models (`--mode train`) can be very time-consuming, especially without a GPU. The `predict` and `evaluate` modes using pre-trained models (if available in your `models` directory) will be significantly faster.
*   **Gemini API Key:** For the AI-generated explanations (if `--run_gemini` is enabled in 'predict' mode), a valid Google AI API key is required. The notebook attempts to use a hardcoded key placeholder or one you might set. Ensure this is correctly configured in the "Gemini AI Setup" section within the final cell if you wish to use this feature.
*   **Output Files:** Logs, trained models, evaluation results (JSON), prediction CSVs, and report images will be saved into subdirectories (`logs/`, `models/`, `results/`, `results/prediction_reports/`) within your `base_dir` (which will be inside your copied project folder on Google Drive if using Colab).

---
Now, you can proceed with running the cells below.

# ***A Robust and Explainable Deep Learning Pipeline for Lung Cancer Classification: Integrating Transfer Learning, Ensemble Methods, and Interpretability with Systematic Debugging***

# **Step 1: Mount Drive**

In [1]:
import sys
import os
import logging
import psutil

# --- Basic Logging Setup ---
# Configure logging to show timestamps, module name, level, and message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# --- Google Drive Mounting ---
IS_COLAB_ENVIRONMENT = 'google.colab' in sys.modules
PROJECT_PATH_COLAB = '/content/drive/MyDrive/CNN_Medical_Imaging_Project'
PROJECT_PATH_LOCAL = './CNN_Medical_Imaging_Project'

project_path = "" # Initialize project_path

if IS_COLAB_ENVIRONMENT:
    logger.info("Colab environment detected. Mounting Google Drive...")
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False) # Mount drive
        project_path = PROJECT_PATH_COLAB
        logger.info("Google Drive mounted.")
    except Exception as e:
        logger.error(f"Failed to mount Google Drive: {e}. Halting.")
        raise
else:
    logger.info("Not in Colab. Skipping Drive mount.")
    project_path = PROJECT_PATH_LOCAL

# --- Project Directory Setup ---
try:
    os.makedirs(project_path, exist_ok=True)
    if project_path not in sys.path:
        sys.path.append(project_path)
    logger.info(f"Project path set to: {project_path}")
    logger.info(f"Project path added to sys.path: {project_path in sys.path}")
except OSError as e:
    logger.error(f"Error setting up project directory '{project_path}': {e}. Halting.")
    raise

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **Step 2: Create Directories in Drive**


# **Step 3: Install Dependencies**

In [2]:
!pip install -q tensorflow pandas scikit-learn matplotlib seaborn opencv-python

In [3]:
!pip install -q lime # Had to be seperated due to conflicts

In [4]:
!pip3 install -qU google_genai # An attempt at enhancing explainability

In [5]:
# Standard library imports
import sys
import os
import logging
from pathlib import Path

# Third-party library imports
import tensorflow as tf
import numpy as np
import pandas as pd
import sklearn
import matplotlib
import seaborn as sns
import cv2
import lime

# --- Initialize Logger ---
if logging.getLogger().hasHandlers():
    logging.getLogger().handlers.clear()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# --- Environment and Path Configuration ---
IS_COLAB_ENVIRONMENT = 'google.colab' in sys.modules
DRIVE_BASE_PATH = Path('/content/drive')
PROJECT_DIR_NAME = 'CNN_Medical_Imaging_Project'
LOCAL_PROJECT_PATH = Path(f'./{PROJECT_DIR_NAME}')

project_path = None

if IS_COLAB_ENVIRONMENT:
    logger.info("Colab environment: Verifying Google Drive mount and project path.")
    google_drive_project_path = DRIVE_BASE_PATH / 'MyDrive' / PROJECT_DIR_NAME
    try:
        if not (DRIVE_BASE_PATH / 'MyDrive').exists(): # Checks if MyDrive is accessible
            from google.colab import drive
            logger.info("Attempting to mount Google Drive...")
            drive.mount(str(DRIVE_BASE_PATH)) # Mounts the base '/content/drive'
            logger.info("Google Drive mounted.")
        else:
            logger.info("Google Drive appears to be already mounted or accessible.")

        project_path = google_drive_project_path
    except Exception as e:
        logger.error(f"Error related to Google Drive in Colab: {e}. Using local fallback path.")
        project_path = LOCAL_PROJECT_PATH
else:
    logger.info("Non-Colab environment: Using local project path.")
    project_path = LOCAL_PROJECT_PATH

# Ensures project directory exists and add its absolute path to sys.path
try:
    project_path.mkdir(parents=True, exist_ok=True)
    project_path_abs_str = str(project_path.resolve())
    if project_path_abs_str not in sys.path:
        sys.path.append(project_path_abs_str)
        logger.info(f"Project path '{project_path_abs_str}' added to sys.path.")
    else:
        logger.info(f"Project path '{project_path_abs_str}' already in sys.path.")
    logger.info(f"Project directory verified: {project_path_abs_str}")
except Exception as e:
    logger.error(f"Failed to setup project path '{project_path}': {e}. Halting.")
    raise

# --- Library Version Verification ---
logger.info("--- Library Versions ---")
libraries_to_check = {
    "TensorFlow": tf, "NumPy": np, "Scikit-learn": sklearn,
    "Pandas": pd, "Seaborn": sns, "Matplotlib": matplotlib,
    "OpenCV": cv2, "LIME": lime
}
for name, lib in libraries_to_check.items():
    try:
        logger.info(f"{name} version: {lib.__version__}")
    except AttributeError:
        logger.warning(f"Could not retrieve version for {name} (possibly not imported or no __version__ attribute).")
    except Exception as e:
        logger.error(f"Error checking version for {name}: {e}")
logger.info("------------------------")

# --- TensorFlow GPU Sanity Check ---
logger.info("Performing TensorFlow GPU sanity check...")
try:
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        logger.info(f"GPUs available: {gpus}")
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logger.info("GPU memory growth enabled for all detected GPUs.")
        except RuntimeError as e:
            logger.warning(f"Could not set memory growth for GPUs (may be already initialized or in use): {e}")
    else:
        logger.info("No GPUs detected by TensorFlow. TensorFlow will run on CPU.")

    # Basic TensorFlow operation test
    tensor_a = tf.constant([1.0, 2.0], name="tensor_a")
    tensor_b = tf.constant([3.0, 4.0], name="tensor_b")
    tensor_c = tensor_a + tensor_b
    logger.info(f"TensorFlow test (tensor_a + tensor_b): {tensor_c.numpy()}")
    logger.info("TensorFlow basic operation successful.")

except Exception as e:
    logger.error(f"TensorFlow sanity check FAILED: {e}", exc_info=True)
    raise  # Critical failure if TensorFlow isn't working

logger.info("Setup, imports, path configuration, and verifications completed.")

2025-05-24 09:21:24,337 - __main__ - INFO - Colab environment: Verifying Google Drive mount and project path.
2025-05-24 09:21:24,344 - __main__ - INFO - Google Drive appears to be already mounted or accessible.
2025-05-24 09:21:24,347 - __main__ - INFO - Project path '/content/drive/MyDrive/CNN_Medical_Imaging_Project' already in sys.path.
2025-05-24 09:21:24,348 - __main__ - INFO - Project directory verified: /content/drive/MyDrive/CNN_Medical_Imaging_Project
2025-05-24 09:21:24,350 - __main__ - INFO - --- Library Versions ---
2025-05-24 09:21:24,351 - __main__ - INFO - TensorFlow version: 2.18.0
2025-05-24 09:21:24,352 - __main__ - INFO - NumPy version: 2.0.2
2025-05-24 09:21:24,355 - __main__ - INFO - Scikit-learn version: 1.6.1
2025-05-24 09:21:24,356 - __main__ - INFO - Pandas version: 2.2.2
2025-05-24 09:21:24,357 - __main__ - INFO - Seaborn version: 0.13.2
2025-05-24 09:21:24,360 - __main__ - INFO - Matplotlib version: 3.10.0
2025-05-24 09:21:24,363 - __main__ - INFO - OpenCV v

# **Step 4: data_processing.py: Load and Preprocess Data**

In [6]:
# Data is the lifeblood of any AI system.
# We're using the IQ-OTH/NCCD lung cancer dataset from Kaggle,
# containing CT scan slices categorized by experts as 'normal,'
# 'benign,' or 'malignant'. To create a more robust training set
# from this initially modest dataset, we employ data augmentation techniques.

import tensorflow as tf
import numpy as np
from pathlib import Path
import logging
from sklearn.utils import class_weight
import math
from typing import List, Tuple, Dict, Optional, Union

logger = logging.getLogger(__name__)

class DataProcessor:
    IMAGE_EXTENSIONS = ['*.jpg', '*.png', '*.jpeg'] # Defines supported image extensions

    def __init__(self,
                 target_size: Tuple[int, int] = (224, 224),
                 batch_size: int = 32,
                 seed: int = 42,
                 class_names: Optional[List[str]] = None):
        self.target_size = target_size
        self.batch_size = batch_size
        self.seed = seed

        if class_names is None:
            self.class_names = ['benign', 'malignant', 'normal']
        else:
            self.class_names = class_names

        self.num_classes = len(self.class_names)
        self.class_map = {name: i for i, name in enumerate(self.class_names)}

        logger.info(
            f"DataProcessor initialized: Target Size={self.target_size}, "
            f"Batch Size={self.batch_size}, Seed={self.seed}, "
            f"Classes={self.class_names}"
        )

    # Responsible for locating all our image files
    def _get_paths_and_labels(self, data_dir_str: str) -> Tuple[List[str], List[int]]:
        data_root = Path(data_dir_str)
        if not data_root.is_dir():
            logger.error(f"Data directory not found or is not a directory: {data_root}")
            raise FileNotFoundError(f"Directory not found: {data_root}")

        all_image_paths: List[str] = []
        all_image_labels: List[int] = []

        logger.info(f"Scanning for images in {data_root}...")
        for class_name in self.class_names:
            class_dir = data_root / class_name
            if not class_dir.is_dir():
                logger.warning(f"Class directory '{class_name}' not found in {data_root}. Skipping.")
                continue

            current_class_paths = [
                str(p) for ext in self.IMAGE_EXTENSIONS for p in class_dir.glob(ext)
            ]

            if not current_class_paths:
                logger.warning(f"No images with extensions {self.IMAGE_EXTENSIONS} found in {class_dir}")
                continue

            label = self.class_map[class_name]
            all_image_paths.extend(current_class_paths)
            all_image_labels.extend([label] * len(current_class_paths))
            logger.info(f"Found {len(current_class_paths)} images for class '{class_name}' (label {label}).")

        if not all_image_paths:
            logger.error(f"No images found in any class subdirectories of {data_root}.")
            raise ValueError(f"No images found in {data_root}")

        # Shuffle paths and labels consistently for reproducibility
        indices = np.arange(len(all_image_paths))
        np.random.seed(self.seed)
        np.random.shuffle(indices)

        all_image_paths = np.array(all_image_paths)[indices].tolist()
        all_image_labels = np.array(all_image_labels)[indices].tolist()

        logger.info(f"Found and shuffled {len(all_image_paths)} total images from {data_root}.")
        return all_image_paths, all_image_labels

    # Raw image files are not directly understood by neural networks. The _parse_image method converts them into a usable format
    def _parse_image(self, filename_tensor: tf.Tensor, label_tensor: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        try:
            image_string = tf.io.read_file(filename_tensor) # loads the raw pixel data from the disk.
            image = tf.io.decode_image(image_string, channels=3, dtype=tf.float32, expand_animations=False) # transforms this raw data into a tensor. A tensor is a fundamental data structure in deep learning.
            # Pixel values in a typical image file range from 0 (black) to 255 (white), tf.io.decode_image with dtype=tf.float32 handles this, converting the 0-255 range to 0.0-1.0.
            image.set_shape([None, None, 3])
            image_resized = tf.image.resize(image, self.target_size) # The tf.image.resize function ensures every image is resized to our defined self.target_size, which is 224x224 pixels if you scroll up to the initialization of the class.
            image_resized.set_shape([self.target_size[0], self.target_size[1], 3])

            return image_resized, label_tensor
        except Exception as e:
            logger.error(f"Error parsing an image file (see stack trace for details): {e}", exc_info=True)
            raise

    # Creating More Data from Less
    def _augment(self, image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """applies random transformations to the training images during the training process, creating slightly modified versions."""
        image = tf.image.random_flip_left_right(image, seed=self.seed)
        image = tf.image.random_brightness(image, max_delta=0.1, seed=self.seed)
        image = tf.clip_by_value(image, 0.0, 1.0) # Ensure pixel values stay in [0,1]
        return image, label

    def _configure_for_performance(self,
                                   dataset: tf.data.Dataset,
                                   shuffle: bool = False,
                                   shuffle_buffer_size: Optional[int] = None) -> tf.data.Dataset:
        if shuffle:
            buffer_to_use = shuffle_buffer_size if shuffle_buffer_size is not None else self.batch_size * 10
            if buffer_to_use <= 0: # Ensure buffer size is positive
                logger.warning(f"Shuffle buffer size was {buffer_to_use}, defaulting to 1000.")
                buffer_to_use = 1000
            dataset = dataset.shuffle(buffer_size=buffer_to_use, seed=self.seed, reshuffle_each_iteration=True)

        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
        return dataset

    def create_dataset_for_splitting(self, data_dir_str: str) -> Tuple[Optional[tf.data.Dataset], List[str], List[int], int]:
        logger.info(f"Creating initial dataset for splitting from: {data_dir_str}")
        try:
            image_paths, labels = self._get_paths_and_labels(data_dir_str)
        except (FileNotFoundError, ValueError) as e:
            logger.error(f"Failed to get image paths and labels from '{data_dir_str}': {e}")
            return None, [], [], 0

        dataset_size = len(image_paths)
        if dataset_size == 0:
            # This case should be caught by _get_paths_and_labels, but as a safeguard:
            logger.error(f"No images found to create a dataset from '{data_dir_str}'.")
            return None, [], [], 0

        logger.info(f"Total image samples found for splitting: {dataset_size}")
        unique_labels, counts = np.unique(labels, return_counts=True)
        stats = {self.class_names[l]: c for l, c in zip(unique_labels, counts)}
        logger.info(f"Initial dataset statistics (before split): {stats}")

        dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
        dataset = dataset.map(self._parse_image, num_parallel_calls=tf.data.AUTOTUNE)

        logger.info("Successfully created intermediate tf.data.Dataset for splitting.")
        return dataset, image_paths, labels, dataset_size

    def split_and_configure_dataset(self,
                                    dataset: tf.data.Dataset,
                                    dataset_size: int,
                                    validation_split: float = 0.2
                                    ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        if not (0 < validation_split < 1):
            msg = f"validation_split ({validation_split}) must be between 0 and 1 (exclusive)."
            logger.error(msg)
            raise ValueError(msg)

        val_size = math.floor(dataset_size * validation_split)
        train_size = dataset_size - val_size
        logger.info(f"Splitting dataset: {train_size} training samples, {val_size} validation samples.")

        # Shuffle before splitting. A larger buffer provides better shuffling.
        # reshuffle_each_iteration=False as this is a one-time split.
        # Heuristic for initial shuffle buffer: 10% of dataset size, or at least 1000.
        initial_shuffle_buffer_size = max(1000, dataset_size // 10)
        dataset = dataset.shuffle(buffer_size=initial_shuffle_buffer_size, seed=self.seed, reshuffle_each_iteration=False)

        val_ds = dataset.take(val_size)
        train_ds = dataset.skip(val_size)

        # Apply augmentations only to the training set
        train_ds = train_ds.map(self._augment, num_parallel_calls=tf.data.AUTOTUNE)

        # Configure performance (batching, prefetching, and per-epoch shuffle for train_ds)
        # Use the same large buffer for the per-epoch shuffle of the training set
        train_ds = self._configure_for_performance(train_ds, shuffle=True, shuffle_buffer_size=initial_shuffle_buffer_size)
        logger.info("Configured training dataset with augmentations and shuffling.")

        val_ds = self._configure_for_performance(val_ds, shuffle=False) # No shuffle for validation
        logger.info("Configured validation dataset.")

        return train_ds, val_ds

    def create_test_dataset(self, data_dir_str: str) -> Tuple[Optional[tf.data.Dataset], List[str], Optional[List[int]]]:
        logger.info(f"Creating test dataset from: {data_dir_str}")
        data_root = Path(data_dir_str)
        if not data_root.is_dir():
            logger.error(f"Test data directory not found or is not a directory: {data_root}")
            raise FileNotFoundError(f"Directory not found: {data_root}")

        all_image_paths: List[str] = []
        all_image_labels: List[int] = [] # Will hold integer labels if found

        # 1. Attempt to load labeled data from class subdirectories
        logger.info(f"Attempting to load labeled test data from subdirectories of {data_root}...")
        for class_idx, class_name in enumerate(self.class_names):
            class_dir = data_root / class_name
            if class_dir.is_dir():
                current_class_paths = [
                    str(p) for ext in self.IMAGE_EXTENSIONS for p in class_dir.glob(ext)
                ]
                if current_class_paths:
                    all_image_paths.extend(current_class_paths)
                    all_image_labels.extend([class_idx] * len(current_class_paths))
                    logger.info(f"Found {len(current_class_paths)} test images for class '{class_name}'")

        # 2. If no labeled data found, try to load unlabeled data from the root directory
        if not all_image_paths:
            logger.info(f"No images found in class subdirectories of {data_root}. Checking root directory for unlabeled images.")
            root_image_paths = [
                str(p) for ext in self.IMAGE_EXTENSIONS for p in data_root.glob(ext)
            ]
            if root_image_paths:
                all_image_paths.extend(root_image_paths)
                # all_image_labels remains empty, signifying unlabeled data
                logger.info(f"Found {len(root_image_paths)} images directly in {data_root} (assuming unlabeled).")

        # 3. If still no images, raise error
        if not all_image_paths:
            logger.error(f"No images found in test directory '{data_root}' (neither in subdirs nor root).")
            raise ValueError(f"No images found in {data_root}")

        # Create the TensorFlow dataset
        final_labels_to_return: Optional[List[int]] = None
        if all_image_labels: # Labeled dataset for evaluation
            logger.info(f"Creating labeled test dataset with {len(all_image_paths)} images.")
            unique_labels_found, counts = np.unique(all_image_labels, return_counts=True)
            stats = {self.class_names[l]: c for l, c in zip(unique_labels_found, counts)}
            logger.info(f"Labeled test dataset statistics: {stats}")

            dataset = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
            dataset = dataset.map(self._parse_image, num_parallel_calls=tf.data.AUTOTUNE)
            final_labels_to_return = all_image_labels
        else: # Unlabeled dataset for prediction
            logger.info(f"Creating unlabeled test dataset with {len(all_image_paths)} images.")
            # Dummy labels are needed for the initial from_tensor_slices structure
            dummy_labels = tf.zeros(len(all_image_paths), dtype=tf.int32)
            dataset = tf.data.Dataset.from_tensor_slices((all_image_paths, dummy_labels))
            # Map to keep only the image part after parsing
            dataset = dataset.map(lambda img_path, lbl: self._parse_image(img_path, lbl)[0],
                                  num_parallel_calls=tf.data.AUTOTUNE)
            # final_labels_to_return remains None

        # Configure performance (no shuffle or augmentation for test data)
        dataset = self._configure_for_performance(dataset, shuffle=False)
        logger.info("Successfully created and configured test tf.data.Dataset.")
        return dataset, all_image_paths, final_labels_to_return

    def get_class_weights(self, labels: List[int]) -> Optional[Dict[int, float]]:
        if not labels: # Handles None or empty list
            logger.warning("Cannot compute class weights: No labels provided.")
            return None

        try:
            int_labels = np.array(labels, dtype=int) # Ensure integer type
            if not np.issubdtype(int_labels.dtype, np.integer):
                 # This should ideally be caught by the previous line, but as a safeguard
                 raise TypeError(f"Labels must be integers; received type {np.array(labels).dtype}.")

            unique_labels_in_data = np.unique(int_labels)

            # Validate labels against the defined number of classes
            if np.any(unique_labels_in_data < 0) or np.any(unique_labels_in_data >= self.num_classes):
                logger.error(
                    f"Labels contain values outside expected range [0, {self.num_classes-1}]. "
                    f"Found: min={np.min(unique_labels_in_data)}, max={np.max(unique_labels_in_data)}. "
                    "Cannot compute weights."
                )
                return None

            if len(unique_labels_in_data) < self.num_classes:
                missing_classes_indices = set(range(self.num_classes)) - set(unique_labels_in_data)
                missing_classes_names = [self.class_names[i] for i in missing_classes_indices]
                logger.warning(
                    f"Training data contains only {len(unique_labels_in_data)} of {self.num_classes} classes. "
                    f"Present labels: {unique_labels_in_data}. Missing classes ({missing_classes_names}) will get default weight 1.0."
                )

            # Calculate weights using sklearn for the classes present in the data
            computed_weights_array = class_weight.compute_class_weight(
                class_weight='balanced',
                classes=unique_labels_in_data, # Only pass classes present in y
                y=int_labels
            )

            # Map computed weights to their class indices
            weights_for_present_classes = {
                int(cls_idx): float(weight)
                for cls_idx, weight in zip(unique_labels_in_data, computed_weights_array)
            }

            # Assign weights for all classes, defaulting to 1.0 for missing ones
            final_class_weights = {
                i: weights_for_present_classes.get(i, 1.0) for i in range(self.num_classes)
            }

            logger.info(f"Calculated class weights: {final_class_weights}")
            return final_class_weights

        except Exception as e:
            logger.error(f"Failed to compute class weights: {e}", exc_info=True)
            return None

# **Step 5: model_architecture.py : Model Creation Script**

In [7]:
# Imagine a neural network as a complex system of
# interconnected processing units, which we call 'neurons.'
# This structure is loosely inspired by the human brain.
# Each neuron receives signals (data), performs a simple
# calculation, and then passes its output to other neurons.
# When these neurons are arranged in layers, the network as
# a whole can learn to recognize very complex patterns from the input data.
# Instead of starting from zero, we leverage models (of neural networks(particularly convolutional))
# like EfficientNetB0, DenseNet121, InceptionV3, ResNet50V2, and MobileNetV2

import logging
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, applications, regularizers
from pathlib import Path
from typing import Tuple, Optional, Dict, List, Union, Callable

logger_ma = logging.getLogger(__name__)

from tensorflow.keras.applications import inception_v3

@tf.keras.utils.register_keras_serializable(package="Custom", name="inception_v3_preprocess_wrapper")
def inception_v3_preprocess_wrapper(x: tf.Tensor) -> tf.Tensor:
    x_scaled_to_255 = x * 255.0
    return inception_v3.preprocess_input(x_scaled_to_255)

from tensorflow.keras.applications import resnet_v2

@tf.keras.utils.register_keras_serializable(package="Custom", name="resnet_v2_preprocess_wrapper")
def resnet_v2_preprocess_wrapper(x: tf.Tensor) -> tf.Tensor:
    x_scaled_to_255 = x * 255.0
    return resnet_v2.preprocess_input(x_scaled_to_255)

@tf.keras.utils.register_keras_serializable(package="Custom", name="mobilenet_preprocess_wrapper")
def mobilenet_preprocess_wrapper(x: tf.Tensor) -> tf.Tensor:
    x_scaled_to_255 = x * 255.0
    return applications.mobilenet_v2.preprocess_input(x_scaled_to_255)

# =============================================================================
# Standardized Classifier Head
# =============================================================================

def build_standard_classifier_head(inputs: tf.Tensor,
                                   num_classes: int,
                                   base_model_name: str,
                                   dropout_rate: float = 0.4, # Standardized dropout
                                   l2_lambda: float = 0.001   # Standardized L2
                                   ) -> tf.Tensor:
    if not isinstance(num_classes, int) or num_classes <= 0:
        raise ValueError(f"num_classes must be a positive integer, got {num_classes}")
    if not (0.0 <= dropout_rate < 1.0):
        logger_ma.warning(f"Invalid dropout_rate ({dropout_rate}). Clamping to [0, 0.99].")
        dropout_rate = np.clip(dropout_rate, 0.0, 0.99)
    if l2_lambda < 0:
        logger_ma.warning(f"Invalid l2_lambda ({l2_lambda}). Disabling L2 regularization.")
        l2_lambda = 0.0

    prefix = f"{base_model_name}_head"
    regularizer = regularizers.l2(l2_lambda) if l2_lambda > 0 else None

    x = layers.GlobalAveragePooling2D(name=f'{prefix}_gap')(inputs)

    # First Dense Block
    x = layers.Dense(256, kernel_regularizer=regularizer, name=f'{prefix}_dense_1')(x)
    x = layers.BatchNormalization(name=f'{prefix}_bn_1')(x)
    x = layers.Activation('relu', name=f'{prefix}_relu_1')(x)
    x = layers.Dropout(dropout_rate, name=f'{prefix}_dropout_1')(x)

    # Second Dense Block
    x = layers.Dense(128, kernel_regularizer=regularizer, name=f'{prefix}_dense_2')(x)
    x = layers.BatchNormalization(name=f'{prefix}_bn_2')(x)
    x = layers.Activation('relu', name=f'{prefix}_relu_2')(x)
    x = layers.Dropout(max(0.0, dropout_rate * 0.75), name=f'{prefix}_dropout_2')(x)

    # Output layer
    logits = layers.Dense(num_classes, activation=None, name=f'{prefix}_logits')(x)
    outputs = layers.Activation('softmax', dtype='float32', name=f'{prefix}_softmax')(logits)

    logger_ma.debug(f"Classification head '{prefix}' built for {num_classes} classes.")
    return outputs

# =============================================================================
# Base Model Freezing Logic
# =============================================================================

def _freeze_base_model_layers(base_model: keras.Model,
                              freeze_ratio: float = 1.0,
                              fine_tune_from_layer: Optional[Union[str, int]] = None):
    if not isinstance(base_model, keras.Model):
        logger_ma.error(f"'{base_model.name}' is not a Keras Model. Cannot freeze layers.")
        return

    num_layers = len(base_model.layers)
    if num_layers == 0:
        logger_ma.warning(f"Base model '{base_model.name}' has no layers to freeze/unfreeze.")
        return

    if not (0.0 <= freeze_ratio <= 1.0):
        logger_ma.warning(f"Invalid freeze_ratio ({freeze_ratio}) for '{base_model.name}'. Clamping to 1.0 (freeze all).")
        freeze_ratio = 1.0

    # Initial pass based on freeze_ratio
    if freeze_ratio == 1.0:
        base_model.trainable = False
        logger_ma.info(f"Froze ALL layers in base model '{base_model.name}' (base_model.trainable = False).")
    elif freeze_ratio == 0.0:
        base_model.trainable = True # Make the container trainable
        for layer in base_model.layers: # Then iterate through its layers
            if not isinstance(layer, layers.BatchNormalization):
                layer.trainable = True
            else:
                layer.trainable = False # Keep BN frozen
        logger_ma.info(f"Unfroze ALL non-BN layers in base model '{base_model.name}'. BN layers remain frozen.")
    else: # Partial freeze based on ratio
        base_model.trainable = True # Make the container trainable
        freeze_until_index = int(num_layers * freeze_ratio)
        logger_ma.info(f"Partially freezing base model '{base_model.name}': "
                       f"First {freeze_until_index}/{num_layers} layers ({freeze_ratio*100:.1f}%) targeted for freezing.")
        for i, layer in enumerate(base_model.layers):
            if i < freeze_until_index:
                layer.trainable = False
            else:
                if not isinstance(layer, layers.BatchNormalization):
                    layer.trainable = True
                else:
                    layer.trainable = False # Keep BN frozen

    # Second pass for fine_tune_from_layer, if specified
    if fine_tune_from_layer is not None:
        base_model.trainable = True # Ensure container is trainable for fine-tuning
        unfreeze_start_index = -1

        if isinstance(fine_tune_from_layer, str):
            try:
                layer_names = [l.name for l in base_model.layers]
                unfreeze_start_index = layer_names.index(fine_tune_from_layer)
            except ValueError:
                logger_ma.error(f"Layer name '{fine_tune_from_layer}' not found in base model '{base_model.name}'. "
                                f"Fine-tuning unfreeze rule not applied. Current state based on freeze_ratio stands.")
                fine_tune_from_layer = None # Invalidate to skip further processing
        elif isinstance(fine_tune_from_layer, int):
            if 0 <= fine_tune_from_layer < num_layers:
                unfreeze_start_index = fine_tune_from_layer
            else:
                logger_ma.error(f"Invalid fine_tune_from_layer index {fine_tune_from_layer} (max: {num_layers-1}). "
                                f"Fine-tuning unfreeze rule not applied. Current state based on freeze_ratio stands.")
                fine_tune_from_layer = None # Invalidate
        else:
            logger_ma.error(f"Invalid type for fine_tune_from_layer: {type(fine_tune_from_layer)}. "
                            f"Fine-tuning unfreeze rule not applied.")
            fine_tune_from_layer = None # Invalidate

        if fine_tune_from_layer is not None and unfreeze_start_index != -1:
            logger_ma.info(f"Fine-tuning '{base_model.name}': Unfreezing layers from index {unfreeze_start_index} ('{base_model.layers[unfreeze_start_index].name}') onwards.")
            for i, layer in enumerate(base_model.layers):
                if i >= unfreeze_start_index:
                    if not isinstance(layer, layers.BatchNormalization):
                        layer.trainable = True
                    else:
                        layer.trainable = False # BN layers remain frozen

    trainable_layers_count = sum(1 for layer in base_model.layers if layer.trainable and layer.weights)
    logger_ma.info(f"Base model '{base_model.name}' final state: {trainable_layers_count}/{len(base_model.weights)} trainable weights/layers with weights.")


# =============================================================================
# Specific Model Creation Functions
# =============================================================================

def _create_base_model(model_name: str,
                       input_shape: Tuple[int, int, int],
                       weights_source: Optional[Union[str, Path]] = 'imagenet',
                       input_tensor: Optional[tf.Tensor] = None
                       ) -> keras.Model:
    logger_ma.info(f"Loading base model structure: {model_name}")
    logger_ma.info(f"  Input Shape: {input_shape}, Weights: {weights_source}")

    model_fn_map = {
        'EfficientNetB0': applications.EfficientNetB0,
        'DenseNet121': applications.DenseNet121,
        'MobileNetV2': applications.MobileNetV2,
        'ResNet50V2': applications.ResNet50V2,
        'InceptionV3': applications.InceptionV3,
    }

    if model_name not in model_fn_map:
        raise ValueError(f"Unsupported base model_name: {model_name}. Supported: {list(model_fn_map.keys())}")

    model_application = model_fn_map[model_name]
    actual_weights_arg = 'imagenet'
    load_local_weights_path = None

    if isinstance(weights_source, (str, Path)):
        if str(weights_source).lower() == 'none':
            actual_weights_arg = None
        elif str(weights_source).lower() != 'imagenet':
            weights_path_obj = Path(weights_source)
            if weights_path_obj.is_file():
                actual_weights_arg = None # Load structure first, then weights
                load_local_weights_path = weights_path_obj
                logger_ma.info(f"  Will load local weights from: {load_local_weights_path}")
            else:
                logger_ma.warning(f"  Local weights path '{weights_source}' not found. Defaulting to ImageNet.")
                actual_weights_arg = 'imagenet'

    if input_shape[-1] != 3 and actual_weights_arg == 'imagenet':
         error_msg = (f"Input shape {input_shape} has {input_shape[-1]} channels, but ImageNet weights "
                      f"for {model_name} require 3 channels.")
         logger_ma.error(error_msg)
         raise ValueError(error_msg)

    try:
        base = model_application(
            input_shape=input_shape,
            include_top=False,
            weights=actual_weights_arg,
            input_tensor=input_tensor
        )
        logger_ma.info(f"SUCCESS: Instantiated base structure for {model_name}.")

        if load_local_weights_path:
            logger_ma.info(f"Attempting to load local weights into {model_name} from {load_local_weights_path}")
            try:
                base.load_weights(str(load_local_weights_path), by_name=True, skip_mismatch=True)
                logger_ma.info(f"  SUCCESS: Loaded local weights into {model_name}.")
            except Exception as e_load:
                logger_ma.error(f"  FAILED to load local weights for {model_name}: {e_load}. Model has '{actual_weights_arg}' weights.", exc_info=True)
                raise RuntimeError(f"Failed to load critical local weights for {model_name}") from e_load
        return base
    except Exception as e:
        logger_ma.error(f"FAILED to instantiate base model {model_name}. Error: {e}", exc_info=True)
        raise RuntimeError(f"Base model instantiation failed for {model_name}") from e


def create_efficientnet_b0_classifier(input_shape: Tuple[int, int, int],
                                      num_classes: int,
                                      initial_freeze_ratio: float = 1.0,
                                      weights_path: Optional[Union[str, Path]] = 'imagenet',
                                      dropout_rate: float = 0.4,
                                      l2_lambda: float = 0.001,
                                      fine_tune_from_layer: Optional[Union[str, int]] = None
                                      ) -> keras.Model:
    model_key = "efficientnet_b0"
    logger_ma.info(f"--- Creating {model_key} Classifier ---")

    image_input = layers.Input(shape=input_shape, name=f'{model_key}_input')

    base_model = _create_base_model(
        model_name='EfficientNetB0',
        input_shape=input_shape,
        weights_source=weights_path,
        input_tensor=image_input
    )
    base_model._name = f"{model_key}_base"

    _freeze_base_model_layers(base_model, initial_freeze_ratio, fine_tune_from_layer)

    classifier_outputs = build_standard_classifier_head(
        inputs=base_model.output,
        num_classes=num_classes,
        base_model_name=model_key,
        dropout_rate=dropout_rate,
        l2_lambda=l2_lambda
    )

    model = models.Model(inputs=image_input, outputs=classifier_outputs, name=f"{model_key}_classifier")
    logger_ma.info(f"Successfully created model: '{model.name}'")
    return model


def create_densenet121_classifier(input_shape: Tuple[int, int, int],
                                  num_classes: int,
                                  trainable_base: bool = False,
                                  weights_path: Optional[Union[str, Path]] = 'imagenet',
                                  dropout_rate: float = 0.4,
                                  l2_lambda: float = 0.001,
                                  fine_tune_from_layer: Optional[Union[str, int]] = None
                                 ) -> keras.Model:
    model_key = "densenet121"
    logger_ma.info(f"--- Creating {model_key} Classifier ---")

    image_input = layers.Input(shape=input_shape, name=f'{model_key}_input')

    base_model = _create_base_model(
        model_name='DenseNet121',
        input_shape=input_shape,
        weights_source=weights_path,
        input_tensor=image_input
    )
    base_model._name = f"{model_key}_base"

    initial_freeze_ratio = 1.0 if not trainable_base else 0.0
    _freeze_base_model_layers(base_model, initial_freeze_ratio, fine_tune_from_layer)

    classifier_outputs = build_standard_classifier_head(
        inputs=base_model.output,
        num_classes=num_classes,
        base_model_name=model_key,
        dropout_rate=dropout_rate,
        l2_lambda=l2_lambda
    )

    model = models.Model(inputs=image_input, outputs=classifier_outputs, name=f"{model_key}_classifier")
    logger_ma.info(f"Successfully created model: '{model.name}'")
    return model


def create_mobilenetv2_classifier(input_shape: Tuple[int, int, int],
                                  num_classes: int,
                                  trainable_base: bool = False,
                                  weights_path: Optional[Union[str, Path]] = 'imagenet',
                                  dropout_rate: float = 0.4,
                                  l2_lambda: float = 0.001,
                                  fine_tune_from_layer: Optional[Union[str, int]] = None
                                 ) -> keras.Model:
    model_key = "mobilenetv2"
    logger_ma.info(f"--- Creating {model_key} Classifier ---")

    image_input = layers.Input(shape=input_shape, name=f'{model_key}_input')
    preprocessed_inputs = layers.Lambda(
        mobilenet_preprocess_wrapper,
        name=f'{model_key}_preprocessing'
    )(image_input)

    base_model = _create_base_model(
        model_name='MobileNetV2',
        input_shape=input_shape,
        weights_source=weights_path,
        input_tensor=preprocessed_inputs
    )
    base_model._name = f"{model_key}_base"

    initial_freeze_ratio = 1.0 if not trainable_base else 0.0
    _freeze_base_model_layers(base_model, initial_freeze_ratio, fine_tune_from_layer)

    classifier_outputs = build_standard_classifier_head(
        inputs=base_model.output,
        num_classes=num_classes,
        base_model_name=model_key,
        dropout_rate=dropout_rate,
        l2_lambda=l2_lambda
    )

    model = models.Model(inputs=image_input, outputs=classifier_outputs, name=f"{model_key}_classifier")
    logger_ma.info(f"Successfully created model: '{model.name}'")
    return model

def create_resnet50v2_classifier(input_shape: Tuple[int, int, int],
                                 num_classes: int,
                                 trainable_base: bool = False,
                                 weights_path: Optional[Union[str, Path]] = 'imagenet',
                                 dropout_rate: float = 0.4,
                                 l2_lambda: float = 0.001,
                                 fine_tune_from_layer: Optional[Union[str, int]] = None
                                ) -> keras.Model:
    model_key = "resnet50v2"
    logger_ma.info(f"--- Creating {model_key} Classifier ---")

    image_input = layers.Input(shape=input_shape, name=f'{model_key}_input')
    preprocessed_inputs = layers.Lambda(
        resnet_v2_preprocess_wrapper,
        name=f'{model_key}_preprocessing'
    )(image_input)

    base_model = _create_base_model(
        model_name='ResNet50V2',
        input_shape=input_shape,
        weights_source=weights_path,
        input_tensor=preprocessed_inputs
    )
    base_model._name = f"{model_key}_base"

    initial_freeze_ratio = 1.0 if not trainable_base else 0.0
    _freeze_base_model_layers(base_model, initial_freeze_ratio, fine_tune_from_layer)

    classifier_outputs = build_standard_classifier_head(
        inputs=base_model.output,
        num_classes=num_classes,
        base_model_name=model_key,
        dropout_rate=dropout_rate,
        l2_lambda=l2_lambda
    )

    model = models.Model(inputs=image_input, outputs=classifier_outputs, name=f"{model_key}_classifier")
    logger_ma.info(f"Successfully created model: '{model.name}'")
    return model

def create_inceptionv3_classifier(input_shape: Tuple[int, int, int],
                                  num_classes: int,
                                  trainable_base: bool = False,
                                  weights_path: Optional[Union[str, Path]] = 'imagenet',
                                  dropout_rate: float = 0.4,
                                  l2_lambda: float = 0.001,
                                  fine_tune_from_layer: Optional[Union[str, int]] = None
                                 ) -> keras.Model:
    model_key = "inceptionv3"
    logger_ma.info(f"--- Creating {model_key} Classifier ---")

    image_input = layers.Input(shape=input_shape, name=f'{model_key}_input')

    # Apply InceptionV3 specific preprocessing
    preprocessed_inputs = layers.Lambda(
        inception_v3_preprocess_wrapper,
        name=f'{model_key}_preprocessing'
    )(image_input)

    base_model = _create_base_model(
        model_name='InceptionV3',
        input_shape=input_shape,
        weights_source=weights_path,
        input_tensor=preprocessed_inputs
    )
    base_model._name = f"{model_key}_base"

    initial_freeze_ratio = 1.0 if not trainable_base else 0.0
    _freeze_base_model_layers(base_model, initial_freeze_ratio, fine_tune_from_layer)

    classifier_outputs = build_standard_classifier_head(
        inputs=base_model.output,
        num_classes=num_classes,
        base_model_name=model_key,
        dropout_rate=dropout_rate,
        l2_lambda=l2_lambda
    )

    model = models.Model(inputs=image_input, outputs=classifier_outputs, name=f"{model_key}_classifier")
    logger_ma.info(f"Successfully created model: '{model.name}'")
    return model

# =============================================================================
# Ensemble Configuration
# =============================================================================
ENSEMBLE_MODEL_FUNCTIONS: Dict[str, Callable[..., keras.Model]] = {
    'efficientnet_b0': create_efficientnet_b0_classifier,
    'densenet121': create_densenet121_classifier,
    'mobilenetv2': create_mobilenetv2_classifier,
    'resnet50v2': create_resnet50v2_classifier,
    'inceptionv3': create_inceptionv3_classifier,
}

# =============================================================================
# Script Self-Test Block
# =============================================================================
if __name__ == '__main__':
    # Ensure basicConfig is called only once if running standalone for testing
    if not logging.getLogger().hasHandlers() or not any(isinstance(h, logging.StreamHandler) for h in logging.getLogger().handlers):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler(sys.stdout)] # Ensure output to console
        )

    logger_ma.info("model_architecture.py executed as script for testing/debugging.")

    test_input_shape = (224, 224, 3)
    test_num_classes = 3

    def test_model_creation(model_key: str, creation_func: Callable):
        logger_ma.info(f"\n--- Testing: {model_key} ---")
        try:
            # Using default parameters for simplicity in test
            model = creation_func(
                input_shape=test_input_shape,
                num_classes=test_num_classes
            )
            model.summary(print_fn=logger_ma.info)
            logger_ma.info(f"SUCCESS: {model.name} created.")

            # Test freezing logic on the base model part
            # Heuristic to find the base model layer within the created classifier
            base_model_component = None
            for layer in model.layers:
                if isinstance(layer, keras.Model) and "base" in layer.name:
                    base_model_component = layer
                    break

            if base_model_component:
                logger_ma.info(f"Testing freezing logic on base component: {base_model_component.name}")
                _freeze_base_model_layers(base_model_component, freeze_ratio=0.8) # Partial freeze
                _freeze_base_model_layers(base_model_component, freeze_ratio=0.0) # Unfreeze all (except BN)
                _freeze_base_model_layers(base_model_component, freeze_ratio=1.0) # Freeze all
                _freeze_base_model_layers(base_model_component, freeze_ratio=0.5, fine_tune_from_layer=-10) # Unfreeze last 10 layers
            else:
                logger_ma.warning(f"Could not find a distinct 'base' Keras Model layer within {model.name} for detailed freeze test.")

        except Exception as ex:
            logger_ma.error(f"Error during {model_key} creation or test: {ex}", exc_info=True)

    for key, func in ENSEMBLE_MODEL_FUNCTIONS.items():
        test_model_creation(key, func)

    logger_ma.info("\nAll model creation tests completed.")

2025-05-24 09:21:24,623 - __main__ - INFO - model_architecture.py executed as script for testing/debugging.
2025-05-24 09:21:24,624 - __main__ - INFO - 
--- Testing: efficientnet_b0 ---
2025-05-24 09:21:24,628 - __main__ - INFO - --- Creating efficientnet_b0 Classifier ---
2025-05-24 09:21:24,641 - __main__ - INFO - Loading base model structure: EfficientNetB0
2025-05-24 09:21:24,642 - __main__ - INFO -   Input Shape: (224, 224, 3), Weights: imagenet
2025-05-24 09:21:26,651 - __main__ - INFO - SUCCESS: Instantiated base structure for EfficientNetB0.
2025-05-24 09:21:26,656 - __main__ - INFO - Froze ALL layers in base model 'efficientnetb0' (base_model.trainable = False).
2025-05-24 09:21:26,659 - __main__ - INFO - Base model 'efficientnetb0' final state: 0/312 trainable weights/layers with weights.
2025-05-24 09:21:26,773 - __main__ - INFO - Successfully created model: 'efficientnet_b0_classifier'


2025-05-24 09:21:27,252 - __main__ - INFO - Model: "efficientnet_b0_classifier"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ efficientnet_b0_in… │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ rescaling           │ (None, 224, 224,  │          0 │ efficientnet_b0_… │
│ (Rescaling)         │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization       │ (None, 224, 224,  │          7 │ rescaling[0][0]   │
│ (Normalization)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼─────────────────

2025-05-24 09:21:37,357 - __main__ - INFO - Model: "densenet121_classifier"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ densenet121_input   │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ zero_padding2d      │ (None, 230, 230,  │          0 │ densenet121_inpu… │
│ (ZeroPadding2D)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1_conv (Conv2D) │ (None, 112, 112,  │      9,408 │ zero_padding2d[0… │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤


2025-05-24 09:21:39,724 - __main__ - INFO - Model: "mobilenetv2_classifier"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ mobilenetv2_input   │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ mobilenetv2_prepro… │ (None, 224, 224,  │          0 │ mobilenetv2_inpu… │
│ (Lambda)            │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ Conv1 (Conv2D)      │ (None, 112, 112,  │        864 │ mobilenetv2_prep… │
│                     │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤


2025-05-24 09:21:43,692 - __main__ - INFO - Model: "resnet50v2_classifier"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ resnet50v2_input    │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ resnet50v2_preproc… │ (None, 224, 224,  │          0 │ resnet50v2_input… │
│ (Lambda)            │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv1_pad           │ (None, 230, 230,  │          0 │ resnet50v2_prepr… │
│ (ZeroPadding2D)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│

2025-05-24 09:21:52,477 - __main__ - INFO - Model: "inceptionv3_classifier"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ inceptionv3_input   │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ inceptionv3_prepro… │ (None, 224, 224,  │          0 │ inceptionv3_inpu… │
│ (Lambda)            │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d (Conv2D)     │ (None, 111, 111,  │        864 │ inceptionv3_prep… │
│                     │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤


# **Step 6: visualisation.py**

In [8]:
# Grad-CAM provides a visual explanation by producing a heatmap
# that we can overlay onto the original input image. This heatmap
# highlights the regions in the image that were most influential
# or 'important' to the neural network when it made its prediction
# for a specific class. Areas with 'hotter' colors (e.g., red, orange)
# on the heatmap indicate regions that strongly contributed to the model's decision for that class.
# visualization.py (Ensure this is the content of your Step 6 cell)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
import logging
from pathlib import Path
from typing import Optional, Tuple, List, Any, Union

logger = logging.getLogger(__name__)

def find_suitable_gradcam_layer(model: tf.keras.Model, model_key: Optional[str] = None) -> Optional[str]: # <--- CONFIRM THIS LINE
    """
    Finds a suitable convolutional layer for Grad-CAM.
    Prioritizes known good layers if model_key is provided.
    Searches within nested base models if applicable.
    """
    known_good_layers = {
        "efficientnet_b0": "top_activation",
        "densenet121": "relu",
        "mobilenetv2": "out_relu",
        "resnet50v2": "conv5_block3_out",
        "inceptionv3": "mixed10"
    }

    base_model_component = None
    for layer_outer in model.layers: # Use a different variable name here
        if isinstance(layer_outer, tf.keras.Model) and "base" in layer_outer.name.lower():
            base_model_component = layer_outer
            logger.debug(f"Grad-CAM: Found nested base model component: {base_model_component.name}")
            break

    target_model_for_search = base_model_component if base_model_component else model
    # Construct prefix carefully, only if base_model_component is truly a part of the 'model' passed to this function
    # and not the model itself.
    prefix = ""
    if base_model_component and base_model_component != model : # Check if it's a distinct nested model
        # Verify if the base_model_component is indeed a layer within the top-level model
        is_nested_layer = any(base_model_component.name == l.name for l in model.layers if isinstance(l, tf.keras.Model))
        if is_nested_layer:
             prefix = f"{base_model_component.name}/"


    if model_key and model_key in known_good_layers:
        potential_layer_name_in_base = known_good_layers[model_key]

        # Try to access the layer within the target_model_for_search (which could be the base model)
        try:
            target_model_for_search.get_layer(potential_layer_name_in_base)
            # If found, construct the full path using the prefix if target_model_for_search was a nested base
            full_path_name = f"{prefix}{potential_layer_name_in_base}"
            logger.info(f"Grad-CAM: Using known good layer '{potential_layer_name_in_base}' within '{target_model_for_search.name}'. Full path for model.get_layer: '{full_path_name}'")
            return full_path_name
        except ValueError:
            logger.warning(f"Grad-CAM: Known good layer '{potential_layer_name_in_base}' for '{model_key}' not found directly in '{target_model_for_search.name}'. Trying absolute in top model.")
            # If not in base, try if it's an absolute name in the top model (less likely for base model layers)
            try:
                model.get_layer(potential_layer_name_in_base)
                logger.info(f"Grad-CAM: Using known good layer '{potential_layer_name_in_base}' (absolute name) in '{model.name}'.")
                return potential_layer_name_in_base
            except ValueError:
                logger.warning(f"Grad-CAM: Known good layer '{potential_layer_name_in_base}' also not found as absolute in '{model.name}'. Falling back to auto-detection.")


    suitable_layer_name_auto = None
    for layer in reversed(target_model_for_search.layers): # Search in the determined target_model_for_search
        if hasattr(layer, 'output_shape') and \
           isinstance(layer.output_shape, (list, tuple)) and \
           len(layer.output_shape) == 4 and \
           all(isinstance(dim, int) and dim > 0 for dim in layer.output_shape[1:] if dim is not None):

            if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)):
                suitable_layer_name_auto = layer.name
                logger.debug(f"Grad-CAM (auto): Selected Conv2D/DepthwiseConv2D layer: {layer.name} in {target_model_for_search.name}")
                break
            if suitable_layer_name_auto is None and isinstance(layer, tf.keras.layers.Activation):
                if "relu" in str(layer.activation).lower() or "swish" in str(layer.activation).lower() or layer.name.endswith("_relu") or layer.name.endswith("_activation"):
                    suitable_layer_name_auto = layer.name
                    logger.debug(f"Grad-CAM (auto): Tentatively selected Activation layer: {layer.name} in {target_model_for_search.name}")
            if suitable_layer_name_auto is None and isinstance(layer, tf.keras.layers.Add):
                suitable_layer_name_auto = layer.name
                logger.debug(f"Grad-CAM (auto): Tentatively selected Add layer: {layer.name} in {target_model_for_search.name}")

    if suitable_layer_name_auto:
        full_auto_name = f"{prefix}{suitable_layer_name_auto}"
        logger.info(f"Grad-CAM: Final auto-selected layer: '{suitable_layer_name_auto}' within '{target_model_for_search.name}'. Full path for model.get_layer: '{full_auto_name}'")
        return full_auto_name
    else:
        logger.warning(f"Grad-CAM: Could not automatically find a suitable 4D output layer in '{target_model_for_search.name}'.")
        return None


def make_gradcam_heatmap(img_array_batch: np.ndarray,
                         model: tf.keras.Model,
                         last_conv_layer_name: str,
                         pred_index: Optional[int] = None
                         ) -> Optional[np.ndarray]:
    if not last_conv_layer_name:
        logger.error("Grad-CAM: No convolutional layer name provided.")
        return None
    if img_array_batch.ndim != 4 or img_array_batch.shape[0] != 1:
        logger.error(f"Grad-CAM: Requires a batch of 1 image (shape (1,H,W,C)), but got shape {img_array_batch.shape}")
        return None

    try:
        # This get_layer needs to access the layer within the full 'model' context
        last_conv_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        logger.error(f"Grad-CAM: Layer '{last_conv_layer_name}' not found in model '{model.name}'. Available layers (top level): {[l.name for l in model.layers]}")
        if '/' in last_conv_layer_name:
            parent_name, child_name = last_conv_layer_name.split('/',1)
            try:
                parent_model = model.get_layer(parent_name)
                if isinstance(parent_model, tf.keras.Model):
                    logger.error(f"  Layers in nested model '{parent_name}': {[l.name for l in parent_model.layers]}")
            except ValueError:
                pass
        return None

    model_key_inferred = model.name.split('_classifier')[0]
    potential_logits_layer_name = f"{model_key_inferred}_head_logits"
    logits_layer = None

    try:
        logits_layer = model.get_layer(potential_logits_layer_name)
        logger.info(f"Grad-CAM: Using identified logits layer: '{logits_layer.name}' for gradient calculation.")
    except ValueError:
        logger.warning(f"Grad-CAM: Logits layer '{potential_logits_layer_name}' not found. Falling back to model.output.")
        if isinstance(model.layers[-1], tf.keras.layers.Activation) and \
           model.layers[-1].activation == tf.keras.activations.softmax and \
           len(model.layers) > 1 and isinstance(model.layers[-2], tf.keras.layers.Dense):
            logits_layer = model.layers[-2]
            logger.warning(f"Grad-CAM: Using layer '{logits_layer.name}' (assumed pre-softmax) as logits layer.")
        else:
            logits_layer = model.layers[-1]
            logger.warning(f"Grad-CAM: Could not reliably identify a pre-softmax logits layer. Using model's final output layer '{logits_layer.name}'. Results might be suboptimal if it's a softmax.")

    classifier_output_tensor = logits_layer.output

    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[last_conv_layer.output, classifier_output_tensor]
    )

    with tf.GradientTape() as tape:
        inputs_for_grad = tf.cast(img_array_batch, tf.float32)
        last_conv_layer_output, preds_logits = grad_model(inputs_for_grad)

        if pred_index is None:
            pred_index = tf.argmax(preds_logits[0])
        class_channel = preds_logits[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        logger.error(f"Grad-CAM: Gradient computation returned None for layer '{last_conv_layer.name}'.")
        return None

    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    last_conv_layer_output_squeezed = last_conv_layer_output[0]
    heatmap = last_conv_layer_output_squeezed @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0)

    epsilon = tf.keras.backend.epsilon()
    max_val = tf.math.reduce_max(heatmap)
    if tf.abs(max_val) < epsilon:
        logger.warning(f"Grad-CAM: Heatmap for layer '{last_conv_layer.name}' is all zeros or near-zero. Predicted class index: {pred_index}.")
        return np.zeros(heatmap.shape, dtype=np.float32)

    heatmap = heatmap / max_val
    return heatmap.numpy()


def display_gradcam(original_img_normalized: np.ndarray,
                    heatmap: Optional[np.ndarray],
                    alpha: float = 0.5,
                    colormap_name: str = "jet"
                    ) -> Optional[np.ndarray]:
    if heatmap is None:
        logger.warning("Grad-CAM display: Heatmap is None, returning original image converted to uint8.")
        return np.clip(original_img_normalized * 255, 0, 255).astype(np.uint8)

    if not (original_img_normalized.ndim == 3 and original_img_normalized.shape[-1] == 3):
        logger.error(f"Grad-CAM display: Invalid original image shape {original_img_normalized.shape}. Expected (H,W,3).")
        return None
    if not (heatmap.ndim == 2):
        logger.error(f"Grad-CAM display: Invalid heatmap shape {heatmap.shape}. Expected 2D.")
        return None

    try:
        heatmap_resized = tf.image.resize(
            heatmap[..., tf.newaxis],
            [original_img_normalized.shape[0], original_img_normalized.shape[1]]
        ).numpy()
        heatmap_resized = np.squeeze(heatmap_resized)
        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        colormap = plt.colormaps.get_cmap(colormap_name)
        colored_heatmap_rgb = colormap(heatmap_uint8)[:, :, :3]
        original_img_uint8 = np.uint8(255 * original_img_normalized)
        superimposed_img_float = (colored_heatmap_rgb * 255 * alpha) + (original_img_uint8 * (1 - alpha))
        superimposed_img_uint8 = np.clip(superimposed_img_float, 0, 255).astype(np.uint8)
        return superimposed_img_uint8
    except Exception as e:
        logger.error(f"Grad-CAM display: Error during heatmap overlay: {e}", exc_info=True)
        return np.clip(original_img_normalized * 255, 0, 255).astype(np.uint8)

# **Step 7: explainability.py**

In [9]:
# While Grad-CAM gives a somewhat broad heatmap,
# LIME provides a more granular explanation for a single,
# specific image prediction. It identifies which distinct
# segments or superpixels (small, perceptually meaningful patches of pixels)
# of that particular image were most influential in pushing the model towards
# its prediction for a given class

import numpy as np
import matplotlib.pyplot as plt
from lime import lime_image
from skimage.segmentation import mark_boundaries
import logging
from pathlib import Path
import tensorflow as tf
from typing import List, Tuple, Optional, Any, Dict, Union, Callable

logger = logging.getLogger(__name__)

class LimeExplainer:
    def __init__(self, model: tf.keras.Model, class_names: List[str]):
        if not isinstance(model, tf.keras.Model):
            raise TypeError("Input 'model' must be a tf.keras.Model instance.")
        if not (isinstance(class_names, list) and all(isinstance(name, str) for name in class_names)):
            raise TypeError("'class_names' must be a list of strings.")

        self.model = model
        self.class_names = class_names
        self.num_classes = len(class_names)

        try:
            # Initialize LIME's image explainer.
            # verbose=False keeps console cleaner during LIME's internal sampling.
            # random_state can be set for reproducibility of LIME's perturbations.
            self.explainer = lime_image.LimeImageExplainer(verbose=False, random_state=42)
            logger.info("LIME explainer initialized successfully.")
        except Exception as e:
            logger.error(f"Failed to initialize LIME explainer: {e}", exc_info=True)
            raise RuntimeError("LIME explainer initialization failed.") from e

    def _predict_fn(self, images: np.ndarray) -> np.ndarray:
        try:
            # LIME typically provides images in float64 and range [0,1] if the input
            # to explain_instance was in that range.
            # Keras models usually expect float32.
            images_float32 = images.astype(np.float32)
            predictions = self.model.predict(images_float32, verbose=0)

            if predictions.shape[0] != images.shape[0] or predictions.shape[1] != self.num_classes:
                logger.error(
                    f"LIME _predict_fn: Unexpected prediction shape from model: {predictions.shape}. "
                    f"Expected ({images.shape[0]}, {self.num_classes})."
                )
                # Return uniform probabilities to prevent LIME from crashing,
                # though the explanation will likely be meaningless.
                return np.ones((images.shape[0], self.num_classes)) / self.num_classes

            return predictions

        except Exception as e:
            logger.error(f"LIME _predict_fn: Prediction failed: {e}", exc_info=True)
            # Return uniform probabilities to prevent LIME from crashing.
            return np.ones((images.shape[0], self.num_classes)) / self.num_classes


    def explain_instance(self,
                         image_array_normalized: np.ndarray,
                         num_samples: int = 1000,
                         top_labels: Optional[int] = 1,
                         num_features: int = 10,
                         positive_only: bool = True,
                         hide_rest: bool = False,
                         segmentation_fn: Optional[Callable] = None, # Allow custom segmentation
                         output_dir: Optional[Union[str, Path]] = None,
                         filename_prefix: str = "lime_explanation"
                         ) -> Tuple[Optional[Any], Optional[plt.Figure], Optional[np.ndarray]]:
        if not (isinstance(image_array_normalized, np.ndarray) and
                image_array_normalized.ndim == 3 and
                image_array_normalized.shape[-1] == 3):
            logger.error("LIME explain: Invalid input image array shape.")
            return None, None, None

        if not (image_array_normalized.min() >= -0.01 and image_array_normalized.max() <= 1.01):
             logger.warning(
                 "LIME explain: Input image doesn't appear to be normalized to [0, 1] "
                 f"(min: {image_array_normalized.min():.2f}, max: {image_array_normalized.max():.2f}). "
                 "LIME results might be affected if the _predict_fn expects [0,1]."
             )

        logger.info(
            f"Generating LIME explanation: num_samples={num_samples}, "
            f"num_features={num_features}, top_labels={top_labels}..."
        )

        explanation = None
        image_mask = None
        mask = None
        try:
            explanation = self.explainer.explain_instance(
                image=image_array_normalized, # LIME expects float [0,1] or uint8 [0,255]
                classifier_fn=self._predict_fn,
                top_labels=top_labels,
                hide_color=0,
                num_samples=num_samples,
                num_features=num_features,
                segmentation_fn=segmentation_fn, # Use LIME's default if None
                random_seed=42 # For reproducibility of LIME's sampling process
            )
            logger.info("LIME explanation object generated.")

            # --- Prepare data for plotting and overlay ---
            model_preds_on_original = self._predict_fn(np.expand_dims(image_array_normalized, axis=0))[0]
            model_top_pred_index = np.argmax(model_preds_on_original)

            # Get image and mask for the class the model predicted as top
            label_to_explain = model_top_pred_index
            if top_labels is not None and top_labels > 0 and explanation.top_labels:
                 label_to_explain = explanation.top_labels[0]

            lime_explained_class_name = self.class_names[label_to_explain]

            image_mask, mask = explanation.get_image_and_mask(
                label=label_to_explain,
                positive_only=positive_only,
                num_features=num_features,
                hide_rest=hide_rest
            )
        except Exception as e_lime_gen:
            logger.error(f"Failed during LIME explanation generation (explain_instance or get_image_and_mask): {e_lime_gen}", exc_info=True)
            # If generation fails, return None for explanation and visual outputs
            return None, None, None

        # --- Generate the visual overlay array ---
        lime_overlay_array_uint8 = None
        try:
            # mark_boundaries returns float64 or float32 in [0,1] if input image_mask is float [0,1]
            lime_overlay_array_float = mark_boundaries(image_mask, mask)

            # Convert to uint8 [0,255] for consistent handling in plotting/saving
            # Ensure clipping in case of floating point inaccuracies just outside [0,1]
            lime_overlay_array_uint8 = np.clip(lime_overlay_array_float * 255, 0, 255).astype(np.uint8)

        except Exception as e_overlay:
            logger.error(f"LIME explain: Failed to generate mark_boundaries overlay: {e_overlay}", exc_info=True)
            lime_overlay_array_uint8 = None # Set to None on failure

        # --- Create and optionally save the plot ---
        fig = None
        try:
            model_top_class_name = self.class_names[model_top_pred_index]
            model_top_confidence = model_preds_on_original[model_top_pred_index]
            confidence_scores_str = ", ".join([f"{name}: {score:.2f}" for name, score in zip(self.class_names, model_preds_on_original)])

            fig, axs = plt.subplots(1, 2, figsize=(12, 6))
            fig.patch.set_facecolor('white')

            axs[0].imshow(image_array_normalized)
            axs[0].set_title("Original Image")
            axs[0].axis('off')

            if lime_overlay_array_uint8 is not None:
                 axs[1].imshow(lime_overlay_array_uint8)
                 axs[1].set_title(f"LIME: Explaining '{lime_explained_class_name}'")
            else:
                 axs[1].text(0.5, 0.5, "LIME Overlay N/A", ha='center', va='center', fontsize=12, color='gray')
                 axs[1].set_title("LIME Explanation")
            axs[1].axis('off')

            fig.suptitle(
                f"Model Top Pred: '{model_top_class_name}' ({model_top_confidence:.2f})\n"
                f"Confidences: {confidence_scores_str}",
                fontsize=10, y=1.0)

            plt.tight_layout(rect=[0, 0, 1, 0.93])

            if output_dir:
                try:
                    output_path = Path(output_dir) / f"{filename_prefix}.png"
                    output_path.parent.mkdir(parents=True, exist_ok=True)
                    plt.savefig(output_path, bbox_inches='tight')
                    logger.info(f"LIME explanation plot saved to: {output_path}")
                except Exception as e_save_plot:
                    logger.error(f"LIME explain: Failed to save plot to {output_path}: {e_save_plot}", exc_info=True)
                finally:
                    plt.close(fig) # Close the figure after saving

        except Exception as e_plot:
            logger.error(f"LIME explain: Failed to create or save plot: {e_plot}", exc_info=True)
            if fig: plt.close(fig) # Ensure figure is closed even if saving failed
            fig = None # Set fig to None to indicate plotting failure

        # Return explanation object, the figure, and the overlay array
        return explanation, fig, lime_overlay_array_uint8

# **Step 8: training.py**

In [10]:
import tensorflow as tf
import numpy as np
import os
import logging
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
import seaborn as sns
import datetime
import pandas as pd
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any, Union

logger = logging.getLogger(__name__)

class ModelTrainer:
    def __init__(self, model_dir: Union[str, Path], log_dir: Union[str, Path]):
        self.model_dir = Path(model_dir)
        self.log_dir = Path(log_dir)

        try:
            self.model_dir.mkdir(parents=True, exist_ok=True)
            self.log_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(f"Error creating directories {self.model_dir} or {self.log_dir}: {e}", exc_info=True)
            raise # Critical if directories can't be made

        logger.info(f"ModelTrainer initialized. Models dir: {self.model_dir}, Logs dir: {self.log_dir}")

    def _get_callbacks(self,
                       model_name_prefix: str,
                       checkpoint_monitor: str = 'val_accuracy',
                       early_stopping_monitor: str = 'val_loss',
                       early_stopping_patience: int = 10,
                       lr_reduce_patience: int = 5
                       ) -> List[tf.keras.callbacks.Callback]:
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        run_log_dir = self.log_dir / f"{model_name_prefix}_{timestamp}"
        best_model_path = self.model_dir / f"{model_name_prefix}_best.keras" # Use .keras format

        callbacks = [
            EarlyStopping(
                monitor=early_stopping_monitor,
                patience=early_stopping_patience,
                verbose=1,
                restore_best_weights=True # Restores model weights from the epoch with the best value
            ),
            ModelCheckpoint(
                filepath=str(best_model_path),
                monitor=checkpoint_monitor,
                save_best_only=True,
                save_weights_only=False, # Save the entire model
                mode='max' if 'accuracy' in checkpoint_monitor.lower() else 'min',
                verbose=1
            ),
            ReduceLROnPlateau(
                monitor=early_stopping_monitor,
                factor=0.2, # Reduce LR by a factor of 5
                patience=lr_reduce_patience,
                min_lr=1e-6, # Minimum learning rate
                verbose=1
            ),
            TensorBoard(
                log_dir=str(run_log_dir),
                histogram_freq=1, # Log histograms for weights/biases once per epoch
                write_graph=True
            )
        ]
        logger.info(f"Callbacks configured for '{model_name_prefix}'. Best model will be saved to: {best_model_path}")
        logger.info(f"TensorBoard logs will be saved to: {run_log_dir}")
        return callbacks

    def _plot_training_history(self,
                               history_data: Dict[str, List[float]],
                               model_name: str,
                               history_plot_path: Path):
        required_keys = ['accuracy', 'val_accuracy', 'loss', 'val_loss']
        if not all(key in history_data for key in required_keys):
            available_keys = list(history_data.keys())
            logger.error(
                f"History data for {model_name} is missing some required keys ({required_keys}). "
                f"Available: {available_keys}. Cannot plot full history."
            )
            return # Exit if essential data is missing

        acc = history_data.get('accuracy', [])
        val_acc = history_data.get('val_accuracy', [])
        loss = history_data.get('loss', [])
        val_loss = history_data.get('val_loss', [])

        # Ensures all metrics have the same length for plotting
        min_len = min(len(acc), len(val_acc), len(loss), len(val_loss))
        if min_len == 0:
            logger.warning(f"No history data points found for {model_name}. Skipping plot.")
            return

        epochs = range(1, min_len + 1)
        acc, val_acc, loss, val_loss = acc[:min_len], val_acc[:min_len], loss[:min_len], val_loss[:min_len]


        try:
            fig, axs = plt.subplots(1, 2, figsize=(15, 6))
            fig.suptitle(f"Training History for {model_name}", fontsize=16, y=1.02)

            # Accuracy Plot
            axs[0].plot(epochs, acc, 'bo-', label='Training Accuracy')
            axs[0].plot(epochs, val_acc, 'ro-', label='Validation Accuracy')
            axs[0].set_title('Model Accuracy', fontsize=14)
            axs[0].set_xlabel('Epoch', fontsize=12)
            axs[0].set_ylabel('Accuracy', fontsize=12)
            axs[0].legend(fontsize=10)
            axs[0].grid(True, linestyle='--', alpha=0.7)

            # Loss Plot
            axs[1].plot(epochs, loss, 'bo-', label='Training Loss')
            axs[1].plot(epochs, val_loss, 'ro-', label='Validation Loss')
            axs[1].set_title('Model Loss', fontsize=14)
            axs[1].set_xlabel('Epoch', fontsize=12)
            axs[1].set_ylabel('Loss', fontsize=12)
            axs[1].legend(fontsize=10)
            axs[1].grid(True, linestyle='--', alpha=0.7)

            plt.tight_layout(rect=[0, 0, 1, 0.95])
            history_plot_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(history_plot_path)
            logger.info(f"Training history plot saved to: {history_plot_path}")
            plt.close(fig)
        except Exception as e:
            logger.error(f"Failed to plot training history for {model_name}: {e}", exc_info=True)

    def _compile_model(self, model: tf.keras.Model, learning_rate: float):
        """Compiles the Keras model."""
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        logger.info(f"Model '{model.name}' compiled with Adam optimizer (LR={learning_rate:.0e}).")


    def _perform_fine_tuning_setup(self, model: tf.keras.Model, fine_tune_at_layer: Optional[Union[str, int]]):
        """Unfreezes layers of the model for fine-tuning."""
        base_model_to_unfreeze = None
        # Heuristic to find the base model component
        for layer in model.layers:
            if isinstance(layer, tf.keras.Model) and "base" in layer.name.lower():
                base_model_to_unfreeze = layer
                break

        if not base_model_to_unfreeze:
            logger.warning(f"Could not identify a distinct 'base' sub-model in '{model.name}'. "
                           f"Will attempt to make the entire model trainable. This might not be optimal.")
            base_model_to_unfreeze = model # Fallback to unfreezing the whole passed model

        base_model_to_unfreeze.trainable = True
        unfreeze_from_index = 0 # Default: unfreeze all layers in the base_model_to_unfreeze

        if fine_tune_at_layer is not None and base_model_to_unfreeze != model:
            num_base_layers = len(base_model_to_unfreeze.layers)
            if isinstance(fine_tune_at_layer, str):
                try:
                    layer_names = [l.name for l in base_model_to_unfreeze.layers]
                    unfreeze_from_index = layer_names.index(fine_tune_at_layer)
                except ValueError:
                    logger.error(f"Layer name '{fine_tune_at_layer}' not found in base model '{base_model_to_unfreeze.name}'. Unfreezing all its layers.")
            elif isinstance(fine_tune_at_layer, int):
                if 0 <= fine_tune_at_layer < num_base_layers:
                    unfreeze_from_index = fine_tune_at_layer
                else:
                    logger.error(f"Invalid fine_tune_at index {fine_tune_at_layer} for base '{base_model_to_unfreeze.name}'. Unfreezing all its layers.")
            else:
                logger.error(f"Invalid fine_tune_at_layer type: {type(fine_tune_at_layer)}. Unfreezing all base layers.")

            for i, layer in enumerate(base_model_to_unfreeze.layers):
                if i < unfreeze_from_index:
                    layer.trainable = False
                else: # Layers from unfreeze_from_index onwards
                    if not isinstance(layer, tf.keras.layers.BatchNormalization):
                        layer.trainable = True
                    else:
                        layer.trainable = False # Keep BN frozen
            logger.info(f"Fine-tuning '{base_model_to_unfreeze.name}': Layers from index {unfreeze_from_index} ('{base_model_to_unfreeze.layers[unfreeze_from_index].name}') onwards are trainable (BN layers kept frozen).")
        else:
             logger.info(f"Fine-tuning: All layers in '{base_model_to_unfreeze.name}' (or the full model) are set to trainable (BN layers will be kept frozen by explicit check if applicable).")
             # Ensure BN layers are frozen if we are unfreezing the whole model
             for layer in base_model_to_unfreeze.layers:
                 if isinstance(layer, tf.keras.layers.BatchNormalization):
                     layer.trainable = False


        trainable_count = sum(1 for w in model.trainable_weights)
        logger.info(f"Total trainable weights in '{model.name}' after fine-tune setup: {trainable_count}")


    def train_model(self,
                    model: tf.keras.Model,
                    train_ds: tf.data.Dataset,
                    val_ds: tf.data.Dataset,
                    model_base_name: str,
                    initial_epochs: int = 5,
                    fine_tune_epochs: int = 10,
                    initial_lr: float = 1e-3,
                    fine_tune_lr: float = 1e-5,
                    fine_tune_at_layer: Optional[Union[str, int]] = None,
                    class_weights: Optional[Dict[int, float]] = None
                    ) -> Tuple[tf.keras.Model, Dict[str, List[float]]]:
        combined_history: Dict[str, List[float]] = {
            'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': [], 'lr': []
        }
        final_model_path = self.model_dir / f"{model_base_name}_final.keras"
        history_plot_path = self.log_dir / f"{model_base_name}_training_history.png"

        # --- Phase 1: Initial Training ---
        if initial_epochs > 0:
            logger.info(f"--- [{model_base_name}] Starting Initial Training Phase ({initial_epochs} epochs) ---")
            # Base model layers should be frozen by model creation function if intended
            self._compile_model(model, initial_lr)

            logger.info(f"Model summary before initial training for '{model.name}':")
            model.summary(print_fn=logger.info)

            callbacks_initial = self._get_callbacks(f"{model_base_name}_initial_train")

            history_initial = model.fit(
                train_ds,
                epochs=initial_epochs,
                validation_data=val_ds,
                callbacks=callbacks_initial,
                class_weight=class_weights,
                verbose=1
            )
            for key in history_initial.history:
                combined_history.setdefault(key, []).extend(history_initial.history[key])

            initial_epochs_trained = len(history_initial.history['loss']) # Actual epochs run
            logger.info(f"[{model_base_name}] Initial training phase completed after {initial_epochs_trained} epochs.")
        else:
            initial_epochs_trained = 0
            logger.info(f"[{model_base_name}] Skipping initial training phase (initial_epochs=0).")


        # --- Phase 2: Fine-tuning ---
        if fine_tune_epochs > 0:
            logger.info(f"--- [{model_base_name}] Starting Fine-tuning Phase ({fine_tune_epochs} epochs) ---")

            best_initial_model_path = self.model_dir / f"{model_base_name}_initial_train_best.keras"
            if initial_epochs > 0 and best_initial_model_path.exists():
                logger.info(f"Loading best weights from initial phase: {best_initial_model_path}")
                try:
                    custom_objects = {}
                    if "mobilenetv2" in model_base_name.lower():
                         from model_architectures import mobilenet_preprocess_wrapper # Ensure import
                         custom_objects['mobilenet_preprocess_wrapper'] = mobilenet_preprocess_wrapper
                    if "resnet50v2" in model_base_name.lower():
                         from model_architectures import resnet_v2_preprocess_wrapper # Ensure import
                         custom_objects['resnet_v2_preprocess_wrapper'] = resnet_v2_preprocess_wrapper
                    if "inceptionv3" in model_base_name.lower():
                         from model_architectures import inception_v3_preprocess_wrapper # Ensure import
                         custom_objects['inception_v3_preprocess_wrapper'] = inception_v3_preprocess_wrapper
                    model.load_weights(str(best_initial_model_path))
                    logger.info(f"Successfully loaded weights from {best_initial_model_path} into existing model structure.")

                except Exception as e:
                    logger.error(f"Failed to load best weights from initial phase: {e}. Continuing with current model weights.", exc_info=True)

            self._perform_fine_tuning_setup(model, fine_tune_at_layer)
            self._compile_model(model, fine_tune_lr)

            logger.info(f"Model summary before fine-tuning for '{model.name}':")
            model.summary(print_fn=logger.info)

            callbacks_finetune = self._get_callbacks(f"{model_base_name}_finetune")

            total_epochs_for_fit = initial_epochs_trained + fine_tune_epochs

            history_fine = model.fit(
                train_ds,
                epochs=total_epochs_for_fit,
                initial_epoch=initial_epochs_trained,
                validation_data=val_ds,
                callbacks=callbacks_finetune,
                class_weight=class_weights,
                verbose=1
            )
            for key in history_fine.history:
                combined_history.setdefault(key, []).extend(history_fine.history[key])

            logger.info(f"[{model_base_name}] Fine-tuning phase completed.")
        else:
            logger.info(f"[{model_base_name}] Skipping fine-tuning phase (fine_tune_epochs=0).")

        # --- Save Final Model ---
        try:
            model.save(str(final_model_path))
            logger.info(f"[{model_base_name}] Final trained model saved to: {final_model_path}")
        except Exception as e:
            logger.error(f"[{model_base_name}] Failed to save final model: {e}", exc_info=True)

        self._plot_training_history(combined_history, model_base_name, history_plot_path)
        return model, combined_history


    def _evaluate_single_or_ensemble(self,
                                     models_or_model: Union[tf.keras.Model, List[tf.keras.Model]],
                                     dataset: tf.data.Dataset,
                                     class_names: List[str],
                                     dataset_name: str = "Test Set",
                                     is_ensemble: bool = False
                                     ) -> Optional[Dict[str, Any]]:
        model_name_log_prefix = "[Ensemble]" if is_ensemble else f"[{models_or_model.name if not is_ensemble else 'UnknownSingleModel'}]"

        logger.info(f"--- {model_name_log_prefix} Evaluating on {dataset_name} ---") # Corrected log message
        if dataset is None:
            logger.error(f"{model_name_log_prefix} Evaluation dataset ({dataset_name}) is None.")
            return None

        y_true_list: List[int] = []
        y_pred_probs_list: List[np.ndarray] = [] # List of probability arrays

        try:
            logger.info(f"{model_name_log_prefix} Gathering predictions and true labels from {dataset_name}...")
            for images_batch, labels_batch in dataset:
                if is_ensemble:
                    batch_model_preds = []
                    for model_item in models_or_model: # type: ignore
                        preds = model_item.predict(images_batch, verbose=0)
                        batch_model_preds.append(preds)
                    # Average probabilities across models for the batch
                    avg_preds_batch = np.mean(np.stack(batch_model_preds, axis=0), axis=0)
                    y_pred_probs_list.extend(avg_preds_batch)
                else: # Single model
                    preds_batch = models_or_model.predict(images_batch, verbose=0) # type: ignore
                    y_pred_probs_list.extend(preds_batch)

                y_true_list.extend(labels_batch.numpy())

            if not y_true_list or not y_pred_probs_list:
                logger.error(f"{model_name_log_prefix} Could not extract labels or predictions from {dataset_name}.")
                return None

            y_true = np.array(y_true_list)
            y_pred_probs = np.array(y_pred_probs_list)
            y_pred_classes = np.argmax(y_pred_probs, axis=1)

            logger.info(f"{model_name_log_prefix} Calculating evaluation metrics...")
            accuracy = accuracy_score(y_true, y_pred_classes)
            f1 = f1_score(y_true, y_pred_classes, average='weighted', zero_division=0)
            precision = precision_score(y_true, y_pred_classes, average='weighted', zero_division=0)
            recall = recall_score(y_true, y_pred_classes, average='weighted', zero_division=0)
            report_str = classification_report(y_true, y_pred_classes, target_names=class_names, zero_division=0, output_dict=False)
            report_dict = classification_report(y_true, y_pred_classes, target_names=class_names, zero_division=0, output_dict=True)
            cm = confusion_matrix(y_true, y_pred_classes)

            auc_weighted = None
            num_actual_classes = len(class_names)
            if y_pred_probs.shape[1] == num_actual_classes and num_actual_classes > 1: # AUC needs at least 2 classes
                # Check if all true labels belong to a single class
                if len(np.unique(y_true)) > 1:
                    try:
                        y_true_one_hot = tf.keras.utils.to_categorical(y_true, num_classes=num_actual_classes)
                        auc_weighted = roc_auc_score(y_true_one_hot, y_pred_probs, multi_class='ovr', average='weighted')
                    except ValueError as auc_e:
                        logger.warning(f"{model_name_log_prefix} Could not calculate ROC AUC (ValueError): {auc_e}. "
                                       "This can happen if only one class is present in y_true for a given fold/dataset.")
                    except Exception as auc_e:
                        logger.error(f"{model_name_log_prefix} Unexpected error during AUC calculation: {auc_e}", exc_info=True)
                else:
                    logger.warning(f"{model_name_log_prefix} Skipping ROC AUC: Only one class present in true labels for {dataset_name}.")
            else:
                logger.warning(f"{model_name_log_prefix} Skipping ROC AUC: Prediction columns ({y_pred_probs.shape[1]}) != num classes ({num_actual_classes}) or num_classes <= 1.")

            logger.info(f"{model_name_log_prefix} Evaluation Results ({dataset_name}):")
            logger.info(f"  Accuracy: {accuracy:.4f}")
            logger.info(f"  Weighted F1-Score: {f1:.4f}")
            logger.info(f"  Weighted Precision: {precision:.4f}")
            logger.info(f"  Weighted Recall: {recall:.4f}")
            if auc_weighted is not None: logger.info(f"  Weighted ROC AUC: {auc_weighted:.4f}")
            logger.info(f"  Classification Report:\n{report_str}")

            # Plot Confusion Matrix
            cm_filename_suffix = "ensemble" if is_ensemble else (models_or_model.name if not is_ensemble else "model")
            cm_plot_path = self.log_dir / f"{cm_filename_suffix}_{dataset_name.replace(' ','_')}_confusion_matrix.png"
            try:
                plt.figure(figsize=(8, 6))
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                            xticklabels=class_names, yticklabels=class_names)
                plt.ylabel('Actual Class', fontsize=12)
                plt.xlabel('Predicted Class', fontsize=12)
                plt.title(f'Confusion Matrix - {model_name_log_prefix} ({dataset_name})', fontsize=14)
                plt.tight_layout()
                cm_plot_path.parent.mkdir(parents=True, exist_ok=True)
                plt.savefig(cm_plot_path)
                logger.info(f"Confusion matrix plot saved to: {cm_plot_path}")
                plt.close()
            except Exception as plot_e:
                logger.error(f"{model_name_log_prefix} Failed to plot confusion matrix: {plot_e}", exc_info=True)

            results = {
                'accuracy': accuracy, 'f1_score_weighted': f1, 'precision_weighted': precision,
                'recall_weighted': recall, 'roc_auc_weighted': auc_weighted,
                'classification_report_str': report_str,
                'classification_report_dict': report_dict,
                'confusion_matrix': cm.tolist() # Store as list for JSON serialization
            }
            return results

        except Exception as e:
            logger.error(f"{model_name_log_prefix} Error during evaluation on {dataset_name}: {e}", exc_info=True)
            return None

    def evaluate_model(self, model: tf.keras.Model, dataset: tf.data.Dataset, class_names: List[str], dataset_name: str = "Test Set") -> Optional[Dict[str, Any]]:
        return self._evaluate_single_or_ensemble(model, dataset, class_names, dataset_name, is_ensemble=False)

    def evaluate_ensemble(self, models: List[tf.keras.Model], dataset: tf.data.Dataset, class_names: List[str], dataset_name: str = "Test Set") -> Optional[Dict[str, Any]]:
        if not models:
            logger.error("[Ensemble] No models provided for ensemble evaluation.")
            return None
        return self._evaluate_single_or_ensemble(models, dataset, class_names, dataset_name, is_ensemble=True)


    def predict_with_ensemble(self,
                              models: List[tf.keras.Model],
                              test_ds: tf.data.Dataset,
                              test_paths: List[str],
                              class_names: List[str]
                              ) -> Optional[pd.DataFrame]:
        logger.info("--- [Ensemble] Making predictions with ensemble ---")
        if not models:
            logger.error("[Ensemble] No models provided for ensemble prediction.")
            return None

        model_names = [m.name for m in models]
        logger.info(f"[Ensemble] Predicting with {len(models)} models: {model_names}")

        all_ensemble_probs: List[np.ndarray] = []
        num_processed_images = 0

        try:
            logger.info("[Ensemble] Generating predictions from all models...")
            if isinstance(test_ds.element_spec, tuple) and len(test_ds.element_spec) == 2:
                 test_ds_images_only = test_ds.map(lambda img, lbl: img)
                 logger.info("[Ensemble] Input dataset yields (image, label), extracting images for prediction.")
            elif isinstance(test_ds.element_spec, tf.TensorSpec): # Dataset yields only images
                 test_ds_images_only = test_ds
                 logger.info("[Ensemble] Input dataset yields images directly.")
            else:
                logger.error(f"[Ensemble] Unexpected test_ds.element_spec: {test_ds.element_spec}. Cannot proceed with prediction.")
                return None


            for images_batch in test_ds_images_only:
                batch_model_preds = []
                for model in models:
                    preds = model.predict(images_batch, verbose=0)
                    batch_model_preds.append(preds)

                avg_preds_batch = np.mean(np.stack(batch_model_preds, axis=0), axis=0)
                all_ensemble_probs.extend(avg_preds_batch)
                num_processed_images += images_batch.shape[0]
                if num_processed_images % 100 == 0:
                     logger.info(f"[Ensemble] Processed {num_processed_images} images for prediction...")

            logger.info(f"[Ensemble] Finished predicting on {num_processed_images} images.")

            if num_processed_images != len(test_paths):
                 logger.warning(
                     f"[Ensemble] Number of predictions ({num_processed_images}) "
                     f"does not match number of test paths ({len(test_paths)}). "
                     "Results might be misaligned or incomplete. Truncating paths to match predictions."
                 )
                 test_paths = test_paths[:num_processed_images]
                 if not test_paths and num_processed_images > 0:
                     logger.error("[Ensemble] Test paths became empty after truncation. Cannot create meaningful CSV.")
                     return None

            if num_processed_images == 0:
                logger.error("[Ensemble] No predictions were generated. Cannot create results DataFrame.")
                return None

            all_ensemble_probs_np = np.array(all_ensemble_probs)
            ensemble_pred_indices = np.argmax(all_ensemble_probs_np, axis=1)

            results_list = []
            logger.info("[Ensemble] Compiling prediction results into DataFrame...")
            for i, path_str in enumerate(test_paths):
                pred_class_idx = ensemble_pred_indices[i]
                result_entry = {
                    'filename': Path(path_str).name,
                    'predicted_class_ensemble': class_names[pred_class_idx],
                    'confidence_ensemble': float(all_ensemble_probs_np[i, pred_class_idx])
                }
                for j, class_name_score in enumerate(class_names):
                    result_entry[f'score_ensemble_{class_name_score}'] = float(all_ensemble_probs_np[i, j])
                results_list.append(result_entry)

            results_df = pd.DataFrame(results_list)
            logger.info(f"[Ensemble] Successfully created DataFrame with {len(results_df)} predictions.")
            return results_df

        except Exception as e:
            logger.error(f"[Ensemble] Prediction process failed: {e}", exc_info=True)
            return None

# **Step 9: main.py**

In [None]:
import tensorflow as tf
import numpy as np
import os
import sys
import logging
import argparse
from pathlib import Path
import time
import matplotlib.pyplot as plt
import pandas as pd
import textwrap
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import json
from skimage.segmentation import slic

# --- Gemini AI Setup ---
HARDCODED_GOOGLE_API_KEY = "AIzaSyDhPxLnr_BqfzQdYvY3R5LX3OhZeAnT2UQ"
ALL_AVAILABLE_MODEL_KEYS = ['efficientnet_b0', 'densenet121', 'mobilenetv2', 'resnet50v2', 'inceptionv3']
ENSEMBLE_MODEL_KEYS = ['densenet121', 'mobilenetv2', 'resnet50v2', 'inceptionv3']

genai = None
logger_gemini = logging.getLogger("GeminiAI")

try:
    import google.generativeai as genai_imported
    genai = genai_imported
    if HARDCODED_GOOGLE_API_KEY and HARDCODED_GOOGLE_API_KEY != "":
        GOOGLE_API_KEY_TO_USE = HARDCODED_GOOGLE_API_KEY
        try:
            genai.configure(api_key=GOOGLE_API_KEY_TO_USE)
            GEMINI_MODEL_ID = "models/gemini-1.5-flash-latest"
            logger_gemini.info(f"Google Generative AI SDK imported and API key seems valid. Using model: {GEMINI_MODEL_ID}")
        except Exception as e_api_test:
            logger_gemini.warning(f"Hardcoded Gemini API key found, but failed API test: {e_api_test}. Gemini AI explanations will be disabled.", exc_info=False)
            genai = None
    else:
        logger_gemini.warning("Hardcoded GOOGLE_API_KEY variable was empty or not a valid key. Gemini AI explanations will be disabled.")
        genai = None
except ImportError:
    logger_gemini.warning("google.generativeai SDK not found. Gemini AI explanations will be disabled.")
    genai = None
# --- End Gemini API Setup ---

# --- Import Project-Specific Modules ---
try:
    from data_processing import DataProcessor
    from model_architectures import (
        ENSEMBLE_MODEL_FUNCTIONS,
        mobilenet_preprocess_wrapper,
        resnet_v2_preprocess_wrapper,
        inception_v3_preprocess_wrapper
    )
    from training import ModelTrainer
    from visualization import find_suitable_gradcam_layer, make_gradcam_heatmap, display_gradcam
    from explainability import LimeExplainer
except ImportError as e:
    log_func = logging.error if logging.getLogger().hasHandlers() else print
    log_func(f"CRITICAL ERROR: Failed to import one or more project modules: {e}")
    log_func("Ensure all preceding cells/scripts (DataProcessor, model_architectures, etc.) have been executed/are in PYTHONPATH.")
    raise

# --- Global Configuration ---
logger = logging.getLogger() # Get root logger
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# --- Helper Functions ---
def verify_directory(dir_path: Union[str, Path], dir_description: str = "Directory", check_empty: bool = False, create: bool = False) -> bool:
    path = Path(dir_path)
    if not path.exists():
        if create:
            logger.info(f"Creating {dir_description}: {path}")
            try:
                path.mkdir(parents=True, exist_ok=True)
                return True
            except OSError as e:
                logger.error(f"Failed to create {dir_description} at {path}: {e}")
                return False
        else:
            logger.error(f"{dir_description} not found: {path}")
            return False
    if not path.is_dir():
        logger.error(f"Path exists but is not a directory: {path}")
        return False
    if check_empty and not any(path.iterdir()): # Check if directory is empty
        logger.warning(f"{dir_description} is empty: {path}")
    return True

def load_ensemble_models(model_keys_to_load: List[str], model_dir: Path) -> Optional[List[tf.keras.Model]]:
    loaded_models_dict: Dict[str, tf.keras.Model] = {}
    logger.info(f"--- Loading Specified Models: {model_keys_to_load} ---")
    for model_key in model_keys_to_load:
        model_filename = f"{model_key}_final.keras"
        model_path = model_dir / model_filename
        custom_objects = {}
        if model_key == 'mobilenetv2':
            custom_objects['mobilenet_preprocess_wrapper'] = mobilenet_preprocess_wrapper
        elif model_key == 'resnet50v2':
            custom_objects['resnet_v2_preprocess_wrapper'] = resnet_v2_preprocess_wrapper
        elif model_key == 'inceptionv3':
            custom_objects['inception_v3_preprocess_wrapper'] = inception_v3_preprocess_wrapper

        if model_path.exists():
            logger.info(f"Loading model '{model_key}' from: {model_path}")
            try:
                model = tf.keras.models.load_model(
                    str(model_path),
                    custom_objects=custom_objects if custom_objects else None,
                    compile=False # Recommended to compile after loading if needed, esp. for fine-tuning
                )
                loaded_models_dict[model_key] = model
                logger.info(f"Successfully loaded model '{model.name}'.")
            except Exception as e:
                logger.error(f"Failed to load model '{model_key}' from {model_path}: {e}", exc_info=True)
                return None # Critical failure if a model can't be loaded
        else:
            logger.error(f"Model file not found for '{model_key}': {model_path}")
            return None
    if len(loaded_models_dict) == len(model_keys_to_load):
        logger.info(f"--- All {len(model_keys_to_load)} Specified Models Loaded Successfully ---")
        return [loaded_models_dict[key] for key in model_keys_to_load] # Ensure order matches input
    else:
        logger.error("Discrepancy in loaded models vs requested. Returning None.")
        return None

def generate_gemini_explanation(
    image_filename: str, predicted_class: str, confidence: float,
    grad_cam_description: str, lime_description: str, class_names: List[str],
    ensemble_model_names_str: str,
    primary_model_name_for_viz: Optional[str]
    ) -> str:
    if not genai:
        logger_gemini.warning("Gemini AI is disabled. Cannot generate explanation.")
        return "AI explanation not available (Gemini AI is disabled)."

    prompt = f"""
    Analyze the following AI model's interpretation of a lung CT scan image.
    The AI model is an ensemble of deep learning models ({ensemble_model_names_str})
    trained for classification of lung conditions into {', '.join(class_names)}.
    The interpretability (XAI) findings below are from the '{primary_model_name_for_viz if primary_model_name_for_viz else "a primary component"}' model of the ensemble.

    Image Filename: {image_filename}
    Ensemble Model's Predicted Condition: {predicted_class}
    Ensemble Confidence Score for this Prediction: {confidence:.2%}

    Interpretability Tool Findings (from '{primary_model_name_for_viz if primary_model_name_for_viz else "primary component"}'):
    - Grad-CAM Heatmap: {grad_cam_description}
    - LIME Explanation: {lime_description}

    Based ONLY on the information provided above, provide a technical analysis suitable for a data scientist or AI practitioner.
    Your explanation should:
    1.  Discuss the potential implications of the '{predicted_class}' prediction with {confidence:.2%} confidence, considering it's an ensemble output.
    2.  Relate the Grad-CAM findings for the '{primary_model_name_for_viz if primary_model_name_for_viz else "primary component"}' model to how a Convolutional Neural Network (CNN) might arrive at this decision. Mention concepts like feature extraction, activation maps, and how specific highlighted regions could correspond to learned patterns indicative of the '{predicted_class}'.
    3.  Explain the LIME superpixels for the '{primary_model_name_for_viz if primary_model_name_for_viz else "primary component"}' model in terms of local feature importance. How do these specific image patches contribute to or detract from the model's decision for the '{predicted_class}'?
    4.  If possible, briefly comment on the consistency or divergence between Grad-CAM and LIME findings for '{primary_model_name_for_viz if primary_model_name_for_viz else "primary component"}'.
    5.  Use appropriate technical terminology related to deep learning, computer vision, and model interpretability.
    Avoid definitive medical diagnoses. Focus on the AI's decision-making process.
    Be concise yet technically informative (target 4-6 sentences).
    """
    try:
        logger_gemini.info(f"Sending TECHNICAL prompt to Gemini for {image_filename} (explaining '{predicted_class}')...")
        gemini_model_instance = genai.GenerativeModel(GEMINI_MODEL_ID)
        generation_config = genai.GenerationConfig(temperature=0.3, max_output_tokens=300) # Adjusted token limit
        response = gemini_model_instance.generate_content(prompt, generation_config=generation_config)
        if response.candidates and response.candidates[0].content.parts:
            explanation_text = response.candidates[0].content.parts[0].text
            logger_gemini.info(f"Gemini TECHNICAL explanation received for {image_filename}.")
            return explanation_text.strip()
        else:
            full_response_text = "N/A"
            try: full_response_text = response.text
            except Exception: pass
            logger_gemini.warning(f"Gemini response for {image_filename} was empty or malformed. Full response: {full_response_text}. Prompt feedback: {response.prompt_feedback if hasattr(response, 'prompt_feedback') else 'N/A'}")
            return "Gemini could not generate an explanation for this case (empty or malformed response)."
    except Exception as e:
        logger_gemini.error(f"Error generating Gemini explanation for {image_filename}: {e}", exc_info=True)
        return f"Error during AI explanation: {str(e)}"

def create_combined_report_image(
    original_img_np: np.ndarray, gradcam_overlay_np: Optional[np.ndarray],
    lime_overlay_np: Optional[np.ndarray], filename: str, predicted_class: str,
    confidence: float, all_class_scores: Dict[str, float], gemini_explanation: str,
    output_path: Path, primary_viz_model_name: Optional[str]
    ):
    num_visualizations = (1 if gradcam_overlay_np is not None else 0) + \
                         (1 if lime_overlay_np is not None else 0)
    viz_model_info = f" (Viz: {primary_viz_model_name})" if primary_viz_model_name and primary_viz_model_name != "N/A" else ""

    if gradcam_overlay_np is not None and lime_overlay_np is not None:
        fig, axs = plt.subplots(2, 2, figsize=(18, 16), gridspec_kw={'height_ratios': [3, 1]})
        image_axs_flat = [axs[0, 0], axs[0, 1], axs[1, 0]] # Original, GradCAM, LIME
        text_ax = axs[1, 1]
    elif gradcam_overlay_np is not None or lime_overlay_np is not None:
         fig, axs = plt.subplots(2, 2, figsize=(16, 14), gridspec_kw={'height_ratios': [3, 1]}) # Original, One XAI, Empty, Text
         image_axs_flat = [axs[0, 0], axs[0, 1]]
         axs[1,0].axis('off') # Turn off the unused bottom-left plot
         text_ax = axs[1, 1]
    else: # Only original image and text
         fig, axs = plt.subplots(2, 1, figsize=(10, 12), gridspec_kw={'height_ratios': [2, 1]})
         image_axs_flat = [axs[0]]
         text_ax = axs[1]

    fig.patch.set_facecolor('white')
    current_ax_idx = 0

    # Plot Original Image
    if current_ax_idx < len(image_axs_flat):
        image_axs_flat[current_ax_idx].imshow(original_img_np)
        image_axs_flat[current_ax_idx].set_title("Original Image", fontsize=14)
        image_axs_flat[current_ax_idx].axis('off')
        current_ax_idx += 1

    # Plot Grad-CAM if available and requested
    if args.run_gradcam: # Check if GradCAM was requested
        if gradcam_overlay_np is not None and current_ax_idx < len(image_axs_flat):
            image_axs_flat[current_ax_idx].imshow(gradcam_overlay_np)
            image_axs_flat[current_ax_idx].set_title(f"Grad-CAM Overlay{viz_model_info}", fontsize=14)
            image_axs_flat[current_ax_idx].axis('off')
        elif current_ax_idx < len(image_axs_flat): # Space available but no GradCAM data
            image_axs_flat[current_ax_idx].text(0.5, 0.5, "Grad-CAM N/A", ha='center', va='center', fontsize=12, color='gray')
            image_axs_flat[current_ax_idx].set_title(f"Grad-CAM Overlay{viz_model_info}", fontsize=14)
            image_axs_flat[current_ax_idx].axis('off')
        current_ax_idx += 1

    # Plot LIME if available and requested
    if args.run_lime: # Check if LIME was requested
        if lime_overlay_np is not None and current_ax_idx < len(image_axs_flat):
            image_axs_flat[current_ax_idx].imshow(lime_overlay_np)
            image_axs_flat[current_ax_idx].set_title(f"LIME Explanation{viz_model_info}", fontsize=14)
            image_axs_flat[current_ax_idx].axis('off')
        elif current_ax_idx < len(image_axs_flat): # Space available but no LIME data
            image_axs_flat[current_ax_idx].text(0.5, 0.5, "LIME N/A", ha='center', va='center', fontsize=12, color='gray')
            image_axs_flat[current_ax_idx].set_title(f"LIME Explanation{viz_model_info}", fontsize=14)
            image_axs_flat[current_ax_idx].axis('off')
        current_ax_idx += 1

    # Turn off any remaining unused axes in the image row/grid
    while current_ax_idx < len(image_axs_flat):
         image_axs_flat[current_ax_idx].axis('off')
         current_ax_idx += 1

    # Add Text Information
    text_ax.axis('off')
    text_content = f"File: {filename}\\n\\n"
    text_content += f"Ensemble Prediction: {predicted_class}\\n"
    text_content += f"Confidence: {confidence:.2%}\\n\\n"
    text_content += "All Class Scores (Ensemble):\\n"
    sorted_scores = sorted(all_class_scores.items(), key=lambda item: item[1], reverse=True)
    for cls_name, score in sorted_scores:
        text_content += f"  - {cls_name}: {score:.3f}\\n"

    text_content += f"\\nAI Model's Potential Reasoning (Summarized by Gemini AI):\\n"
    wrapped_gemini_text = "\\n".join(textwrap.wrap(gemini_explanation, width=80 if num_visualizations > 0 else 60))
    text_content += wrapped_gemini_text

    text_ax.text(0.01, 0.98, text_content, ha='left', va='top', fontsize=10, wrap=True,
                 bbox=dict(boxstyle='round,pad=0.5', fc='aliceblue', alpha=0.9),
                 transform=text_ax.transAxes)

    plt.suptitle(f"Lung Cancer Classification Report: {filename}", fontsize=18, y=1.02 if num_visualizations > 0 else 0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.95]) # Adjust rect to prevent suptitle overlap

    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, bbox_inches='tight', facecolor=fig.get_facecolor())
        logger.info(f"Combined report image saved to: {output_path}")
    except Exception as e:
        logger.error(f"Failed to save combined report image to {output_path}: {e}", exc_info=True)
    finally:
        plt.close(fig) # Ensure figure is closed

# --- Main Execution Logic ---
def main(args, specific_files_for_notebook_report=None): # specific_files_for_notebook_report is now ignored for selection
    start_time = time.time()
    logger.info(f"--- Starting Pipeline Execution (Mode: {args.mode}) ---")
    logger.info(f"Script arguments: {vars(args)}")
    logger.info(f"Using TensorFlow version: {tf.__version__}")

    if args.mode == 'train' and args.train_models:
        current_run_model_keys = [key for key in args.train_models if key in ALL_AVAILABLE_MODEL_KEYS]
        if not current_run_model_keys:
            logger.error(f"No valid models specified with --train_models. Available: {ALL_AVAILABLE_MODEL_KEYS}. Exiting.")
            return
    else: # Default to ENSEMBLE_MODEL_KEYS for predict/evaluate or if --train_models is None
        current_run_model_keys = ENSEMBLE_MODEL_KEYS
    logger.info(f"Models targeted for this run: {current_run_model_keys}")

    if args.run_gemini and not genai:
        logger.warning("Gemini AI explanations requested (--run_gemini), but Gemini AI is disabled. Skipping Gemini explanations.")
        args.run_gemini = False # Ensure it's off if not available

    base_path = Path(args.base_dir)
    data_base_path = base_path / 'data' / 'raw'
    train_dir = data_base_path / 'train'

    use_curated_test_set = args.use_curated_test_set
    if use_curated_test_set:
        test_dir = data_base_path / 'test' / 'my_curated_test_images_renamed'
        logger.info(f"INFO: Using CURATED RENAMED test directory: {test_dir}")
    else:
        test_dir = data_base_path / 'test'
        logger.info(f"INFO: Using ORIGINAL test directory: {test_dir}")


    model_dir = base_path / 'models'
    log_dir = base_path / 'logs'
    results_dir = base_path / 'results'
    report_images_dir = results_dir / "prediction_reports"

    if not verify_directory(args.base_dir, "Base project directory"): return
    for d, desc, create_flag in [
        (model_dir, "Models output directory", True),
        (log_dir, "Logs output directory", True),
        (results_dir, "Results output directory", True),
        (report_images_dir, "Prediction reports output directory", args.mode=='predict' and (args.run_gradcam or args.run_lime or args.run_gemini)) # Create if reports are on
    ]:
        if not verify_directory(d, desc, create=create_flag): return

    if args.mode == 'train' and not verify_directory(train_dir, "Training data directory", check_empty=True): return
    if args.mode in ['predict', 'evaluate'] and not verify_directory(test_dir, "Test data directory", check_empty=False): return # Don't check empty for test
    if args.mode == 'predict' and not verify_directory(train_dir, "Training data directory (for validation split in predict mode)", check_empty=False):
        logger.warning("Training directory not found for predict mode's validation step. Validation will be skipped if it was intended.")

    logger.info(f"Data source (train for train/validation, test for test/predict): {train_dir}, {test_dir}")
    logger.info(f"Models dir: {model_dir}, Logs dir: {log_dir}, Results dir: {results_dir}")

    try:
        data_processor = DataProcessor(
            target_size=(args.img_height, args.img_width),
            batch_size=args.batch_size,
            seed=SEED
        )
        class_names = data_processor.class_names
        num_classes = data_processor.num_classes
        if not class_names:
             logger.error("DataProcessor could not determine class names. Exiting.")
             return
        logger.info(f"Detected/Using Class Names: {class_names}")
    except Exception as e:
        logger.error(f"Failed to initialize DataProcessor: {e}", exc_info=True)
        return

    trainer = ModelTrainer(model_dir=model_dir, log_dir=log_dir)

    if args.mode == 'train':
        logger.info(f"--- Running Training Mode for models: {current_run_model_keys} ---")
        all_target_models_final_files_present = all((model_dir / f"{key}_final.keras").exists() for key in current_run_model_keys)
        trained_models_map: Dict[str, tf.keras.Model] = {}
        skip_individual_training = False

        if all_target_models_final_files_present and args.load_existing_models_if_present:
            logger.info(f"All final files for targeted models ({current_run_model_keys}) found. Attempting to load them.")
            loaded_list = load_ensemble_models(current_run_model_keys, model_dir)
            if loaded_list and len(loaded_list) == len(current_run_model_keys):
                for i, key in enumerate(current_run_model_keys):
                    trained_models_map[key] = loaded_list[i]
                skip_individual_training = True
                logger.info(f"Successfully loaded all existing targeted final models. Training steps will be skipped for these.")
            else:
                logger.warning(f"Failed to load one or more existing targeted models. Proceeding with training for: {current_run_model_keys}.")
                trained_models_map = {} # Reset if loading failed

        if not skip_individual_training:
            logger.info(f"Proceeding with individual model training for: {current_run_model_keys}")
            intermediate_ds, all_paths, all_labels, ds_size = data_processor.create_dataset_for_splitting(str(train_dir))
            if not intermediate_ds or ds_size == 0:
                logger.error("Failed to create training dataset or dataset is empty. Exiting training.")
                return
            train_ds, val_ds = data_processor.split_and_configure_dataset(intermediate_ds, ds_size, args.validation_split)
            class_weights_dict = data_processor.get_class_weights(all_labels) if args.use_class_weights else None
            all_models_trained_successfully_this_run = True

            for model_key in current_run_model_keys:
                logger.info(f"===== Training Model: {model_key} ===== ")
                if model_key not in ENSEMBLE_MODEL_FUNCTIONS:
                    logger.error(f"Model key '{model_key}' not found in ENSEMBLE_MODEL_FUNCTIONS. Skipping.")
                    all_models_trained_successfully_this_run = False
                    continue
                model_func = ENSEMBLE_MODEL_FUNCTIONS[model_key]
                try:
                    tf.keras.backend.clear_session() # Clear session before creating a new model
                    model_instance = model_func(
                        input_shape=(args.img_height, args.img_width, 3),
                        num_classes=num_classes,
                        # Pass other relevant args if model_func expects them, e.g., weights_path
                    )
                    logger.info(f"Model '{model_instance.name}' structure created.")
                    trained_model, _ = trainer.train_model(
                        model=model_instance, train_ds=train_ds, val_ds=val_ds,
                        model_base_name=model_key, initial_epochs=args.initial_epochs,
                        fine_tune_epochs=args.fine_tune_epochs, initial_lr=args.initial_lr,
                        fine_tune_lr=args.fine_tune_lr, fine_tune_at_layer=args.fine_tune_at,
                        class_weights=class_weights_dict
                    )
                    trained_models_map[model_key] = trained_model
                    logger.info(f"===== Finished Training Model: {model_key} ===== ")
                except Exception as train_e:
                    logger.error(f"Error during training of model '{model_key}': {train_e}", exc_info=True)
                    all_models_trained_successfully_this_run = False
                    logger.warning(f"Training failed for {model_key}. Continuing with other models if any.")
            if not all_models_trained_successfully_this_run:
                 logger.warning("One or more specified models failed during training in this run.")

        # Validation evaluation after training or loading
        if len(trained_models_map) > 0:
            log_prefix_val_eval = f"--- Evaluating {'Ensemble of' if len(trained_models_map) > 1 else 'Model'} ({list(trained_models_map.keys())}) on Validation Set ---"
            logger.info(log_prefix_val_eval)

            # Ensure val_ds is available
            if 'val_ds' not in locals() or val_ds is None: # 'val_ds' might not be defined if training was skipped
                 logger.info("Loading validation dataset for evaluation (as training was skipped or val_ds not available)...")
                 intermediate_ds_eval, _, _, ds_size_eval = data_processor.create_dataset_for_splitting(str(train_dir))
                 if not intermediate_ds_eval or ds_size_eval == 0:
                     logger.error("Failed to create dataset for validation evaluation or dataset is empty. Skipping.")
                     val_ds_for_eval = None
                 else:
                    _, val_ds_for_eval = data_processor.split_and_configure_dataset(intermediate_ds_eval, ds_size_eval, args.validation_split)
            else: # val_ds was defined during training
                 val_ds_for_eval = val_ds

            if val_ds_for_eval:
                models_for_eval = [trained_models_map[key] for key in trained_models_map.keys() if key in trained_models_map]
                if models_for_eval:
                    eval_results = trainer.evaluate_ensemble(models_for_eval, val_ds_for_eval, class_names, "Validation Set") if len(models_for_eval) > 1 else trainer.evaluate_model(models_for_eval[0], val_ds_for_eval, class_names, "Validation Set")
                    if eval_results:
                        results_filename = f"{'_'.join(sorted(trained_models_map.keys()))}_validation_results.json"
                        eval_results_path = results_dir / results_filename
                        try:
                            # Ensure numpy arrays are converted to lists for JSON serialization
                            eval_results_serializable = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in eval_results.items()}
                            if 'classification_report_dict' in eval_results_serializable: # Ensure keys in dict are strings
                                eval_results_serializable['classification_report_dict'] = {str(k_): v_ for k_, v_ in eval_results_serializable['classification_report_dict'].items()}

                            with open(eval_results_path, 'w') as f:
                                 json.dump(eval_results_serializable, f, indent=4)
                            logger.info(f"Validation results for {list(trained_models_map.keys())} saved to {eval_results_path}")
                        except Exception as save_e:
                            logger.error(f"Failed to save validation results: {save_e}", exc_info=True)
                else:
                    logger.error(f"Evaluation on validation set failed for {list(trained_models_map.keys())} or no models were available.")
            else:
                logger.warning("Validation dataset (val_ds_for_eval) could not be prepared. Skipping validation evaluation.")
        else:
             logger.warning("No models were successfully trained or loaded in this run. Skipping validation set evaluation.")


    elif args.mode == 'evaluate':
        logger.info(f"--- Running Evaluation Mode for models: {current_run_model_keys} ---")
        ensemble_models = load_ensemble_models(current_run_model_keys, model_dir)
        if not ensemble_models: return # Exit if models can't be loaded

        test_ds_eval, test_paths, test_labels = data_processor.create_test_dataset(str(test_dir))
        if not test_ds_eval or test_labels is None or not test_paths: # test_labels must exist for evaluation
            logger.error("Failed to create labeled test dataset or no images/labels found for evaluation. Exiting.")
            return
        logger.info(f"Evaluating on {len(test_paths)} test samples.")

        eval_results_ensemble = trainer.evaluate_ensemble(ensemble_models, test_ds_eval, class_names, "Test Set")
        if eval_results_ensemble:
            try:
                csv_suffix = "_curated" if use_curated_test_set else ""
                ensemble_results_path = results_dir / f"{'_'.join(sorted(current_run_model_keys))}_test_set_evaluation_results{csv_suffix}.json"

                eval_results_serializable = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in eval_results_ensemble.items()}
                if 'classification_report_dict' in eval_results_serializable:
                    eval_results_serializable['classification_report_dict'] = {str(k_): v_ for k_, v_ in eval_results_serializable['classification_report_dict'].items()}
                with open(ensemble_results_path, 'w') as f:
                     json.dump(eval_results_serializable, f, indent=4)
                logger.info(f"Ensemble test results saved to {ensemble_results_path}")
            except Exception as save_e:
                logger.error(f"Failed to save ensemble test results: {save_e}", exc_info=True)
        else:
            logger.error("Ensemble evaluation on test set failed.")

    elif args.mode == 'predict':
        logger.info(f"--- Running Prediction & Optional Validation Evaluation Mode for models: {current_run_model_keys} ---\n")
        ensemble_models = load_ensemble_models(current_run_model_keys, model_dir)
        if not ensemble_models:
            logger.error("Failed to load ensemble models. Exiting predict mode.")
            return

        primary_model_for_viz = None
        primary_model_name_for_viz = "N/A" # Default if not found
        selected_primary_model_key_for_viz = args.primary_viz_model_key

        if selected_primary_model_key_for_viz:
            try:
                model_index = current_run_model_keys.index(selected_primary_model_key_for_viz)
                if model_index < len(ensemble_models): # Check if index is valid for the loaded models
                    primary_model_for_viz = ensemble_models[model_index]
                    primary_model_name_for_viz = primary_model_for_viz.name
                    logger.info(f"Selected '{primary_model_name_for_viz}' (key: {selected_primary_model_key_for_viz}) as the primary model for visualizations.")
                else: # Should not happen if current_run_model_keys is source for index
                    logger.error(f"Index out of bounds for selected primary viz model key '{selected_primary_model_key_for_viz}'. Defaulting...")
                    selected_primary_model_key_for_viz = None # Force default
            except ValueError: # Key not in current_run_model_keys
                logger.warning(f"Primary visualization model key '{selected_primary_model_key_for_viz}' not found in the list of currently run/loaded models: {current_run_model_keys}. Defaulting...")
                selected_primary_model_key_for_viz = None # Force default
            except Exception as e_sel_mod:
                logger.error(f"Error selecting primary visualization model: {e_sel_mod}. Defaulting...")
                selected_primary_model_key_for_viz = None # Force default

        # Default to the first model if selection failed or was not specified properly
        if not primary_model_for_viz and ensemble_models: # If still None, and we have models
            primary_model_for_viz = ensemble_models[0]
            if current_run_model_keys: # Ensure there's a key to associate
                 selected_primary_model_key_for_viz = current_run_model_keys[0]
                 primary_model_name_for_viz = primary_model_for_viz.name
                 logger.info(f"Defaulted to using '{primary_model_name_for_viz}' (key: {selected_primary_model_key_for_viz}) as the primary model for visualizations.")
            else: # Should not happen if ensemble_models is populated
                 logger.error("Cannot default primary visualization model as current_run_model_keys is empty.")
                 # Disable detailed reports if no primary model can be set
                 args.run_gradcam = args.run_lime = args.run_gemini = False
        elif not primary_model_for_viz: # No models loaded at all
            logger.error("No models available for visualization. Cannot proceed with detailed reports.")
            args.run_gradcam = args.run_lime = args.run_gemini = False


        _, test_paths, _ = data_processor.create_test_dataset(str(test_dir)) # Get all paths, labels are not strictly needed for predict CSV
        if not test_paths:
            logger.error("No image paths found in the test directory. Cannot proceed with prediction or detailed reports.")
            return

        csv_suffix = "_curated" if use_curated_test_set else ""
        main_csv_path = results_dir / f"{'_'.join(sorted(current_run_model_keys))}_test_predictions{csv_suffix}.csv"

        main_predictions_df = None
        skip_full_test_prediction_and_validation = False

        if main_csv_path.exists() and args.load_existing_models_if_present:
            logger.info(f"Found existing prediction CSV: {main_csv_path}. Attempting to load.")
            try:
                main_predictions_df = pd.read_csv(main_csv_path)
                if not main_predictions_df.empty:
                    logger.info(f"Successfully loaded {len(main_predictions_df)} predictions from CSV. Skipping validation and full test set prediction.")
                    skip_full_test_prediction_and_validation = True
                else:
                    logger.warning(f"Existing prediction CSV {main_csv_path} is empty. Will re-run predictions.")
                    main_predictions_df = None # Ensure it's None to trigger re-prediction
            except Exception as e_csv_load:
                logger.warning(f"Failed to load or parse existing CSV {main_csv_path}: {e_csv_load}. Will re-run predictions.")
                main_predictions_df = None # Ensure it's None
        else: # CSV doesn't exist or load_existing is False
            if main_csv_path.exists(): # But load_existing is False
                 logger.info(f"Prediction CSV {main_csv_path} found, but --load_existing_models_if_present is False. Will re-run predictions.")
            else: # CSV doesn't exist
                 logger.info(f"Prediction CSV {main_csv_path} not found. Will run full prediction and validation.")


        # Validation on a split of training data (if not skipping)
        if not skip_full_test_prediction_and_validation:
            logger.info("--- [Ensemble] Validation Phase (using split from training data) ---")
            if not verify_directory(train_dir, "Training data directory (for validation split)", check_empty=False):
                logger.warning("Training directory not found or empty. Skipping validation phase in predict mode.")
            else:
                intermediate_val_ds, _, all_val_labels, val_ds_size = data_processor.create_dataset_for_splitting(str(train_dir))
                if not intermediate_val_ds or val_ds_size == 0 or (val_ds_size > 0 and not all_val_labels): # Check if labels are also present
                    logger.error("Failed to create intermediate dataset or labels for validation. Skipping validation phase.")
                else:
                    _, val_ds_for_eval = data_processor.split_and_configure_dataset(intermediate_val_ds, val_ds_size, args.validation_split)
                    logger.info("Evaluating loaded ensemble on the validation set...")
                    eval_results_validation = trainer.evaluate_ensemble(ensemble_models, val_ds_for_eval, class_names, "Validation Set")
                    if eval_results_validation:
                        val_results_filename = f"{'_'.join(sorted(current_run_model_keys))}_validation_set_evaluation_results_predict_mode.json"
                        val_eval_results_path = results_dir / val_results_filename
                        try:
                            eval_results_serializable = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in eval_results_validation.items()}
                            if 'classification_report_dict' in eval_results_serializable:
                                 eval_results_serializable['classification_report_dict'] = {str(k_): v_ for k_, v_ in eval_results_serializable['classification_report_dict'].items()}
                            with open(val_eval_results_path, 'w') as f:
                                 json.dump(eval_results_serializable, f, indent=4)
                            logger.info(f"Ensemble validation set evaluation results (predict mode) saved to {val_eval_results_path}")
                        except Exception as save_e:
                            logger.error(f"Failed to save ensemble validation set evaluation results (predict mode): {save_e}", exc_info=True)
                    else:
                        logger.warning("Ensemble evaluation on validation set (predict mode) failed or produced no results.")
        else: # skip_full_test_prediction_and_validation is True
            logger.info("Skipping [Ensemble] Validation Phase as main prediction CSV was loaded.")

        # Full test set prediction (if not skipping)
        if not skip_full_test_prediction_and_validation:
            logger.info("--- [Ensemble] Prediction Phase on Test Set ---")
            # Create test_ds again, this time it might be unlabeled or labeled depending on test_dir structure
            test_ds_for_actual_prediction, _, _ = data_processor.create_test_dataset(str(test_dir)) # test_labels might be None if unlabeled
            if not test_ds_for_actual_prediction:
                 logger.error("Test dataset (test_ds_for_actual_prediction) could not be created. Cannot run predictions.")
                 return # Critical

            logger.info(f"Preparing to predict on {len(test_paths)} test samples for CSV output.")
            main_predictions_df = trainer.predict_with_ensemble(
                ensemble_models, test_ds_for_actual_prediction, test_paths, class_names
            )

            if main_predictions_df is not None and not main_predictions_df.empty:
                # If using curated (renamed) test set, try to merge true labels from the mapping CSV
                if use_curated_test_set:
                    mapping_csv_path = data_base_path / 'test' / 'renaming_map.csv' # Expected location
                    if mapping_csv_path.exists():
                        try:
                            mapping_df = pd.read_csv(mapping_csv_path)
                            main_predictions_df = pd.merge(
                                main_predictions_df,
                                mapping_df[['new_random_filename', 'true_label']],
                                left_on='filename', # 'filename' in main_predictions_df is the new_random_filename
                                right_on='new_random_filename',
                                how='left'
                            )
                            # main_predictions_df.drop(columns=['new_random_filename'], inplace=True, errors='ignore') # Keep for clarity or drop
                            logger.info(f"Successfully merged prediction results with true labels from {mapping_csv_path}.")
                        except Exception as e_merge:
                            logger.error(f"Failed to merge with {mapping_csv_path}: {e_merge}", exc_info=True)
                    else:
                        logger.warning(f"Curated test set used, but mapping file {mapping_csv_path} not found. True labels will be missing in CSV.")

                main_predictions_df.to_csv(main_csv_path, index=False)
                logger.info(f"Ensemble predictions on test set completed and CSV saved to {main_csv_path}.")
            else:
                logger.warning("Ensemble prediction on test set failed or produced no DataFrame. Detailed reports might be affected.")
        else: # skip_full_test_prediction_and_validation is True
            logger.info("Skipping [Ensemble] Full Test Set Prediction as data was loaded from existing CSV.")


        # --- Detailed Report Generation for ALL files if XAI features are enabled ---
        target_filenames_for_report = []
        filename_to_details_map = {}
        generate_reports_for_all = args.mode == 'predict' and (args.run_gradcam or args.run_lime or args.run_gemini)

        if generate_reports_for_all:
            if main_predictions_df is not None and not main_predictions_df.empty:
                logger.info(f"Preparing to generate detailed reports for ALL {len(main_predictions_df)} samples found in the prediction results.")
                target_filenames_for_report = main_predictions_df['filename'].tolist()
                for index, row_data_report in main_predictions_df.iterrows():
                    filename = row_data_report['filename']
                    try:
                        pred_class_idx = class_names.index(row_data_report['predicted_class_ensemble'])
                    except ValueError:
                        logger.error(f"Predicted class '{row_data_report['predicted_class_ensemble']}' for file '{filename}' not in known class_names {class_names}. Skipping this file for detailed report.")
                        if filename in target_filenames_for_report:
                            target_filenames_for_report.remove(filename)
                        continue # Skip this file

                    filename_to_details_map[filename] = {
                        'pred_class': row_data_report['predicted_class_ensemble'],
                        'pred_idx': pred_class_idx,
                        'confidence': row_data_report['confidence_ensemble'],
                        'all_scores_dict': {name: row_data_report.get(f'score_ensemble_{name}', 0.0) for name in class_names}
                    }
                if not target_filenames_for_report:
                     logger.warning("No files selected for detailed reporting after processing main_predictions_df (possibly due to class name errors).")
            else:
                logger.warning("Prediction data (main_predictions_df) is not available or empty. Cannot generate detailed reports.")
        else:
            logger.info("Detailed reporting for individual samples is not applicable or not requested for this run.")


        if target_filenames_for_report: # If there are files to report on
            logger.info(f"--- Generating Detailed Reports for {len(target_filenames_for_report)} Test Samples ---")
            collected_samples_for_report: List[Dict[str, Any]] = []
            filename_to_fullpath_map = {Path(p).name: p for p in test_paths}

            # Efficiently get image data for all target files
            temp_image_store: Dict[str, np.ndarray] = {}
            # Create a dataset filtered for only the target images if possible, or load one by one
            # For simplicity here, we'll assume test_paths contains paths to all images in test_dir
            # and we filter based on target_filenames_for_report.
            for filename_report in target_filenames_for_report:
                full_path_str = filename_to_fullpath_map.get(filename_report)
                if full_path_str:
                    try:
                        # Use _parse_image which handles reading and initial processing
                        img_tensor, _ = data_processor._parse_image(tf.constant(full_path_str), tf.constant(0)) # Dummy label
                        temp_image_store[filename_report] = img_tensor.numpy()
                    except Exception as e_load_direct:
                        logger.error(f"Failed to load image {filename_report} from {full_path_str} using _parse_image: {e_load_direct}")
                else:
                    logger.warning(f"Full path for {filename_report} not found in filename_to_fullpath_map. Cannot generate report for it.")


            for filename_report in target_filenames_for_report:
                if filename_report in temp_image_store and filename_report in filename_to_details_map:
                    details = filename_to_details_map[filename_report]
                    collected_samples_for_report.append({
                        'path': filename_to_fullpath_map[filename_report], # Store full path
                        'image_np': temp_image_store[filename_report],
                        'pred_idx': details['pred_idx'],
                        'pred_class': details['pred_class'],
                        'confidence': details['confidence'],
                        'all_scores_dict': details['all_scores_dict']
                    })
                else:
                    logger.warning(f"Could not find image data or prediction details for selected filename {filename_report}. Skipping for detailed report.")

            if collected_samples_for_report:
                lime_explainer_instance = None
                if args.run_lime and primary_model_for_viz:
                    logger.info(f"Initializing LIME explainer for {primary_model_name_for_viz} for detailed reporting...")
                    try:
                        lime_explainer_instance = LimeExplainer(primary_model_for_viz, class_names)
                    except Exception as lime_init_e:
                        logger.error(f"Failed to initialize LIME for {primary_model_name_for_viz}: {lime_init_e}. Disabling LIME for reports.")
                        args.run_lime = False # Disable LIME if init fails

                for i_report, sample_data_report in enumerate(collected_samples_for_report):
                    img_path_report = Path(sample_data_report['path'])
                    img_np_normalized_report = sample_data_report['image_np'] # Already normalized by DataProcessor
                    logger.info(f"--- Generating detailed report for sample {i_report+1}/{len(collected_samples_for_report)}: {img_path_report.name} ---")

                    gradcam_overlay_np_report, grad_cam_description_report = None, "Grad-CAM not run or failed."
                    if args.run_gradcam and primary_model_for_viz:
                        logger.info(f"  Generating Grad-CAM for {primary_model_name_for_viz}...")
                        try:
                            gradcam_layer_name = find_suitable_gradcam_layer(primary_model_for_viz, model_key=selected_primary_model_key_for_viz)
                            if not gradcam_layer_name: # Fallback if auto-detection fails
                                logger.warning(f"find_suitable_gradcam_layer did not return a layer for {selected_primary_model_key_for_viz}. Attempting explicit fallback.")
                                fallbacks = {"inceptionv3": "mixed10", "resnet50v2": "conv5_block3_out", "mobilenetv2": "out_relu", "efficientnet_b0": "top_activation", "densenet121": "relu"}
                                gradcam_layer_name = fallbacks.get(selected_primary_model_key_for_viz)
                                if gradcam_layer_name: logger.info(f"  Explicitly setting Grad-CAM layer for {selected_primary_model_key_for_viz} to: {gradcam_layer_name}")

                            if gradcam_layer_name:
                                heatmap = make_gradcam_heatmap(
                                    np.expand_dims(img_np_normalized_report, axis=0),
                                    primary_model_for_viz,
                                    gradcam_layer_name,
                                    sample_data_report['pred_idx']
                                )
                                if heatmap is not None and heatmap.max() > 0: # Check if heatmap is not all zeros
                                    gradcam_overlay_np_report = display_gradcam(img_np_normalized_report, heatmap, alpha=0.6)
                                    grad_cam_description_report = f"Highlights areas for '{sample_data_report['pred_class']}' (layer '{gradcam_layer_name}')."
                                else:
                                    grad_cam_description_report = f"Grad-CAM heatmap was zero or non-indicative for layer '{gradcam_layer_name}'."
                                    if gradcam_overlay_np_report is None and heatmap is not None: # Still provide original if overlay fails but heatmap was generated
                                        gradcam_overlay_np_report = np.clip(img_np_normalized_report * 255, 0, 255).astype(np.uint8)
                            else: # No suitable layer found even after fallback
                                grad_cam_description_report = f"Suitable Grad-CAM layer not found in {primary_model_name_for_viz}."
                                gradcam_overlay_np_report = np.clip(img_np_normalized_report * 255, 0, 255).astype(np.uint8) # Show original
                        except Exception as viz_e:
                             logger.error(f"Grad-CAM failed for {img_path_report.name}: {viz_e}", exc_info=True)
                             grad_cam_description_report = "Grad-CAM generation encountered an error."
                             gradcam_overlay_np_report = np.clip(img_np_normalized_report * 255, 0, 255).astype(np.uint8) # Show original
                    elif not primary_model_for_viz:
                        grad_cam_description_report = "Primary model for Grad-CAM not available."
                        gradcam_overlay_np_report = np.clip(img_np_normalized_report * 255, 0, 255).astype(np.uint8)
                    elif not args.run_gradcam:
                        grad_cam_description_report = "Grad-CAM was not requested to run."
                        gradcam_overlay_np_report = np.clip(img_np_normalized_report * 255, 0, 255).astype(np.uint8)


                    lime_overlay_np_report, lime_description_report = None, "LIME not run or failed."
                    if args.run_lime and lime_explainer_instance: # Check if LIME was requested AND explainer is valid
                        logger.info(f"  Generating LIME for {primary_model_name_for_viz}...")
                        try:
                            # Use a simpler segmentation for faster processing if many images
                            seg_fn = slic if len(collected_samples_for_report) > 10 else None
                            _, lime_fig, lime_overlay_array = lime_explainer_instance.explain_instance(
                                img_np_normalized_report, num_samples=args.lime_samples, top_labels=1, num_features=10,
                                positive_only=True, hide_rest=False, output_dir=None, segmentation_fn=seg_fn
                            )
                            if lime_overlay_array is not None:
                                 lime_overlay_np_report = lime_overlay_array
                                 lime_description_report = f"Highlights superpixels contributing to '{sample_data_report['pred_class']}'."
                            else: lime_description_report = "LIME overlay generation failed."
                            if lime_fig: plt.close(lime_fig) # Close the figure LIME might generate
                        except Exception as lime_e:
                             logger.error(f"LIME failed for {img_path_report.name} on {primary_model_name_for_viz}: {lime_e}", exc_info=False) # exc_info=False for LIME verbosity
                             lime_description_report = "LIME generation encountered an error."
                    elif args.run_lime and not lime_explainer_instance: # LIME requested but explainer failed
                        lime_description_report = "LIME explainer not initialized."


                    gemini_text_report = "AI explanation disabled or not requested."
                    if args.run_gemini: # Check if Gemini was requested
                        logger.info("  Generating Gemini explanation...")
                        try:
                              gemini_text_report = generate_gemini_explanation(
                                   img_path_report.name, sample_data_report['pred_class'], sample_data_report['confidence'],
                                   grad_cam_description_report, lime_description_report, class_names,
                                   ensemble_model_names_str = ', '.join([m.name for m in ensemble_models]), # Pass ensemble names
                                   primary_model_name_for_viz = primary_model_name_for_viz
                              )
                        except Exception as gemini_gen_e:
                              logger.error(f"Gemini explanation generation failed for {img_path_report.name}: {gemini_gen_e}", exc_info=True)
                              gemini_text_report = f"Error during AI explanation: {str(gemini_gen_e)}"

                    report_output_path_detailed = report_images_dir / f"{img_path_report.stem}_classification_report.png"
                    create_combined_report_image(
                        img_np_normalized_report, gradcam_overlay_np_report, lime_overlay_np_report, img_path_report.name,
                        sample_data_report['pred_class'], sample_data_report['confidence'], sample_data_report['all_scores_dict'],
                        gemini_text_report, report_output_path_detailed, primary_model_name_for_viz
                    )
            else: # No samples collected for reporting
                logger.info("No samples were ultimately collected for detailed reporting.")
        else: # Detailed reporting not enabled or no files selected
            logger.info("Detailed reporting for individual samples is disabled or no samples were selected/available.")


    else: # Invalid mode
        logger.error(f"Invalid mode specified: {args.mode}. Choose 'train', 'evaluate', or 'predict'.")

    end_time = time.time()
    logger.info(f"--- Pipeline Execution Finished (Mode: {args.mode}) ---")
    logger.info(f"Total execution time: {end_time - start_time:.2f} seconds")

    # Clean up logging handlers
    handlers = logger.handlers[:]
    for handler in handlers:
        try:
            handler.flush()
            handler.close()
            logger.removeHandler(handler)
        except Exception as e_log_close:
            print(f"Error closing/removing log handler {handler}: {e_log_close}") # Use print as logger might be closing
    logging.shutdown()


# ==================================\n
# --- Argument Parsing & Entry Point ---\n
# ==================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Lung Cancer Classification Pipeline using Ensemble CNNs with Enhanced Reporting.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter # Shows default values in help
    )
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'predict'],
                        help='Pipeline execution mode.')
    parser.add_argument('--base_dir', type=str, default=None, # Default handled in script
                        help='Base directory for project data, models, logs, and results.')
    parser.add_argument('--img_height', type=int, default=224, help='Target image height for resizing.')
    parser.add_argument('--img_width', type=int, default=224, help='Target image width for resizing.')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training and evaluation.')
    parser.add_argument('--validation_split', type=float, default=0.2, help='Fraction of training data to use for validation.')
    parser.add_argument('--initial_epochs', type=int, default=1, help='Number of epochs for initial training (feature extraction).')
    parser.add_argument('--fine_tune_epochs', type=int, default=1, help='Number of epochs for fine-tuning.')
    parser.add_argument('--initial_lr', type=float, default=1e-3, help='Learning rate for initial training.')
    parser.add_argument('--fine_tune_lr', type=float, default=1e-5, help='Learning rate for fine-tuning.')
    parser.add_argument('--fine_tune_at', type=str, default=None, # Can be int or str
                        help='Layer name or index (int) to start fine-tuning from within the base model. If None, behavior depends on model creation function.')
    parser.add_argument('--use_class_weights', action='store_true', default=False,
                        help='Whether to use class weights to handle imbalanced datasets during training.')
    parser.add_argument('--load_existing_models_if_present', action='store_true', default=False,
                        help='In "train" or "predict" mode, if final model files exist, load them instead of retraining/regenerating predictions.')
    parser.add_argument('--train_models', type=str, nargs='+', default=None, choices=ALL_AVAILABLE_MODEL_KEYS,
                        help=f'Specify which models to train if mode is "train". If None, defaults to: {ENSEMBLE_MODEL_KEYS}. Available: {ALL_AVAILABLE_MODEL_KEYS}')
    parser.add_argument('--num_predict_samples', type=int, default=0, # Default to 0, meaning no detailed reports unless XAI flags are on
                        help='Enables detailed visual reports (Grad-CAM, LIME, Gemini) if > 0 and XAI flags are set. Reports will be for ALL test images.')
    parser.add_argument('--run_gradcam', action='store_true', default=False, help='Generate Grad-CAM visualizations for detailed reports.')
    parser.add_argument('--run_lime', action='store_true', default=False, help='Generate LIME visualizations for detailed reports.')
    parser.add_argument('--lime_samples', type=int, default=200, # Reduced default for potentially many images
                        help='Number of samples for LIME explainer per image.')
    parser.add_argument('--run_gemini', action='store_true', default=False, help='Generate Gemini AI explanations for detailed reports.')
    parser.add_argument('--primary_viz_model_key', type=str,
                        default=ENSEMBLE_MODEL_KEYS[0] if ENSEMBLE_MODEL_KEYS else (ALL_AVAILABLE_MODEL_KEYS[0] if ALL_AVAILABLE_MODEL_KEYS else None),
                        choices=ALL_AVAILABLE_MODEL_KEYS,
                        help='Model key from the loaded/trained models to use for generating Grad-CAM and LIME visualizations in detailed reports.')
    parser.add_argument('--use_curated_test_set', action='store_true', default=False,
                        help='If true, uses the "my_curated_test_images_renamed" folder for prediction/evaluation and expects "renaming_map.csv".')


    is_notebook = 'ipykernel' in sys.modules or 'google.colab' in sys.modules
    # specific_files_for_notebook_report is no longer used for selection in the modified main()
    # but we keep num_predict_samples > 0 to trigger the report section.
    num_specific_files_placeholder = 1 # Set to 1 to enable report section if XAI flags are on

    if is_notebook:
        print(f"Running in notebook environment. If XAI reports are enabled, they will be generated for ALL test images.")
        default_base_dir_val = '/content/drive/MyDrive/CNN_Medical_Imaging_Project' if 'google.colab' in sys.modules else './CNN_Medical_Imaging_Project'
        chosen_viz_model = 'resnet50v2' # Or inceptionv3, densenet121, mobilenetv2. Model can be changed.
        args_list = [
            '--mode', 'predict',
            '--base_dir', default_base_dir_val,
            '--num_predict_samples', str(num_specific_files_placeholder),
            '--run_gradcam',
            '--run_lime',
            '--run_gemini',
            '--lime_samples', '2000',
            '--primary_viz_model_key', chosen_viz_model,
            '--load_existing_models_if_present'
        ]

        args = parser.parse_args(args_list)
        print(f"Notebook using base_dir: {args.base_dir}")
        print(f"Notebook mode: {args.mode}, initial num_predict_samples_flag: {args.num_predict_samples}, run_gradcam: {args.run_gradcam}, run_lime: {args.run_lime}, run_gemini: {args.run_gemini}, load_existing: {args.load_existing_models_if_present}, use_curated: {args.use_curated_test_set}")
    else: # Running as a script
        print("Running as a script, parsing command-line arguments.")
        args = parser.parse_args()
        if args.base_dir is None: # Set default base_dir if not provided for script
             args.base_dir = './CNN_Medical_Imaging_Project'
             print(f"Setting default base_dir for script: {args.base_dir}")
        # If XAI features are on but num_predict_samples is 0, set it to 1 to enable the report block
        if args.mode == 'predict' and (args.run_gradcam or args.run_lime or args.run_gemini) and args.num_predict_samples == 0:
            args.num_predict_samples = 1 # Enable report block, actual count determined later
            print("XAI reports enabled with num_predict_samples=0, setting flag to enable report block for all images.")


    if not args.base_dir: # Should be set by now
        print("Base directory (--base_dir) is not set. Cannot proceed.")
        sys.exit(1)

    # Setup logging (ensure it's done once)
    log_file_path = Path(args.base_dir) / 'ensemble_pipeline_main.log'
    # Clear existing handlers from root logger before basicConfig
    for handler in logging.getLogger().handlers[:]:
        logging.getLogger().removeHandler(handler)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout), # To show logs in notebook output / console
            logging.FileHandler(log_file_path, mode='a') # Append to log file
        ]
    )
    logger.info(f"Logging to console and to file: {log_file_path}")

    try:
        # Validate arguments
        if not (0 < args.validation_split < 1):
             raise ValueError("--validation_split must be between 0 and 1 (exclusive).")
        if args.fine_tune_at is not None: # Attempt to convert to int if it's a digit string
             try: args.fine_tune_at = int(args.fine_tune_at)
             except ValueError: pass # Keep as string if not convertible (layer name)

        tf.keras.backend.clear_session() # Good practice before starting model operations
        logger.info("Cleared Keras backend session.")

        main(args) # Call main, specific_files_for_notebook_report is not used for selection anymore

    except ValueError as e_val:
        logger.critical(f"Configuration or Data Error: {e_val}", exc_info=True)
    except FileNotFoundError as e_file:
        logger.critical(f"File/Directory Error: {e_file}. Check paths derived from base_dir.", exc_info=True)
    except Exception as e_main:
        logger.critical(f"An unexpected error occurred in the main pipeline: {e_main}", exc_info=True)
    finally:
        # Final attempt to close log handlers
        handlers = logging.getLogger().handlers[:]
        for handler in handlers:
            try:
                handler.flush()
                handler.close()
                logging.getLogger().removeHandler(handler)
            except Exception as e_log_close:
                print(f"Error closing/removing log handler {handler}: {e_log_close}")
        logging.shutdown()

2025-05-24 09:21:54,001 - GeminiAI - INFO - Google Generative AI SDK imported and API key seems valid. Using model: models/gemini-1.5-flash-latest
Running in notebook environment. If XAI reports are enabled, they will be generated for ALL test images.
Notebook using base_dir: /content/drive/MyDrive/CNN_Medical_Imaging_Project
Notebook mode: predict, initial num_predict_samples_flag: 1, run_gradcam: True, run_lime: True, run_gemini: True, load_existing: True, use_curated: False
2025-05-24 09:21:54,207 - root - INFO - Logging to console and to file: /content/drive/MyDrive/CNN_Medical_Imaging_Project/ensemble_pipeline_main.log
2025-05-24 09:21:54,784 - root - INFO - Cleared Keras backend session.
2025-05-24 09:21:54,803 - root - INFO - --- Starting Pipeline Execution (Mode: predict) ---
2025-05-24 09:21:54,804 - root - INFO - Script arguments: {'mode': 'predict', 'base_dir': '/content/drive/MyDrive/CNN_Medical_Imaging_Project', 'img_height': 224, 'img_width': 224, 'batch_size': 16, 'valid