# Analysis of training results and visualization of YOLO predictions

## Imports and Settings

In [1]:
%matplotlib inline
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import logging
import torch
from pathlib import Path
from pprint import pprint
from typing import Any, List, Dict, Optional, Union
from tqdm import tqdm

from ultralytics import YOLO
from ultralytics.utils.metrics import DetMetrics

from inference import create_filepath_inference_lists, Yolo3dBatchInference

from viz import (
    plot_images_with_bboxes,
    png_to_np,
    bbox_txt_to_list,
    unnormalize_yolo_bboxes,
)
from utils import (
    analyze_best_metrics,
    MAP50_COL,
    MAP50_95_COL,
    PRECISION_COL,
    RECALL_COL,
    FITNESS_COL,
)

# Configure the root logger once for the whole notebook: INFO and above,
# each record rendered as "timestamp - level - message".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger used by every helper and cell below.
logger = logging.getLogger(__name__)

## Configuration and constants


In [None]:
# Slicing axis of the 3D volume -> anatomical plane name.
SLICE_NAMES: Dict[int, str] = {0: "sagittal", 1: "coronal", 2: "axial"}
# File-name tags of the available volumes ("gt" presumably marks the
# ground-truth mask -- confirm against the dataset layout).
MODALITIES: List[str] = ["FLAIR", "T1CE", "T2", "PD", "T1", "gt"]

# Paths to the raw NIfTI (.nii.gz) files.
BASE_DATASET_PATH = Path("/path/to/your/data/please/edit/me")
RAW_MS_SHIFT_DATASET_PATH = BASE_DATASET_PATH / "MS_shift"
RAW_MSLESSEG_DATASET_PATH = BASE_DATASET_PATH / "MSLesSeg"
# Backward-compatible alias: the constant was originally defined under this
# misspelled name, so existing references keep working.
RAW_MSLESSEG_DATASET_PAT = RAW_MSLESSEG_DATASET_PATH

# In the MS_Shift dataset, test patients 1-20 were moved to the train split
# and test patients 21-23 to the val split. That is why the test split
# contains only 10 patients (24-33).
TEST_SPLIT_INDEXES_MS_SHIFT = tuple(range(24, 34))

# Paths to the processed files (2D .png images).
PROCESSED_MS_SHIFT_DATASET_PATH = Path(
    "/path/to/your/data/please/edit/me"
)
PROCESSED_MS_SHIFT_MSLESSEG_DATASET_PATH = Path(
    "/path/to/your/data/please/edit/me"
)

# Sub-directory names inside a processed dataset.
LABELS_DIR = "labels"
IMAGES_DIR = "images"

# Directory that holds the training runs (one sub-directory per model).
RUNS_DIR = "/path/to/your/data/please/edit/me/runs/detect"

"""
Model Naming Convention Specification for MS Lesion Detection YOLO Models

**Format:**
{DATASET_CODES}_{MODALITIES}_{SLICE_DIMS}_{EPOCHS}ep_{CLASSES}cls_{IMGSZ}imgsz_{AUGMENT}_{MODEL_SIZE}{YOLO_VERSION}_{COMMENTS}

1.  **{DATASET_CODES}** (Required)
    - Code(s) indicating the dataset(s) used.
    - Codes: 'mss' (MS Shift), 'msl' (MSLesSeg).
    - Combined: 'msl_mss' or 'mss_msl'.

2.  **{MODALITIES}** (Optional - specify only if **not** using all modalities)
    - MRI modality subset used.
    - Codes: 'FLAIR', 'T1', 'T2', 'PD', 'T1CE'.
    - Example: 'FLAIR'.

3.  **{SLICE_DIMS}** (Optional - specify only if **not** using all projections)
    - Slice projections used for 2D data generation.
    - Codes: 'sag' (sagittal, dim=0), 'cor' (coronal, dim=1), 'ax' (axial, dim=2).
    - Example: 'ax'.

4.  **{EPOCHS}ep** (Required)
    - Number of training epochs.
    - Format: number + 'ep'.
    - Example: '100ep', '50ep'

5.  **{CLASSES}cls** (Optional. Default is 1 cls)
    - Number of detection classes.
    - Format: number + 'cls'.
    - Example: '1cls', '2cls'

6.  **{IMGSZ}imgsz** (Required)
    - Square image size used during training.
    - Format: number + 'imgsz'.
    - Example: '640imgsz', '256imgsz'

7.  **{AUGMENT}** (Optional - specify **if** augmentation was enabled)
    - Flag indicating use of standard Ultralytics augmentations.
    - Code: 'aug'
    - Example: 'aug'

8.  **{MODEL_SIZE}** (Required)
    - Base YOLO model size/architecture.
    - Codes: 'n', 's', 'm', 'l', 'x'.
    - Example: 'l', 'm'

9.  **{YOLO_VERSION}** (Required)
    - YOLO architecture version if different from standard (e.g., v8) or custom.
    - Format: number
    - Example: '11'

10.  **{COMMENTS}** (Optional)
    - Other parameters
    - Example: 'filtered data'
"""
# Registry of trained runs: short model key -> run directory relative to RUNS_DIR.
# Commented-out entries were trained on train/val/test splits of different sizes;
# all other models were trained, validated and tested on splits of the same size.
MODEL_DIRS_NAMES_PATHS = {
    "mss_23ep_640imgsz_l12": "MS_shift/train_All_mods_dims_23ep_1cls_640imgsz_Yolo12_large",
}

