## 1. Imports and Settings

In [None]:
import random
import logging
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from skimage import measure
import seaborn as sns

from utils import analyze_nifti_dimensions, show_slices_all_modalities
from viz import get_bbox_from_mask_skimage, plot_images_with_bboxes


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# Dataset roots: placeholders that must be edited before running the notebook.
ROOT_MS_SHIFT = Path("/path/to/your/data/please/edit/me")
ROOT_MSLESSEG = Path("/path/to/your/data/please/edit/me")

# --- MS_Shift Configuration ---
# Split name -> split directory (one sub-directory per split under the root).
MS_SHIFT_SPLITS: Dict[str, Path] = {
    split_name: ROOT_MS_SHIFT / split_name for split_name in ("Train", "Test", "Val")
}
# Modality keys; "gt" is the ground-truth lesion mask.
MS_SHIFT_MODALITIES: List[str] = ["FLAIR", "T1CE", "T2", "PD", "T1", "gt"]
# File-name suffixes expected after the "<SubjectID>_" prefix in a subject folder.
MS_SHIFT_SUFFIXES: List[str] = [
    "FLAIR_isovox.nii.gz",
    "isovox_fg_mask.nii.gz",  # foreground mask
    "T1CE_isovox.nii.gz",
    "T2_isovox.nii.gz",
    "gt_isovox.nii.gz",
    "PD_isovox.nii.gz",
    "T1_isovox.nii.gz",
]
MS_SHIFT_GT_KEY = "gt"
# Suffix substrings meant to be skipped by structure checks
# (presumably passed as `ignore_suffixes`).
MS_SHIFT_IGNORE_KEYS = ("isovox_fg_mask",)

# --- MSLesSeg Configuration ---
MSLESSEG_SPLITS: Dict[str, Path] = {
    split_name: ROOT_MSLESSEG / split_name for split_name in ("train", "test")
}
MSLESSEG_MODALITIES: List[str] = ["FLAIR", "T2", "T1", "MASK"]

# "{ts}" is the timepoint folder name (T1, T2, ... -- not the T1 modality),
# "{mod}" is a modality key such as FLAIR.
MSLESSEG_TRAIN_SUFFIX_FORMAT = "{ts}_{mod}.nii.gz"
MSLESSEG_TEST_SUFFIX_FORMAT = "{mod}.nii.gz"
MSLESSEG_MASK_SUFFIX = "MASK.nii.gz"
MSLESSEG_GT_KEY = "MASK"

# File extensions recognised as NIfTI volumes.
NIFTI_FORMATS = (".nii", ".nii.gz")

## 2. Helper functions

In [2]:
def get_slice(img_3d: np.ndarray, slice_dim: int, slice_index: int) -> np.ndarray:
    """
    Returns one 2D plane of a 3D volume (a view, not a copy).

    Args:
        img_3d (np.ndarray): Volume to slice.
        slice_dim (int): Axis to cut across; must be 0, 1 or 2.
        slice_index (int): Position of the plane along ``slice_dim``.

    Returns:
        np.ndarray: The selected plane, i.e. ``img_3d`` with axis ``slice_dim`` removed.

    Raises:
        ValueError: If ``slice_dim`` is not 0, 1 or 2.
        IndexError: If ``slice_index`` lies outside ``[0, img_3d.shape[slice_dim])``.
    """
    dim_is_valid = 0 <= slice_dim <= 2
    if not dim_is_valid:
        raise ValueError(f"Invalid slice_dim: {slice_dim}. Must be 0, 1, or 2.")

    axis_length = img_3d.shape[slice_dim]
    if slice_index < 0 or slice_index >= axis_length:
        raise IndexError(
            f"slice_index {slice_index} out of bounds for dimension {slice_dim} with shape {img_3d.shape}"
        )

    # Full slices on every axis except the chosen one, which gets the index.
    selector = [slice(None), slice(None), slice(None)]
    selector[slice_dim] = slice_index
    return img_3d[tuple(selector)]


