Main Jupyter notebook containing all the code used to generate the plots in the paper by R.S. Kort, A. Hagopian, K. Doblhoff-Dier and M. Koper on the Impact of Applied Potential and Hydrogen Coverage on Alkali Metal Cation Behaviour.





In [None]:
# =============================================================================
# IMPORTS
# =============================================================================

import numpy as np
import pickle
import logging
import os

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, ArrowStyle
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerLine2D

In [None]:
# =============================================================================
# CONSTANTS & PARAMETERS
# =============================================================================

# Alkali-metal cations handled in this notebook, in plotting order.
IONS = ["Li", "Na", "K", "Cs"]

# Colour (hex RGB string) used for each ion in the plots.
ION_COLORS = dict(
    Li="#70AD47",
    Na="#FFC000",
    K="#A032A0",
    Cs="#4472C4",
)

# Per-ion distance cut-offs.
# NOTE(review): not referenced in the visible part of the notebook; presumably
# ion–O cut-offs in Å — confirm against the analysis that uses them.
ION_CUTOFFS = dict(Li=2.75, Na=3.2, K=3.8, Cs=4.25)

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def blend_color(hex_color, fraction):
    """
    Blend a given hexadecimal color with white.

    Parameters:
        hex_color (str): The base hex color, e.g. "#70AD47". The leading '#'
                         is optional. The CSS shorthand "#RGB" is expanded to
                         "#RRGGBB". For longer strings (e.g. "#RRGGBBAA") only
                         the first six digits are used.
        fraction (float): Fraction for blending with white.
                          0 returns the original color; 1 returns white.
                          Values outside [0, 1] are clamped so the result is
                          always a valid "#RRGGBB" string.

    Returns:
        str: The resulting blended color as an upper-case "#RRGGBB" string.

    Raises:
        ValueError: If hex_color has fewer than six hex digits (and is not
                    3-digit shorthand) or contains non-hex characters.
    """
    digits = hex_color.lstrip('#')
    # Expand shorthand: "ABC" -> "AABBCC".
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Expected a '#RRGGBB' or '#RGB' color, got {hex_color!r}")
    # int(..., 16) raises ValueError for non-hex characters.
    channels = [int(digits[start:start + 2], 16) for start in (0, 2, 4)]
    # Clamp so out-of-range fractions cannot produce channel values < 0 or
    # > 255, which would format to invalid strings such as "-5" or "1FE".
    fraction = min(max(float(fraction), 0.0), 1.0)
    # Interpolate each channel towards white (255); int() truncates.
    blended = [int(c + (255 - c) * fraction) for c in channels]
    return "#{:02X}{:02X}{:02X}".format(*blended)

def setup_logging(verbose=False):
    """
    Configure the logging settings for the application.

    logging.basicConfig() does nothing when the root logger already has
    handlers (for example after an earlier call in the same kernel). The
    level is therefore also set on the root logger explicitly, so repeated
    calls can switch between verbose and quiet output.

    Parameters:
        verbose (bool): If True, set logging level to DEBUG; otherwise INFO.
    """
    log_level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=log_level
    )
    # basicConfig is a no-op once handlers exist; apply the level directly so
    # the verbose flag always takes effect.
    logging.getLogger().setLevel(log_level)

def safe_getattr(obj, attr, default=None):
    """
    Return ``obj.<attr>``, or ``default`` when that attribute does not exist.

    Only AttributeError counts as "missing"; any other exception raised
    during the lookup propagates unchanged.

    Parameters:
        obj: Object to inspect.
        attr (str): Name of the attribute to look up.
        default: Value returned when the attribute is absent (None if omitted).

    Returns:
        The attribute's value, or ``default``.
    """
    try:
        value = getattr(obj, attr)
    except AttributeError:
        value = default
    return value

# =============================================================================
# FILE I/O FUNCTIONS
# =============================================================================
def load_simulation(filename):
    """
    Unpickle and return the Simulation object stored in ``filename``.

    Every failure is logged at ERROR level and then re-raised unchanged, so
    callers still see the original exception type.

    Note: unpickling can execute arbitrary code, so only load files from a
    trusted source.

    Parameters:
        filename (str): Path to the pickle file.

    Returns:
        Simulation: The loaded Simulation object.

    Raises:
        FileNotFoundError: If the file cannot be found.
        pickle.UnpicklingError: If unpickling fails due to a corrupted file.
        Exception: Any other error raised while opening or unpickling.
    """
    try:
        with open(filename, 'rb') as handle:
            loaded = pickle.load(handle)
    except FileNotFoundError:
        logging.error("File not found: %s", filename)
        raise
    except pickle.UnpicklingError as exc:
        logging.error("Error unpickling file %s: %s", filename, exc)
        raise
    except Exception as exc:
        logging.error("Unexpected error loading %s: %s", filename, exc)
        raise
    else:
        logging.debug("Successfully loaded Simulation from %s", filename)
        return loaded

# =============================================================================
# DATA PROCESSING FUNCTIONS
# =============================================================================
def compute_histogram(data, bin_edges, normalize=False):
    """
    Histogram 1D data on the given bin edges.

    Parameters:
        data (np.ndarray): 1D array of data points.
        bin_edges (np.ndarray): Array defining bin boundaries.
        normalize (bool): If True, divide the counts by their total so they
                          sum to 1. Skipped when there are no counts.

    Returns:
        tuple: (counts, bin_centers, bin_width)
            counts (np.ndarray): Histogram counts, or fractions when normalized.
            bin_centers (np.ndarray): Midpoint of each bin.
            bin_width (float): Width of the first bin (the bins are assumed
                               to be uniform).
    """
    hist, _unused_edges = np.histogram(data, bins=bin_edges)
    left, right = bin_edges[:-1], bin_edges[1:]
    centers = (left + right) / 2
    width = bin_edges[1] - bin_edges[0]

    n_total = np.sum(hist)
    if normalize and n_total > 0:
        hist = hist / n_total
    return hist, centers, width

def smooth_data(data, window_size=3):
    """
    Smooth a 1D data array using an edge-corrected moving average.

    Each output point is the mean of the input samples inside a window
    aligned exactly as ``np.convolve(..., mode='same')`` aligns it.
    - Near the ends, part of the window falls outside the data. Those points
      are averaged over the samples that actually exist rather than padding
      with zeros, so the curve is not pulled towards 0 at its edges.
    - The output always has the same length as the input, even when
      window_size exceeds len(data).
    - An even window cannot be centred: output[i] averages
      data[i - window_size//2 : i + window_size//2], i.e. the window leans
      half a sample towards lower indices.

    Parameters:
        data (np.ndarray): 1D array of data points to smooth.
        window_size (int): Size of the moving average window.
                           A value of 1 (or less) returns the original data.

    Returns:
        np.ndarray: The smoothed data, same length as the input. For
        window_size <= 1 the input object itself is returned.
    """
    if window_size <= 1:
        return data
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        # np.convolve rejects empty inputs.
        return values.copy()

    window = int(window_size)
    kernel = np.ones(window)
    # Take the slice of the full convolution that mode='same' would return
    # when len(data) >= window. Slicing explicitly keeps the input length
    # even when the window is longer than the data.
    start = (window - 1) // 2
    stop = start + values.size
    window_sums = np.convolve(values, kernel, mode='full')[start:stop]
    # Number of real samples under the window at each position (>= 1 always,
    # smaller than `window` near the edges).
    window_counts = np.convolve(np.ones_like(values), kernel, mode='full')[start:stop]
    return window_sums / window_counts

def apply_periodic_boundary(diff, box_lengths):
    """
    Apply the minimum-image convention to displacement vectors in an
    orthorhombic periodic box.

    Each component is shifted by an integer number of box lengths so that it
    lies within about half a box length of zero (np.round rounds exact
    halves to even).

    Parameters:
        diff (array-like): Difference vector(s) between coordinates. The last
                           axis must broadcast against box_lengths.
        box_lengths (array-like): Box lengths in each dimension. Lists and
                                  tuples are accepted as well as arrays.

    Returns:
        np.ndarray: Adjusted difference vector(s) under periodic boundaries.

    Raises:
        ValueError: If any box length is effectively zero (or negative).
    """
    diff = np.asarray(diff, dtype=float)
    # asarray so the comparison below also works for plain lists/tuples.
    box_lengths = np.asarray(box_lengths, dtype=float)
    if np.any(box_lengths < 1e-8):
        raise ValueError("Box lengths must be non-zero for periodic boundary conditions.")
    return diff - box_lengths * np.round(diff / box_lengths)

# =============================================================================
# SIMULATION UTILITY FUNCTIONS
# =============================================================================
def has_adsorbates(simulation):
    """
    Determine if the simulation contains adsorbate position data.

    Looks up ``simulation.trajectories.positions.adsorbates``. Any missing
    level of that attribute chain is treated as "no adsorbates" rather than
    raising.

    Parameters:
        simulation: Simulation object that should contain trajectory data.

    Returns:
        bool: True if adsorbate positions exist and are non-empty; otherwise False.
    """
    # getattr(None, name, None) returns None, so each missing level short-circuits.
    trajectories = getattr(simulation, "trajectories", None)
    positions = getattr(trajectories, "positions", None)
    ads = getattr(positions, "adsorbates", None)
    # np.size also handles lists/tuples, which have no .size attribute.
    return ads is not None and np.size(ads) > 0

# =============================================================================
# PLOTTING FUNCTIONS
# =============================================================================
def _apply_grid(ax, grid):
    """Turn grid lines on or off according to custom_plot's `grid` option."""
    if grid == "x":
        ax.yaxis.grid(False)
        ax.xaxis.grid(True)
    elif grid == "y":
        ax.xaxis.grid(False)
        ax.yaxis.grid(True)
    elif grid in [None, "None", False]:
        ax.grid(False)
    else:
        # Any other value (e.g. "both" or True) enables both grids.
        ax.grid(True)


def _configure_legend(ax, legend_options):
    """Create the legend if none exists yet, then apply custom_plot's legend styling; return it."""
    legend_loc = legend_options.get('loc', 'best')
    legend = ax.get_legend()
    if legend is None:
        legend_kwargs = {'loc': legend_loc}
        handler_map = legend_options.get('handler_map')
        if handler_map:
            legend_kwargs['handler_map'] = handler_map
        handles = legend_options.get('handles')
        labels = legend_options.get('labels')
        if handles and labels:
            legend = ax.legend(handles, labels, **legend_kwargs)
        else:
            legend = ax.legend(**legend_kwargs)

    # Update legend font size.
    plt.setp(legend.get_texts(), fontsize=legend_options.get('font_size', 36))

    # Re-apply an explicit location, e.g. to a legend the plot callback created.
    if legend_loc != 'best':
        legend.set_loc(legend_loc)

    # Customize legend background and border if requested.
    if legend_options.get('draw_bg', True):
        frame = legend.get_frame()
        frame.set_facecolor(legend_options.get('bgcolor', 'white'))
        frame.set_alpha(legend_options.get('alpha', 0.8))
        if legend_options.get('border', False):
            frame.set_edgecolor(legend_options.get('edgecolor', 'black'))
        else:
            frame.set_edgecolor('none')
    return legend


def custom_plot(xlabel="X", ylabel="Y", font_size=60, font_family="Times New Roman",
                figsize=(12, 8), yticks_remove=False, grid=None, tight_layout=True,
                xticks=None, yticks=None, xlim=None, ylim=None,
                legend_options=None, plot_func=None, 
                save_png=None, save_svg=None, show=True):
    """
    Create and finalize a plot with extensive customization options.

    Steps, in order:
    - Set up the figure and axis and apply labels, spines, grid and y-tick
      removal.
    - Run the optional plotting callback.
    - Apply ticks and limits, then build and style the legend.
    - Apply tight layout, save, and show.

    Note: font size and family are written to the global ``plt.rcParams``, so
    they also affect figures created afterwards.

    Parameters:
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        font_size (int): Base font size for text elements.
        font_family (str): Font family for text.
        figsize (tuple): Figure size (width, height) in inches.
        yticks_remove (bool): Remove y-axis ticks if True.
        grid (str or None): Grid type ("x", "y", "both", or None).
        tight_layout (bool): Apply tight layout (after all content is added) if True.
        xticks (list or None): Custom x-axis tick positions.
        yticks (list or None): Custom y-axis tick positions.
        xlim (tuple or None): x-axis limits as (min, max).
        ylim (tuple or None): y-axis limits as (min, max).
        legend_options (dict or None): Custom legend options. If falsy, no
            legend is created or styled here.
            Options: 
                'font_size' (int): Font size for legend text (default 36).
                'loc' (str): Location for the legend (e.g., 'best' or 'upper right').
                'handles' (list): Custom handles for the legend.
                'labels' (list): Custom labels for the legend.
                'handler_map' (dict): Custom handler map for the legend.
                'draw_bg' (bool): Draw a background for the legend.
                'bgcolor' (str): Background color for the legend.
                'alpha' (float): Transparency for the legend background.
                'border' (bool): Draw a border around the legend.
                'edgecolor' (str): Color for the legend border.
        plot_func (callable or None): Callback for custom plotting; receives (ax, legend_font_size, legend_loc).
        save_png (str or None): Base filename; saved as figures/<name>.png.
        save_svg (str or None): Base filename; saved as figures/<name>.svg.
            The 'figures' directory is created if it does not exist.
        show (bool): Display the plot if True.

    Returns:
        tuple: (ax, legend) where ax is the plot axis and legend is the created legend (or None).
    """
    # Update global font settings.
    plt.rcParams.update({'font.size': font_size, 'font.family': font_family})

    # Create figure and primary axis.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if yticks_remove:
        ax.set_yticks([])

    # Hide top, right, and left spines for a cleaner look.
    for spine in ('top', 'right', 'left'):
        ax.spines[spine].set_visible(False)

    _apply_grid(ax, grid)

    # Execute a custom plotting callback if provided.
    if plot_func is not None:
        # NOTE(review): the callback's default font size (32) differs from the
        # legend styling default (36). When legend_options is given without
        # 'font_size', the styling value is what ends up displayed.
        legend_font_size = legend_options.get('font_size', 32) if legend_options else 32
        legend_loc = legend_options.get('loc', 'best') if legend_options else 'best'
        plot_func(ax, legend_font_size, legend_loc)

    # Set custom ticks and axis limits if specified.
    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    legend = _configure_legend(ax, legend_options) if legend_options else None

    # Lay out only after content, ticks and limits are final, so the spacing
    # is computed for the tick labels that are actually shown.
    if tight_layout:
        fig.tight_layout()

    # Save this figure (not whatever figure happens to be current) in the
    # requested formats.
    for base_name, fmt in ((save_png, 'png'), (save_svg, 'svg')):
        if base_name:
            os.makedirs('figures', exist_ok=True)
            fig.savefig(os.path.join('figures', f"{base_name}.{fmt}"), format=fmt)

    if show:
        plt.show()

    return ax, legend