# Short model key -> path of the "weights/best.pt" file inside that run directory.
BEST_MODEL_PATHS = dict(
    (model_key, os.path.join(RUNS_DIR, run_subdir, "weights/best.pt"))
    for model_key, run_subdir in MODEL_DIRS_NAMES_PATHS.items()
)

# Dataset description files passed to YOLO validation.
DATA_YAML_MS_SHIFT = "MS_Shift.yaml"
DATA_YAML_MSSHIFT_MSLESSEG = "MSShift_MSLesSeg.yaml"

# Images to render, given as paths relative to IMAGES_DIR.
IMAGES_TO_VISUALIZE: List[str] = [
    "test/MSShift_test_24_FLAIR_idx_169_axial.png",
    "test/MSShift_test_28_FLAIR_idx_89_axial.png",
    # "test/MSShift_test_25_FLAIR_idx_89_axial.png",
]

## Helper functions

In [3]:
def load_yolo_models(
    model_paths_dict: Dict[str, str],
    model_keys_to_load: Optional[Union[str, List[str]]] = None,
) -> Dict[str, YOLO]:
    """
    Load one or more YOLO models listed in a name -> weights-path mapping.

    A model is loaded only when its weights file exists; missing files and
    loading errors are logged and the model is left out of the result.

    Args:
        model_paths_dict (Dict[str, str]): Model name -> path to a .pt weights file.
        model_keys_to_load (Optional[Union[str, List[str]]], optional):
            - None (default): load every model in model_paths_dict.
            - str: load only the model with that key.
            - list of str: load only the models whose keys are in the list.
            Lets the caller avoid loading models it does not need.

    Returns:
        Dict[str, YOLO]: Model name -> loaded YOLO object, containing only the
        models that were loaded successfully. Empty when the selector has an
        unsupported type or matches no key.
    """
    models_by_name: Dict[str, YOLO] = {}

    # Step 1: turn the selector into the concrete list of keys to load.
    if model_keys_to_load is None:
        selected_keys: List[str] = list(model_paths_dict)
        logger.info(f"Attempting to load all {len(selected_keys)} models specified.")
    elif isinstance(model_keys_to_load, str):
        if model_keys_to_load not in model_paths_dict:
            logger.warning(
                f"Specified model key '{model_keys_to_load}' not found in model_paths_dict. No models will be loaded."
            )
            return models_by_name
        selected_keys = [model_keys_to_load]
        logger.info(f"Attempting to load specified model: '{model_keys_to_load}'")
    elif isinstance(model_keys_to_load, list):
        selected_keys = [
            candidate for candidate in model_keys_to_load if candidate in model_paths_dict
        ]
        if len(selected_keys) < len(model_keys_to_load):
            missing_keys = set(model_keys_to_load) - set(selected_keys)
            logger.warning(
                f"Specified model keys not found in model_paths_dict: {missing_keys}. Loading only found models."
            )
        if not selected_keys:
            logger.warning(
                "None of the specified model keys were found. No models will be loaded."
            )
            return models_by_name
        logger.info(f"Attempting to load specified models: {selected_keys}")
    else:
        logger.error(
            f"Invalid type for model_keys_to_load: {type(model_keys_to_load)}. Expected None, str, or List[str]."
        )
        return models_by_name

    # Step 2: load each selected model whose weights file is present.
    for name in selected_keys:
        weights_path = Path(model_paths_dict[name].strip())

        if not weights_path.is_file():
            logger.warning(
                f"Skipping model '{name}' - file not found at {weights_path}."
            )
            continue

        try:
            # YOLO() receives a plain string path.
            models_by_name[name] = YOLO(str(weights_path))
            logger.info(f"Loaded model '{name}' from {weights_path}")
        except Exception as e:
            logger.error(
                f"Error loading model '{name}' from {weights_path}: {e}", exc_info=True
            )

    logger.info(f"Loaded {len(models_by_name)} models.")
    return models_by_name