def _find_subject_dirs(split_name: str, split_dir: Path) -> List[Path]:
    """Return the sorted subject directories of one split.

    MSLesSeg subjects are ``P*`` folders. If none exist, numeric MS_Shift
    folders are used instead (sorted numerically), except for a split named
    "Val", where that fallback is skipped.
    """
    subject_dirs = sorted(
        d for d in split_dir.iterdir() if d.is_dir() and d.name.startswith("P")
    )
    # NOTE(review): MS_Shift validation may have a different structure, so the
    # numeric fallback is skipped for "Val"; a numeric-only Val split therefore
    # reports no subjects. Confirm this is intended.
    if not subject_dirs and split_name != "Val":
        numeric_dirs = [
            d for d in split_dir.iterdir() if d.is_dir() and d.name.isdigit()
        ]
        subject_dirs = sorted(numeric_dirs, key=lambda path_obj: int(path_obj.name))
    return subject_dirs


def _is_ignored(name: str, ignore_suffixes: Optional[Tuple[str, ...]]) -> bool:
    """Return True if *name* contains any of the ignored suffix substrings."""
    return bool(ignore_suffixes) and any(ign_suf in name for ign_suf in ignore_suffixes)


def _expected_suffixes_for_location(
    expected_suffixes: Union[List[str], Dict[str, str]],
    split_name: str,
    loc_dir: Path,
    modality_keys: List[str],
    is_timepoint_dir: bool,
) -> List[str]:
    """Return the file-name suffixes (after ``<SubjectID>_``) expected in *loc_dir*.

    The caller must already have checked that *expected_suffixes* is a list, or
    a dict holding a non-empty format string for *split_name*.
    """
    if isinstance(expected_suffixes, list):  # MS_Shift-like: explicit suffixes
        return list(expected_suffixes)
    suffix_format = expected_suffixes[split_name]  # MSLesSeg-like: format string
    if is_timepoint_dir:
        # File names carry the timepoint folder name (e.g. T1) as "{ts}".
        return [suffix_format.format(ts=loc_dir.name, mod=mod) for mod in modality_keys]
    # Flat layout (test split or flat train): one file per modality, mask included.
    return [suffix_format.format(mod=mod) for mod in modality_keys]