# =============================================================================
# CUSTOM LEGEND HANDLER
# =============================================================================
class HandlerMaybeOutlinedLine2D(HandlerLine2D):
    """
    Legend handler for Line2D handles that can optionally draw an outline.

    A handle whose ``_outline`` attribute is truthy produces two horizontal
    segments:
    - a black segment whose linewidth is the handle's plus 2;
    - the handle's own coloured segment.
    Any other handle produces only the coloured segment. Linestyle, linewidth
    and marker are copied from the original handle.
    """

    @staticmethod
    def _segment(orig_handle, x0, y0, width, height, trans, color, linewidth):
        """Build one horizontal legend segment through the vertical middle of the entry box."""
        y_mid = y0 + height / 2
        segment = Line2D(
            [x0, x0 + width],
            [y_mid, y_mid],
            linestyle=orig_handle.get_linestyle(),
            color=color,
            linewidth=linewidth,
            marker=orig_handle.get_marker(),
        )
        segment.set_transform(trans)
        return segment

    def create_artists(self, legend, orig_handle,
                       x0, y0, width, height, fontsize, trans):
        artists = []
        if getattr(orig_handle, '_outline', False):
            # Listed first so the wider black segment sits beneath the coloured one.
            artists.append(self._segment(orig_handle, x0, y0, width, height, trans,
                                         'black', orig_handle.get_linewidth() + 2))
        artists.append(self._segment(orig_handle, x0, y0, width, height, trans,
                                     orig_handle.get_color(), orig_handle.get_linewidth()))
        return artists

In [None]:
# =============================================================================
# RDF CORE FUNCTIONS
# =============================================================================
# These functions extract simulation parameters, compute the radial distribution
# function (RDF) for a single frame, and average the RDF over multiple frames.

def extract_simulation_parameters(data, bin_edges):
    """
    Extract key simulation parameters needed for RDF calculations.

    Returns the simulation cell volume, the volumes of the spherical shells
    defined by bin_edges, the bin centres and width, and the box lengths used
    for periodic boundary conditions. The box is treated as orthorhombic:
    - a 1D cell uses its first three entries as the box lengths;
    - a matrix cell uses its diagonal.

    Parameters:
        data: Simulation object; its ``cell_dimensions`` attribute is read.
        bin_edges (array-like): Bin edge values for the RDF histogram.

    Returns:
        tuple: (volume, shell_volumes, bin_centers, bin_width, box_lengths)
            - volume (float): Simulation cell volume (product of box lengths).
            - shell_volumes (np.ndarray): Volumes of spherical shells defined by bin_edges.
            - bin_centers (np.ndarray): Centers of the bins.
            - bin_width (float): Width of the first bin (bins assumed uniform).
            - box_lengths (np.ndarray): Box lengths in each dimension.
        If the cell dimensions are missing or None, a warning is logged and
        (None, None, None, None, None) is returned; nothing is raised.
    """
    cell = getattr(data, "cell_dimensions", None)
    if cell is None:
        logging.warning("Simulation data missing cell dimensions.")
        return None, None, None, None, None

    cell_array = np.asarray(cell, dtype=float)
    if cell_array.ndim == 1:
        # Only the first three entries are lengths. Without this, a cell with
        # extra entries (e.g. [a, b, c, alpha, beta, gamma]) would multiply the
        # extra values into the volume and give more than 3 box lengths.
        box_lengths = cell_array[:3].copy()
        # NOTE(review): assumes entries 3-5, when present, are angles in
        # degrees — confirm against the Simulation class.
        if cell_array.size >= 6 and not np.allclose(cell_array[3:6], 90.0):
            logging.warning("Non-orthogonal cell angles %s ignored; treating box as orthorhombic.",
                            cell_array[3:6])
    else:
        # Cell given as a matrix: only the diagonal is used (orthorhombic box).
        box_lengths = np.diag(cell_array).copy()
    volume = np.prod(box_lengths)

    bin_edges = np.asarray(bin_edges, dtype=float)
    # Volumes of spherical shells between consecutive bin edges.
    shell_volumes = (4.0 / 3.0) * np.pi * (bin_edges[1:]**3 - bin_edges[:-1]**3)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    bin_width = bin_edges[1] - bin_edges[0]
    return volume, shell_volumes, bin_centers, bin_width, box_lengths


def compute_frame_rdf(oxygen_positions, ion_positions, oxygen_label, ion_label,
                      bin_edges, shell_volumes, bin_centers, bin_width, volume, box_lengths,
                      smoothing_window=1):
    """
    Compute the RDF for a single simulation frame between oxygen atoms and ions.

    g(r) per bin = (observed ion–oxygen pair count) / (count expected for a
    homogeneous distribution at the same density). When the two labels are
    identical:
    - zero-distance self-pairs are excluded;
    - each centre has only N_oxygen - 1 possible partners, which the
      reference count uses.

    Parameters:
        oxygen_positions (np.ndarray): Coordinates for oxygen atoms (shape: (N_oxygen, 3)).
        ion_positions (np.ndarray): Coordinates for ion positions (shape: (N_ions, 3)).
        oxygen_label (str): Identifier for the oxygen species.
        ion_label (str): Identifier for the ion species.
        bin_edges (np.ndarray): Bin edge values for histogram calculation.
        shell_volumes (np.ndarray): Precomputed volumes of spherical shells.
        bin_centers (np.ndarray): Centers of the histogram bins (unused; kept for interface compatibility).
        bin_width (float): Width of the histogram bins (unused; kept for interface compatibility).
        volume (float): Total simulation cell volume.
        box_lengths (np.array): Simulation box lengths for applying periodic boundaries.
        smoothing_window (int, optional): Window size for smoothing the histogram counts. Default is 1 (no smoothing).

    Returns:
        tuple: (rdf_frame, density_oxygen)
            - rdf_frame (np.ndarray): RDF values for the current frame; 0 in
              bins where no pairs are expected (e.g. empty shells or no atoms).
            - density_oxygen (float): N_oxygen / volume.
    """
    oxygen_positions = np.asarray(oxygen_positions, dtype=float)
    ion_positions = np.asarray(ion_positions, dtype=float)

    # Pairwise displacement vectors, shape (N_oxygen, N_ions, 3), wrapped with
    # the minimum-image convention.
    diff_vectors = oxygen_positions[:, np.newaxis, :] - ion_positions[np.newaxis, :, :]
    diff_vectors = apply_periodic_boundary(diff_vectors, box_lengths)
    distances = np.linalg.norm(diff_vectors, axis=-1).ravel()

    num_oxygen = oxygen_positions.shape[0]
    num_ions = ion_positions.shape[0]
    density_oxygen = num_oxygen / volume

    same_species = oxygen_label == ion_label
    if same_species:
        # Exclude zero-distance self-pairs.
        distances = distances[distances > 1e-6]

    counts, _, _ = compute_histogram(distances, bin_edges, normalize=False)
    if smoothing_window > 1:
        counts = smooth_data(counts, window_size=smoothing_window)

    # Homogeneous reference: each ion centre sees partners at the bulk
    # density. A centre cannot pair with itself when the species coincide.
    partner_count = num_oxygen - 1 if same_species else num_oxygen
    expected_counts = num_ions * (partner_count / volume) * shell_volumes

    # Report 0 (instead of NaN/inf) in bins where no pairs are expected.
    rdf_frame = np.divide(counts, expected_counts,
                          out=np.zeros(np.shape(expected_counts), dtype=float),
                          where=expected_counts > 0)
    return rdf_frame, density_oxygen


def calculate_average_rdf(data, oxygen_label, ion_label, skip_time, bin_edges, ion_indices=None,
                          smoothing_window=1):
    """
    Compute the average RDF between oxygen and ion species over multiple frames.

    Per-frame RDFs are computed for every frame whose time is >= skip_time
    and then averaged, together with the oxygen density.

    Parameters:
        data: Simulation object containing trajectories and metadata.
        oxygen_label (str): Name of the oxygen positions array under
            ``data.trajectories.positions`` (e.g., "watO").
        ion_label (str): Name of the ion positions array under
            ``data.trajectories.positions`` (e.g., "ions").
        skip_time (float): Frames with a time below this threshold are excluded.
        bin_edges (np.ndarray): Bin edge values for the RDF histogram.
        ion_indices (list, optional): Indices (second axis of the ion array)
            selecting a specific ion type.
        smoothing_window (int, optional): Passed to compute_frame_rdf to
            smooth each frame's histogram counts. Default 1 (no smoothing).

    Returns:
        tuple: (bin_centers, avg_rdf, avg_density)
            - bin_centers (np.ndarray): Centers of the RDF bins.
            - avg_rdf (np.ndarray): Average RDF computed over selected frames.
            - avg_density (float): Average oxygen density computed over selected frames.
        
        Returns (None, None, None) if either position array is missing, the
        cell dimensions are missing, or no frame passes skip_time.
    """
    try:
        positions = data.trajectories.positions
        oxygen_positions_all = getattr(positions, oxygen_label)
    except AttributeError:
        logging.warning("No position data for oxygen species '%s' found.", oxygen_label)
        return None, None, None

    # The second species is looked up by name, just like the oxygen species.
    try:
        ion_positions_all = getattr(positions, ion_label)
    except AttributeError:
        logging.warning("No position data for ion species '%s' found.", ion_label)
        return None, None, None
    if ion_indices is not None:
        ion_positions_all = ion_positions_all[:, ion_indices, :]

    # Extract simulation parameters required for RDF calculation.
    parameters = extract_simulation_parameters(data, bin_edges)
    if parameters[0] is None:
        return None, None, None
    volume, shell_volumes, bin_centers, bin_width, box_lengths = parameters

    rdf_collection = []
    density_collection = []
    for frame_idx, current_time in enumerate(data.trajectories.times):
        if current_time < skip_time:
            continue
        rdf_frame, density_oxygen = compute_frame_rdf(
            oxygen_positions_all[frame_idx], ion_positions_all[frame_idx],
            oxygen_label, ion_label,
            bin_edges, shell_volumes, bin_centers, bin_width,
            volume, box_lengths,
            smoothing_window=smoothing_window,
        )
        rdf_collection.append(rdf_frame)
        density_collection.append(density_oxygen)

    if not rdf_collection:
        logging.warning("No frames exceeded the skip time threshold (skip_time=%.2f).", skip_time)
        return None, None, None

    avg_rdf = np.mean(rdf_collection, axis=0)
    avg_density = np.mean(density_collection)
    return bin_centers, avg_rdf, avg_density


# =============================================================================
# AGGREGATION AND PLOTTING OF RDF DATA
# =============================================================================
# This section aggregates RDF data from multiple simulation files and produces a plot
# with one averaged RDF curve per unique ion type.