def extract_imgsz_from_name(model_name: str, default: int = 640) -> int:
    """
    Extract the training image size from a model name.

    The naming convention encodes the size as '{IMGSZ}imgsz' (for example
    '640imgsz' or '256imgsz'). Any positive size written that way is
    recognised, not just a fixed list of values.

    Args:
        model_name (str): Model name (a key of MODEL_DIRS_NAMES_PATHS).
        default (int, optional): Value returned when the name carries no
            usable '{IMGSZ}imgsz' tag. Defaults to 640.

    Returns:
        int: The image size parsed from the name, or `default`.
    """
    import re  # local import keeps the helper self-contained

    tag = re.search(r"([0-9]+)imgsz", model_name)
    if tag is None:
        return default

    size = int(tag.group(1))
    # A zero size is never a valid image size; treat it as "not specified".
    return size if size > 0 else default


def extract_serializable_metrics(
    metrics_obj: Optional[DetMetrics],
) -> Optional[Dict[str, Any]]:
    """
    Extract the primary numeric metrics from a DetMetrics object for storing in JSON.

    Scalar attributes are converted with float(). Array-like attributes
    (numpy arrays, lists, tuples) are averaged, so a per-class array with
    more than one entry no longer makes the whole extraction fail; for a
    single-entry array the result equals that entry.

    Args:
        metrics_obj (Optional[DetMetrics]): The metrics object from model.val().

    Returns:
        Optional[Dict[str, Any]]: Metric name -> float, plus, when a valid
        confusion matrix is available, "confusion_matrix" (list of lists) and
        "images_scanned" (int). None if the input has no 'box' attribute or a
        box metric is missing or cannot be converted.
    """
    if metrics_obj is None or not hasattr(metrics_obj, "box"):
        return None

    def _as_float(value: Any) -> float:
        """Collapse a scalar or an array-like of values into one Python float."""
        if isinstance(value, (np.ndarray, list, tuple)):
            values = np.asarray(value, dtype=float)
            if values.size == 0:
                raise ValueError("metric array is empty")
            return float(values.mean())
        # float(None) raises TypeError, so a missing attribute aborts the
        # extraction in the except-branch below.
        return float(value)

    try:
        box_metrics = metrics_obj.box
        output_metrics: Dict[str, Any] = {
            "map50": _as_float(getattr(box_metrics, "map50", None)),
            "map75": _as_float(getattr(box_metrics, "map75", None)),
            "map50-95": _as_float(getattr(box_metrics, "map", None)),
            "precision": _as_float(getattr(box_metrics, "p", None)),
            "recall": _as_float(getattr(box_metrics, "r", None)),
            "mean_precision": _as_float(getattr(box_metrics, "mp", None)),
            "mean_recall": _as_float(getattr(box_metrics, "mr", None)),
            "mean_f1": _as_float(getattr(box_metrics, "f1", None)),
            "fitness": _as_float(getattr(metrics_obj, "fitness", None)),
        }
    except Exception as e:
        logger.error(f"Error extracting serializable metrics: {e}", exc_info=True)
        return None

    confusion = getattr(metrics_obj, "confusion_matrix", None)
    matrix = getattr(confusion, "matrix", None)
    if isinstance(matrix, np.ndarray):
        # Convert numpy array to list of lists for JSON serialization
        output_metrics["confusion_matrix"] = matrix.tolist()
        # NOTE(review): the indexing below assumes a single-class matrix in
        # which [0, 0] counts true positives and [1, 0] false negatives, so
        # the stored value is tp + fn -- confirm that this is what
        # "images_scanned" is meant to report.
        if matrix.ndim == 2 and matrix.shape[0] >= 2 and matrix.shape[1] >= 1:
            tp = matrix[0, 0]
            fn = matrix[1, 0]
            output_metrics["images_scanned"] = int(tp + fn)
        else:
            logger.warning(
                f"Confusion matrix has unexpected shape {matrix.shape}; 'images_scanned' was not computed."
            )
    else:
        # Covers both a missing/None 'confusion_matrix' and one without a
        # numpy '.matrix'.
        logger.warning(
            "Metrics object has no valid 'confusion_matrix.matrix' numpy array; confusion-matrix fields were skipped."
        )

    return output_metrics


def save_metrics_to_json(metrics_dict: Dict[str, Any], filepath: Union[str, Path]):
    """
    Save a dictionary of metrics to a JSON file, overwriting an existing file.

    The dictionary is serialized to text before the target file is opened,
    so a serialization failure leaves a previously saved file untouched
    instead of truncating it. Errors are logged, not raised.

    Args:
        metrics_dict (Dict[str, Any]): Dictionary of metrics to save.
                                       Values are assumed to be JSON-serializable.
        filepath (Union[str, Path]): Path of the JSON file to write. Missing
                                     parent directories are created.
    """
    try:
        # Serialize first: opening the file with "w" truncates it immediately.
        payload = json.dumps(metrics_dict, indent=4, ensure_ascii=False)

        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.info(f"Metrics successfully saved to: {path}")
    except TypeError as te:
        logger.error(
            f"Failed to serialize metrics to JSON. Ensure all values are JSON-serializable (e.g., numbers, strings, lists, dicts): {te}",
            exc_info=True,
        )
    except Exception as e:
        logger.error(
            f"Failed to save metrics summary to {filepath}: {e}", exc_info=True
        )