def verify_dataset_structure(
    split_paths: Dict[str, Path],
    expected_suffixes: Union[List[str], Dict[str, str]],  # list of suffixes or per-split format strings
    modality_keys: List[str],
    gt_key: str,
    ignore_suffixes: Optional[Tuple[str, ...]] = None,
    check_train_timepoints: bool = False,
) -> bool:
    """
    Checks the directory structure and file naming of one or more dataset splits.

    For every subject directory (or, for MSLesSeg train, every ``T*`` timepoint
    sub-directory), the NIfTI files present are compared with the expected names
    ``<SubjectID>_<suffix>``. Missing files fail the check. Extra files are only
    reported. In a split named exactly "test", a missing file whose name contains
    ``gt_key`` is tolerated. A split with no subject directories only logs a warning.

    Args:
        split_paths (Dict[str, Path]): Dictionary of {split_name: split_path}.
        expected_suffixes (Union[List[str], Dict[str, str]]):
            - For MS_Shift: List of expected file suffixes (after SubjectID_).
            - For MSLesSeg: Dictionary with format strings per split
            (e.g., {"train": "{ts}_{mod}.nii.gz", "test": "{mod}.nii.gz"}).
        modality_keys (List[str]): Expected modality keys (including gt_key); they
            fill the "{mod}" placeholder of dict-style formats.
        gt_key (str): Key identifying the mask file.
        ignore_suffixes (Optional[Tuple[str, ...]]): Substrings. Expected suffixes
            containing one are not required. Present files containing one are not
            reported as extra.
        check_train_timepoints (bool): If True, the "train" split is checked per
            timepoint sub-directory (T1, T2, ...) instead of per subject directory.

    Returns:
        bool: True if no split failed (missing directory, invalid configuration or
        missing files), False otherwise.
    """
    all_ok = True
    for split_name, split_dir in split_paths.items():
        logger.info(
            f"--- Verifying structure for split: {split_name} in {split_dir} ---"
        )
        if not split_dir.is_dir():
            logger.error(f"Directory not found: {split_dir}")
            all_ok = False
            continue

        # Validate the naming configuration once per split, not once per folder.
        if isinstance(expected_suffixes, dict):
            if not expected_suffixes.get(split_name):
                logger.error(
                    f"Suffix format string not found for split '{split_name}' in expected_suffixes dict."
                )
                all_ok = False
                continue
        elif not isinstance(expected_suffixes, list):
            logger.error("Invalid format for expected_suffixes.")
            all_ok = False
            continue

        subject_dirs = _find_subject_dirs(split_name, split_dir)
        if not subject_dirs:
            logger.warning(f"No subject directories found in {split_dir}.")
            continue
        logger.info(f"Found {len(subject_dirs)} subject directories.")

        is_train_mslesseg = check_train_timepoints and split_name == "train"
        # A missing GT mask is tolerated only in a split named exactly "test".
        is_test_split = split_name == "test"

        for subj_dir in subject_dirs:
            subj_id = subj_dir.name
            if is_train_mslesseg:
                # Sorted so the log output appears in a stable order.
                locations_to_check = sorted(
                    d
                    for d in subj_dir.iterdir()
                    if d.is_dir() and d.name.startswith("T")
                )
                if not locations_to_check:
                    logger.warning(
                        f"MSLesSeg train subject {subj_id} has no timepoint subdirectories (T*)."
                    )
                    continue
            else:
                locations_to_check = [subj_dir]

            for loc_dir in locations_to_check:
                loc_files = {
                    f.name
                    for f in loc_dir.iterdir()
                    if f.is_file() and f.name.lower().endswith(NIFTI_FORMATS)
                }
                suffixes = _expected_suffixes_for_location(
                    expected_suffixes,
                    split_name,
                    loc_dir,
                    modality_keys,
                    is_train_mslesseg,
                )
                expected_files = {
                    f"{subj_id}_{suffix}"
                    for suffix in suffixes
                    if not _is_ignored(suffix, ignore_suffixes)
                }

                missing_files = {
                    name
                    for name in expected_files - loc_files
                    if not (is_test_split and gt_key in name)
                }
                # Hidden files and ignored suffixes are never reported as extra.
                extra_files = {
                    name
                    for name in loc_files - expected_files
                    if not name.startswith(".")
                    and not _is_ignored(name, ignore_suffixes)
                }

                if missing_files:
                    logger.warning(
                        f"Subject {subj_id} (in {loc_dir.name}): Missing files: {sorted(missing_files)}"
                    )
                    all_ok = False

                if extra_files:
                    logger.warning(
                        f"Subject {subj_id} (in {loc_dir.name}): Found extra/unexpected files: {sorted(extra_files)}"
                    )

    if all_ok:
        logger.info(
            f"Structure verification PASSED for splits: {list(split_paths.keys())}"
        )
    else:
        logger.error(
            f"Structure verification FAILED for splits: {list(split_paths.keys())}. See warnings above."
        )
    return all_ok


