In [None]:
# #========================BLACK AND WHITE=====================================
# import os
# import h5py
# import numpy as np
# import matplotlib.pyplot as plt
#
# intermediate_h5_dir = "nnUNet_results/Dataset027_ACDC/FlexibleTrainerV1__nnUNetPlans__2d/predictions/ensemble_0_1_2_3_4/final_model/intermediate_features"
#
# def plot_and_save_array(data: np.ndarray, save_dir: str, base_name: str, is_input=False):
#     """
#     Plots and saves the given data array as one or more .png files.
#
#     Args:
#         data (np.ndarray):    The numpy array data to plot.
#         save_dir (str):       Directory where .png files will be saved.
#         base_name (str):      Base name (without extension) for the output files.
#         is_input (bool):      If True, treat this data as the original input image
#                               instead of a multi-channel feature map.
#     """
#     # Ensure directory exists
#     os.makedirs(save_dir, exist_ok=True)
#
#     # If there's a batch dimension, remove it (assuming B=1).
#     # Typically shape is (B, C, H, W) for 2D. Adjust if your shapes differ.
#     if data.ndim == 4:  # (B, C, H, W)
#         data = data[0]  # remove batch dimension => (C, H, W)
#
#     # If there's only one channel left, make it (1, H, W) so we can unify logic below
#     if data.ndim == 2:  # (H, W)
#         data = np.expand_dims(data, axis=0)  # => (1, H, W)
#
#     # data should now be (C, H, W) or something similar
#     num_channels = data.shape[0]
#
#     if is_input:
#         # Plot the input as "figure" - often each channel or just a single channel.
#         # Here we show every channel as a separate figure.
#         for c in range(num_channels):
#             plt.figure()
#             plt.imshow(data[c], cmap='gray')
#             plt.axis('off')
#             out_path = os.path.join(save_dir, f"{base_name}_channel_{c}.png")
#             plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
#             plt.close()
#     else:
#         # Plot each channel as a separate feature map
#         for c in range(num_channels):
#             plt.figure()
#             plt.imshow(data[c], cmap='gray')
#             plt.axis('off')
#             out_path = os.path.join(save_dir, f"{base_name}_feature_{c}.png")
#             plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
#             plt.close()
#
#
# def process_h5_group(h5_group, save_dir):
#     """
#     Recursively traverse the HDF5 group structure, creating
#     mirrored folders in `save_dir`, and plot/save any datasets found.
#
#     Args:
#         h5_group (h5py.Group):  The current HDF5 group to process.
#         save_dir (str):         The directory corresponding to this group.
#     """
#     os.makedirs(save_dir, exist_ok=True)
#
#     for key in h5_group.keys():
#         item = h5_group[key]
#         # If it's another group, recurse
#         if isinstance(item, h5py.Group):
#             sub_group_dir = os.path.join(save_dir, key)
#             process_h5_group(item, sub_group_dir)
#         else:
#             # It's a dataset
#             data = item[()]  # read the entire dataset into memory
#             # Check if it's the special "input_x" under encoder_intermediates
#             # The simplest check is just if key == "input_x", or you can
#             # check the full path if you need to be more precise.
#             if key == "input_x":
#                 # Save it as an "input figure"
#                 plot_and_save_array(data, save_dir, base_name="input_x", is_input=True)
#             else:
#                 # Save it as a feature map
#                 plot_and_save_array(data, save_dir, base_name=key, is_input=False)
#
#
# for fname in os.listdir(intermediate_h5_dir):
#     if fname.endswith(".h5"):
#         full_h5_path = os.path.join(intermediate_h5_dir, fname)
#         # create a dir for this .h5 (without extension)
#         h5_dir_name = os.path.splitext(fname)[0]
#         output_dir_for_this_h5 = os.path.join(intermediate_h5_dir, h5_dir_name)
#         os.makedirs(output_dir_for_this_h5, exist_ok=True)
#
#         # Open the HDF5 file and recurse
#         with h5py.File(full_h5_path, 'r') as f:
#             process_h5_group(f, output_dir_for_this_h5)

In [None]:
#====================BETTER=====================================
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt

# Directory that holds the ``.h5`` intermediate-feature files. The driver loop
# at the bottom of this cell reads every ``.h5`` file found here and writes the
# PNG output into a sub-folder of this same directory, named after the file.
# The path is relative, so it resolves against the current working directory.
intermediate_h5_dir = (
    "nnUNet_results/Dataset027_ACDC/"
    "FlexibleTrainerV1__nnUNetPlans__2d/predictions/"
    "ensemble_0_1_2_3_4/final_model/intermediate_features"
)

