In [19]:
import numpy as np
import matplotlib.pyplot as plt

def _centers_to_edges(centers):
    """Convert 1D cell-center coordinates into the len(centers) + 1 edges pcolormesh needs.

    Interior edges are midpoints between neighbouring centers, so non-uniform
    spacing is placed correctly; each outer edge extends by half of the adjacent
    gap. A single center gets a cell of width 1 centred on it.
    """
    centers = np.asarray(centers, dtype=float)
    if centers.size == 1:
        return np.array([centers[0] - 0.5, centers[0] + 0.5])
    midpoints = (centers[:-1] + centers[1:]) / 2
    first_edge = centers[0] - (midpoints[0] - centers[0])
    last_edge = centers[-1] + (centers[-1] - midpoints[-1])
    return np.concatenate(([first_edge], midpoints, [last_edge]))


def plot_heatmap_with_nan(data, x_list, y_list, save_path, cmap="jet", nan_color="lightgray"):
    """
    Plot a heatmap in which NaN cells are drawn in a distinct colour, and save it to a file.

    Parameters:
        data (array-like): 2D array of values; NaN entries are allowed.
        x_list (sequence of float): Coordinates of the rows of ``data``. They are drawn
            along the vertical axis (labelled "X-Axis Values").
        y_list (sequence of float): Coordinates of the columns of ``data``. They are drawn
            along the horizontal axis (labelled "Y-Axis Values").
        save_path (str or path-like): Destination file passed to ``Figure.savefig``.
        cmap (str or matplotlib Colormap): Colormap for finite values. Defaults to "jet".
        nan_color: Any matplotlib colour spec used for NaN cells. Defaults to "lightgray",
            a colour that does not occur in jet.

    Returns:
        None

    Raises:
        ValueError: If ``data`` is not 2D, if its shape does not match the lengths of
            ``x_list``/``y_list``, or if it has no non-NaN value to scale the colours by.
    """
    # Function-scope imports keep the cell self-contained. Both come from the stdlib
    # or from matplotlib, which this notebook already depends on.
    import copy
    from matplotlib.patches import Patch

    values = np.asarray(data, dtype=float)
    if values.ndim != 2:
        raise ValueError("data must be a 2D array.")

    # Validate dimensions
    if values.shape[0] != len(x_list):
        raise ValueError("The length of x_list must match the number of rows in data.")
    if values.shape[1] != len(y_list):
        raise ValueError("The length of y_list must match the number of columns in data.")

    nan_mask = np.isnan(values)
    if nan_mask.all():  # also True for an empty array
        raise ValueError("data must contain at least one non-NaN value.")
    vmin = float(np.nanmin(values))
    vmax = float(np.nanmax(values))

    # Mask NaNs instead of replacing them with a sentinel. A sentinel such as max * 1.1
    # falls inside the data range whenever max <= 0, and it also shares the colormap's
    # top colour with the largest real values.
    masked_values = np.ma.masked_array(values, mask=nan_mask)

    # Copy the colormap so set_bad does not alter the globally registered instance.
    heat_cmap = copy.copy(plt.get_cmap(cmap))
    heat_cmap.set_bad(nan_color)

    # Edges are computed from midpoints, so non-uniform or length-1 coordinates work.
    x_edges = _centers_to_edges(x_list)
    y_edges = _centers_to_edges(y_list)
    x_vals = np.asarray(x_list, dtype=float)
    y_vals = np.asarray(y_list, dtype=float)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Columns (y_list) run along the horizontal axis and rows (x_list) along the vertical one.
        mesh = ax.pcolormesh(y_edges, x_edges, masked_values, cmap=heat_cmap,
                             vmin=vmin, vmax=vmax, shading="flat")

        # Colorbar covering the finite data range, with 6 evenly spaced ticks.
        cbar = fig.colorbar(mesh, ax=ax)
        cbar.set_label("Values", fontsize=12)
        cbar_ticks = np.linspace(vmin, vmax, num=6)
        cbar.set_ticks(cbar_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in cbar_ticks])

        if nan_mask.any():
            # Put a "NaN" swatch directly above the colorbar so it reads as part of the scale.
            cbar.ax.legend(
                handles=[Patch(facecolor=nan_color, edgecolor="black", label="NaN")],
                loc="lower center",
                bbox_to_anchor=(0.5, 1.01),
                frameon=False,
                fontsize=10,
            )

        # Place one tick at each cell centre.
        ax.set_xticks(y_vals)
        ax.set_xticklabels([f"{val:.2f}" for val in y_vals], fontsize=10)
        ax.set_yticks(x_vals)
        ax.set_yticklabels([f"{val:.2f}" for val in x_vals], fontsize=10)

        ax.set_xlabel("Y-Axis Values", fontsize=12)
        ax.set_ylabel("X-Axis Values", fontsize=12)
        ax.set_title("Heatmap with NaN", fontsize=14)

        # bbox_inches='tight' trims surrounding whitespace (and keeps the NaN swatch in frame).
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


In [21]:
# Example data: a seeded generator makes "Restart & Run All" reproduce the same figure.
SEED = 0
M, N = 9, 5
rng = np.random.default_rng(SEED)
data = rng.random((M, N))
# Inject a couple of missing values to exercise the NaN rendering.
data[2, 3] = np.nan
data[4, 1] = np.nan

x_list = np.linspace(1.1, 1.5, M)  # row coordinates (vertical axis)
y_list = np.linspace(1, 5, N)      # column coordinates (horizontal axis)

# Save path for the heatmap
save_path = "heatmap_with_nan.png"

# Call the function
plot_heatmap_with_nan(data, x_list, y_list, save_path)