def analyze_mask_statistics(
    mask_filepath: Path, threshold: float = 0.5
) -> Optional[Dict]:
    """
    Computes lesion statistics for a single NIfTI segmentation mask.

    Voxels whose value is strictly greater than ``threshold`` count as lesion.
    Lesions are the connected components of those voxels, computed with
    ``connectivity=2`` (for a 3D mask: voxels sharing a face or an edge are
    neighbours).

    Args:
        mask_filepath (Path): Path to NIfTI mask file.
        threshold (float): Binarization threshold. The default 0.5 gives the
            expected result for 0/1 masks, including 0/1 masks stored as floats
            with small rounding errors.

    Returns:
        Optional[Dict]: Statistics with keys ``shape``, ``total_voxels``,
        ``lesion_voxels``, ``lesion_percentage``, ``num_lesions``,
        ``lesion_sizes`` (voxel count of each lesion), ``mean_lesion_size`` and
        ``median_lesion_size``; None if the file cannot be loaded.
    """
    try:
        img_data = nib.load(mask_filepath).get_fdata()
    except (OSError, nib.filebasedimages.ImageFileError) as err:
        logger.error(f"Could not load mask {mask_filepath}: {err}")
        return None

    # Threshold explicitly instead of casting to uint8: the cast truncates
    # float values such as 0.9999 to 0 and silently drops those lesion voxels.
    lesion_mask = img_data > threshold
    total_voxels = lesion_mask.size
    lesion_voxels = int(np.count_nonzero(lesion_mask))
    lesion_percentage = (lesion_voxels / total_voxels) * 100 if total_voxels > 0 else 0

    # Counting individual lesions (connected components). The returned count
    # excludes the background label 0.
    labels, num_lesions = measure.label(lesion_mask, connectivity=2, return_num=True)

    # Voxel count of every label; index 0 is background, so drop it.
    lesion_sizes = np.bincount(labels.ravel())[1:].tolist()

    stats = {
        "shape": lesion_mask.shape,
        "total_voxels": total_voxels,
        "lesion_voxels": lesion_voxels,
        "lesion_percentage": lesion_percentage,
        "num_lesions": num_lesions,
        "lesion_sizes": lesion_sizes,  # List of sizes of each lesion
        "mean_lesion_size": np.mean(lesion_sizes) if lesion_sizes else 0,
        "median_lesion_size": np.median(lesion_sizes) if lesion_sizes else 0,
    }
    return stats


def plot_intensity_distribution(
    img_filepath: Path, modality_name: str, bins: int = 100
) -> None:
    """
    Loads a NIfTI image and plots a histogram of its positive voxel intensities.

    Voxels <= 0 (background) are excluded so the tissue distribution is visible.
    If no voxel is positive, a warning is logged and nothing is plotted.

    Args:
        img_filepath (Path): Path to the NIfTI image file.
        modality_name (str): Modality name for the plot title.
        bins (int): Number of histogram bins.
    """
    img_data = nib.load(img_filepath).get_fdata()
    # Boolean indexing already returns a 1-D copy, so no flatten() is needed.
    non_zero_voxels = img_data[img_data > 0]
    if non_zero_voxels.size == 0:
        logger.warning(f"Image {img_filepath.name} contains only zero values.")
        return

    # Create the figure only when there is something to draw, so an all-zero
    # image does not leave an empty figure behind.
    plt.figure(figsize=(10, 4))
    sns.histplot(non_zero_voxels, bins=bins, kde=True)
    plt.title(f"Intensity distribution for {modality_name} ({img_filepath.name})")
    plt.xlabel("Voxel intensity (excluding 0)")
    plt.ylabel("Frequency")
    plt.show()

## 3. EDA for MS_Shift

### Checking the Structure of MS_Shift Files

In [None]:
# Check that every subject in every MS_Shift split has the expected files.
# Suffixes listed in MS_SHIFT_IGNORE_KEYS (the foreground mask) are not required.
verify_dataset_structure(
    MS_SHIFT_SPLITS,
    expected_suffixes=MS_SHIFT_SUFFIXES,
    modality_keys=MS_SHIFT_MODALITIES,
    gt_key=MS_SHIFT_GT_KEY,
    ignore_suffixes=MS_SHIFT_IGNORE_KEYS,
)

### MS_Shift Image Dimension Analysis