def plot_mri_image(data: np.ndarray,
                   save_dir: str,
                   dataset_name: str):
    """
    Plot and save the input MRI image(s) in a typical grayscale style.
    1) Each channel is saved to its own PNG.
    2) A single multi-channel figure is also created, showing all channels side by side.
    All panels share one percentile-based window (0.5%–99.5%), computed over the
    finite values of every channel, so a few extreme pixels do not dominate the
    gray scale.

    Args:
        data (np.ndarray):  (B, C, H, W), (C, H, W) or (H, W).
                            For (B, C, H, W) only the first batch element is
                            plotted (B=1 is assumed). (H, W) is treated as a
                            single channel.
        save_dir (str):     The parent directory. Files are written to
                            ``<save_dir>/<dataset_name>/``.
        dataset_name (str): Typically "input_x". Used both as the sub-folder
                            name and as the file-name prefix.

    Raises:
        ValueError: If ``data`` is empty or is not 2-D, 3-D or 4-D.
    """
    data = np.asarray(data)

    # Normalise to (C, H, W). Shapes that cannot be drawn are rejected here,
    # before any directory or figure is created.
    if data.ndim == 4:  # (B, C, H, W) => keep the first batch element
        data = data[0]
    elif data.ndim == 2:  # (H, W) => (1, H, W)
        data = data[np.newaxis]
    elif data.ndim != 3:
        raise ValueError(
            f"'{dataset_name}': expected a 2-D, 3-D or 4-D array, "
            f"got shape {data.shape}"
        )
    if data.size == 0:
        raise ValueError(
            f"'{dataset_name}': cannot plot an empty array of shape {data.shape}"
        )

    dataset_dir = os.path.join(save_dir, dataset_name)
    os.makedirs(dataset_dir, exist_ok=True)

    num_channels = data.shape[0]

    # Robust window shared by all channels. Non-finite values are excluded
    # because a single NaN would otherwise turn both percentiles into NaN.
    finite_vals = data[np.isfinite(data)]
    if finite_vals.size:
        vmin, vmax = np.percentile(finite_vals, [0.5, 99.5])
    else:
        # Nothing finite to window on: let imshow pick its own limits.
        vmin = vmax = None

    # 1) Plot each channel separately
    for c in range(num_channels):
        fig, ax = plt.subplots()
        try:
            ax.imshow(data[c], cmap='gray', vmin=vmin, vmax=vmax)
            ax.axis('off')

            out_name = f"{dataset_name}_channel_{c}.png"
            out_path = os.path.join(dataset_dir, out_name)
            fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    # 2) Create a single figure with all channels side by side
    #    This helps to see them all in one "multi-contrast" layout.
    #    squeeze=False keeps `axes` a 2-D array even for one channel, so no
    #    special case is needed.
    fig, axes = plt.subplots(1, num_channels,
                             figsize=(4 * num_channels, 4),
                             squeeze=False)
    try:
        for c, ax in enumerate(axes[0]):
            ax.imshow(data[c], cmap='gray', vmin=vmin, vmax=vmax)
            ax.axis('off')
            ax.set_title(f"Channel {c}")

        fig.tight_layout()
        multi_out_path = os.path.join(dataset_dir, f"{dataset_name}_all_channels.png")
        fig.savefig(multi_out_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

def create_colorbar_figure(save_path, cmap='viridis'):
    """
    Render a small stand-alone colour-scale legend and write it to `save_path`.

    The figure contains a thin gradient strip running from 0 to 1 drawn with
    `cmap`, plus a horizontal colorbar labelled "low → high". It only shows
    how low and high values map to colours in that colormap.

    Args:
        save_path: Destination file for the legend image.
        cmap:      Name of the matplotlib colormap to illustrate.
    """
    # A 256-step ramp from 0 to 1, stacked twice => (2, 256) strip image.
    strip = np.tile(np.linspace(0, 1, 256), (2, 1))

    legend_fig, strip_ax = plt.subplots(figsize=(4, 0.5))
    strip_image = strip_ax.imshow(strip, aspect='auto', cmap=cmap)
    strip_ax.set_axis_off()

    # The colorbar takes its limits from the strip image (0 to 1).
    scale_bar = legend_fig.colorbar(strip_image, ax=strip_ax, orientation='horizontal')
    scale_bar.set_label("Color scale (low → high)")

    legend_fig.tight_layout()
    legend_fig.savefig(save_path, dpi=120)
    plt.close(legend_fig)

def plot_and_save_array(data: np.ndarray,
                        save_dir: str,
                        dataset_name: str,
                        is_input=False,
                        cmap='viridis'):
    """
    Plots and saves the given data array as one .png file per channel.

    Files are written to ``<save_dir>/<dataset_name>/`` so that each dataset
    (conv_0, nonlin_0, ...) gets its own folder. Every channel is drawn with
    matplotlib's default per-image colour scaling (no shared vmin/vmax).

    Args:
        data (np.ndarray):    (B, C, H, W), (C, H, W) or (H, W).
                              For (B, C, H, W) only the first batch element is
                              plotted (B=1 is assumed). A 3-D array is taken
                              to be (C, H, W). (H, W) is treated as a single
                              channel.
        save_dir (str):       Parent directory for the dataset sub-folder.
        dataset_name (str):   The name of the dataset (e.g. "conv_0", "nonlin_1").
                              Used as sub-folder name and file-name prefix.
        is_input (bool):      Only affects file naming: True gives
                              ``<name>_channel_<c>.png``, False gives
                              ``<name>_feature_<c>.png``.
        cmap (str):           Colormap for visualization (e.g. 'viridis').

    Raises:
        ValueError: If ``data`` is not 2-D, 3-D or 4-D.
    """
    data = np.asarray(data)

    # Normalise to (C, H, W). Unsupported ranks (scalars, vectors, 5-D, ...)
    # are rejected before any directory or figure is created.
    if data.ndim == 4:  # (B, C, H, W) => keep the first batch element
        data = data[0]
    elif data.ndim == 2:  # (H, W) => (1, H, W)
        data = data[np.newaxis]
    elif data.ndim != 3:
        raise ValueError(
            f"'{dataset_name}': expected a 2-D, 3-D or 4-D array, "
            f"got shape {data.shape}"
        )

    # Make a subfolder for this particular dataset
    dataset_dir = os.path.join(save_dir, dataset_name)
    os.makedirs(dataset_dir, exist_ok=True)

    # e.g. input_x_channel_0.png vs. conv_0_feature_0.png
    kind = "channel" if is_input else "feature"

    # Save each channel / filter response as its own image
    for c, channel in enumerate(data):
        fig, ax = plt.subplots()
        try:
            ax.imshow(channel, cmap=cmap)
            ax.axis('off')

            out_path = os.path.join(dataset_dir, f"{dataset_name}_{kind}_{c}.png")
            fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure; a feature map can have many channels
            # and figures left open would accumulate.
            plt.close(fig)


def process_h5_group(h5_group, save_dir):
    """
    Walk an HDF5 group depth-first and export every dataset it contains.

    The group hierarchy is mirrored on disk: each sub-group becomes a
    sub-folder of `save_dir` with the same name. A dataset named 'input_x'
    is exported with `plot_mri_image`; every other dataset is exported as a
    feature map with `plot_and_save_array` using the 'viridis' colormap.

    Args:
        h5_group (h5py.Group):  The group (or open file) to traverse.
        save_dir (str):         Output directory that corresponds to this group.
    """
    # The folder for this level must exist before anything is written into it
    os.makedirs(save_dir, exist_ok=True)

    for name in h5_group.keys():
        node = h5_group[name]

        # Sub-group: descend into a folder of the same name and move on
        if isinstance(node, h5py.Group):
            process_h5_group(node, os.path.join(save_dir, name))
            continue

        # Dataset: pull the whole array into memory, then pick the plotter
        array = node[()]
        if name == 'input_x':
            plot_mri_image(array, save_dir, name)
        else:
            plot_and_save_array(array, save_dir, name, is_input=False, cmap='viridis')


# Export every .h5 file that sits directly inside intermediate_h5_dir
for fname in os.listdir(intermediate_h5_dir):
    # Anything that is not an .h5 file is ignored
    if not fname.endswith(".h5"):
        continue

    full_h5_path = os.path.join(intermediate_h5_dir, fname)

    # Output folder lives next to the source file and carries its name
    # without the extension
    h5_dir_name = os.path.splitext(fname)[0]
    output_dir_for_this_h5 = os.path.join(intermediate_h5_dir, h5_dir_name)
    os.makedirs(output_dir_for_this_h5, exist_ok=True)

    # Put a colour-scale legend at the top level of the output folder
    colorbar_path = os.path.join(output_dir_for_this_h5, "colormap_scale.png")
    create_colorbar_figure(colorbar_path, cmap='viridis')

    # Read-only traversal of the file; the root group maps to the output folder
    with h5py.File(full_h5_path, 'r') as f:
        process_h5_group(f, output_dir_for_this_h5)