def load_metrics_from_json(filepath: Union[str, Path]) -> Optional[Dict[str, Any]]:
    """
    Read a metrics dictionary back from a JSON file.

    Args:
        filepath (Union[str, Path]): Location of the JSON file with metrics.

    Returns:
        Optional[Dict[str, Any]]: The parsed content, or None when the file
        does not exist or cannot be read or parsed (the reason is logged).
    """
    result: Optional[Dict[str, Any]] = None
    source = Path(filepath)

    if source.is_file():
        try:
            with source.open(encoding="utf-8") as handle:
                parsed = json.load(handle)
            logger.info(f"Metrics successfully loaded from: {source}")
            result = parsed
        except json.JSONDecodeError as jde:
            # The file exists but does not contain valid JSON.
            logger.error(f"Error decoding JSON from {source}: {jde}", exc_info=True)
        except Exception as e:
            logger.error(
                f"Failed to load metrics summary from {source}: {e}", exc_info=True
            )
    else:
        logger.warning(f"Metrics file not found: {source}. Returning None.")

    return result


def display_metrics_summary(metrics_summary: Dict[str, Any]):
    """
    Convert a dictionary of metrics to a DataFrame and print it.

    One table row is produced per (model, split) pair. Pairs without metrics
    are shown with N/A in the scalar metric columns.

    Args:
        metrics_summary (Dict[str, Any]): Dictionary loaded from JSON or assembled in the process.
            Expected structure: {model_name: {split_name: {metric: value}}}
    """
    if not metrics_summary:
        print("Metrics summary is empty. Nothing to display.")
        return

    # Scalar metric names as written by extract_serializable_metrics; using
    # the same names keeps placeholder rows from adding columns that no real
    # row ever fills.
    placeholder_columns = (
        "map50",
        "map75",
        "map50-95",
        "precision",
        "recall",
        "mean_precision",
        "mean_recall",
        "mean_f1",
        "fitness",
    )

    rows = []
    for model_name, splits_data in metrics_summary.items():
        # A model entry may be empty/None; it then contributes no rows.
        for split_name, metrics_data in (splits_data or {}).items():
            row = {"model": model_name, "split": split_name}
            # NaN placeholders for models/splits with errors or without metrics
            row.update(metrics_data or dict.fromkeys(placeholder_columns, np.nan))
            rows.append(row)

    summary_df = pd.DataFrame(rows)
    print("\n--- Validation Metrics Summary ---")
    # Use to_string to output the entire table without truncation. The float
    # formatter is passed as a callable, the form to_string documents.
    print(summary_df.to_string(float_format="{:.4f}".format, na_rep="N/A"))
    print("-----------------------------------")