In [None]:
# Inspect which 3D volume sizes occur in each MS_Shift split
# (via utils.analyze_nifti_dimensions).
ms_shift_dims_train, ms_shift_dims_val, ms_shift_dims_test = (
    analyze_nifti_dimensions(str(MS_SHIFT_SPLITS[split_key]))
    for split_key in ("Train", "Val", "Test")
)

### MS_Shift Slice Visualization

In [None]:
# Select a random subject from the split. Ids are sorted because
# Path.iterdir() order is filesystem-dependent, which would defeat random.seed.
split = "Train"
subj_ids = sorted(
    d.name for d in MS_SHIFT_SPLITS[split].iterdir() if d.is_dir() and d.name.isdigit()
)
if not subj_ids:
    raise FileNotFoundError(f"No subject directories found in {MS_SHIFT_SPLITS[split]}.")
idx = random.choice(subj_ids)
# idx = str(21)  # uncomment to pin a specific subject
slice_index_axial = 121

logger.info(f"Visualizing modalities for MS_Shift split {split}: idx {idx}")
show_slices_all_modalities(
    str(MS_SHIFT_SPLITS[split] / idx), slice_idx=slice_index_axial, dim=2
)

### MS_Shift Pixel/Voxel Value Analysis

In [None]:
# Select a random subject from the split. Ids are sorted because
# Path.iterdir() order is filesystem-dependent, which would defeat random.seed.
split = "Train"
subj_ids = sorted(
    d.name for d in MS_SHIFT_SPLITS[split].iterdir() if d.is_dir() and d.name.isdigit()
)
if not subj_ids:
    raise FileNotFoundError(f"No subject directories found in {MS_SHIFT_SPLITS[split]}.")
random_idx = random.choice(subj_ids)
modality = MS_SHIFT_MODALITIES[0]  # ["FLAIR", "T1CE", "T2", "PD", "T1", "gt"]
subj_dir = MS_SHIFT_SPLITS[split] / random_idx
# The "_" after the modality keeps e.g. "T1" from also matching "T1CE" files.
img_path = next(subj_dir.glob(f"{random_idx}_{modality}_*.nii.gz"), None)
gt_path = next(subj_dir.glob(f"{random_idx}_{MS_SHIFT_GT_KEY}_*.nii.gz"), None)

# Reset both volumes so a missing file cannot leave stale data from an earlier
# run; img_data and gt_data are also used by the bounding-box cell below.
img_data = None
gt_data = None

if img_path:
    logger.info(f"\n--- Analyzing {split} {modality}: {img_path.name} ---")
    img_data = nib.load(img_path).get_fdata()
    print(f"Shape: {img_data.shape}")
    print(f"Min: {img_data.min():.2f}, Max: {img_data.max():.2f}")
    print(f"Mean: {img_data.mean():.2f}, Std: {img_data.std():.2f}")
    plot_intensity_distribution(img_path, modality)
else:
    logger.warning(f"{modality} file not found for subject {random_idx}.")

if gt_path:
    gt_data = nib.load(gt_path).get_fdata()
    logger.info(f"\n--- Analyzing GT Mask: {split} {gt_path.name} ---")
    mask_stats = analyze_mask_statistics(gt_path)
    if mask_stats:
        print(f"Shape: {mask_stats['shape']}")
        print(f"Total Voxels: {mask_stats['total_voxels']}")
        print(f"Lesion Voxels: {mask_stats['lesion_voxels']}")
        print(f"Lesion Percentage: {mask_stats['lesion_percentage']:.4f}%")
        print(f"Number of Lesions (connected components): {mask_stats['num_lesions']}")
        if mask_stats["lesion_sizes"]:
            print(
                f"Lesion Sizes (voxels): Min={min(mask_stats['lesion_sizes'])}, Max={max(mask_stats['lesion_sizes'])}, Mean={mask_stats['mean_lesion_size']:.2f}, Median={mask_stats['median_lesion_size']:.1f}"
            )
            plt.figure(figsize=(8, 4))
            sns.histplot(mask_stats["lesion_sizes"], bins=50)
            plt.title(f"Lesion Size Distribution (voxels) - {split} {gt_path.name}")
            plt.xlabel("Lesion Size (voxels)")
            plt.ylabel("Number")
            plt.yscale("log")  # Use a logarithmic scale for sizes
            plt.show()
    else:
        logger.error("Could not load GT data for analysis.")