def plot_ion_oxygen_rdf(simulation_files, smoothing=2, skip_time=5, num_bins=50, max_distance=8, verbose=False):
    """
    Aggregate and plot oxygen-ion RDFs across multiple simulation files.

    Processing per file:
    - For each ion type listed in ``sim_data.ions``, compute the averaged
      ion–oxygen RDF, with oxygen positions taken from the "watO" array.
    - Weight the RDF by the number of ions of that type and combine it with
      the same ion type from other files.
    - Files that fail to load are logged and skipped.

    One curve is drawn per ion type, after smoothing with `smoothing`. The
    figure is saved as figures/ion-O_rdf.svg and shown.

    Parameters:
        simulation_files (list): List of file paths to simulation pickle files.
        smoothing (int): Moving-average window applied to the averaged RDF before plotting.
        skip_time (float): Time threshold to skip initial frames.
        num_bins (int): Number of bins for the RDF histogram.
        max_distance (float): Maximum distance (in Å) for the RDF analysis.
        verbose (bool): Toggle for detailed logging.
    """
    setup_logging(verbose)

    bin_edges = np.linspace(0, max_distance, num_bins + 1)
    # Per ion type: bin centres, sum of ion-count-weighted RDFs, total ion count.
    aggregated_data = {}

    for file_path in simulation_files:
        logging.debug("Processing simulation file: %s", file_path)
        try:
            sim_data = load_simulation(file_path)
        except Exception as error:
            # Best effort: an unreadable file is logged and skipped.
            logging.error("Error loading file %s: %s", file_path, error)
            continue

        # dict.fromkeys keeps first-appearance order. A set would make the
        # curve/legend order depend on string-hash randomisation.
        for ion_type in dict.fromkeys(sim_data.ions):
            ion_indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]

            # For this analysis, oxygen positions are assumed to be under "watO".
            bin_centers, rdf_average, _ = calculate_average_rdf(
                sim_data, "watO", "ions", skip_time, bin_edges, ion_indices=ion_indices
            )
            if bin_centers is None or rdf_average is None:
                logging.warning("Skipping simulation %s for ion type '%s' due to insufficient RDF data.", 
                                file_path, ion_type)
                continue

            # Weight the RDF by the number of ions of this type.
            weighted_rdf = rdf_average * len(ion_indices)
            label = f"{ion_type}"
            if label not in aggregated_data:
                aggregated_data[label] = {"bin_centers": bin_centers, "rdf_sum": weighted_rdf, "ion_count": len(ion_indices)}
            else:
                aggregated_data[label]["rdf_sum"] += weighted_rdf
                aggregated_data[label]["ion_count"] += len(ion_indices)

    # Ion-count-weighted mean RDF per ion type.
    rdf_plot_data = []
    for label, data_dict in aggregated_data.items():
        avg_rdf = data_dict["rdf_sum"] / data_dict["ion_count"]
        rdf_plot_data.append((data_dict["bin_centers"], avg_rdf, label))
        logging.debug("Aggregated RDF for ion type '%s': Total ion count = %d", label, data_dict["ion_count"])

    def plot_callback(ax, legend_font_size=32, legend_loc='best'):
        """Draw one smoothed RDF curve per ion type and add the legend."""
        for bin_centers, rdf, label in rdf_plot_data:
            rdf = smooth_data(rdf, window_size=smoothing)
            ax.plot(bin_centers, rdf, linestyle='-', linewidth=8, marker=None,
                    label=label, color=ION_COLORS.get(label, 'black'))
        ax.legend(fontsize=legend_font_size, loc=legend_loc)

    custom_plot(
        xlabel="Ion–O Distance (Å)",
        ylabel="g(r)",
        font_size=60,
        yticks_remove=True,
        figsize=(12, 8),
        xticks=np.arange(2, 7, 1),
        xlim=(1.5, 6.5),
        # Fix the lower y-limit at 0 and leave the upper limit automatic.
        ylim=(0, None),
        legend_options={
            'loc': 'upper right',
            'font_size': 32,
            'draw_bg': True,
            'bgcolor': 'white',
            'alpha': 0.8,
            'border': False
        },
        plot_func=plot_callback,
        save_svg='ion-O_rdf'
    )


# =============================================================================
# EXAMPLE USAGE
# =============================================================================
# Simulation pickle files for the ion–O RDF analysis: one per ion in IONS,
# named data/simulations/Pt111_<ion>2.pkl.
simulation_file_paths = list(map("data/simulations/Pt111_{}2.pkl".format, IONS))

# Uncomment the next line to run the analysis and draw (and save) the RDF plot:
# plot_ion_oxygen_rdf(simulation_files=simulation_file_paths, smoothing=2)

In [None]:
# =============================================================================
# HELPER FUNCTIONS: DENSITY CALCULATION & DATA EXTRACTION
# =============================================================================
def calculate_histogram_density(all_distances, bin_edges, valid_frames):
    """
    Compute the density histogram using compute_histogram.

    The raw (un-normalized) histogram counts are divided by the number of
    frames, giving the average number of samples per bin per frame.
    Density is reported in arbitrary units.

    Parameters:
        all_distances (np.ndarray): Array of distance values.
        bin_edges (np.ndarray): Bin edges used for the histogram.
        valid_frames (int): Number of simulation frames used in the histogram.

    Returns:
        np.ndarray: Density histogram in arbitrary units.

    Raises:
        ValueError: If valid_frames is not positive; dividing by it would
                    yield inf/NaN densities.
    """
    if valid_frames <= 0:
        raise ValueError(f"valid_frames must be positive, got {valid_frames}")
    counts, _, _ = compute_histogram(all_distances, bin_edges, normalize=False)
    return counts / valid_frames

def extract_initial_positions(simulation, ion, skip):
    """
    Return first-frame ion–surface z distances for every ion of one type.

    Parameters:
        simulation: Simulation object. Its ``ions`` list plus
            ``trajectories.positions.ions`` and ``trajectories.surface_z``
            are read.
        ion (str): Ion identifier (e.g., "Li", "Na", etc.).
        skip (float): Unused; accepted for interface consistency with the
            other density helpers.

    Returns:
        list: |z_ion - z_surface| in frame 0 for each matching ion, in the
        order the ions appear in ``simulation.ions``. Empty if none match.
    """
    selected = np.asarray([name == ion for name in simulation.ions], dtype=bool)
    if not selected.any():
        return []
    trajectories = simulation.trajectories
    z_first = trajectories.positions.ions[0, selected, 2]
    surface_first = trajectories.surface_z[0]
    return np.abs(z_first - surface_first).tolist()

def process_density(simulation, target_positions, skip, bin_edges):
    """
    Process simulation data to compute the density histogram.

    For every frame with time >= skip:
    - compute |target_positions[frame] - surface_z[frame]|;
    - add those distances to the pool.
    The pooled distances are then histogrammed per frame via
    calculate_histogram_density.

    Parameters:
        simulation: Simulation object containing trajectories.
        target_positions (np.ndarray): Per-frame target z coordinates,
            indexed by frame like ``simulation.trajectories.times``.
        skip (float): Time threshold (in ps) to skip early simulation frames.
        bin_edges (np.ndarray): Edges defining histogram bins.

    Returns:
        tuple: (density, valid_frames)
            density (np.ndarray): Computed density histogram.
            valid_frames (int): Number of frames that were processed.

        Returns (None, 0) and logs a warning if no valid frames are found.
    """
    trajectories = simulation.trajectories
    distances_list = []
    for i, t in enumerate(trajectories.times):
        if t < skip:
            continue
        # atleast_1d so frames holding a single scalar z can still be concatenated.
        distances_list.append(np.atleast_1d(np.abs(target_positions[i] - trajectories.surface_z[i])))

    valid_frames = len(distances_list)
    if valid_frames == 0:
        # %s (not %d) so fractional skip times are not truncated in the message.
        logging.warning("No valid frames found after skipping %s ps.", skip)
        return None, 0

    all_distances = np.concatenate(distances_list)
    density = calculate_histogram_density(all_distances, bin_edges, valid_frames)
    return density, valid_frames

# =============================================================================
# MAIN PLOTTING FUNCTION: DENSITY PLOT GENERATION
# =============================================================================
def plot_density(files, skip=5, bins=50, smoothing=1, initial_position=False, verbose=False, normalize=True):
    """
    Process simulation files and generate a density plot for ions and oxygen.

    Distances of ions and water oxygens to the surface are histogrammed over
    2-5 Å and averaged per ion and per surface type. A surface type is
    hydrogen-covered if ``has_adsorbates`` is truthy, otherwise bare.

    Drawing rules:
      * Ion curves of hydrogen-covered systems get a black outline (a wider
        black line drawn underneath). Bare systems use plain colored lines.
      * The oxygen density of each surface type is drawn as a translucent
        black filled area. It is rescaled so its peak equals the highest
        (smoothed, normalized) ion peak of the same surface type.
      * For hydrogen-covered systems the oxygen area also gets a thin black
        edge line.

    Parameters:
        files (list): List of file paths to simulation pickle files.
        skip (float): Time (in ps) to skip at the beginning of the simulation.
        bins (int): Number of bins for the histogram.
        smoothing (int): Window size passed to ``smooth_data`` for the ion
            curves. Values <= 1 disable smoothing.
        initial_position (bool): If True, record and plot initial ion positions.
        verbose (bool): Enable verbose logging if True.
        normalize (bool): Normalize each averaged ion curve so that the area
            under it (bin-width rectangle sum) is unity.

    Returns:
        None. The figure is produced by ``custom_plot``, which is called with
        ``show=True`` and ``save_svg="density_plot"``.
    """
    setup_logging(verbose)

    # Histogram bins spanning 2 Å to 5 Å from the surface (no unit conversion).
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    bin_width = bin_edges[1] - bin_edges[0]

    # Keyed by coverage: False = bare surface, True = hydrogen-covered surface.
    ion_density_dict = {False: {ion: [] for ion in IONS}, True: {ion: [] for ion in IONS}}
    oxygen_density_list = {False: [], True: []}
    initial_positions = {ion: [] for ion in IONS}

    # -------------------------------------------------------------------------
    # ACCUMULATE PER-FILE DENSITIES
    # -------------------------------------------------------------------------
    for file_path in files:
        simulation = load_simulation(file_path)
        # bool() guarantees a valid key for the True/False dictionaries above.
        coverage_check = bool(has_adsorbates(simulation))

        if initial_position:
            for ion in IONS:
                initial_positions[ion].extend(extract_initial_positions(simulation, ion, skip))

        for ion in IONS:
            target_mask = np.array([elem == ion for elem in simulation.ions])
            if not np.any(target_mask):
                continue
            target_positions = simulation.trajectories.positions.ions[:, target_mask, 2]
            density, _ = process_density(simulation, target_positions, skip, bin_edges)
            if density is None:
                logging.warning("No valid frames for ion %s in file %s", ion, file_path)
                continue
            ion_density_dict[coverage_check][ion].append(density)

        # Water-oxygen density, only if the trajectory stores oxygen positions.
        water_oxygen = getattr(simulation.trajectories.positions, "watO", None)
        if water_oxygen is not None and water_oxygen.size > 0:
            density, valid_frames = process_density(simulation, water_oxygen[:, :, 2], skip, bin_edges)
            if valid_frames > 0 and density is not None:
                oxygen_density_list[coverage_check].append(density)
            else:
                logging.warning("No valid frames for oxygen in file %s", file_path)
        else:
            logging.debug("No water oxygen positions found in file %s", file_path)

    # -------------------------------------------------------------------------
    # AVERAGE, SMOOTH AND NORMALIZE ION DENSITIES
    # -------------------------------------------------------------------------
    group_max_ion = {False: 0.0, True: 0.0}
    avg_ion_density = {False: {}, True: {}}
    for coverage_flag in (False, True):
        for ion in IONS:
            densities = ion_density_dict[coverage_flag][ion]
            if not densities:
                avg_ion_density[coverage_flag][ion] = None
                continue
            avg_density = np.mean(densities, axis=0)
            # Smooth before normalizing and before taking the peak, so that the
            # drawn curve has unit area and the oxygen rescaling below matches
            # the curves that are actually plotted.
            if smoothing > 1:
                avg_density = smooth_data(avg_density, window_size=smoothing)
            if normalize:
                area = np.sum(avg_density * bin_width)
                if area != 0:
                    avg_density = avg_density / area
            avg_ion_density[coverage_flag][ion] = avg_density
            group_max_ion[coverage_flag] = max(group_max_ion[coverage_flag], np.max(avg_density))

    # -------------------------------------------------------------------------
    # AVERAGE OXYGEN DENSITY AND RESCALE IT TO THE ION PEAK OF EACH GROUP
    # -------------------------------------------------------------------------
    avg_oxygen_density = {False: None, True: None}
    for coverage_flag in (False, True):
        if oxygen_density_list[coverage_flag]:
            oxygen_density = np.mean(oxygen_density_list[coverage_flag], axis=0)
            if normalize:
                area = np.trapezoid(oxygen_density, bin_centers)
                if area != 0:
                    oxygen_density = oxygen_density / area
            avg_oxygen_density[coverage_flag] = oxygen_density

    scaled_oxygen_density = {False: None, True: None}
    for coverage_flag in (False, True):
        od = avg_oxygen_density[coverage_flag]
        if od is not None and np.max(od) > 0 and group_max_ion[coverage_flag] > 0:
            scaled_oxygen_density[coverage_flag] = od * (group_max_ion[coverage_flag] / np.max(od))
        else:
            scaled_oxygen_density[coverage_flag] = od

    # =============================================================================
    # PLOTTING HELPERS
    # =============================================================================
    def _draw_ion_curve(ax, density, color, outlined):
        """Draw one ion density curve (width 8) as LineCollections, optionally black-outlined."""
        # Interpolate onto a dense grid (100 points) for smoother-looking lines.
        x_dense = np.linspace(bin_centers[0], bin_centers[-1], 100)
        y_dense = np.interp(x_dense, bin_centers, density)
        points = np.column_stack([x_dense, y_dense]).reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        line_widths = np.full(len(segments), 8.0)
        if outlined:
            # A wider black line underneath acts as the outline.
            ax.add_collection(LineCollection(
                segments, linewidths=line_widths + 2.0,
                colors='black', capstyle='round', joinstyle='round'
            ))
        ax.add_collection(LineCollection(
            segments, linewidths=line_widths,
            colors=color, capstyle='round', joinstyle='round'
        ))

    def plot_callback(ax, legend_font_size=32, legend_loc='best'):
        """
        Draw oxygen areas, ion curves, the legend and optional initial positions on ``ax``.

        Parameters:
            ax (matplotlib.axes.Axes): Axis for plotting.
            legend_font_size (int): Font size for legend entries.
            legend_loc (str): Location of the legend.
        """
        for coverage_flag in (False, True):
            od = scaled_oxygen_density[coverage_flag]
            if od is not None and len(od) == len(bin_centers):
                ax.fill_between(bin_centers, od, color='black', alpha=0.3)
                if coverage_flag:
                    ax.plot(bin_centers, od, color='black', lw=1)

        custom_handles = []
        custom_labels = []
        for coverage_flag in (False, True):
            for ion in IONS:
                density = avg_ion_density[coverage_flag][ion]
                if density is None or len(density) != len(bin_centers):
                    continue
                _draw_ion_curve(ax, density, ION_COLORS[ion], outlined=coverage_flag)
                handle = Line2D([], [], color=ION_COLORS[ion], lw=8, linestyle='-')
                if coverage_flag:
                    handle._outline = True  # Flag for the custom legend handler
                custom_handles.append(handle)
                custom_labels.append(f"{ion}")

        ax.legend(
            handles=custom_handles, labels=custom_labels, loc=legend_loc,
            fontsize=legend_font_size, handler_map={Line2D: HandlerMaybeOutlinedLine2D()}
        )

        if initial_position:
            for ion in IONS:
                for pos in initial_positions[ion]:
                    ax.plot(pos, 0.2, marker='o', markersize=10, linestyle='None', color=ION_COLORS[ion])

    # =============================================================================
    # FINALIZE THE PLOT USING THE CUSTOM PLOT FUNCTION
    # =============================================================================
    custom_plot(
        xlabel="Distance to Surface (Å)",
        ylabel="ρ (a.u.)",
        font_size=40,
        yticks_remove=True,
        # Ticks stay inside xlim: set_xticks may otherwise widen the view to 5.5 Å.
        xticks=np.arange(2, 5.5, 0.5),
        xlim=(2, 5),
        ylim=0,  # lower limit only (a bare 0, not a tuple)
        legend_options={
            'loc': 'upper left',
            'font_size': 32,
            'draw_bg': True,
            'bgcolor': 'white',
            'alpha': 0.8
        },
        plot_func=plot_callback,
        show=True,
        save_svg="density_plot"
    )