def validate_and_print_detailed_metrics(
    model: YOLO,
    model_name: str,
    data_yaml: str,
    splits: Optional[List[str]] = None,
    imgsz: int = 640,
    batch: int = 24,
    iou_threshold: Optional[float] = None,
    conf_threshold: Optional[float] = None,
) -> Dict[str, Optional[DetMetrics]]:
    """
    Run YOLO model validation on the specified splits and print detailed metrics.

    Args:
        model (YOLO): The loaded Ultralytics YOLO model object.
        model_name (str): The model name for logging, headers and the run name.
        data_yaml (str): Path to the data.yaml file describing the dataset.
        splits (Optional[List[str]], optional): Splits to validate (e.g., ['val', 'test']).
            If None, ['val'] is used. Defaults to None.
        imgsz (int, optional): Image size passed to model.val(). Defaults to 640.
        batch (int, optional): Batch size passed to model.val(). Defaults to 24.
        iou_threshold (Optional[float], optional): Passed as 'iou' to model.val().
            If None, the argument is omitted and the Ultralytics default applies.
        conf_threshold (Optional[float], optional): Passed as 'conf' to model.val().
            If None, the argument is omitted and the Ultralytics default applies.

    Returns:
        Dict[str, Optional[DetMetrics]]: One entry per requested split; the
        value is the DetMetrics object returned by model.val(), or None when
        validation of that split failed or the data YAML file was not found.
    """
    if splits is None:
        splits = ["val"]

    if not Path(data_yaml).is_file():
        logger.error(f"Data YAML file not found: {data_yaml}")
        return {split: None for split in splits}

    logger.info(f"--- Starting Validation for Model: {model_name} ---")
    logger.info(f"  Dataset: {data_yaml}")
    logger.info(f"  Splits: {splits}")
    logger.info(f"  Image Size: {imgsz}")
    logger.info(f"  Batch Size: {batch}")
    if iou_threshold is not None:
        logger.info(f"  IoU Threshold: {iou_threshold}")
    if conf_threshold is not None:
        logger.info(f"  Conf Threshold: {conf_threshold}")
    print("-" * 30)

    results_dict: Dict[str, Optional[DetMetrics]] = {}

    for split in splits:
        logger.info(f"Validating on split: '{split}'...")
        # Register the split up front so a failed run is reported as None
        # instead of being absent from the result.
        results_dict[split] = None
        run_name = f"{model_name}_{split}"
        try:
            # Build the model.val() arguments so that only non-None thresholds are passed
            val_args = {
                "data": data_yaml,
                "split": split,
                "imgsz": imgsz,
                "batch": batch,
                "name": run_name,
                "verbose": True,  # Enable YOLO output for tracking
            }
            if iou_threshold is not None:
                val_args["iou"] = iou_threshold
            if conf_threshold is not None:
                val_args["conf"] = conf_threshold

            metrics = model.val(**val_args)
            results_dict[split] = metrics

            # --- Print Metrics ---
            print("-" * 40)
            print(f"\n--- Metrics for Model '{model_name}' on Split '{split}' ---")
            summary = extract_serializable_metrics(metrics)
            if summary is None:
                # Extraction failures are logged by extract_serializable_metrics;
                # validation itself succeeded, so this is not an error here.
                print("  No serializable metrics could be extracted.")
            else:
                for key, value in summary.items():
                    if isinstance(value, (list, tuple)):
                        formatted_value = str(value)  # Convert the list to a string
                    else:
                        formatted_value = f"{float(value):>10.4f}"  # Formatting numbers
                    print(f"  {key:<20}    {formatted_value}")
            print("-" * 40)

        except Exception as e:
            logger.error(
                f"An error occurred during validation for split '{split}': {e}",
                exc_info=True,
            )

    return results_dict

## Analysis of the number of labels in the dataset

In [None]:
def total_lesion_labels_amount(labels_dirpath: str) -> int:
    """
    Count the labels stored in the .txt files of one labels directory.

    Every non-blank line of every "*.txt" file directly inside the directory
    (e.g. the train, val, or test labels folder) counts as one label.

    Args:
        labels_dirpath (str): Path to the directory containing the label (.txt) files.

    Returns:
        int: Total number of non-blank lines found. 0 if the directory does
        not exist or holds no .txt files. Files that fail to read are logged
        and skipped.
    """
    labels_dir = Path(labels_dirpath)
    if not labels_dir.is_dir():
        logger.warning(f"Labels directory not found: {labels_dirpath}")
        return 0

    txt_files = [*labels_dir.glob("*.txt")]
    if len(txt_files) == 0:
        logger.warning(f"No label files (.txt) found in {labels_dirpath}")
        return 0

    logger.info(f"Counting labels in {len(txt_files)} files in {labels_dirpath}...")
    label_count = 0
    for txt_file in tqdm(txt_files, desc="Counting labels"):
        try:
            with txt_file.open("r") as handle:
                for raw_line in handle:
                    # Only non-empty lines are counted
                    label_count += 1 if raw_line.strip() else 0
        except Exception as e:
            logger.error(f"Error reading file {txt_file}: {e}")
    logger.info(f"Total labels found: {label_count}")
    return label_count


# Label counts per split for the processed MS_Shift dataset.
(
    train_labels_count_MS_SHIFT,
    val_labels_count_MS_SHIFT,
    test_labels_count_MS_SHIFT,
) = (
    total_lesion_labels_amount(
        str(PROCESSED_MS_SHIFT_DATASET_PATH / LABELS_DIR / split_dir)
    )
    for split_dir in ("train", "val", "test")
)

print(f"Total labels in train set: {train_labels_count_MS_SHIFT}")
print(f"Total labels in validation set: {val_labels_count_MS_SHIFT}")
print(f"Total labels in test set: {test_labels_count_MS_SHIFT}")

# Label counts per split for the processed MS_Shift + MSLesSeg dataset.
(
    train_labels_count_MSShift_MSLesSeg,
    val_labels_count_MSShift_MSLesSeg,
    test_labels_count_MSShift_MSLesSeg,
) = (
    total_lesion_labels_amount(
        str(PROCESSED_MS_SHIFT_MSLESSEG_DATASET_PATH / LABELS_DIR / split_dir)
    )
    for split_dir in ("train", "val", "test")
)