else:
    logger.warning(f"GT mask file not found for subject {random_idx}.")

### MS_Shift Bounding Box Visualization


In [None]:
# Reuses split, random_idx, modality, img_data and gt_data from the
# pixel/voxel analysis cell above; run that cell first.
slice_dim_vis = 2
slice_idx_vis = 115

if img_data is None or gt_data is None:
    logger.warning("Image or GT mask data is missing; cannot visualize bounding boxes.")
elif 0 <= slice_idx_vis < img_data.shape[slice_dim_vis]:
    img_slice = get_slice(img_data, slice_dim_vis, slice_idx_vis)
    mask_slice = get_slice(gt_data, slice_dim_vis, slice_idx_vis)

    bboxes_pixels = get_bbox_from_mask_skimage(mask_slice)
    print(
        f"\nBounding boxes found on split {split} slice {slice_idx_vis} (dim {slice_dim_vis}): {len(bboxes_pixels)}"
    )

    plot_images_with_bboxes(
        images=[img_slice, mask_slice],
        bboxes_list=[bboxes_pixels, bboxes_pixels],  # Draw the same boxes on both panels
        titles=[
            f"{modality} Slice {slice_idx_vis}",
            f"Mask Slice {slice_idx_vis}/ BBoxes",
        ],
        main_title=f"Split {split} Idx {random_idx} - Slice {slice_idx_vis} - Dim {slice_dim_vis}",
        cols=2,
    )
else:
    logger.warning(f"Slice index {slice_idx_vis} is out of bounds for visualization.")

## 4. EDA for MSLesSeg

### Checking the Structure of MSLesSeg Files

In [None]:
# Verify that every MSLesSeg subject has the expected files. Train files live in
# per-timepoint sub-folders (T1, T2, ...); test subjects are flat folders.
mslesseg_suffix_formats = {
    "train": MSLESSEG_TRAIN_SUFFIX_FORMAT,
    "test": MSLESSEG_TEST_SUFFIX_FORMAT,
}
# Arguments shared by both checks below.
mslesseg_common_kwargs = dict(
    expected_suffixes=mslesseg_suffix_formats,
    modality_keys=MSLESSEG_MODALITIES,
    gt_key=MSLESSEG_GT_KEY,
)

# Train split: checked per timepoint sub-folder.
verify_dataset_structure(
    {"train": MSLESSEG_SPLITS["train"]},
    check_train_timepoints=True,  # Important for MSLesSeg train
    **mslesseg_common_kwargs,
)
# Test split: checked per subject folder.
verify_dataset_structure(
    {"test": MSLESSEG_SPLITS["test"]},
    check_train_timepoints=False,
    **mslesseg_common_kwargs,
)

### MSLesSeg Image Dimension Analysis

In [None]:
# Inspect which 3D volume sizes occur in each MSLesSeg split
# (via utils.analyze_nifti_dimensions).
mslesseg_dims_train, mslesseg_dims_test = (
    analyze_nifti_dimensions(str(MSLESSEG_SPLITS[split_key]))
    for split_key in ("train", "test")
)

### MSLesSeg Slice Visualization

In [None]:
# Select a random subject and timepoint from the split. Both lists are sorted
# because Path.iterdir() order is filesystem-dependent.
split = "train"
subj_ids_msl = sorted(
    d.name
    for d in MSLESSEG_SPLITS[split].iterdir()
    if d.is_dir() and d.name.startswith("P")
)
if not subj_ids_msl:
    raise FileNotFoundError(f"No subject directories found in {MSLESSEG_SPLITS[split]}.")