# =============================================================================
# EXECUTION BLOCK
# =============================================================================
# Uncomment one of the following lines to run the density plot function on sample files (bare or H-covered Pt(111)).
# plot_density([f"data/simulations/Pt111_{ion}2.pkl" for ion in IONS], smoothing=2)
# plot_density([f"data/simulations/Pt111_{ion}4_H.pkl" for ion in IONS], smoothing=2)


In [None]:
# =============================================================================
# FUNCTION: compare_ion_densities
# =============================================================================
def compare_ion_densities(simulation_files, skip=5, bins=50, smoothing=1, ox_smoothing=1, normalize=True, verbose=False):
    """
    Compare ion densities and oxygen density for a set of simulations.

    For each simulation, every ion density is drawn as a separate curve. The
    curve color is the ion color blended towards white in proportion to the
    simulation's electrode potential within the set. Hydrogen-covered
    simulations get black-outlined curves; bare simulations get plain colored
    lines.

    The oxygen density of each simulation is drawn as a shaded area. It is
    scaled by one global factor so that the largest oxygen peak equals the
    largest ion peak. Two arrows connect the oxygen curves of the most
    positive and the most negative potential: one at their maxima and one at
    x = 4.25 Å.

    Parameters:
        simulation_files (list of str): List of simulation pickle file paths.
        skip (float): Skip frames with simulation time (in ps) less than this value.
        bins (int): Number of bins for the density histogram.
        smoothing (int): Window size for smoothing the ion density curves.
        ox_smoothing (int): Window size for smoothing the oxygen density curve.
        normalize (bool): If True, normalize each density curve to unit
            (trapezoidal) area after smoothing.
        verbose (bool): Enable verbose logging and error printing.

    Returns:
        tuple: (fig, ax) as returned by ``custom_plot``, or (None, None) if no
        simulation could be loaded.
    """
    # ----------------------------
    # SETUP: Logging and Histogram Bins
    # ----------------------------
    setup_logging(verbose)
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # ----------------------------
    # LOAD SIMULATIONS & EXTRACT POTENTIALS WITH ERROR HANDLING
    # ----------------------------
    simulations = []
    potentials = []  # potentials[i] belongs to simulations[i]; 0 if the attribute is missing
    for file_path in simulation_files:
        try:
            sim = load_simulation(file_path)
        except Exception as e:
            if verbose:
                print(f"Error loading simulation from {file_path}: {e}")
            continue
        simulations.append(sim)
        try:
            potentials.append(sim.electrode_potential)
        except AttributeError as e:
            if verbose:
                print(f"Missing electrode potential in simulation from {file_path}: {e}")
            potentials.append(0)

    if not simulations:
        if verbose:
            print("No valid simulations loaded. Exiting function.")
        return None, None

    # A simulation is considered "covered" if adsorbates are present.
    coverage_flags = [has_adsorbates(sim) for sim in simulations]

    # ----------------------------
    # HELPER FUNCTION: Compute Density
    # ----------------------------
    def compute_density(sim, target_positions, smooth_value):
        """
        Compute a density with process_density, then smooth and normalize it.

        Parameters:
            sim: Simulation object.
            target_positions: Array of target positions to process.
            smooth_value (int): Window size for smoothing (<= 1 disables it).

        Returns:
            tuple: (density, valid_frames), or (None, None) if processing fails
            or no frame survives the skip.
        """
        try:
            density, valid_frames = process_density(sim, target_positions, skip, bin_edges)
        except Exception as e:
            if verbose:
                print(f"Error processing density for simulation: {e}")
            return None, None
        if density is None:
            return None, None
        if smooth_value > 1:
            try:
                density = smooth_data(density, window_size=smooth_value)
            except Exception as e:
                if verbose:
                    print(f"Error smoothing data: {e}")
        if normalize:
            area = np.trapezoid(density, bin_centers)
            if area != 0:
                # Not in place: works for any dtype and never mutates shared arrays.
                density = density / area
        return density, valid_frames

    # ----------------------------
    # COMPUTE ALL DENSITIES ONCE (reused by the plotting callback)
    # ----------------------------
    ion_curves = []     # ion_curves[i]: list of (ion, density) for simulations[i]
    oxygen_curves = []  # oxygen_curves[i]: oxygen density of simulations[i], or None
    for sim in simulations:
        curves = []
        for ion in IONS:
            target_mask = np.array(sim.ions) == ion  # vectorized comparison
            if not np.any(target_mask):
                continue
            target_positions = sim.trajectories.positions.ions[:, target_mask, 2]
            density, _ = compute_density(sim, target_positions, smoothing)
            if density is not None:
                curves.append((ion, density))
        ion_curves.append(curves)

        oxygen_density = None
        water_oxygen = getattr(sim.trajectories.positions, "watO", None)
        if water_oxygen is not None and water_oxygen.size > 0:
            oxygen_density, _ = compute_density(sim, water_oxygen[:, :, 2], ox_smoothing)
        oxygen_curves.append(oxygen_density)

    # ----------------------------
    # GLOBAL SCALING AND COLOR RANGE
    # ----------------------------
    global_ion_max = max((d.max() for curves in ion_curves for _, d in curves), default=0)
    global_oxygen_max = max((d.max() for d in oxygen_curves if d is not None), default=0)
    scaling_factor = global_ion_max / global_oxygen_max if global_oxygen_max > 0 else 1

    group_min = min(potentials)
    group_max = max(potentials)
    group_range = group_max - group_min if group_max != group_min else 1

    # ----------------------------
    # DEFINE PLOTTING CALLBACK FUNCTION
    # ----------------------------
    def plot_callback(ax, legend_font_size, legend_loc):
        """
        Render oxygen areas, potential-trend arrows, ion curves and the legend on ``ax``.
        """
        min_length = 0.1    # minimum arrow length in data units
        cutoff_val = 4.25   # x position (Å) of the second oxygen arrow

        def adjust_arrow(p, q):
            """Stretch the segment p->q symmetrically about its midpoint to at least min_length."""
            dx = q[0] - p[0]
            dy = q[1] - p[1]
            d = np.sqrt(dx**2 + dy**2)
            if d < min_length:
                mid = ((p[0] + q[0]) / 2, (p[1] + q[1]) / 2)
                u = (1, 0) if d == 0 else (dx / d, dy / d)
                new_p = (mid[0] - 0.5 * min_length * u[0], mid[1] - 0.5 * min_length * u[1])
                new_q = (mid[0] + 0.5 * min_length * u[0], mid[1] + 0.5 * min_length * u[1])
                return new_p, new_q
            return p, q

        def draw_arrow(start, end):
            """Draw an arrow from ``start`` to ``end`` (after length adjustment)."""
            p, q = adjust_arrow(start, end)
            ax.annotate(
                '',
                xy=q,
                xytext=p,
                arrowprops=dict(
                    facecolor='black', edgecolor='black',
                    arrowstyle=ArrowStyle.Simple(head_length=0.2, head_width=0.2, tail_width=0.07),
                    alpha=0.7
                )
            )

        line_segments_data = []
        legend_entries = []
        oxygen_points = []  # (potential, (x_max, y_max), (cutoff_val, y_at_cutoff))

        for idx, coverage_flag in enumerate(coverage_flags):
            potential = potentials[idx]

            for ion, density in ion_curves[idx]:
                exact_fraction = (potential - group_min) / group_range
                ion_color = blend_color(ION_COLORS[ion], 0.75 * exact_fraction)
                label = f"{ion} ({potential:.2f} V)"
                # Interpolate density curve onto a dense grid.
                x_dense = np.linspace(bin_centers[0], bin_centers[-1], 100)
                y_dense = np.interp(x_dense, bin_centers, density)
                points = np.column_stack([x_dense, y_dense]).reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                lw_segments = np.full(len(segments), 8.0)
                line_segments_data.append((segments, lw_segments, ion_color, coverage_flag))
                legend_entries.append((ion_color, label, coverage_flag))

            oxygen_density = oxygen_curves[idx]
            if oxygen_density is not None:
                # New array: the cached density must not be rescaled in place.
                scaled = oxygen_density * scaling_factor
                ax.fill_between(bin_centers, scaled, color=blend_color("#000000", 0.5), alpha=0.3)
                if coverage_flag:
                    ax.plot(bin_centers, scaled, color='black', lw=1)
                max_index = np.argmax(scaled)
                oxygen_points.append((
                    potential,
                    (bin_centers[max_index], scaled[max_index]),
                    (cutoff_val, np.interp(cutoff_val, bin_centers, scaled)),
                ))

        # Arrows from the most positive to the most negative potential
        # (first occurrence wins on ties).
        if oxygen_points:
            most_positive = max(oxygen_points, key=lambda entry: entry[0])
            most_negative = min(oxygen_points, key=lambda entry: entry[0])
            draw_arrow(most_positive[1], most_negative[1])
            draw_arrow(most_positive[2], most_negative[2])

        # Ion curves are drawn last so they sit on top of the oxygen areas.
        for segments, lw_segments, color, coverage_flag in line_segments_data:
            if coverage_flag:
                ax.add_collection(LineCollection(segments, linewidths=lw_segments + 2.0,
                                                 colors='black', capstyle='round', joinstyle='round'))
            ax.add_collection(LineCollection(segments, linewidths=lw_segments,
                                             colors=color, capstyle='round', joinstyle='round'))

        custom_handles = []
        custom_labels = []
        for color, label, cov_flag in legend_entries:
            handle = Line2D([], [], color=color, lw=8, linestyle='-')
            if cov_flag:
                handle._outline = True  # Flag for the custom legend handler
            custom_handles.append(handle)
            custom_labels.append(label)
        ax.legend(
            handles=custom_handles, labels=custom_labels,
            fontsize=legend_font_size, loc=legend_loc,
            handler_map={Line2D: HandlerMaybeOutlinedLine2D()}
        )

    # ----------------------------
    # FINAL PLOTTING: Invoke custom_plot and return figure and axis
    # ----------------------------
    fig, ax = custom_plot(
        xlabel="Distance to Surface (Å)",
        ylabel="ρ (a.u.)",
        font_size=40,
        yticks_remove=True,
        plot_func=plot_callback,
        xticks=np.arange(2, 5.5, 0.5),
        xlim=(2, 5),
        legend_options={"font_size": 24, "alpha": 0.8, "loc": "upper left"},
        save_svg=f"ion_density_{simulations[0].ions[0]}{'_H' if coverage_flags[0] else ''}",
        show=True
    )
    return fig, ax