print(f"Total labels in train set: {train_labels_count_MSShift_MSLesSeg}")
print(f"Total labels in validation set: {val_labels_count_MSShift_MSLesSeg}")
print(f"Total labels in test set: {test_labels_count_MSShift_MSLesSeg}")

## Visualization of Predictions

In [None]:
# Load every model that has a best-weights file, then draw the ground truth
# next to each selected model's predictions for every configured image.
models = load_yolo_models(BEST_MODEL_PATHS)

dataset_to_use = PROCESSED_MS_SHIFT_DATASET_PATH
# models_to_visualize = list(models.keys())
models_to_visualize = [
    "mss_23ep_640imgsz_l12",
]

for img_relative_path in IMAGES_TO_VISUALIZE:
    img_full_path = dataset_to_use / IMAGES_DIR / img_relative_path
    if not img_full_path.is_file():
        logger.warning(
            f"Image file not found: {img_full_path}. Skipping visualization."
        )
        continue

    logger.info(f"\n--- Generating predictions for: {img_relative_path} ---")

    source_img_np = png_to_np(str(img_full_path))
    if source_img_np is None:
        logger.warning(
            f"Failed to load image array for {img_full_path}. Skipping visualization."
        )
        continue

    # Get Ground Truth BBoxes once. The label file mirrors the image path
    # under LABELS_DIR; with_suffix() swaps only the file extension, whereas
    # a plain string replace would also rewrite any other occurrence of the
    # extension inside the path.
    gt_label_path = str(
        (dataset_to_use / LABELS_DIR / img_relative_path).with_suffix(".txt")
    )
    gt_bboxes_yolo = bbox_txt_to_list(gt_label_path)
    img_h, img_w = source_img_np.shape[:2]
    gt_bboxes_pixels = unnormalize_yolo_bboxes(gt_bboxes_yolo, img_w, img_h)

    # The first panel is always the ground truth; one panel per model follows.
    images_for_plot = [source_img_np]
    bboxes_for_plot = [gt_bboxes_pixels]
    titles_for_plot = ["Ground Truth"]

    for model_name in models_to_visualize:
        if model_name not in models:
            # Make it visible why a requested model has no panel.
            logger.warning(
                f"Model '{model_name}' is not loaded. Skipping it for {img_relative_path}."
            )
            continue

        model = models[model_name]
        logger.info(f"Predicting with model: {model_name}")
        try:
            # Determine imgsz from the model name or use the default
            imgsz_pred = extract_imgsz_from_name(model_name)

            results = model.predict(
                str(img_full_path), imgsz=imgsz_pred, verbose=False
            )
            if results and results[0].boxes:
                pred_bboxes_yolo = results[0].boxes.xywhn.cpu().tolist()
                pred_bboxes_pixels = unnormalize_yolo_bboxes(
                    pred_bboxes_yolo, img_w, img_h
                )
            else:
                # No detections: the panel is still drawn, without boxes.
                pred_bboxes_pixels = []

            images_for_plot.append(source_img_np)
            bboxes_for_plot.append(pred_bboxes_pixels)
            titles_for_plot.append(model_name)

        except Exception as e:
            logger.error(
                f"Error predicting with model {model_name} on {img_relative_path}: {e}",
                exc_info=True,
            )

    # Plot only when at least one model contributed a panel.
    if len(images_for_plot) > 1:
        plot_images_with_bboxes(
            images=images_for_plot,
            bboxes_list=bboxes_for_plot,
            titles=titles_for_plot,
            main_title=f"Predictions for: {img_relative_path}",
            cols=min(3, len(images_for_plot)),
            figsize_scale=5,
            fontsize=20,
            # filename_to_save=f"prediction_{Path(img_relative_path).stem}.png"
        )
    else:
        logger.warning(
            f"No successful predictions to visualize for {img_relative_path}"
        )

## Model Validation and Metrics Output

### Model.val() YOLO method

In [None]:
# -----Configuration-----
OUTPUT_METRICS_FILE = Path("validation_metrics_summary_correct.json")
SPLITS_TO_VALIDATE = ["val", "test"]
BATCH_SIZE = 40
# Dataset codes that mark a model trained on the combined MS_Shift + MSLesSeg
# data; the naming convention allows both orders.
COMBINED_DATASET_CODES = ("msl_mss", "mss_msl")

# Load existing metrics if the file exists, or create an empty dictionary
all_extracted_metrics = load_metrics_from_json(OUTPUT_METRICS_FILE) or {}

logger.info("--- Starting/Resuming Sequential Validation ---")

# Select which models to validate now (you can use the entire BEST_MODEL_PATHS.keys())
models_to_process_now = BEST_MODEL_PATHS.keys()
# models_to_process_now = ["mss_23ep_640imgsz_l12"]