idx = random.choice(subj_ids_msl)
# idx = "P1"  # uncomment to pin a specific subject
timepoint_dirs = sorted(
    d
    for d in (MSLESSEG_SPLITS[split] / idx).iterdir()
    if d.is_dir() and d.name.startswith("T")
)
# Stop here instead of falling through to a random_timepoint_dir that is
# undefined or left over from another cell.
if not timepoint_dirs:
    raise FileNotFoundError(
        f"No timepoint directories found for MSLesSeg split {split} {idx}."
    )
random_timepoint_dir = random.choice(timepoint_dirs)
logger.info(
    f"Visualizing modalities for MSLesSeg split {split}: {idx}, Timepoint: {random_timepoint_dir.name}"
)

slice_index_axial = 78
show_slices_all_modalities(
    str(random_timepoint_dir), slice_idx=slice_index_axial, dim=2
)

### MSLesSeg Pixel/Voxel Value Analysis

In [None]:
# Select a random subject and timepoint from the split. Both lists are sorted
# because Path.iterdir() order is filesystem-dependent.
split = "train"
subj_ids_msl = sorted(
    d.name
    for d in MSLESSEG_SPLITS[split].iterdir()
    if d.is_dir() and d.name.startswith("P")
)
if not subj_ids_msl:
    raise FileNotFoundError(f"No subject directories found in {MSLESSEG_SPLITS[split]}.")
idx = random.choice(subj_ids_msl)
# idx = "P1"  # uncomment to pin a specific subject
timepoint_dirs = sorted(
    d
    for d in (MSLESSEG_SPLITS[split] / idx).iterdir()
    if d.is_dir() and d.name.startswith("T")
)
# Stop here instead of falling through to a random_timepoint_dir that is
# undefined or left over from another cell.
if not timepoint_dirs:
    raise FileNotFoundError(
        f"No timepoint directories found for MSLesSeg split {split} {idx}."
    )
random_timepoint_dir = random.choice(timepoint_dirs)
tp_id = random_timepoint_dir.name
logger.info(f"Analyzing MSLesSeg split {split}: {idx}, Timepoint: {tp_id}")

modality = MSLESSEG_MODALITIES[0]  # ["FLAIR", "T2", "T1", "MASK"]
img_path_msl = next(random_timepoint_dir.glob(f"{idx}_{tp_id}_{modality}.nii.gz"), None)
mask_path_msl = next(
    random_timepoint_dir.glob(f"{idx}_{tp_id}_{MSLESSEG_GT_KEY}.nii.gz"), None
)

if img_path_msl:
    logger.info(f"\n--- Analyzing {split} {modality}: {img_path_msl.name} ---")
    plot_intensity_distribution(img_path_msl, modality)
else:
    logger.warning(f"{modality} file not found for {idx}/{tp_id}.")

if mask_path_msl:
    logger.info(f"\n--- Analyzing GT Mask: {split} {mask_path_msl.name} ---")
    mask_stats_msl = analyze_mask_statistics(mask_path_msl)
    if mask_stats_msl:
        print(f"Shape: {mask_stats_msl['shape']}")
        print(f"Total Voxels: {mask_stats_msl['total_voxels']}")
        print(f"Lesion Voxels: {mask_stats_msl['lesion_voxels']}")
        print(f"Lesion Percentage: {mask_stats_msl['lesion_percentage']:.4f}%")
        print(
            f"Number of Lesions (connected components): {mask_stats_msl['num_lesions']}"
        )
        if mask_stats_msl["lesion_sizes"]:
            print(
                f"Lesion Sizes (voxels): Min={min(mask_stats_msl['lesion_sizes'])}, Max={max(mask_stats_msl['lesion_sizes'])}, Mean={mask_stats_msl['mean_lesion_size']:.2f}, Median={mask_stats_msl['median_lesion_size']:.1f}"
            )
            plt.figure(figsize=(8, 4))
            sns.histplot(mask_stats_msl["lesion_sizes"], bins=50)
            plt.title(f"Lesion Size Distribution (voxels) - {mask_path_msl.name}")
            plt.xlabel("Lesion Size (voxels)")
            plt.ylabel("Number")
            plt.yscale("log")  # Use a logarithmic scale for sizes
            plt.show()
    else:
        logger.error("Could not load GT data for MSLesSeg analysis.")