# =============================================================================
# EXECUTION BLOCK (UNCOMMENT TO RUN)
# =============================================================================
# Example usage:
# for ion in IONS:
#     compare_ion_densities(
#         [f"data/simulations/Pt111_{ion}{i}.pkl" for i in [2, 3, 4]],
#         smoothing=2,
#         ox_smoothing=10
#     )


In [None]:
# =============================================================================
# FUNCTION: plot_opening_angles
# =============================================================================

def _opening_angle_box_lengths(sim):
    """
    Return per-axis box lengths for minimum-image corrections, or None.

    Uses the same cell handling as plot_ccn_vs_distance: a 1-D
    ``cell_dimensions`` is used as-is, a 2-D one contributes its diagonal.
    Returns None when the simulation exposes no cell dimensions, in which case
    no periodic correction is applied.
    """
    cell = getattr(sim, "cell_dimensions", None)
    if cell is None:
        return None
    return np.array(cell) if np.ndim(cell) == 1 else np.diag(cell)


def _accumulate_opening_angles(sim, file, angle_data, region_width):
    """
    Append ion–water opening angles (degrees) from one simulation to ``angle_data``.

    ``angle_data`` is mutated in place as
    ``angle_data[ion][(region_min, region_max)] -> list of angles``. An ion
    contributes only when its height above the surface lies in a region with
    region_min > 1.5 Å and region_max < 5.5 Å. Only oxygens within
    ``ION_CUTOFFS[ion]`` of the ion are counted.
    """
    n_frames = sim.trajectories.times.shape[0]
    surface_zs = sim.trajectories.surface_z              # One surface z per frame.
    ion_positions_all = sim.trajectories.positions.ions  # (n_frames, n_ions, 3)
    o_positions_all = sim.trajectories.positions.watO    # (n_frames, n_water, 3)
    if o_positions_all is None or o_positions_all.size == 0:
        print(f"No water oxygen positions found in file: {file}")
        return

    box_lengths = _opening_angle_box_lengths(sim)

    # Indices of each ion species within the per-frame ion position array.
    ion_indices = {}
    for i, ion in enumerate(sim.ions):
        ion_indices.setdefault(ion, []).append(i)

    for frame in range(n_frames):
        frame_ions = ion_positions_all[frame]  # (n_ions, 3)
        frame_o = o_positions_all[frame]       # (n_water, 3)
        surface_z = surface_zs[frame]

        for ion in IONS:
            for idx in ion_indices.get(ion, ()):
                ion_pos = frame_ions[idx]
                distance = ion_pos[2] - surface_z
                if distance < 0:
                    continue
                # Bin the ion height into a region of width region_width.
                region_min = np.floor(distance / region_width) * region_width
                region_max = region_min + region_width
                if not (region_min > 1.5 and region_max < 5.5):
                    continue
                # Register the region even if no water is within the cutoff
                # in this frame.
                region_angles = angle_data[ion].setdefault((region_min, region_max), [])

                diff = frame_o - ion_pos
                if box_lengths is not None:
                    # Minimum image, so waters across a periodic boundary are
                    # counted (same correction used by compute_local_ccn).
                    diff = apply_periodic_boundary(diff, box_lengths)
                dists = np.linalg.norm(diff, axis=1)
                valid = dists <= ION_CUTOFFS[ion]
                if not np.any(valid):
                    continue
                cos_theta = np.clip(diff[valid, 2] / dists[valid], -1.0, 1.0)
                # arccos gives the polar angle from +z. 180° minus it is the
                # angle from -z, i.e. from the direction pointing back toward
                # the surface (only ions above surface_z reach this point).
                angles = np.abs(180 - np.degrees(np.arccos(cos_theta)))
                region_angles.extend(angles.tolist())


def _smoothed_density(sim, target_positions, bin_edges, smoothing):
    """
    Histogram z positions with process_density and apply an optional boxcar smoothing.

    Returns None when process_density returns None.
    """
    density, _ = process_density(sim, target_positions, skip=0, bin_edges=bin_edges)
    if density is not None and smoothing > 1:
        density = np.convolve(density, np.ones(smoothing) / smoothing, mode='same')
    return density


def plot_opening_angles(files, bins=180, region_width=0.5, min_count=1000, smoothing=3):
    """
    Plot the distribution of opening angles between ions, water oxygen,
    and the surface-normal. Angles are binned for each ion type and grouped by
    the ion's distance from the surface. Only regions with region boundaries
    (region_min > 1.5 Å and region_max < 5.5 Å) that collect at least min_count angles
    are plotted. An inset displays ion density curves (with coverage-dependent
    black outlines) alongside the oxygen density (displayed as a shaded area).

    Each simulation file is loaded once. Ion–oxygen displacement vectors use
    the minimum-image convention when the simulation exposes cell dimensions.

    Parameters:
      files (list of str): List of simulation pickle file paths.
      bins (int): Number of bins for the opening-angle histogram (angles 0–180°).
      region_width (float): Width (in Å) of the distance regions.
      min_count (int): Minimum number of angles required for a region to be included.
      smoothing (int): Smoothing window size for density curves.
    """
    # -------------------------------------------------------------------------
    # Setup: Logging and Data Containers
    # -------------------------------------------------------------------------
    setup_logging(verbose=False)
    angle_data = {ion: {} for ion in IONS}        # ion -> region -> list of angles
    ion_density_dict = {ion: [] for ion in IONS}  # ion -> per-file density curves
    oxygen_density_list = []                      # per-file oxygen density curves
    coverage_flag = False                         # True if any simulation has adsorbates

    density_bin_edges = np.linspace(2, 5, 51)  # 50 bins from 2 Å to 5 Å.
    density_bin_centers = 0.5 * (density_bin_edges[:-1] + density_bin_edges[1:])

    # -------------------------------------------------------------------------
    # Single pass over the simulations: coverage flag, opening angles,
    # ion densities and oxygen density, each file loaded exactly once.
    # -------------------------------------------------------------------------
    for file in files:
        sim = load_simulation(file)
        if not coverage_flag and has_adsorbates(sim):
            coverage_flag = True

        _accumulate_opening_angles(sim, file, angle_data, region_width)

        for ion in IONS:
            target_mask = np.array([elem == ion for elem in sim.ions])
            if not np.any(target_mask):
                continue
            target_positions = sim.trajectories.positions.ions[:, target_mask, 2]
            density = _smoothed_density(sim, target_positions, density_bin_edges, smoothing)
            if density is not None:
                ion_density_dict[ion].append(density)

        water_o = sim.trajectories.positions.watO
        if water_o is not None and water_o.size > 0:
            density = _smoothed_density(sim, water_o[:, :, 2], density_bin_edges, smoothing)
            if density is not None:
                oxygen_density_list.append(density)

    # Average the per-file curves (None when an ion never appeared).
    for ion in IONS:
        curves = ion_density_dict[ion]
        ion_density_dict[ion] = np.mean(curves, axis=0) if curves else None
    oxygen_density = np.mean(oxygen_density_list, axis=0) if oxygen_density_list else None

    # -------------------------------------------------------------------------
    # Build Histogram Data for Regions Meeting Minimum Count
    # -------------------------------------------------------------------------
    histogram_data = []
    angle_bins = np.linspace(0, 180, bins + 1)
    for ion in IONS:
        for region_key, angles in angle_data[ion].items():
            if len(angles) < min_count:
                continue
            counts, _ = np.histogram(np.array(angles), bins=angle_bins)
            probability = counts / np.sum(counts)
            bin_centers = (angle_bins[:-1] + angle_bins[1:]) / 2
            histogram_data.append((ion, region_key, bin_centers, probability))
    # Order by ion, then regions with a higher lower bound first.
    histogram_data_sorted = sorted(histogram_data,
                                   key=lambda x: (IONS.index(x[0]), -x[1][0]))

    # -------------------------------------------------------------------------
    # Generate Gradient Colors per Ion and Region
    # -------------------------------------------------------------------------
    ion_region_colors = {}
    for ion in IONS:
        valid_regions = sorted(
            (rk for rk, angles in angle_data[ion].items() if len(angles) >= min_count),
            key=lambda x: x[0],
        )
        n_regions = len(valid_regions)
        colors = {}
        for idx, region_key in enumerate(valid_regions):
            # Lighten the base ion color (up to 50 % towards white) with region order.
            fraction = (idx / (n_regions - 1) / 2) if n_regions > 1 else 0
            colors[region_key] = blend_color(ION_COLORS[ion], fraction)
        ion_region_colors[ion] = colors

    # -------------------------------------------------------------------------
    # Scale Oxygen Density to Match Ion Density Maximum (for Inset)
    # -------------------------------------------------------------------------
    max_ion_value = 0.0
    for ion in IONS:
        d = ion_density_dict[ion]
        if d is not None:
            max_ion_value = max(max_ion_value, np.max(d))
    if oxygen_density is not None and np.max(oxygen_density) > 0 and max_ion_value > 0:
        oxygen_density_scaled = oxygen_density * (max_ion_value / np.max(oxygen_density))
    else:
        oxygen_density_scaled = oxygen_density

    # -------------------------------------------------------------------------
    # Define the Plotting Callback Function
    # -------------------------------------------------------------------------
    def plot_callback(ax, legend_font_size, legend_loc):
        """
        Custom plotting callback that renders opening-angle histograms and an inset
        displaying ion and oxygen density curves. The legend arguments are part of
        the custom_plot callback signature and are not used here.
        """
        # --- Plot Opening-Angle Histograms ---
        for ion, region_key, centers, prob in histogram_data_sorted:
            color = ion_region_colors[ion].get(region_key, ION_COLORS[ion])
            if coverage_flag:
                # For covered simulations, draw with a black outline.
                ax.plot(centers, prob, linewidth=10, color='black')
                ax.plot(centers, prob, linewidth=8, color=color)
            else:
                ax.plot(centers, prob, linestyle='-', linewidth=8, color=color)
        ax.set_xticks(np.arange(0, 181, 30))
        ax.set_xlim(0, 180)
        ax.set_ylim(0)

        # --- Create Inset Axis for Density Plot ---
        fig = ax.figure
        main_pos = ax.get_position().bounds  # (left, bottom, width, height) in figure fractions
        # Candidate inset positions, in figure-fraction coordinates.
        candidates = {
            "top_right": [main_pos[0] + 0.68 * main_pos[2], main_pos[1] + 0.62 * main_pos[3],
                          0.35 * main_pos[2], 0.35 * main_pos[3]],
            "top_left": [main_pos[0] + 0.05 * main_pos[2], main_pos[1] + 0.60 * main_pos[3],
                         0.35 * main_pos[2], 0.35 * main_pos[3]],
            "bottom_right": [main_pos[0] + 0.60 * main_pos[2], main_pos[1] + 0.05 * main_pos[3],
                             0.35 * main_pos[2], 0.35 * main_pos[3]],
            "bottom_left": [main_pos[0] + 0.05 * main_pos[2], main_pos[1] + 0.05 * main_pos[3],
                            0.35 * main_pos[2], 0.35 * main_pos[3]]
        }

        def candidate_score(candidate_box):
            """Fraction of main-axes line vertices lying under candidate_box (lower is better)."""
            box_left, box_bottom, box_width, box_height = candidate_box
            # The box is in figure fractions. Go figure -> display (pixels) ->
            # data; transData.inverted() alone expects display coordinates.
            corners = fig.transFigure.transform(
                [(box_left, box_bottom), (box_left + box_width, box_bottom + box_height)]
            )
            (x_min, y_min), (x_max, y_max) = ax.transData.inverted().transform(corners)
            covered, total = 0, 0
            for line in ax.lines:
                xdata = np.asarray(line.get_xdata(), dtype=float)
                ydata = np.asarray(line.get_ydata(), dtype=float)
                inside = ((xdata >= x_min) & (xdata <= x_max)
                          & (ydata >= y_min) & (ydata <= y_max))
                covered += np.count_nonzero(inside)
                total += xdata.size
            return covered / total if total else 0.0

        candidate_scores = {key: candidate_score(box) for key, box in candidates.items()}
        # min() returns the first key on ties, so "top_right" is the default.
        best_candidate = min(candidate_scores, key=candidate_scores.get)
        best_box = candidates[best_candidate]
        ax_density = fig.add_axes(best_box, zorder=3)

        # Draw a semi-transparent white background slightly around the inset.
        fig.canvas.draw()
        renderer = fig.canvas.get_renderer()
        bbox = ax_density.get_tightbbox(renderer).transformed(fig.transFigure.inverted())
        margin = 0.02
        new_x0 = bbox.x0
        new_y0 = bbox.y0 - 3 * margin
        new_width = bbox.width - 2 * margin
        new_height = bbox.height + 2 * margin
        rect = Rectangle((new_x0, new_y0), new_width, new_height,
                         transform=fig.transFigure, facecolor='white', alpha=0.8,
                         edgecolor='none', zorder=2)
        fig.add_artist(rect)

        # --- Plot Oxygen Density in the Inset ---
        if (oxygen_density_scaled is not None
                and len(density_bin_centers) == len(oxygen_density_scaled)):
            oxy_color = blend_color("#000000", 0.5)
            ax_density.fill_between(density_bin_centers, oxygen_density_scaled,
                                    color=oxy_color, alpha=0.3)
            if coverage_flag:
                ax_density.plot(density_bin_centers, oxygen_density_scaled,
                                color='black', lw=1, zorder=0)

        # --- Plot Ion Density Curves in the Inset ---
        for ion in IONS:
            density = ion_density_dict[ion]
            if density is None:
                continue
            if coverage_flag:
                ax_density.plot(density_bin_centers, density, linewidth=5, color='black')
                ax_density.plot(density_bin_centers, density, linewidth=3, color=ION_COLORS[ion])
            else:
                ax_density.plot(density_bin_centers, density, linestyle='-', linewidth=3,
                                color=ION_COLORS[ion])
            # Fill the area for each region with sufficient angle data.
            for region_key, angles in angle_data[ion].items():
                if len(angles) < min_count:
                    continue
                region_min, region_max = region_key
                mask = (density_bin_centers >= region_min) & (density_bin_centers <= region_max)
                if np.any(mask):
                    color = ion_region_colors[ion].get(region_key, ION_COLORS[ion])
                    ax_density.fill_between(density_bin_centers[mask],
                                            density[mask],
                                            color=color)
        ax_density.set_xlabel("Distance to Surface (Å)", fontsize=28)
        ax_density.set_ylabel("ρ (a.u.)", fontsize=32)
        ax_density.set_yticks([])
        ax_density.tick_params(labelsize=28)
        ax_density.set_xticks(np.arange(2, 6, 1))
        ax_density.set_xlim(2, 5)
        ax_density.set_ylim(0)

    # -------------------------------------------------------------------------
    # Finalize Plot using the custom_plot Utility
    # -------------------------------------------------------------------------
    custom_plot(
        xlabel="Opening angle (degrees)",
        ylabel="Probability (a.u.)",
        font_size=48,
        yticks_remove=True,
        grid=False,
        xlim=(0, 180),
        xticks=np.arange(0, 181, 30),
        ylim=0,  # NOTE(review): a bare scalar (bottom limit), not a tuple; relies on custom_plot accepting it.
        legend_options=None,  # No custom legend defined for the main plot.
        plot_func=plot_callback,
        show=True
    )