for name in models_to_process_now:
    if name not in BEST_MODEL_PATHS:
        logger.warning(f"Model key '{name}' not found in BEST_MODEL_PATHS. Skipping.")
        continue

    # Skip if metrics for this model already exist for every requested split (optional)
    if name in all_extracted_metrics and all(
        all_extracted_metrics[name].get(s) for s in SPLITS_TO_VALIDATE
    ):
        logger.info(f"Metrics for model '{name}' already exist. Skipping validation.")
        continue

    logger.info(f"Processing model: {name}")
    model_obj: Optional[YOLO] = None
    model_results: Dict[str, Optional[DetMetrics]] = {}
    # Defined before the try so the cleanup below can always drop it.
    loaded_dict: Dict[str, YOLO] = {}

    try:
        # 1. Load one model
        loaded_dict = load_yolo_models(BEST_MODEL_PATHS, model_keys_to_load=name)
        if name not in loaded_dict:
            raise RuntimeError(f"Failed to load model object for '{name}'")
        model_obj = loaded_dict[name]

        # 2. Define the parameters
        uses_combined_dataset = any(code in name for code in COMBINED_DATASET_CODES)
        current_data_yaml = (
            DATA_YAML_MSSHIFT_MSLESSEG if uses_combined_dataset else DATA_YAML_MS_SHIFT
        )
        current_imgsz = extract_imgsz_from_name(name)
        current_batch = BATCH_SIZE

        # 3. Perform validation
        model_results = validate_and_print_detailed_metrics(
            model=model_obj,
            model_name=name,
            data_yaml=current_data_yaml,
            splits=SPLITS_TO_VALIDATE,
            imgsz=current_imgsz,
            batch=current_batch,
        )

        # 4. Extract and save metrics for the current model
        extracted_metrics_for_model = {}
        for split, metrics_obj in model_results.items():
            extracted_metrics_for_model[split] = extract_serializable_metrics(
                metrics_obj
            )
        all_extracted_metrics[name] = (
            extracted_metrics_for_model  # Updating the common dictionary
        )

        # 5. Save the updated metrics dictionary to a file after each model
        save_metrics_to_json(all_extracted_metrics, OUTPUT_METRICS_FILE)

    except Exception as e:
        logger.error(f"Error during processing model '{name}': {e}", exc_info=True)

    finally:
        # 6. Freeing up GPU memory
        if model_obj is not None:
            logger.info(f"Unloading model '{name}' and clearing CUDA cache...")
            # Drop every reference held by this cell: the dict returned by
            # load_yolo_models still points at the model, so deleting only
            # model_obj would keep it alive.
            loaded_dict.clear()
            del model_obj
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.info("CUDA cache cleared.")
            else:
                logger.info("CUDA not available, skipping cache clear.")
        print("-" * 50)

logger.info("--- Finished Sequential Validation Run ---")

# 7. Output the final table
final_metrics = load_metrics_from_json(OUTPUT_METRICS_FILE)
if final_metrics:
    display_metrics_summary(final_metrics)

### Custom inference methods

In [None]:
# Key of the single model used by the custom 3D inference cells below.
models_to_process = "mss_23ep_640imgsz_l12"

# Load just that model and pick it out of the returned name -> model mapping.
_models_for_inference = load_yolo_models(
    BEST_MODEL_PATHS,
    model_keys_to_load=models_to_process,
)
current_model = _models_for_inference[models_to_process]

# File lists for the raw MS_Shift volumes of the test patients and their
# ground-truth files.
imgs_list, gts_list = create_filepath_inference_lists(
    RAW_MS_SHIFT_DATASET_PATH,
    "Test",
    TEST_SPLIT_INDEXES_MS_SHIFT,
)

In [None]:
# Inference settings: confidence and IoU thresholds, and the volume axes
# (0, 1, 2) along which slices are taken.
_inference_settings = dict(conf=0.001, iou=0.6, slice_dims=[0, 1, 2])

# Build the batch runner over the volumes and ground-truth files, then run it.
batch_inference_mss_23ep_640imgsz_l12 = Yolo3dBatchInference(
    yolo_model=current_model,
    nifti_filepaths=imgs_list,
    gt_filepaths=gts_list,
    **_inference_settings,
)
batch_inference_mss_23ep_640imgsz_l12.run_batch_inference()

In [None]:
# Compute the aggregate metrics for the batch inference object created above
# (what exactly is aggregated is defined in inference.py).
batch_inference_mss_23ep_640imgsz_l12.compute_aggregate_metrics()

## Analyzing metrics from results.csv file

In [None]:
RESULTS_FILENAME = "results.csv"
PARENT_DIR = Path("/path/to/your/data/please/edit/me")