else:
    logger.warning(f"{MSLESSEG_GT_KEY} file not found for {idx}/{tp_id}.")

### MSLesSeg Bounding Box Visualization


In [None]:
# Select a random subject and timepoint from the split. Both lists are sorted
# because Path.iterdir() order is filesystem-dependent.
split = "train"
subj_ids_msl = sorted(
    d.name
    for d in MSLESSEG_SPLITS[split].iterdir()
    if d.is_dir() and d.name.startswith("P")
)
if not subj_ids_msl:
    raise FileNotFoundError(f"No subject directories found in {MSLESSEG_SPLITS[split]}.")
idx = random.choice(subj_ids_msl)
# idx = "P1"  # uncomment to pin a specific subject
timepoint_dirs = sorted(
    d
    for d in (MSLESSEG_SPLITS[split] / idx).iterdir()
    if d.is_dir() and d.name.startswith("T")
)
# Stop here instead of falling through to a random_timepoint_dir that is
# undefined or left over from another cell.
if not timepoint_dirs:
    raise FileNotFoundError(
        f"No timepoint directories found for MSLesSeg split {split} {idx}."
    )
random_timepoint_dir = random.choice(timepoint_dirs)
tp_id = random_timepoint_dir.name
logger.info(
    f"Visualizing modalities for MSLesSeg split {split}: {idx}, Timepoint: {tp_id}"
)

modality = MSLESSEG_MODALITIES[0]  # ["FLAIR", "T2", "T1", "MASK"]
img_path_msl = next(random_timepoint_dir.glob(f"{idx}_{tp_id}_{modality}.nii.gz"), None)
mask_path_msl = next(
    random_timepoint_dir.glob(f"{idx}_{tp_id}_{MSLESSEG_GT_KEY}.nii.gz"), None
)
# Fail with a clear message instead of passing None to nib.load.
if img_path_msl is None or mask_path_msl is None:
    raise FileNotFoundError(f"{modality} image or mask file not found for {idx}/{tp_id}.")
img_data_msl = nib.load(img_path_msl).get_fdata()
mask_data_msl = nib.load(mask_path_msl).get_fdata()

slice_dim_vis = 2
slice_idx_vis = 128

if 0 <= slice_idx_vis < img_data_msl.shape[slice_dim_vis]:
    img_slice = get_slice(img_data_msl, slice_dim_vis, slice_idx_vis)
    mask_slice = get_slice(mask_data_msl, slice_dim_vis, slice_idx_vis)

    bboxes_pixels = get_bbox_from_mask_skimage(mask_slice)
    print(
        f"\nBounding boxes found on split {split} slice {slice_idx_vis} (dim {slice_dim_vis}): {len(bboxes_pixels)}"
    )

    plot_images_with_bboxes(
        images=[img_slice, mask_slice],
        bboxes_list=[bboxes_pixels, bboxes_pixels],  # Same boxes on both panels
        titles=[
            f"{modality} Slice {slice_idx_vis}",
            f"Mask Slice {slice_idx_vis}/ BBoxes",
        ],
        # Use this cell's subject id (idx), not random_idx from the MS_Shift cells.
        main_title=f"Split {split} Idx {idx} ({tp_id}) - Slice {slice_idx_vis} - Dim {slice_dim_vis}",
        cols=2,
    )
else:
    logger.warning(f"Slice index {slice_idx_vis} is out of bounds for visualization.")