# =============================================================================
# EXECUTION BLOCK (UNCOMMENT TO RUN)
# =============================================================================

# Uncomment the following lines to process the desired simulations.
# for ion in IONS:
#     plot_opening_angles([f"data/simulations/Pt111_{ion}2.pkl"], bins=36, region_width=0.5, smoothing=2)
#     plot_opening_angles([f"data/simulations/Pt111_{ion}4_H.pkl"], bins=36, region_width=0.5, smoothing=2)

In [None]:
# =============================================================================
# LOCAL CONTINUOUS COORDINATION NUMBER (CCN) CALCULATION
# =============================================================================
def compute_local_ccn(ion_position, o_positions, o_density, rdf_bin_edges, cutoff, box_lengths=None, ccn_d=0.1):
    """
    Continuous coordination number of one ion, plus the weighted oxygen height.

    Every water oxygen at distance r from the ion is weighted by the Fermi-type
    switching function

        f(r) = 1 / (1 + exp((r - cutoff) / ccn_d))

    and the coordination number is the sum of these weights.

    Parameters:
        ion_position (np.ndarray): (3,) position of the ion.
        o_positions (np.ndarray): (n_O, 3) positions of the water oxygens.
        o_density (float): Not used; kept so existing callers keep working.
        rdf_bin_edges (np.ndarray): Not used; kept so existing callers keep working.
        cutoff (float): Switching distance (Å) at which a weight equals 0.5.
        box_lengths (np.array, optional): If given, displacements are passed
            through apply_periodic_boundary before measuring distances.
        ccn_d (float, optional): Width of the switching region.

    Returns:
        tuple: (cn, avg_z), where cn is the summed weight and avg_z the
        weight-averaged z-coordinate of the oxygens. (None, None) when the
        summed weight is not positive (e.g. no oxygens).
    """
    displacements = o_positions - ion_position
    if box_lengths is not None:
        displacements = apply_periodic_boundary(displacements, box_lengths)

    radii = np.linalg.norm(displacements, axis=1)
    switching_weights = 1.0 / (1.0 + np.exp((radii - cutoff) / ccn_d))
    total_weight = np.sum(switching_weights)

    # Guard clause: with no effective weight the average height is undefined.
    if not total_weight > 0:
        return None, None

    weighted_height = np.sum(o_positions[:, 2] * switching_weights) / total_weight
    return total_weight, weighted_height