# model name -> {"df": training log as DataFrame, "analysis": best-metric summary};
# both values stay None when the log could not be loaded or analyzed.
results_data: Dict[str, Dict[str, Optional[Union[pd.DataFrame, Dict]]]] = {}

for model_name, run_dir in MODEL_DIRS_NAMES_PATHS.items():
    # NOTE: when RUNS_DIR is an absolute path, pathlib discards PARENT_DIR
    # and the result is RUNS_DIR / run_dir / RESULTS_FILENAME.
    results_file = PARENT_DIR / RUNS_DIR / Path(run_dir) / RESULTS_FILENAME
    logger.info(f"--- Analyzing results for: {model_name} ---")
    results_data[model_name] = {"df": None, "analysis": None}

    try:
        df = pd.read_csv(results_file)
        # Strip surrounding whitespace from the column names before they are
        # looked up by name.
        df.columns = df.columns.str.strip()
        results_data[model_name]["df"] = df
        logger.info(f"Loaded results from: {results_file}")

        # Work on a copy so the stored DataFrame is not modified by the analysis.
        analysis = analyze_best_metrics(
            df.copy(),
            map50_col=MAP50_COL,
            map50_95_col=MAP50_95_COL,
            precision_col=PRECISION_COL,
            recall_col=RECALL_COL,
            fitness_col=FITNESS_COL,
        )
        results_data[model_name]["analysis"] = analysis
        logger.info(f"Analysis complete for {model_name}.")

        print(f"\nAnalysis Summary for '{model_name}':")
        pprint(analysis)
        print("-" * 40)

    except FileNotFoundError:
        logger.error(f"Results file not found: {results_file}")
    except KeyError as e:
        logger.error(
            f"Missing expected column in {results_file}: {e}. Analysis might be incomplete."
        )
        # Show the available columns to help find the mismatch. Logged at INFO
        # because logging is configured at INFO level, so a DEBUG record
        # would never be displayed.
        loaded_df = results_data[model_name]["df"]
        if loaded_df is not None:
            logger.info(f"Available columns: {list(loaded_df.columns)}")
    except Exception as e:
        logger.error(f"Failed to process results for {model_name}: {e}", exc_info=True)

In [None]:
# Comparison of Models

# model name -> values taken at the epoch with the best fitness.
comparison_metrics = {}
for model_name, data in results_data.items():
    analysis = data.get("analysis")
    if not analysis:
        logger.warning(
            f"No analysis data available for {model_name} to include in comparison."
        )
        continue

    # Each analysis entry is unpacked as an (epoch, value) pair; a missing
    # entry falls back to (None, None).
    fitness_epoch, fitness_value = analysis.get("best_fitness", (None, None))
    _, map50_value = analysis.get("best_fitness_epoch_mAP50", (None, None))
    _, map50_95_value = analysis.get("best_fitness_epoch_mAP50-95", (None, None))

    if fitness_epoch is None:
        # Without a best-fitness epoch there is nothing to compare.
        continue

    comparison_metrics[model_name] = {
        "best_fitness_epoch": fitness_epoch,
        "best_fitness": fitness_value,
        "mAP50_at_best_fitness": map50_value,
        "mAP50-95_at_best_fitness": map50_95_value,
    }

# Create a DataFrame for comparison
if not comparison_metrics:
    print("\nNo valid analysis results found for comparison.")
else:
    # One row per model, sorted by best fitness (descending)
    comparison_df = pd.DataFrame.from_dict(
        comparison_metrics, orient="index"
    ).sort_values(by="best_fitness", ascending=False)

    print("\n--- Model Comparison Based on Best Fitness Epoch ---")
    print(comparison_df)
    print("----------------------------------------------------")

In [None]:
# Visualizing Learning Curves

models_to_plot = list(results_data)[-3:]  # Show last 3 models

plt.figure(figsize=(12, 6))

# One panel per metric: (subplot position, results column, legend tag, title, y-axis label).
_curve_panels = (
    (1, MAP50_COL, "mAP50", "Validation mAP@0.50 vs Epoch", "mAP@0.50"),
    (2, MAP50_95_COL, "mAP50-95", "Validation mAP@0.50-0.95 vs Epoch", "mAP@0.50-0.95"),
)

for _position, _column, _legend_tag, _panel_title, _y_label in _curve_panels:
    plt.subplot(1, 2, _position)
    for model_name in models_to_plot:
        df = results_data.get(model_name, {}).get("df")
        # Skip models without a loaded log or without this metric column.
        if df is None or _column not in df.columns:
            continue
        plt.plot(df["epoch"], df[_column], label=f"{model_name} {_legend_tag}")
    plt.title(_panel_title)
    plt.xlabel("Epoch")
    plt.ylabel(_y_label)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()