# =============================================================================
# PLOT CCN VERSUS DISTANCE TO THE SURFACE
# =============================================================================
def plot_ccn_vs_distance(sim_files, skip=5, default_cutoff=3.0, bins=100,
                         max_distance=5.0, ccn_d=0.1, threshold=1):
    """
    Process simulation pickle files to compute and plot the continuous coordination number (CN)
    versus distance to the surface.

    For each simulation:
      - Compute the local CN for every ion using the switching function.
      - Determine the distance of the ion from the surface (ion's z-coordinate relative to the surface).
      - Bin the CN values as a function of distance.

    Special considerations:
      - For hydrogen-adsorbed (covered) systems, CN curves are drawn with a black outline.
      - Oxygen density is shown as a shaded area, with an additional black outline for covered systems.
      - The line width of each CN curve scales with the (smoothed) number of samples per bin.

    Parameters:
        sim_files (list): List of simulation pickle file paths.
        skip (int): Time (ps) to skip at the beginning of each simulation. Frames with
            ``frame * sim.timestep < skip * 1000`` are ignored (presumably sim.timestep is in fs).
        default_cutoff (float): Default cutoff (Å) for the switching function if not specified per ion.
        bins (int): Number of bins for the distance-to-surface histogram.
        max_distance (float): Maximum distance (Å) for binning.
        ccn_d (float): Smoothing parameter for the switching function.
        threshold (int): Minimum number of ion samples a distance bin needs for its average CN
            to be plotted. Values below 1 are treated as 1, so empty bins are never drawn.
    """
    # -------------------------------
    # Define Distance Bins and Initialize Containers
    # -------------------------------
    distance_bins = np.linspace(0, max_distance, bins + 1)
    bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])

    # key -> (ion_type, hydrogen_flag);
    # value -> [cn_sum per bin, sample count per bin, cutoff, n_frames * n_ions, n_frames]
    results = {}
    # Per-frame arrays of in-range oxygen heights, per coverage flag (concatenated once at the end).
    oxygen_distances = {False: [], True: []}
    oxygen_frames = {False: 0, True: 0}  # Frames that contributed at least one oxygen.
    global_area = None                    # Lateral area (x * y) of the first simulation with cell dimensions.

    # -------------------------------
    # Process Each Simulation File
    # -------------------------------
    for file_name in sim_files:
        sim = load_simulation(file_name)

        # Determine if the simulation is hydrogen-covered.
        coverage_check = has_adsorbates(sim)

        # Determine ion type and cutoff from the simulation metadata.
        ion_type = sim.ions[0] if isinstance(sim.ions, (list, tuple)) else sim.ions
        specific_cutoff = ION_CUTOFFS.get(ion_type, default_cutoff)
        rdf_bin_edges = np.linspace(0, specific_cutoff, 51)  # Unused by compute_local_ccn; kept for its signature.

        # Ensure necessary atomic positions are present.
        if sim.trajectories.positions.watO is None or sim.trajectories.positions.ions is None:
            print(f"Required atom positions (water O or {ion_type}) not found in {file_name}.")
            continue

        # Obtain simulation cell dimensions (for periodic corrections and the area).
        cell = sim.cell_dimensions
        if cell is None:
            print(f"No cell dimensions found in {file_name}; skipping periodic corrections.")
            box_lengths = None
        else:
            box_lengths = np.array(cell) if np.ndim(cell) == 1 else np.diag(cell)
            if global_area is None:
                global_area = box_lengths[0] * box_lengths[1]

        volume = np.prod(box_lengths) if box_lengths is not None else 1.0
        o_count = sim.trajectories.positions.watO.shape[1]
        o_density_global = o_count / volume  # Passed through; unused by compute_local_ccn.

        n_frames = sim.trajectories.positions.watO.shape[0]
        n_ions = sim.trajectories.positions.ions.shape[1]

        # Temporary accumulators for this simulation.
        cn_sum = np.zeros(bins)
        counts = np.zeros(bins)

        # Loop over simulation frames; skip early frames based on simulation time.
        for frame in range(n_frames):
            if frame * sim.timestep < (skip * 1000):
                continue

            surface_z = sim.trajectories.surface_z[frame]
            o_positions = sim.trajectories.positions.watO[frame]    # (n_water, 3)
            ion_positions = sim.trajectories.positions.ions[frame]  # (n_ions, 3)

            # Accumulate oxygen heights above the surface that fall inside the binning range.
            oxygen_distances_frame = o_positions[:, 2] - surface_z
            valid_mask = ((oxygen_distances_frame >= distance_bins[0])
                          & (oxygen_distances_frame <= distance_bins[-1]))
            if np.any(valid_mask):
                oxygen_distances[coverage_check].append(oxygen_distances_frame[valid_mask])
                oxygen_frames[coverage_check] += 1

            # Process each ion in the frame.
            for ion_pos in ion_positions:
                result = compute_local_ccn(ion_pos, o_positions, o_density_global,
                                           rdf_bin_edges, specific_cutoff, box_lengths, ccn_d=ccn_d)
                # compute_local_ccn signals "no weight" with (None, None), not a bare None.
                if result is None or result[0] is None:
                    continue
                cn = result[0]
                dist_to_surface = ion_pos[2] - surface_z
                bin_idx = np.digitize(dist_to_surface, distance_bins) - 1
                if bin_idx < 0 or bin_idx >= bins:
                    continue
                cn_sum[bin_idx] += cn
                counts[bin_idx] += 1

        key = (ion_type, coverage_check)
        if key in results:
            results[key][0] += cn_sum
            results[key][1] += counts
            results[key][3] += n_frames * n_ions
            results[key][4] += n_frames
        else:
            results[key] = [cn_sum, counts, specific_cutoff, n_frames * n_ions, n_frames]

    # -------------------------------
    # Normalize and Process Oxygen Density Data
    # -------------------------------
    oxygen_density_norm = {}
    for coverage_flag in [False, True]:
        if oxygen_distances[coverage_flag] and oxygen_frames[coverage_flag] > 0 and global_area is not None:
            all_distances = np.concatenate(oxygen_distances[coverage_flag])
            counts_o, _ = np.histogram(all_distances, bins=distance_bins)
            bin_width = distance_bins[1] - distance_bins[0]
            oxygen_density = counts_o / (global_area * bin_width * oxygen_frames[coverage_flag])
            # Rescale to the CN axis range: 2.5 at zero density up to 11 at the maximum.
            if oxygen_density.max() > 0:
                oxygen_density_norm[coverage_flag] = 2.5 + oxygen_density * (8.5 / oxygen_density.max())
            else:
                oxygen_density_norm[coverage_flag] = oxygen_density
        else:
            oxygen_density_norm[coverage_flag] = None

    # -------------------------------
    # Prepare Plotting Data for CN Curves
    # -------------------------------
    line_segments_data = []  # List of tuples: (segments, line widths, color, hydrogen_flag)
    legend_dict = {}         # Mapping: ion -> (color, hydrogen_flag)
    baseline = 1
    max_lw = 12
    min_samples = max(threshold, 1)  # Never treat empty bins as valid.
    for (ion, coverage_flag), (cn_sum_total, counts_total, _, _, _) in results.items():
        # Compute average CN per bin.
        avg_cn = np.divide(cn_sum_total, counts_total, out=np.zeros_like(cn_sum_total),
                           where=counts_total > 0)
        valid = counts_total >= min_samples
        if not np.any(valid):
            continue
        x_valid = bin_centers_distance[valid]
        y_valid = avg_cn[valid]

        # Line width proportional to the (3-bin smoothed) sample count.
        smoothed = np.convolve(counts_total, np.ones(3) / 3, mode='same')[valid]
        norm = ((smoothed - smoothed.min()) / (smoothed.max() - smoothed.min() + 1e-8)
                if smoothed.max() > 0 else smoothed)
        lw_valid = baseline + (max_lw - baseline) * norm

        # Interpolate onto a dense grid for smooth curve rendering.
        num_dense = 100
        x_dense = np.linspace(x_valid[0], x_valid[-1], num_dense)
        y_dense = np.interp(x_dense, x_valid, y_valid)
        lw_dense = np.interp(x_dense, x_valid, lw_valid)

        # Build line segments for plotting.
        points = np.array([x_dense, y_dense]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lw_segments = 0.5 * (lw_dense[:-1] + lw_dense[1:])

        color = ION_COLORS.get(ion, "#4472C4")
        line_segments_data.append((segments, lw_segments, color, coverage_flag))

        # For legend purposes, if an ion appears in both conditions, prefer the hydrogen-covered version.
        if ion not in legend_dict or coverage_flag:
            legend_dict[ion] = (color, coverage_flag)

    # Build custom legend handles.
    custom_handles = []
    custom_labels = []
    for ion, (color, coverage_flag) in legend_dict.items():
        handle = Line2D([], [], color=color, lw=max_lw, linestyle='-')
        if coverage_flag:
            handle._outline = True
        custom_handles.append(handle)
        custom_labels.append(ion)

    # -------------------------------
    # Define Plotting Callback for Custom Plot
    # -------------------------------
    def plot_callback(ax, legend_font_size, legend_loc):
        # Plot oxygen density as a shaded area (outlined for covered systems).
        for coverage_flag in [False, True]:
            od = oxygen_density_norm[coverage_flag]
            if od is None:
                continue
            oxygen_color = blend_color("#000000", 0.5)
            ax.fill_between(bin_centers_distance, od, color=oxygen_color, alpha=0.3)
            if coverage_flag:
                ax.plot(bin_centers_distance, od, color='black', lw=1)

        # Plot CN curves using LineCollection (variable width per segment).
        for segments, lw_segments, color, coverage_flag in line_segments_data:
            if coverage_flag:
                outline_width = lw_segments + 2.0
                lc_outline = LineCollection(segments, linewidths=outline_width,
                                            colors='black', capstyle='round', joinstyle='round')
                ax.add_collection(lc_outline)
                lc = LineCollection(segments, linewidths=lw_segments,
                                    colors=color, capstyle='round', joinstyle='round')
                ax.add_collection(lc)
            else:
                # Bare systems: plain line without an outline.
                lc = LineCollection(segments, linewidths=lw_segments,
                                    colors=color, capstyle='round', joinstyle='round')
                ax.add_collection(lc)

        # Add the custom legend.
        ax.legend(handles=custom_handles, labels=custom_labels,
                  fontsize=legend_font_size, loc=legend_loc,
                  handler_map={Line2D: HandlerMaybeOutlinedLine2D()})

    # -------------------------------
    # Final Plot Generation using custom_plot Utility
    # -------------------------------
    custom_plot(
        figsize=(12, 8),
        plot_func=plot_callback,
        font_size=32,
        xlabel="Distance to Surface (Å)",
        xlim=(2, 5),
        xticks=np.arange(2, 5.25, 0.5),
        ylabel="CN",
        ylim=(2.5, 11.2),
        yticks=np.arange(4, 12, 2),
        legend_options={
            'handles': custom_handles,
            'labels': custom_labels,
            'loc': 'upper left',
            'handler_map': {Line2D: HandlerMaybeOutlinedLine2D()},
            'font_size': 24
        },
        yticks_remove=False,
        grid='y',
        save_svg='ccn_vs_distance',
        show=True
    )


# =============================================================================
# EXECUTION BLOCK (EXAMPLE USAGE)
# =============================================================================
# Construct simulation file paths for both bare and hydrogen-covered systems.
# Bare-surface files (suffix 2) for every ion, followed by the hydrogen-covered files (suffix 1_H).
sim_files = [f"data/simulations/Pt111_{ion}{i}.pkl" for ion in IONS for i in [2]]
sim_files += [f"data/simulations/Pt111_{ion}{i}_H.pkl" for ion in IONS for i in [1]]

# Uncomment the following line to run the analysis:
# plot_ccn_vs_distance(sim_files=sim_files)

In [None]:
# =============================================================================
# VECTORIZED COORDINATION NUMBER CALCULATION (HOLLOW CN)
# =============================================================================
def compute_local_hollow_cn_vectorized(ion_positions, o_positions,
                                         inner_radius=2.75, outer_radius=4.25,
                                         box_lengths=None):
    """
    Count water oxygens inside a hollow spherical shell around each ion position.

    An oxygen contributes to an ion's coordination number (CN) when its distance d
    to the ion satisfies ``inner_radius <= d <= outer_radius`` (both bounds
    inclusive). When ``box_lengths`` is given, distances use the minimum-image
    convention of an orthorhombic periodic box.

    Parameters:
        ion_positions (array_like): Array of shape (N, 3) with N ion positions.
            A single position of shape (3,) is accepted and treated as N = 1.
        o_positions (array_like): Array of shape (n, 3) with water oxygen
            positions. An empty (0, 3) array gives a CN of 0 for every ion.
        inner_radius (float): Inner radius of the hollow sphere (Å).
        outer_radius (float): Outer radius of the hollow sphere (Å).
        box_lengths (array_like, optional): Box lengths for periodic boundary
            corrections, either as a length-3 vector or as a 3x3 cell matrix
            (in which case only its diagonal is used, i.e. an orthorhombic
            cell is assumed).

    Returns:
        np.ndarray: Integer array of length N containing the CN for each ion position.
    """
    ions = np.atleast_2d(np.asarray(ion_positions))
    oxygens = np.asarray(o_positions)

    # Difference vectors between every ion and every oxygen: shape (N, n, 3).
    diff = oxygens[None, :, :] - ions[:, None, :]

    if box_lengths is not None:
        box = np.asarray(box_lengths)
        if box.ndim > 1:
            # A full cell matrix was supplied; keep only the box lengths on its
            # diagonal so it broadcasts against the (N, n, 3) differences.
            box = np.diag(box)
        # Minimum-image convention: shift each component into [-L/2, L/2].
        diff = diff - box * np.round(diff / box)

    # Euclidean ion-oxygen distances: shape (N, n).
    distances = np.linalg.norm(diff, axis=2)

    # Oxygens inside the shell (both radii inclusive), counted per ion.
    in_shell = (distances >= inner_radius) & (distances <= outer_radius)
    return np.sum(in_shell, axis=1)

# =============================================================================
# PLOT HOLLOW CN VERSUS DISTANCE TO THE SURFACE
# =============================================================================
def plot_hollow_cn_vs_distance(sim_file, ion='Cs', d_min=2.0, d_max=5.0,
                               n_d=50, grid_resolution=10):
    """
    Compute and plot the average hollow coordination number (CN) as a function
    of the distance from the metal surface.

    For each simulation frame, an "imaginary ion" is placed at every point of a
    grid_resolution x grid_resolution (x, y) grid at height d above the frame's
    surface, and the CN is computed using a hollow spherical shell (see
    compute_local_hollow_cn_vectorized). The result is averaged over grid points
    and frames. An oxygen density profile over [d_min, d_max] is drawn as a
    shaded area on the same axes, rescaled to span the plotted y-range.

    Coverage handling:
      - If the simulation filename ends with '_H.pkl', the CN curve is drawn in
        black and the oxygen density profile gets a black outline.

    Parameters:
        sim_file (str): Path to the simulation pickle file.
        ion (str): Ion type ('Li', 'Na', 'K', or 'Cs') selecting the shell radii
                   and y-axis range. Unknown values fall back to 'Cs' and a
                   warning is logged.
        d_min (float): Minimum distance above the surface (Å) to evaluate CN.
        d_max (float): Maximum distance above the surface (Å) to evaluate CN.
        n_d (int): Number of distance bins.
        grid_resolution (int): Number of grid points per dimension in the x-y plane.

    Raises:
        ValueError: If the simulation has no cell dimensions or no frames.
    """
    # ----------------------------
    # Ion-Specific Parameter Setup
    # ----------------------------
    # ion: (inner shell radius, outer shell radius, y-axis start, y-axis span)
    hollow_params = {
        "Li": (1.75, 2.75, 0.01, 3.0),
        "Na": (2.0, 3.2, 1.01, 3.5),
        "K": (2.5, 3.8, 2.51, 3.5),
        "Cs": (2.75, 4.25, 2.01, 6.5),
    }
    if ion not in hollow_params:
        # Only warn for genuinely unknown ions ('Cs' itself is valid).
        logging.warning("Invalid ion type; defaulting to Cs.")
        ion = "Cs"
    irad, orad, o_start, y_span = hollow_params[ion]
    ylim = (o_start, o_start + y_span)

    # ----------------------------
    # Load Simulation and Setup Grid
    # ----------------------------
    sim = load_simulation(sim_file)
    n_frames = sim.trajectories.positions.all.shape[0]
    if n_frames == 0:
        raise ValueError("Simulation contains no frames.")

    # Cell dimensions are needed for periodic boundary corrections and the grid.
    if sim.cell_dimensions is None:
        raise ValueError("Cell dimensions are required for PBC correction and grid positioning.")
    cell = np.array(sim.cell_dimensions)
    if cell.ndim > 1:
        # A full cell matrix was supplied: use its diagonal (orthorhombic box
        # assumed), matching how compute_water_residence_time treats the cell.
        cell = np.diag(cell)

    # Grid points at the centres of a regular grid_resolution x grid_resolution
    # tiling of the x-y cell.
    x_centers = np.linspace(0, cell[0], grid_resolution, endpoint=False) + cell[0] / (2 * grid_resolution)
    y_centers = np.linspace(0, cell[1], grid_resolution, endpoint=False) + cell[1] / (2 * grid_resolution)
    xv, yv = np.meshgrid(x_centers, y_centers, indexing='ij')
    grid_xy = np.column_stack((xv.ravel(), yv.ravel()))
    n_grid = grid_xy.shape[0]

    # Heights above the surface at which the CN is evaluated.
    d_values = np.linspace(d_min, d_max, n_d)
    cn_accum = np.zeros(n_d)

    # Containers for the oxygen density profile.
    oxygen_distances_all = []
    total_oxygen_frames = 0
    distance_bins = np.linspace(d_min, d_max, n_d + 1)

    # Log progress roughly every 25% of frames; max() avoids a modulo by zero
    # for trajectories with fewer than four frames.
    log_every = max(1, n_frames // 4)

    # ----------------------------
    # Process Simulation Frames
    # ----------------------------
    for frame in range(n_frames):
        if frame % log_every == 0:
            logging.info("Processing frame %d of %d", frame, n_frames)

        surface_z = sim.trajectories.surface_z[frame]
        o_positions = sim.trajectories.positions.watO[frame]

        # Accumulate oxygen heights relative to the surface within [d_min, d_max].
        oxy_z = o_positions[:, 2] - surface_z
        valid_mask = (oxy_z >= d_min) & (oxy_z <= d_max)
        if np.any(valid_mask):
            oxygen_distances_all.extend(oxy_z[valid_mask])
            total_oxygen_frames += 1

        # For each height d, average the CN over all grid points.
        for i, d in enumerate(d_values):
            ion_positions = np.hstack((grid_xy, np.full((n_grid, 1), surface_z + d)))
            cn_grid = compute_local_hollow_cn_vectorized(ion_positions, o_positions,
                                                         inner_radius=irad, outer_radius=orad,
                                                         box_lengths=cell)
            cn_accum[i] += np.mean(cn_grid)

    # Every height receives exactly one contribution per frame.
    avg_cn = cn_accum / n_frames

    # ----------------------------
    # Compute Oxygen Density Profile (Optional)
    # ----------------------------
    oxygen_density_norm = None
    bin_centers_distance = None
    if oxygen_distances_all and total_oxygen_frames > 0:
        counts_o, _ = np.histogram(oxygen_distances_all, bins=distance_bins)
        bin_width = distance_bins[1] - distance_bins[0]
        global_area = cell[0] * cell[1]
        oxygen_density = counts_o / (global_area * bin_width * total_oxygen_frames)
        if np.max(oxygen_density) > 0:
            # Rescale so the density peak reaches the top of the CN y-range.
            oxygen_density_norm = o_start + oxygen_density * ((ylim[1] - ylim[0]) / np.max(oxygen_density))
        bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])

    # Covered (hydrogen-adsorbed) systems are identified by their file name.
    is_covered = sim_file.endswith('_H.pkl')

    # ----------------------------
    # Define Plotting Callback Function
    # ----------------------------
    def plot_callback(ax, legend_font_size, legend_loc):
        # Oxygen density profile as a shaded area (outlined for covered systems).
        if oxygen_density_norm is not None and bin_centers_distance is not None:
            ax.fill_between(bin_centers_distance, oxygen_density_norm, color="black", alpha=0.2)
            if is_covered:
                ax.plot(bin_centers_distance, oxygen_density_norm, color='black', lw=1)

        # Average CN curve: black for covered systems, ion colour otherwise.
        curve_color = "black" if is_covered else ION_COLORS.get(ion, "#4472C4")
        ax.plot(d_values, avg_cn, linestyle='-', linewidth=8, color=curve_color)

    # ----------------------------
    # Generate and Display the Plot Using Custom Plot
    # ----------------------------
    custom_plot(
        xlabel="Distance from Surface (Å)",
        ylabel="CN",
        font_size=40,
        xticks=np.arange(d_min, d_max + 1, 1),
        xlim=(d_min, d_max),
        ylim=ylim,
        legend_options={"font_size": 24, "alpha": 0.8, "loc": "upper left"},
        plot_func=plot_callback,
        show=True
    )

# =============================================================================
# EXECUTION BLOCK (UNCOMMENT TO RUN)
# =============================================================================
simulation_file: str = "data/simulations/Pt111_Li4_H.pkl"

# The hollow-CN analysis is disabled by default; remove the leading '#' below to run it.
# plot_hollow_cn_vs_distance(simulation_file, ion='Li', d_max=6, grid_resolution=5, n_d=20)

In [None]:
# =============================================================================
# WATER RESIDENCE TIME CALCULATION
# =============================================================================

def compute_water_residence_time(sim_file, ion_type, skip_time=0):
    """
    Compute the average residence time (lifetime) of water molecules within the
    solvation sphere of a specified cation.

    A residence event is a run of consecutive frames during which a water oxygen
    lies within ION_CUTOFFS[ion_type] of at least one ion of the selected type.
    Each run length is converted to time by multiplying by the timestep.

    The process involves:
      1. Loading the simulation.
      2. Selecting the target cations based on the specified ion type.
      3. Determining whether water oxygen atoms are within the solvation sphere of any target ion.
      4. Constructing an occupancy time series (True if inside, False otherwise).
      5. Extracting contiguous "in-shell" segments from the occupancy series and converting them
         to time units using the simulation timestep.
      6. Returning the average residence time, all residence durations, and a coverage flag.

    Durations are in the time units of ``sim_data.timestep`` (or of
    ``sim_data.trajectories.times`` when no timestep is stored), documented here as ps.
    NOTE: segments that are already in progress at the first analysed frame, or
    still in progress at the last one, are counted with their truncated length.

    The coverage flag indicates whether the simulation is "covered" by adsorbates (True) or not (False).

    Parameters:
        sim_file (str): Path to the simulation pickle file.
        ion_type (str): Ion element (e.g., "Cs", "K", etc.) defining the solvation sphere.
        skip_time (float): Skip simulation frames with time < skip_time (in ps).

    Returns:
        tuple: (avg_lifetime, lifetimes, coverage_flag)
            avg_lifetime (float): Average residence time of water molecules within the
                                  solvation sphere (0.0 if no event was found).
            lifetimes (list): Individual residence durations, ordered by water molecule
                              and then by time.
            coverage_flag (bool): True if adsorbates are present (covered), False otherwise.

    Raises:
        ValueError: If required data (cutoff, ions of the given type, water positions,
                    frames after skip_time) are missing.
    """
    # -------------------------------
    # Load Simulation and Determine Coverage
    # -------------------------------
    sim_data = load_simulation(sim_file)
    positions = sim_data.trajectories.positions

    # The system counts as covered when a non-empty adsorbate array is present.
    adsorbates = getattr(positions, 'adsorbates', None)
    coverage_flag = adsorbates is not None and adsorbates.size > 0

    # -------------------------------
    # Retrieve Ion-Specific Parameters and Data
    # -------------------------------
    cutoff = ION_CUTOFFS.get(ion_type, None)
    if cutoff is None:
        raise ValueError(f"No cutoff defined for ion type {ion_type}.")

    ion_indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]
    if not ion_indices:
        raise ValueError(f"No ions of type {ion_type} found in simulation.")

    # Water oxygen positions (expected shape: [n_frames, n_water, 3]).
    try:
        water_positions = positions.watO
    except AttributeError as err:
        raise ValueError("Water positions with label 'watO' not found in simulation.") from err

    # Positions of the selected ions (shape: [n_frames, n_selected, 3]).
    ion_positions_all = positions.ions[:, ion_indices, :]

    # -------------------------------
    # Select Valid Frames Based on Skip Time
    # -------------------------------
    times = sim_data.trajectories.times
    valid_frame_indices = np.where(times >= skip_time)[0]
    if valid_frame_indices.size == 0:
        raise ValueError("No frames available after applying skip_time.")

    water_positions = water_positions[valid_frame_indices]      # (n_valid, n_water, 3)
    ion_positions_all = ion_positions_all[valid_frame_indices]  # (n_valid, n_selected, 3)
    valid_times = times[valid_frame_indices]

    # -------------------------------
    # Vectorized Occupancy Calculation
    # -------------------------------
    # Water-ion displacement vectors: (n_valid, n_water, n_selected, 3).
    diff = water_positions[:, :, None, :] - ion_positions_all[:, None, :, :]

    if sim_data.cell_dimensions is not None:
        cell = np.array(sim_data.cell_dimensions)
        box_lengths = np.diag(cell) if cell.ndim > 1 else cell
        diff = apply_periodic_boundary(diff, box_lengths)

    # Distances (n_valid, n_water, n_selected); a water is "in shell" in a frame
    # if any selected ion lies within the cutoff.
    distances = np.linalg.norm(diff, axis=-1)
    occupancy = np.any(distances < cutoff, axis=2)  # (n_valid, n_water)

    # -------------------------------
    # Determine Simulation Timestep
    # -------------------------------
    if hasattr(sim_data, "timestep") and sim_data.timestep is not None:
        dt = sim_data.timestep
    else:
        dt_array = np.diff(valid_times)
        dt = np.mean(dt_array) if dt_array.size > 0 else 1.0

    # -------------------------------
    # Residence Durations (vectorized run-length detection)
    # -------------------------------
    # One row per water molecule, padded with an "outside" frame at both ends so
    # every in-shell run has exactly one rising (+1) and one falling (-1) edge.
    padded = np.pad(occupancy.T.astype(np.int8), ((0, 0), (1, 1)))
    edges = np.diff(padded, axis=1)
    # np.nonzero returns indices in row-major order (water, then time), so the
    # k-th rising edge pairs with the k-th falling edge.
    _, run_starts = np.nonzero(edges == 1)
    _, run_ends = np.nonzero(edges == -1)
    all_lifetimes = ((run_ends - run_starts) * dt).tolist()

    if all_lifetimes:
        avg_lifetime = np.mean(all_lifetimes)
    else:
        avg_lifetime = 0.0
        logging.info("No water molecule residence events found for ion type %s.", ion_type)

    return avg_lifetime, all_lifetimes, coverage_flag


# =============================================================================
# EXECUTION BLOCK (UNCOMMENT TO RUN)
# =============================================================================

# # Compute the average lifetime per ion type.
# ion_lifetimes = {}
# for ion in IONS:
#     lifetimes = []
#     for i in range(1, 5):
#         for h in ["", "_H"]:
#             sim_file = f"data/simulations/Pt111_{ion}{i}{h}.pkl"
#             avg_lifetime, _, _ = compute_water_residence_time(sim_file, ion)
#             lifetimes.append(avg_lifetime)
#     ion_lifetimes[ion] = np.mean(lifetimes)
# print("Average residence times (fs) per ion type:", ion_lifetimes)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Patch

def plot_density_on_surface(pickle_path, group='water', distance_threshold=3.0, bins=250, metal_radius=1.0):
    """
    Visualize the average XY density of a selected group near the metal surface
    and overlay the averaged positions of all metal atoms grouped by layers.

    Metal atoms are drawn as silver-colored circles with a specified radius.
    Layers are plotted in order (bottom layer first, top layer last) so that the
    top layer appears on top. The semi-transparent density heatmap is drawn above
    the metal atoms; for covered systems (group 'water' with adsorbates present),
    density contour lines are drawn above the heatmap.

    The density heatmap is plotted over the cell (0 to cell_limit), but metal atoms
    are drawn in periodic copies (shifted by the cell dimensions) so that atoms near
    the boundaries show their periodic images (even if only partially visible).

    NOTE(review): metal positions are averaged frame by frame without unwrapping
    periodic images, so an atom that crosses a cell boundary during the
    trajectory would be averaged to a wrong position.

    Parameters:
        pickle_path (str): Path to the simulation pickle file.
        group (str): Group to analyze ('water', 'adsorbates', or 'ions').
        distance_threshold (float): Maximum vertical distance (Å) from the surface
                                    (measured in both directions, |z - surface_z|).
        bins (int): Number of bins per axis for the 2D histogram.
        metal_radius (float): Radius (in Å) used to plot the metal atoms.

    Raises:
        ValueError: If ``group`` is not one of the allowed groups.
    """
    # Validate group selection.
    allowed_groups = ['water', 'adsorbates', 'ions']
    if group not in allowed_groups:
        raise ValueError(f"Invalid group specified: {group}. Allowed groups: {allowed_groups}")

    # Setup logging and load simulation data.
    setup_logging(verbose=False)
    sim = load_simulation(pickle_path)
    n_frames = sim.trajectories.times.shape[0]
    traj_positions = sim.trajectories.positions

    # Water coverage contours are only drawn when adsorbates are present.
    adsorbates = getattr(traj_positions, 'adsorbates', None)
    coverage_flag = group == 'water' and adsorbates is not None and adsorbates.size > 0

    # Cell extent for plotting (diagonal of the cell matrix if one is given).
    cell = sim.cell_dimensions
    if np.ndim(cell) == 1:
        xlim, ylim = (0, cell[0]), (0, cell[1])
    else:
        diag = np.diag(cell)
        xlim, ylim = (0, diag[0]), (0, diag[1])

    # Histogram edges; the heatmap spans exactly the cell limits.
    xedges = np.linspace(xlim[0], xlim[1], bins + 1)
    yedges = np.linspace(ylim[0], ylim[1], bins + 1)
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    group_traj = getattr(traj_positions, group)
    metal_traj = traj_positions.metal

    # The number of metal atoms is taken from the first frame.
    n_metal = metal_traj[0].shape[0]
    group_hist_sum = np.zeros((bins, bins))
    metal_positions_sum = np.zeros((n_metal, 3))

    # Accumulate the near-surface group density and metal positions per frame.
    for frame in range(n_frames):
        group_positions = group_traj[frame]
        surface_z = sim.trajectories.surface_z[frame]
        near_mask = np.abs(group_positions[:, 2] - surface_z) < distance_threshold
        group_near = group_positions[near_mask]
        if group_near.size > 0:
            hist, _, _ = np.histogram2d(group_near[:, 0], group_near[:, 1], bins=[xedges, yedges])
            group_hist_sum += hist

        metal_positions_sum += metal_traj[frame]

    group_hist_avg = group_hist_sum / n_frames
    avg_metal_positions = metal_positions_sum / n_frames

    # Number of metal layers from the lattice dimensions (1 if unavailable).
    # Atoms beyond num_layers * n_atoms_layer (if n_metal is not divisible)
    # are not drawn.
    num_layers = sim.lattice_dimensions[2] if sim.lattice_dimensions is not None else 1
    n_atoms_layer = n_metal // num_layers

    # Sort by average z, highest first, so layer 0 is the top layer.
    sorted_indices = np.argsort(avg_metal_positions[:, 2])[::-1]
    metal_layers = [
        avg_metal_positions[sorted_indices[i * n_atoms_layer:(i + 1) * n_atoms_layer]]
        for i in range(num_layers)
    ]

    def plot_callback(ax, legend_font_size, legend_loc):
        # Fix the axes to the cell so periodic images are clipped at the border.
        ax.set_xlim(xlim[0], xlim[1])
        ax.set_ylim(ylim[0], ylim[1])

        cell_x = xlim[1] - xlim[0]
        cell_y = ylim[1] - ylim[0]

        # Draw layers bottom-up; deeper layers get a darker silver shade.
        for i in reversed(range(num_layers)):
            if num_layers > 1:
                brightness = 0.9 - 0.2 * (i / (num_layers - 1))
            else:
                brightness = 0.9
            silver_color = (brightness, brightness, brightness)
            zorder_layer = num_layers - i  # top layer gets the highest zorder

            # Each atom plus its eight in-plane periodic images.
            for pos in metal_layers[i]:
                for dx in (-cell_x, 0, cell_x):
                    for dy in (-cell_y, 0, cell_y):
                        circ = Circle((pos[0] + dx, pos[1] + dy), metal_radius,
                                      facecolor=silver_color, edgecolor='k', lw=1,
                                      zorder=zorder_layer)
                        ax.add_patch(circ)

        # Density heatmap above the metal atoms, with slight transparency.
        im = ax.imshow(group_hist_avg.T, extent=extent, origin='lower', aspect='auto',
                       cmap='Reds', interpolation='nearest', alpha=0.9, zorder=num_layers + 2)
        cbar = ax.figure.colorbar(im, ax=ax)
        cbar.set_label(f"Average {group.capitalize()} Density", fontsize=legend_font_size)

        # Coverage contours must sit above the 90%-opaque heatmap; below it
        # they would be almost completely hidden.
        if coverage_flag:
            ax.contour(group_hist_avg.T, extent=extent,
                       origin='lower', levels=5, colors='black', linewidths=1, zorder=num_layers + 3)

    # Generate and display the final plot using the custom plotting routine.
    custom_plot(
        xlabel="X (Å)",
        ylabel="Y (Å)",
        font_size=20,
        figsize=(10, 8),
        plot_func=plot_callback
    )

# =============================================================================
# EXECUTION BLOCK (UNCOMMENT TO RUN)
# =============================================================================
sim_file = "data/simulations/Pt111_Li1.pkl"

# Uncomment the following line to run the analysis on the provided simulation pickle file.
# (Kept commented, like the other execution blocks, so running this cell does not
# trigger the full trajectory analysis and plot by default.)
# plot_density_on_surface(sim_file, group='water', metal_radius=1.23, distance_threshold=2.5)
