Main Jupyter notebook containing all the code used to generate the plots in the paper by R.S. Kort, A. Hagopian, K. Doblhoff-Dier, and M. Koper on the Impact of Applied Potential and Hydrogen Coverage on Alkali Metal Cation Behaviour.





In [1]:
# Import all the required libraries
import numpy as np
import pickle
import logging
import os

# Import all matplotlib related libraries
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.patheffects as pe
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerLine2D
from matplotlib.patches import ArrowStyle

In [2]:
# Definition of parameters, basic functions, and logging
# -------------------------------------------------------

# Alkali cations analysed; loops over the cations iterate in this order.
IONS = ["Li", "Na", "K", "Cs"]

# Plot color assigned to each cation (hex "#RRGGBB" strings).
ION_COLORS = dict(
    Li="#70AD47",
    Na="#FFC000",
    K="#A032A0",
    Cs="#4472C4",
)

# Per-cation cutoff values. NOTE(review): not used in this section; presumably
# distances in Å like the rest of the analysis — confirm.
ION_CUTOFFS = dict(Li=2.75, Na=3.2, K=3.8, Cs=4.25)

# Define some general functions
# -----------------------------
def blend_color(hex_color, fraction):
    """
    Blend the given hex color with white.

    Parameters:
        hex_color (str): Color as "#RRGGBB" or "RRGGBB"; the 3-digit shorthand
            "#RGB" is also accepted. Anything after the first six hex digits
            (e.g. an alpha channel) is ignored.
        fraction (float): Share of white to mix in. fraction=0 returns the
            original color, fraction=1 returns white. Values outside [0, 1]
            are clamped to that range.

    Returns:
        str: The blended color as an upper-case "#RRGGBB" string.

    Raises:
        ValueError: If hex_color is not a valid 3- or 6-digit hex color.
    """
    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        # Expand CSS-style shorthand: "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    # Clamp: fraction > 1 or < 0 would push a channel outside 0..255 and yield a
    # malformed string such as "#1FE..." or one containing a minus sign.
    fraction = min(max(float(fraction), 0.0), 1.0)
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    # int() truncates, exactly like the original per-channel formula.
    r_new = int(r + (255 - r) * fraction)
    g_new = int(g + (255 - g) * fraction)
    b_new = int(b + (255 - b) * fraction)
    return f"#{r_new:02X}{g_new:02X}{b_new:02X}"

def setup_logging(verbose=False):
    """
    Configure root logging based on verbosity.

    logging.basicConfig() does nothing once the root logger already has handlers
    (e.g. after an earlier call in the same notebook session), so the level is also
    set explicitly on the root logger; otherwise switching ``verbose`` between calls
    would have no effect.

    Parameters:
        verbose (bool): If True, set logging level to DEBUG, else INFO.
    """
    log_level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=log_level)
    logging.getLogger().setLevel(log_level)



# Function to load the Simulation object from a pickle file.
def load_simulation(filename):
    """
    Unpickle and return the Simulation object stored in ``filename``.

    Any exception raised while opening or unpickling the file is logged at
    ERROR level and then re-raised unchanged.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file, so only load pickle files from trusted sources.

    Parameters:
        filename (str): Path to the pickle file.

    Returns:
        Simulation: The unpickled object.
    """
    try:
        with open(filename, 'rb') as handle:
            loaded = pickle.load(handle)
    except Exception as err:
        logging.error("Failed to load Simulation from %s: %s", filename, err)
        raise
    logging.debug("Successfully loaded Simulation from %s", filename)
    return loaded

# Function to compute the histogram of a 1D data array.
def compute_histogram(data, bin_edges, normalize=False):
    """
    Compute a histogram for the given data and bin edges.

    Parameters:
      data (array-like): 1D array of data points.
      bin_edges (array-like): Monotonically increasing bin edges (at least two
          values). Plain lists are accepted as well as NumPy arrays.
      normalize (bool): If True, divide the counts by their total so that the bin
          values sum to 1 (per-bin probabilities; the bin width is NOT divided out).
          If no data falls inside the bins, the zero counts are returned unchanged.

    Returns:
      counts (np.ndarray): The histogram counts (or probabilities if normalized).
      bin_centers (np.ndarray): The centers of the bins.
      bin_width (float): Width of the first bin (the width of every bin when the
          edges are uniformly spaced).

    Raises:
      ValueError: If bin_edges is not 1D or has fewer than two values.
    """
    # Convert up front: with a plain list, edges[:-1] + edges[1:] would concatenate
    # the two lists instead of adding them element-wise.
    edges = np.asarray(bin_edges, dtype=float)
    if edges.ndim != 1 or edges.size < 2:
        raise ValueError("bin_edges must be a 1D sequence with at least two values.")

    counts, _ = np.histogram(data, bins=edges)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    bin_width = edges[1] - edges[0]

    if normalize:
        # Normalize by the total count so that the bin values sum to 1.
        total = np.sum(counts)
        if total > 0:
            counts = counts / total
    return counts, bin_centers, bin_width

# Function to smooth a 1D data array using a moving average.
def smooth_data(data, window_size=3):
    """
    Smooth a 1D data array using a centred moving average (simple convolution).

    The signal is implicitly zero-padded at both ends, so roughly the first and
    last window_size // 2 points are pulled towards zero.

    Parameters:
      data (array-like): 1D array of data points to smooth.
      window_size (int): The size of the smoothing window. A window_size of 1
          (or less) returns ``data`` unchanged.

    Returns:
      smoothed (np.ndarray): The smoothed data, always with the same length as
          ``data`` (``data`` itself is returned when no smoothing is applied).
    """
    if window_size <= 1:
        return data
    values = np.asarray(data)
    if values.size == 0:
        # np.convolve rejects empty input; there is nothing to smooth.
        return values
    # Create a uniform kernel
    kernel = np.ones(window_size) / window_size
    if values.size >= window_size:
        # 'same' keeps the original length when the data is at least as long
        # as the kernel.
        return np.convolve(values, kernel, mode='same')
    # np.convolve(..., mode='same') returns max(len(data), len(kernel)) samples,
    # so a window wider than the data would lengthen the output. Take the same
    # centred slice of the full convolution instead, preserving len(data).
    full = np.convolve(values, kernel, mode='full')
    start = (window_size - 1) // 2
    return full[start:start + values.size]

def apply_periodic_boundary(diff, box_lengths):
    """
    Wrap displacement vectors into the minimum-image convention.

    Each component is shifted by the whole number of box lengths nearest to
    ``diff / box_lengths`` (NumPy rounding, halves go to even), using one
    independent length per dimension.

    Parameters:
        diff (np.ndarray): Displacement vector(s); broadcast against box_lengths.
        box_lengths (np.ndarray): The box length along each dimension.

    Returns:
        np.ndarray: The wrapped displacements.
    """
    image_shift = np.round(diff / box_lengths)
    return diff - box_lengths * image_shift


# Custom plotting functions 
# -------------------------

def custom_plot(xlabel="X", ylabel="Y", font_size=60, font_family="Times New Roman",
                figsize=(12, 8), yticks_remove=False, grid=None, tight_layout=True,
                xticks=None, yticks=None, xlim=None, ylim=None,
                legend_options=None, plot_func=None,
                save_png=None, save_svg=None, show=True):
    """
    Create and finalize a plot in one function. This function sets up the figure,
    configures the axis, applies a custom plotting callback, and optionally saves
    the figure to file.

    Parameters:
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        font_size (int): Font size for the plot labels.
        font_family (str): Font family for the plot labels.
        figsize (tuple): Figure size (width, height) in inches.
        yticks_remove (bool): If True, removes the y-axis ticks.
        grid (str or None): Type of grid to display. Options are "x", "y", "both", or None.
        tight_layout (bool): If True, applies tight layout once all content is drawn.
        xticks (list or None): Locations for x-axis ticks.
        yticks (list or None): Locations for y-axis ticks.
        xlim (tuple or None): Limits for the x-axis as (min, max).
        ylim (tuple or None): Limits for the y-axis as (min, max).
        legend_options (dict or None): Options for customizing the legend.
            Supported keys:
                - "font_size" (int): Font size for legend text (default 36 here;
                  plot_func receives 32 when the key is absent).
                - "draw_bg" (bool): Whether to display a background patch.
                - "border" (bool): Whether to draw a border around the legend.
                - "bgcolor" (str): Background color for the legend.
                - "edgecolor" (str): Border color for the legend.
                - "alpha" (float): Transparency for the legend background (0 to 1).
                - "linewidth" (float, optional): Line width for legend lines (currently unused).
                - "loc" (str): Location of the legend.
                - "handles" (list, optional): Custom legend handles.
                - "labels" (list, optional): Custom legend labels.
                - "handler_map" (dict, optional): Custom legend handler map.
        plot_func (function or None): Callback called as
            ``plot_func(ax, legend_font_size, legend_loc)`` to perform custom plotting.
        save_png (str or None): Filename (without extension) to save the plot as PNG
            in the 'figures' directory (created if missing).
        save_svg (str or None): Filename (without extension) to save the plot as SVG
            in the 'figures' directory (created if missing).
        show (bool): If True, displays the plot.

    Returns:
        tuple: A tuple containing the axis (ax) and the legend (if created, otherwise None).

    Notes:
        font_size and font_family are written to the global ``plt.rcParams`` and
        therefore persist for figures created afterwards.
    """
    # Update global plot parameters BEFORE creating the figure: rcParams are read
    # when artists are created, so axes/tick labels created earlier would keep
    # the previous font settings.
    plt.rcParams.update({'font.size': font_size, 'font.family': font_family})

    # Create the figure and axis.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)

    # Set axis labels.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Remove y-axis ticks if requested.
    if yticks_remove:
        ax.set_yticks([])

    # Hide top, right, and left spines.
    for spine in ('top', 'right', 'left'):
        ax.spines[spine].set_visible(False)

    # Configure grid settings.
    if grid == "x":
        ax.yaxis.grid(False)
        ax.xaxis.grid(True)
    elif grid == "y":
        ax.xaxis.grid(False)
        ax.yaxis.grid(True)
    elif grid in [None, "None", False]:
        ax.grid(False)
    else:  # Default: enable grid on both axes.
        ax.grid(True)

    # Execute custom plotting function if provided.
    if plot_func is not None:
        # NOTE(review): this default (32) differs from the 36 used in the legend
        # post-processing below; kept to preserve existing figures.
        legend_font_size = legend_options.get('font_size', 32) if legend_options else 32
        legend_loc = legend_options.get('loc', 'best') if legend_options else 'best'
        plot_func(ax, legend_font_size, legend_loc)

    # Set custom ticks and axis limits if specified.
    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    # Handle legend customization if legend_options are provided.
    legend = None
    if legend_options:
        legend_loc = legend_options.get('loc', 'best')
        # Reuse a legend the callback may already have created.
        legend = ax.get_legend()
        if legend is None:
            handles = legend_options.get('handles')
            labels = legend_options.get('labels')
            legend_kwargs = {'loc': legend_loc}
            handler_map = legend_options.get('handler_map')
            if handler_map:
                legend_kwargs['handler_map'] = handler_map
            if handles and labels:
                legend = ax.legend(handles, labels, **legend_kwargs)
            else:
                legend = ax.legend(**legend_kwargs)

        # Update legend properties.
        plt.setp(legend.get_texts(), fontsize=legend_options.get('font_size', 36))

        # Adjust legend location if specified.
        if legend_loc != 'best':
            legend.set_loc(legend_loc)

        # Customize the legend background if draw_bg is True.
        if legend_options.get('draw_bg', True):
            legend_border = legend_options.get('border', False)
            frame = legend.get_frame()
            frame.set_facecolor(legend_options.get('bgcolor', 'white'))
            frame.set_alpha(legend_options.get('alpha', 0.8))
            frame.set_edgecolor(legend_options.get('edgecolor', 'black') if legend_border else 'none')

    # Lay out only now that ticks, limits and the legend exist; calling
    # tight_layout earlier ignores everything added afterwards.
    if tight_layout:
        fig.tight_layout()

    # Save through `fig` (not the pyplot "current figure", which plot_func may
    # have changed) and make sure the output directory exists.
    for name, ext in ((save_png, 'png'), (save_svg, 'svg')):
        if name:
            os.makedirs('figures', exist_ok=True)
            fig.savefig(os.path.join('figures', f"{name}.{ext}"), format=ext)

    # Display the plot if required.
    if show:
        plt.show()

    return ax, legend

# Custom legend handler to add an outline if desired.
class HandlerMaybeOutlinedLine2D(HandlerLine2D):
    """
    Legend handler that draws a horizontal sample line for a Line2D handle.

    If the handle has a truthy ``_outline`` attribute, a black copy of the line
    with linewidth + 2 is drawn underneath the colored line, so the legend entry
    appears outlined.
    """

    def create_artists(self, legend, orig_handle,
                       x0, y0, width, height, fontsize, trans):
        mid_y = y0 + height / 2
        base_width = orig_handle.get_linewidth()

        def _sample(color, linewidth):
            # One straight segment spanning the legend entry at mid-height.
            sample = Line2D([x0, x0 + width],
                            [mid_y, mid_y],
                            linestyle=orig_handle.get_linestyle(),
                            color=color,
                            linewidth=linewidth,
                            marker=orig_handle.get_marker())
            sample.set_transform(trans)
            return sample

        artists = []
        if getattr(orig_handle, '_outline', False):
            # Thicker black outline goes first so it is drawn behind the line.
            artists.append(_sample('black', base_width + 2))
        artists.append(_sample(orig_handle.get_color(), base_width))
        return artists

In [3]:
# ------------------ Radial Distribution Function (RDF) Core Functions ------------------

def extract_simulation_parameters(data, bin_edges):
    """
    Extract key simulation parameters including cell volume, spherical shell volumes for RDF calculations,
    bin centers, bin width, and effective box lengths.

    Parameters:
        data: Simulation object with a ``cell_dimensions`` attribute holding either
            the box lengths (1D) or a square cell matrix (2D).
        bin_edges (array-like): Bin edge values for the RDF histogram.

    Returns:
        tuple: (volume, shell_volumes, bin_centers, bin_width, box_lengths)
               - volume (float): Simulation cell volume: product of the box lengths,
                 or |det| of a cell matrix with non-zero off-diagonal entries.
               - shell_volumes (np.ndarray): Volumes of spherical shells defined by bin_edges.
               - bin_centers (np.ndarray): Centers of the bins.
               - bin_width (float): Width of the first bin (uniform for evenly spaced edges).
               - box_lengths (np.ndarray): Box lengths (the matrix diagonal for a cell
                 matrix). These are exact for an orthorhombic cell only.
        Returns (None, None, None, None, None) and logs a warning when the
        simulation has no cell dimensions (nothing is raised).
    """
    cell = data.cell_dimensions
    if cell is None:
        logging.warning("Simulation data missing cell dimensions.")
        return None, None, None, None, None

    cell_array = np.asarray(cell, dtype=float)
    if cell_array.ndim == 1:
        box_lengths = cell_array
        volume = np.prod(box_lengths)
    else:
        box_lengths = np.diag(cell_array)
        is_square = cell_array.shape[0] == cell_array.shape[1]
        skewed = is_square and np.any(cell_array[~np.eye(cell_array.shape[0], dtype=bool)])
        if skewed:
            # The product of the diagonal is not the volume of a general skewed
            # cell; |det| is (and equals that product for triangular matrices).
            volume = abs(np.linalg.det(cell_array))
            logging.warning("Non-orthorhombic cell: minimum-image distances use only "
                            "the diagonal box lengths and may be inaccurate.")
        else:
            volume = np.prod(box_lengths)

    # Convert so list input is handled element-wise (list + list would concatenate).
    edges = np.asarray(bin_edges, dtype=float)
    shell_volumes = (4.0 / 3.0) * np.pi * (edges[1:]**3 - edges[:-1]**3)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    bin_width = edges[1] - edges[0]
    return volume, shell_volumes, bin_centers, bin_width, box_lengths


def compute_frame_rdf(oxygen_positions, ion_positions, oxygen_label, ion_label,
                      bin_edges, shell_volumes, bin_centers, bin_width, volume, box_lengths,
                      smoothing_window=1):
    """
    Compute the radial distribution function (RDF) for a single simulation frame
    between oxygen atoms and ions.

    Per bin, g(r) = counts / (N_ions * rho * V_shell), where rho is the number
    density of partner ("oxygen") particles. When oxygen_label == ion_label the two
    arrays are treated as the same species: zero-length pairs (|d| <= 1e-6, i.e.
    self-correlations) are removed and rho = (N_oxygen - 1) / volume, because every
    particle then has only N_oxygen - 1 partners (the usual N(N-1) normalization).

    Parameters:
        oxygen_positions (np.ndarray): Coordinates for oxygen atoms (shape: (N_oxygen, 3)).
        ion_positions (np.ndarray): Coordinates for ion positions (shape: (N_ions, 3)).
        oxygen_label (str): Identifier for the oxygen species.
        ion_label (str): Identifier for the ion species.
        bin_edges (np.ndarray): Bin edge values for histogram calculation.
        shell_volumes (np.ndarray): Precomputed volumes of spherical shells.
        bin_centers (np.ndarray): Accepted for interface compatibility; not used.
        bin_width (float): Accepted for interface compatibility; not used.
        volume (float): Total simulation cell volume.
        box_lengths (np.ndarray): Simulation box lengths for applying periodic boundary conditions.
        smoothing_window (int, optional): Window size for smoothing the histogram. Default is 1 (no smoothing).

    Returns:
        tuple: (rdf_frame, density_oxygen)
               - rdf_frame (np.ndarray): RDF values for the current frame (bins whose
                 normalization gives NaN are set to 0).
               - density_oxygen (float): N_oxygen / volume.
    """
    # Calculate all oxygen-ion displacement vectors using the minimum image.
    diff_vectors = oxygen_positions[:, np.newaxis, :] - ion_positions
    diff_vectors = apply_periodic_boundary(diff_vectors, box_lengths)
    distances = np.linalg.norm(diff_vectors, axis=-1).ravel()

    num_oxygen = oxygen_positions.shape[0]
    num_ions = ion_positions.shape[0]
    density_oxygen = num_oxygen / volume
    partner_density = density_oxygen

    if oxygen_label == ion_label:
        # Same species: drop self-pairs and normalize by the N - 1 remaining partners.
        distances = distances[distances > 1e-6]
        partner_density = max(num_oxygen - 1, 0) / volume

    # Compute histogram of distances.
    counts, _, _ = compute_histogram(distances, bin_edges, normalize=False)

    # Apply smoothing if required.
    if smoothing_window > 1:
        counts = smooth_data(counts, window_size=smoothing_window)

    # Normalize histogram to obtain RDF (ideal-gas expectation per shell).
    expected_counts = num_ions * partner_density * shell_volumes
    with np.errstate(divide='ignore', invalid='ignore'):
        rdf_frame = counts / expected_counts
    rdf_frame[np.isnan(rdf_frame)] = 0

    return rdf_frame, density_oxygen


def calculate_average_rdf(data, oxygen_label, ion_label, skip_time, bin_edges, ion_indices=None):
    """
    Compute the averaged RDF between oxygen and ion species over simulation frames.

    Parameters:
        data: Simulation object containing trajectories and metadata.
        oxygen_label (str): Name of the positions attribute used as the "oxygen"
            species (e.g., "watO").
        ion_label (str): Identifier of the ion species. Ion coordinates are always
            read from ``data.trajectories.positions.ions``; this label is only
            compared with oxygen_label to detect self-correlation.
        skip_time (float): Frames with a time below this threshold are excluded.
        bin_edges (np.ndarray): Bin edge values for RDF histogram.
        ion_indices (list, optional): Indices (into the ions axis) selecting a specific ion type.

    Returns:
        tuple: (bin_centers, avg_rdf, avg_density)
               - bin_centers (np.ndarray): Centers of the RDF bins.
               - avg_rdf (np.ndarray): Average RDF computed over selected frames.
               - avg_density (float): Average oxygen density computed over selected frames.

        Returns (None, None, None), after logging a warning, when the oxygen or ion
        positions are missing, the cell dimensions are missing, or no frame is kept.
    """
    try:
        oxygen_positions_all = getattr(data.trajectories.positions, oxygen_label)
    except AttributeError:
        logging.warning("No position data for oxygen species '%s' found.", oxygen_label)
        return None, None, None

    # Handle missing ion data the same way as missing oxygen data.
    try:
        ion_positions_all = data.trajectories.positions.ions
    except AttributeError:
        logging.warning("No ion position data found.")
        return None, None, None
    if ion_indices is not None:
        ion_positions_all = ion_positions_all[:, ion_indices, :]

    parameters = extract_simulation_parameters(data, bin_edges)
    if parameters[0] is None:
        return None, None, None
    volume, shell_volumes, bin_centers, bin_width, box_lengths = parameters

    rdf_collection = []
    density_collection = []
    for frame_idx, current_time in enumerate(data.trajectories.times):
        if current_time < skip_time:
            continue
        rdf_frame, density_oxygen = compute_frame_rdf(
            oxygen_positions_all[frame_idx], ion_positions_all[frame_idx],
            oxygen_label, ion_label,
            bin_edges, shell_volumes, bin_centers, bin_width,
            volume, box_lengths
        )
        rdf_collection.append(rdf_frame)
        density_collection.append(density_oxygen)

    if not rdf_collection:
        logging.warning("No frames exceeded the skip time threshold (skip_time=%.2f).", skip_time)
        return None, None, None

    avg_rdf = np.mean(rdf_collection, axis=0)
    avg_density = np.mean(density_collection)
    return bin_centers, avg_rdf, avg_density

# ------------------ Aggregation and Plotting of RDF Data ------------------

def plot_ion_oxygen_rdf(simulation_files, skip_time=5, num_bins=50, max_distance=8, verbose=False):
    """
    Aggregate and plot the oxygen-ion radial distribution functions (RDFs) across multiple simulation files.

    Parameters:
        simulation_files (list): List of file paths to simulation pickle files.
        skip_time (float): Time threshold to skip initial frames.
        num_bins (int): Number of bins for the RDF histogram.
        max_distance (float): Maximum distance (in Å) for the RDF analysis.
        verbose (bool): Toggle for detailed logging.

    The function processes each simulation file and computes the RDF for each
    unique ion type. Per-file RDFs are averaged with weights equal to the number
    of ions of that type. It then plots one averaged RDF curve per ion type, ordered
    as in IONS (unknown types afterwards, alphabetically) so the legend order is
    reproducible. Files that fail to load are logged and skipped.
    """
    # Use the shared helper rather than a second, duplicated basicConfig call.
    setup_logging(verbose)

    bin_edges = np.linspace(0, max_distance, num_bins + 1)
    aggregated_data = {}  # Dictionary to store RDF data per ion type.

    def ion_sort_key(name):
        # Known cations in IONS order first, then anything else alphabetically.
        return (IONS.index(name) if name in IONS else len(IONS), name)

    # Process each simulation file.
    for file_path in simulation_files:
        logging.debug("Processing simulation file: %s", file_path)
        try:
            sim_data = load_simulation(file_path)
        except Exception as error:
            logging.error("Error loading file %s: %s", file_path, error)
            continue

        # Iterate over each unique ion type present in the simulation.
        for ion_type in set(sim_data.ions):
            ion_indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]

            # For this analysis, we consider oxygen positions defined under "watO".
            bin_centers, rdf_average, _ = calculate_average_rdf(
                sim_data, "watO", "ions", skip_time, bin_edges, ion_indices=ion_indices
            )
            if bin_centers is None or rdf_average is None:
                logging.warning("Skipping simulation %s for ion type '%s' due to insufficient RDF data.",
                                file_path, ion_type)
                continue

            # Weight the RDF by the number of ions of this type.
            weighted_rdf = rdf_average * len(ion_indices)
            label = f"{ion_type}"
            if label not in aggregated_data:
                aggregated_data[label] = {"bin_centers": bin_centers, "rdf_sum": weighted_rdf, "ion_count": len(ion_indices)}
            else:
                aggregated_data[label]["rdf_sum"] += weighted_rdf
                aggregated_data[label]["ion_count"] += len(ion_indices)

    # Prepare aggregated RDF data for plotting, in a deterministic order
    # (set iteration order of strings varies between interpreter runs).
    rdf_plot_data = []
    for label in sorted(aggregated_data, key=ion_sort_key):
        data_dict = aggregated_data[label]
        avg_rdf = data_dict["rdf_sum"] / data_dict["ion_count"]
        rdf_plot_data.append((data_dict["bin_centers"], avg_rdf, label))
        logging.debug("Aggregated RDF for ion type '%s': Total ion count = %d", label, data_dict["ion_count"])

    # Define plotting callback.
    def plot_callback(ax, legend_font_size=32, legend_loc='best'):
        for centers, rdf, label in rdf_plot_data:
            ax.plot(centers, rdf, linestyle='-', linewidth=8, marker=None,
                    label=label, color=ION_COLORS.get(label, 'black'))
        if rdf_plot_data:
            # Build the legend once, after every curve has been added.
            ax.legend(fontsize=legend_font_size, loc=legend_loc)

    # Generate and display the plot using the custom plotting utility.
    custom_plot(
        xlabel="Ion–O Distance (Å)",
        ylabel="g(r)",
        font_size=40,
        yticks_remove=True,
        figsize=(12, 8),
        xticks=np.arange(2, 7, 1),
        xlim=(1.5, 6.5),
        ylim=(0, None),  # Fix the lower limit at 0; keep the automatic upper limit.
        legend_options={
            'loc': 'upper right',
            'font_size': 32,
            'draw_bg': True,
            'bgcolor': 'white',
            'alpha': 0.8,
            'border': False
        },
        plot_func=plot_callback,
        show=True,
        save_svg='ion-O_rdf'
    )



# ------------------ Example Usage ------------------
# Simulation identifiers for the ion-O RDF analysis; each maps to
# data/simulations/Pt111_<id>.pkl.
simulation_ids = ["Li2", "Na2", "K2", "Cs2"]
simulation_file_paths = list(map("data/simulations/Pt111_{}.pkl".format, simulation_ids))


# Uncomment to run the analysis and plot the results.
# -----------------------------------------------------

# plot_ion_oxygen_rdf(simulation_files=simulation_file_paths)

In [4]:
# --- Helper Functions ---

def calculate_histogram_density(all_distances, bin_edges, cell_dimensions, valid_frames):
    """
    Compute the density histogram using compute_histogram.

    Raw counts are divided by the lateral cell area (converted from Å² to nm²),
    the bin width (converted from Å to nm) and the number of valid frames, so the
    result is a per-frame number density in nm⁻³ (inputs assumed to be in Å).

    Parameters:
        all_distances (np.ndarray): Distances pooled over all valid frames.
        bin_edges (np.ndarray): Histogram bin edges.
        cell_dimensions: Either box lengths [a, b, ...] or a square cell matrix; the
            lateral area uses the first two lengths (the matrix diagonal for a
            matrix, as in extract_simulation_parameters).
        valid_frames (int): Number of frames pooled into all_distances.

    Returns:
        np.ndarray: Density per bin.

    Raises:
        ValueError: If valid_frames is not positive.
    """
    if valid_frames <= 0:
        raise ValueError("valid_frames must be positive to normalize the density.")
    counts, _, bin_width = compute_histogram(all_distances, bin_edges, normalize=False)
    # Convert bin width from Å to nm.
    bin_width_nm = bin_width * 0.1
    cell = np.asarray(cell_dimensions)
    # Indexing a cell matrix directly would give whole rows (a vector "area");
    # use its diagonal instead.
    lengths = np.diag(cell) if cell.ndim == 2 else cell
    # Compute cell area in nm² (assuming cell_dimensions are in Å).
    area_nm2 = (lengths[0] * 0.1) * (lengths[1] * 0.1)
    return counts / (area_nm2 * bin_width_nm * valid_frames)

def extract_initial_positions(simulation, ion, skip):
    """
    Return the first-frame distances of every ``ion`` particle from the surface.

    For each entry of ``simulation.ions`` equal to ``ion``, this takes the z
    coordinate in frame 0 and returns |z - surface_z[0]|.

    Parameters:
        simulation: Simulation object (uses ions, trajectories.positions.ions and
            trajectories.surface_z).
        ion (str): Ion species to select.
        skip: Unused; kept for call-site compatibility (frame 0 is always used).

    Returns:
        list[float]: One distance per selected ion, or [] if the species is absent.
    """
    selected = np.array([species == ion for species in simulation.ions])
    if not selected.any():
        return []
    traj = simulation.trajectories
    z_first = traj.positions.ions[0, selected, 2]
    return np.abs(z_first - traj.surface_z[0]).tolist()

def process_density(simulation, target_positions, skip, bin_edges):
    """
    Build a density histogram of the particles' distances from the surface.

    For every frame whose time is >= skip, this collects |target_positions[frame] -
    surface_z[frame]|, then pools the distances and normalizes them with
    calculate_histogram_density.

    Parameters:
        simulation: Simulation object (uses trajectories.times,
            trajectories.surface_z and cell_dimensions).
        target_positions (np.ndarray): Per-frame z coordinates, indexed as
            target_positions[frame] (e.g. shape (n_frames, n_particles)).
        skip (float): Frames with a time below this value are ignored.
        bin_edges (np.ndarray): Histogram bin edges for the distances.

    Returns:
        tuple: (density, valid_frames), or (None, 0) when no frame is kept.
    """
    trajectories = simulation.trajectories
    distances_list = []
    for i, t in enumerate(trajectories.times):
        if t < skip:
            continue
        distances_list.append(np.abs(target_positions[i] - trajectories.surface_z[i]))

    valid_frames = len(distances_list)
    if valid_frames == 0:
        # %g rather than %d so fractional skip values are not truncated in the message.
        logging.warning("No valid frames found after skipping %g ps.", skip)
        return None, 0
    all_distances = np.concatenate(distances_list)
    density = calculate_histogram_density(all_distances, bin_edges, simulation.cell_dimensions, valid_frames)
    return density, valid_frames

# --- Main Plotting Function ---

def plot_density(files, skip=5, bins=50, smoothing=3, initial_position=False, verbose=False, normalize=True):
    """
    Process simulation files and plot the density of ions and oxygen versus distance.
    
    For hydrogen-covered systems (detected via adsorbates), ion density curves
    are drawn with a black outline (via a two-layer drawing) and oxygen density
    is filled with a hatch. The bare systems are drawn as solid colored lines.
    """
    setup_logging(verbose)
    
    # Define histogram bins from 2 Å to 5 Å.
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    # Calculate bin width in nm.
    bin_width_nm = (bin_edges[1] - bin_edges[0]) * 0.1

    # Dictionaries to group densities by hydrogen coverage (False = bare, True = covered).
    ion_density_dict = {False: {ion: [] for ion in IONS}, True: {ion: [] for ion in IONS}}
    oxygen_density_list = {False: [], True: []}
    initial_positions = {ion: [] for ion in IONS}
    
    # Process each simulation file.
    for file_path in files:
        simulation = load_simulation(file_path)
        # Detect hydrogen coverage via adsorbates.
        hydrogen_present = (hasattr(simulation.trajectories.positions, 'adsorbates') and
                            simulation.trajectories.positions.adsorbates is not None and
                            simulation.trajectories.positions.adsorbates.size > 0)
        
        # Optionally record initial ion positions.
        if initial_position:
            for ion in IONS:
                initial_positions[ion].extend(extract_initial_positions(simulation, ion, skip))
        
        # Process ion densities.
        for ion in IONS:
            target_mask = np.array([elem == ion for elem in simulation.ions])
            if not np.any(target_mask):
                continue
            target_positions = simulation.trajectories.positions.ions[:, target_mask, 2]
            density, valid_frames = process_density(simulation, target_positions, skip, bin_edges)
            if density is None:
                logging.warning("No valid frames for ion %s in file %s", ion, file_path)
                continue
            ion_density_dict[hydrogen_present][ion].append(density)
        
        # Process oxygen density if available.
        if (hasattr(simulation.trajectories.positions, "watO") and
            simulation.trajectories.positions.watO is not None and
            simulation.trajectories.positions.watO.size > 0):
            target_positions = simulation.trajectories.positions.watO[:, :, 2]
            density, valid_frames = process_density(simulation, target_positions, skip, bin_edges)
            if valid_frames > 0 and density is not None:
                oxygen_density_list[hydrogen_present].append(density)
            else:
                logging.warning("No valid frames for oxygen in file %s", file_path)
        else:
            logging.debug("No water oxygen positions found in file %s", file_path)
    
    # Average ion density curves per ion and per coverage group.
    group_max_ion = {False: 0.0, True: 0.0}
    avg_ion_density = {False: {}, True: {}}
    for coverage_flag in [False, True]:
        for ion in IONS:
            if ion_density_dict[coverage_flag][ion]:
                avg_density = np.mean(ion_density_dict[coverage_flag][ion], axis=0)
                if normalize:
                    area = np.sum(avg_density * bin_width_nm)
                    if area != 0:
                        avg_density = avg_density / area
                avg_ion_density[coverage_flag][ion] = avg_density
                group_max_ion[coverage_flag] = max(group_max_ion[coverage_flag], np.max(avg_density))
            else:
                avg_ion_density[coverage_flag][ion] = None

    # Average oxygen densities per coverage group.
    avg_oxygen_density = {False: None, True: None}
    for coverage_flag in [False, True]:
        if oxygen_density_list[coverage_flag]:
            oxygen_density = np.mean(oxygen_density_list[coverage_flag], axis=0)
            if normalize:
                area = np.trapezoid(oxygen_density, bin_centers)
                if area != 0:
                    oxygen_density = oxygen_density / area
            avg_oxygen_density[coverage_flag] = oxygen_density
        else:
            avg_oxygen_density[coverage_flag] = None

    # Scale oxygen density so its maximum matches the ion density maximum for each group.
    scaled_oxygen_density = {False: None, True: None}
    for coverage_flag in [False, True]:
        od = avg_oxygen_density[coverage_flag]
        if od is not None and np.max(od) > 0 and group_max_ion[coverage_flag] > 0:
            scaled_oxygen_density[coverage_flag] = od * (group_max_ion[coverage_flag] / np.max(od))
        else:
            scaled_oxygen_density[coverage_flag] = od

    # --- Plotting Callback ---
    def plot_callback(ax, legend_font_size=32, legend_loc='best'):
        # --- Plot Oxygen Density ---
        for coverage_flag in [False, True]:
            od = scaled_oxygen_density[coverage_flag]
            if od is not None and len(od) == len(bin_centers):
                ax.fill_between(bin_centers, od, color='black', alpha=0.3)
                # For covered simulations, also add a black outline.
                if coverage_flag:
                    ax.plot(bin_centers, od, color='black', lw=1)
        
        # --- Prepare and Plot Ion Density Curves ---
        # We build lists to later create custom legend entries.
        custom_handles = []
        custom_labels = []
        # For each group (bare and hydrogen-covered).
        for coverage_flag in [False, True]:
            for ion in IONS:
                density = avg_ion_density[coverage_flag][ion]
                if density is not None and len(density) == len(bin_centers):
                    # Optionally smooth the data.
                    if smoothing > 1:
                        density = smooth_data(density, window_size=smoothing)
                    # Interpolate onto a dense grid for smooth line plotting.
                    num_dense = 100
                    x_dense = np.linspace(bin_centers[0], bin_centers[-1], num_dense)
                    y_dense = np.interp(x_dense, bin_centers, density)
                    # Define a constant line width.
                    lw_dense = np.full_like(x_dense, 8)
                    # Create line segments for smooth drawing.
                    points = np.array([x_dense, y_dense]).T.reshape(-1, 1, 2)
                    segments = np.concatenate([points[:-1], points[1:]], axis=1)
                    # Average the line width per segment.
                    lw_segments = 0.5 * (lw_dense[:-1] + lw_dense[1:])
                    
                    # Draw the line using LineCollection.
                    if coverage_flag:
                        # For hydrogen-covered systems, first draw a thick black outline.
                        outline_width = lw_segments + 2.0
                        lc_outline = LineCollection(segments, linewidths=outline_width,
                                                    colors='black', capstyle='round', joinstyle='round')
                        ax.add_collection(lc_outline)
                        # Then overlay the colored line.
                        lc = LineCollection(segments, linewidths=lw_segments,
                                            colors=ION_COLORS[ion], capstyle='round', joinstyle='round')
                        ax.add_collection(lc)
                        # Prepare a custom legend handle with outline.
                        handle = Line2D([], [], color=ION_COLORS[ion], lw=8, linestyle='-')
                        handle._outline = True  # flag for custom legend handler
                    else:
                        # For bare systems, draw a solid colored line.
                        lc = LineCollection(segments, linewidths=lw_segments,
                                            colors=ION_COLORS[ion], capstyle='round', joinstyle='round')
                        ax.add_collection(lc)
                        # Create a standard legend handle.
                        handle = Line2D([], [], color=ION_COLORS[ion], lw=8, linestyle='-')
                    
                    custom_handles.append(handle)
                    custom_labels.append(f"{ion}")
        
        # Add legend with a custom handler so that covered entries show the outline.
        ax.legend(handles=custom_handles, labels=custom_labels, loc=legend_loc,
                  fontsize=legend_font_size, handler_map={Line2D: HandlerMaybeOutlinedLine2D()})
        
        # --- Optionally Plot Initial Ion Positions ---
        if initial_position:
            for ion in IONS:
                for pos in initial_positions[ion]:
                    ax.plot(pos, 0.2, marker='o', markersize=10, linestyle='None', color=ION_COLORS[ion])

    # Use the unified custom_plot function to create the figure.
    custom_plot(
        xlabel="Distance to Surface (Å)",
        ylabel="ρ (a.u.)",
        font_size=40,
        yticks_remove=True,
        xticks=np.arange(2, 6, 0.5),
        xlim=(2, 5),
        legend_options={
            'loc': 'upper left',
            'font_size': 32,
            'draw_bg': True,
            'bgcolor': 'white',
            'alpha': 0.8
        },
        plot_func=plot_callback,
        show=True,
        save_svg="density_plot"
    )

# Uncomment to run the density plot function.
# ----------------------------------------------

# plot_density([f"data/simulations/Pt111_{ion}2.pkl" for ion in IONS])

In [5]:
def compare_ion_densities(simulation_files, skip=5, bins=50, smoothing=3, normalize=True, verbose=False):
    """
    Compare the densities of each ion and the oxygen density per simulation.

    For each simulation, the ion density curves are plotted as separate lines
    (lighter shades for more positive electrode potentials), and the water-oxygen
    density is plotted as a shaded region, scaled so that the global oxygen
    maximum matches the global ion-density maximum.
      - For covered simulations, ion density curves are drawn with a black outline.
      - For bare simulations, they are drawn as solid colored lines.
      - Oxygen density is plotted with a shaded area, with an outline if covered.
    Two arrows link the oxygen density of the most positive and the most
    negative potential: one between their maxima and one at 4.25 Å.

    Parameters:
    - simulation_files (list of str): List of simulation pickle files (at least one).
    - skip (int): Skip frames with time (in ps) less than this value.
    - bins (int): Number of bins for the histogram (between 2 Å and 5 Å).
    - smoothing (int): Moving-average window for the density curves (<= 1 disables).
    - normalize (bool): If True, normalize the density curves (area under the curve = 1).
    - verbose (bool): Enable verbose logging.

    Raises:
    - ValueError: If simulation_files is empty.
    """
    # Set up logging.
    setup_logging(verbose)
    if not simulation_files:
        raise ValueError("compare_ion_densities: simulation_files must contain at least one file.")

    # Histogram bins from 2 Å to 5 Å.
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    num_dense = 100  # number of interpolation samples per drawn curve

    def density_profile(sim, target_positions):
        """Histogram positions with process_density, then smooth/normalize; None if unavailable."""
        density, _ = process_density(sim, target_positions, skip, bin_edges)
        if density is None:
            return None
        # Float copy + out-of-place arithmetic: works for integer histograms and
        # never mutates the array handed back by process_density.
        density = np.asarray(density, dtype=float)
        if smoothing > 1:
            density = np.convolve(density, np.ones(smoothing) / smoothing, mode='same')
        if normalize:
            area = np.trapezoid(density, bin_centers)
            if area != 0:
                density = density / area
        return density

    def water_oxygen_z(sim):
        """Return watO[:, :, 2] of the simulation, or None if watO is missing or empty."""
        wat_o = getattr(sim.trajectories.positions, "watO", None)
        if wat_o is None or wat_o.size == 0:
            return None
        return wat_o[:, :, 2]

    # Load simulations, their electrode potentials and hydrogen-coverage flags.
    simulations = [load_simulation(file_path) for file_path in simulation_files]
    potentials = [sim.electrode_potential for sim in simulations]
    coverage_flags = []
    for sim in simulations:
        adsorbates = getattr(sim.trajectories.positions, 'adsorbates', None)
        coverage_flags.append(adsorbates is not None and adsorbates.size > 0)

    # Compute every density once; the arrays feed both the global scaling
    # factor and the plotting callback.
    ion_profiles = []     # per simulation: list of (ion, density) in IONS order
    oxygen_profiles = []  # per simulation: oxygen density or None
    for sim in simulations:
        profiles = []
        for ion in IONS:
            target_mask = np.array([elem == ion for elem in sim.ions])
            if not np.any(target_mask):
                continue
            density = density_profile(sim, sim.trajectories.positions.ions[:, target_mask, 2])
            if density is not None:
                profiles.append((ion, density))
        ion_profiles.append(profiles)

        oxygen_z = water_oxygen_z(sim)
        oxygen_profiles.append(None if oxygen_z is None else density_profile(sim, oxygen_z))

    # Scale oxygen so its global maximum matches the global ion maximum.
    global_ion_max = max([0] + [density.max() for profiles in ion_profiles for _, density in profiles])
    global_oxygen_max = max([0] + [density.max() for density in oxygen_profiles if density is not None])
    scaling_factor = global_ion_max / global_oxygen_max if global_oxygen_max > 0 else 1

    # Potential range used to lighten the ion colors (a constant potential maps to range 1).
    group_min = min(potentials)
    group_range = max(potentials) - group_min if max(potentials) != group_min else 1
    oxygen_color = blend_color("#000000", 0.5)

    # Define the plotting callback.
    def plot_callback(ax, legend_font_size, legend_loc):
        min_length = 0.1   # minimum arrow length (data units)
        cutoff_val = 4.25  # distance (Å) at which the second arrow is drawn

        def adjust_arrow(p, q, min_length):
            """Re-center a too-short arrow p->q so that it has length min_length."""
            dx = q[0] - p[0]
            dy = q[1] - p[1]
            d = np.sqrt(dx**2 + dy**2)
            if d >= min_length:
                return p, q
            mid = ((p[0] + q[0]) / 2, (p[1] + q[1]) / 2)
            # Avoid division by zero for coincident endpoints.
            u = (1, 0) if d == 0 else (dx / d, dy / d)
            new_p = (mid[0] - 0.5 * min_length * u[0], mid[1] - 0.5 * min_length * u[1])
            new_q = (mid[0] + 0.5 * min_length * u[0], mid[1] + 0.5 * min_length * u[1])
            return new_p, new_q

        def draw_arrow(p, q):
            """Draw a semi-transparent black arrow from p to q (with minimum length)."""
            p, q = adjust_arrow(p, q, min_length)
            ax.annotate(
                '',
                xy=q,
                xytext=p,
                arrowprops=dict(
                    facecolor='black', edgecolor='black',
                    arrowstyle=ArrowStyle.Simple(head_length=0.2, head_width=0.2, tail_width=0.07),
                    alpha=0.7
                )
            )

        def dense_segments(density):
            """Interpolate a density onto num_dense points and return consecutive line segments."""
            x_dense = np.linspace(bin_centers[0], bin_centers[-1], num_dense)
            y_dense = np.interp(x_dense, bin_centers, density)
            points = np.array([x_dense, y_dense]).T.reshape(-1, 1, 2)
            return np.concatenate([points[:-1], points[1:]], axis=1)

        # Constant line width of 8 for every segment.
        lw_segments = np.full(num_dense - 1, 8.0)

        # Oxygen-density points of the most positive / most negative potential.
        pos_max_point = neg_max_point = None
        pos_cutoff_point = neg_cutoff_point = None
        best_pos, best_neg = -np.inf, np.inf

        curves = []  # (segments, color, coverage_flag, label); drawn after the oxygen areas

        for sim, coverage_flag, profiles, oxygen_density in zip(
                simulations, coverage_flags, ion_profiles, oxygen_profiles):
            potential = sim.electrode_potential
            color_fraction = 0.75 * ((potential - group_min) / group_range)
            for ion, density in profiles:
                curves.append((dense_segments(density),
                               blend_color(ION_COLORS[ion], color_fraction),
                               coverage_flag,
                               f"{ion} ({potential:.2f} V)"))

            if oxygen_density is None:
                continue
            scaled_oxygen = oxygen_density * scaling_factor
            ax.fill_between(bin_centers, scaled_oxygen, color=oxygen_color, alpha=0.3)
            if coverage_flag:
                ax.plot(bin_centers, scaled_oxygen, color='black', lw=1)

            max_index = np.argmax(scaled_oxygen)
            max_point = (bin_centers[max_index], scaled_oxygen[max_index])
            cutoff_point = (cutoff_val, np.interp(cutoff_val, bin_centers, scaled_oxygen))
            if potential > best_pos:
                best_pos = potential
                pos_max_point, pos_cutoff_point = max_point, cutoff_point
            if potential < best_neg:
                best_neg = potential
                neg_max_point, neg_cutoff_point = max_point, cutoff_point

        # Arrow between the oxygen maxima, and arrow at x = cutoff_val.
        if pos_max_point is not None and neg_max_point is not None:
            draw_arrow(pos_max_point, neg_max_point)
        if pos_cutoff_point is not None and neg_cutoff_point is not None:
            draw_arrow(pos_cutoff_point, neg_cutoff_point)

        # Draw ion density curves (black outline underneath for covered systems)
        # and build the matching legend handles.
        custom_handles = []
        custom_labels = []
        for segments, color, coverage_flag, label in curves:
            if coverage_flag:
                ax.add_collection(LineCollection(segments, linewidths=lw_segments + 2.0,
                                                 colors='black', capstyle='round', joinstyle='round'))
            ax.add_collection(LineCollection(segments, linewidths=lw_segments,
                                             colors=color, capstyle='round', joinstyle='round'))
            handle = Line2D([], [], color=color, lw=8, linestyle='-')
            if coverage_flag:
                handle._outline = True  # flag for the custom legend handler
            custom_handles.append(handle)
            custom_labels.append(label)

        ax.legend(handles=custom_handles, labels=custom_labels,
                  fontsize=legend_font_size, loc=legend_loc,
                  handler_map={Line2D: HandlerMaybeOutlinedLine2D()})

    # Generate and display the plot using custom_plot.
    custom_plot(
        xlabel="Distance to Surface (Å)",
        ylabel="ρ (a.u.)",
        font_size=40,
        yticks_remove=True,
        plot_func=plot_callback,
        xticks=np.arange(2, 5.5, 0.5),
        xlim=(2, 5),
        legend_options={"font_size": 24, "alpha": 0.8, "loc": "upper left"},
        save_svg=f"ion_density_{simulations[0].ions[0]}{'_H' if coverage_flags[0] else ''}",
        show=True
    )


# Uncomment to execute the function with the desired simulation files.
# -----------------------------------------------------------------------

# for ion in IONS:
#     compare_ion_densities([f"data/simulations/Pt111_{ion}{i}.pkl" for i in [2, 3, 4]])

In [6]:
def plot_opening_angles(files, bins=180, region_width=0.5, min_count=1000, smoothing=3):
    """
    Plot the distribution of opening angles between ions, water oxygen,
    and the surface normal.

    For every frame and every ion, each ion->oxygen vector shorter than
    ION_CUTOFFS[ion] contributes the angle |180° - θ|, where θ is the angle of
    that vector to +z. Angles are binned per ion type and grouped by the ion's
    height above the surface into slabs of width region_width. Only slabs with
    region_min > 1.5 Å and region_max < 5.5 Å that collect at least min_count
    angles are plotted. An inset displays the ion density curves (with black
    outlines if any simulation is hydrogen-covered), fills the slabs used for
    the histograms, and shows the water-oxygen density as a shaded area scaled
    to the ion-density maximum. The inset is placed at whichever candidate
    corner hides the fewest histogram data points.

    Each file is loaded once.

    Parameters:
      files (list of str): List of simulation pickle file paths.
      bins (int): Number of bins for the opening-angle histogram (angles 0–180°).
      region_width (float): Width (in Å) of the distance regions.
      min_count (int): Minimum number of angles required for a region to be included.
      smoothing (int): Smoothing window size for density curves (<= 1 disables).
    """
    # --- Setup logging and initialize data containers ---
    setup_logging(verbose=False)
    angle_data = {ion: {} for ion in IONS}          # ion -> {(region_min, region_max): [angles]}
    ion_density_lists = {ion: [] for ion in IONS}   # ion -> per-file density arrays
    oxygen_density_list = []                        # per-file water-oxygen densities
    hydrogen_covered = False                        # True if any simulation has adsorbates

    density_bin_edges = np.linspace(2, 5, 51)  # 50 bins from 2 Å to 5 Å.
    density_bin_centers = 0.5 * (density_bin_edges[:-1] + density_bin_edges[1:])

    def smoothed(density):
        """Apply the moving-average smoothing window (if any) to a density curve."""
        if smoothing > 1:
            return np.convolve(density, np.ones(smoothing) / smoothing, mode='same')
        return density

    def accumulate_angles(sim, o_positions_all):
        """Append the opening angles of every frame of `sim` to angle_data."""
        n_frames = sim.trajectories.times.shape[0]
        surface_zs = sim.trajectories.surface_z  # one surface z per frame
        ion_positions_all = sim.trajectories.positions.ions

        # Indices of each ion type within the ion position arrays.
        ion_indices = {}
        for i, ion in enumerate(sim.ions):
            ion_indices.setdefault(ion, []).append(i)

        for frame in range(n_frames):
            frame_ions = ion_positions_all[frame]
            frame_o = o_positions_all[frame]
            surface_z = surface_zs[frame]

            for ion in IONS:
                cutoff = ION_CUTOFFS[ion]
                for idx in ion_indices.get(ion, ()):
                    ion_pos = frame_ions[idx]
                    distance = ion_pos[2] - surface_z
                    if distance < 0:
                        continue
                    # Slab containing the ion's distance from the surface.
                    region_min = np.floor(distance / region_width) * region_width
                    region_max = region_min + region_width
                    if not (region_min > 1.5 and region_max < 5.5):
                        continue
                    region_angles = angle_data[ion].setdefault((region_min, region_max), [])

                    # Displacement vectors to the water oxygens within the cutoff.
                    diff = frame_o - ion_pos
                    dists = np.linalg.norm(diff, axis=1)
                    valid = dists <= cutoff
                    if not np.any(valid):
                        continue
                    cos_theta = np.clip(diff[valid, 2] / dists[valid], -1.0, 1.0)
                    angles = np.abs(180 - np.degrees(np.arccos(cos_theta)))
                    region_angles.extend(angles.tolist())

    # --- Single pass over the files: coverage, opening angles and densities ---
    for file in files:
        sim = load_simulation(file)
        positions = sim.trajectories.positions

        adsorbates = getattr(positions, 'adsorbates', None)
        if adsorbates is not None and adsorbates.size > 0:
            hydrogen_covered = True

        o_positions_all = positions.watO
        has_water = o_positions_all is not None and o_positions_all.size > 0
        if has_water:
            accumulate_angles(sim, o_positions_all)
        else:
            print(f"No water oxygen positions found in file: {file}")

        # Ion densities for the inset.
        for ion in IONS:
            target_mask = np.array([elem == ion for elem in sim.ions])
            if not np.any(target_mask):
                continue
            density, _ = process_density(sim, positions.ions[:, target_mask, 2],
                                         skip=0, bin_edges=density_bin_edges)
            if density is not None:
                ion_density_lists[ion].append(smoothed(density))

        # Water-oxygen density for the inset.
        if has_water:
            density, _ = process_density(sim, o_positions_all[:, :, 2],
                                         skip=0, bin_edges=density_bin_edges)
            if density is not None:
                oxygen_density_list.append(smoothed(density))

    ion_density_dict = {ion: (np.mean(densities, axis=0) if densities else None)
                        for ion, densities in ion_density_lists.items()}
    oxygen_density = np.mean(oxygen_density_list, axis=0) if oxygen_density_list else None

    # --- Build histogram data for regions meeting min_count ---
    angle_bins = np.linspace(0, 180, bins + 1)
    angle_bin_centers = (angle_bins[:-1] + angle_bins[1:]) / 2
    histogram_data = []
    for ion in IONS:
        for region_key, angles in angle_data[ion].items():
            if len(angles) < min_count:
                continue
            counts, _ = np.histogram(angles, bins=angle_bins)
            histogram_data.append((ion, region_key, angle_bin_centers, counts / np.sum(counts)))
    # Per ion (IONS order), regions with a higher lower bound appear first.
    histogram_data_sorted = sorted(histogram_data,
                                   key=lambda item: (IONS.index(item[0]), -item[1][0]))

    # --- Generate gradient colors per ion and region ---
    ion_region_colors = {}
    for ion in IONS:
        valid_regions = sorted((rk for rk, angles in angle_data[ion].items() if len(angles) >= min_count),
                               key=lambda rk: rk[0])
        n_regions = len(valid_regions)
        # Lighten the base color by up to 50 % from the closest to the farthest region.
        ion_region_colors[ion] = {
            region_key: blend_color(ION_COLORS[ion], (idx / (n_regions - 1) / 2) if n_regions > 1 else 0)
            for idx, region_key in enumerate(valid_regions)
        }

    # --- Scale oxygen density to match ion density maximum ---
    max_ion_value = 0.0
    for density in ion_density_dict.values():
        if density is not None:
            max_ion_value = max(max_ion_value, np.max(density))
    if oxygen_density is not None and np.max(oxygen_density) > 0 and max_ion_value > 0:
        oxygen_density_scaled = oxygen_density * (max_ion_value / np.max(oxygen_density))
    else:
        oxygen_density_scaled = oxygen_density

    # --- Define the plotting callback ---
    def plot_callback(ax, legend_font_size, legend_loc):
        # Plot the opening-angle histograms (black outline underneath if covered).
        for ion, region_key, centers, prob in histogram_data_sorted:
            color = ion_region_colors[ion].get(region_key, ION_COLORS[ion])
            if hydrogen_covered:
                ax.plot(centers, prob, linewidth=10, color='black')
                ax.plot(centers, prob, linewidth=8, color=color)
            else:
                ax.plot(centers, prob, linestyle='-', linewidth=8, color=color)
        ax.set_xticks(np.arange(0, 181, 30))
        ax.set_xlim(0, 180)
        ax.set_ylim(0)

        # --- Create an inset axis for the density plot ---
        fig = ax.figure
        left, bottom, width, height = ax.get_position().bounds  # figure-fraction coordinates
        inset_w, inset_h = 0.35 * width, 0.35 * height
        candidates = {
            "top_right": [left + 0.68 * width, bottom + 0.62 * height, inset_w, inset_h],
            "top_left": [left + 0.05 * width, bottom + 0.60 * height, inset_w, inset_h],
            "bottom_right": [left + 0.60 * width, bottom + 0.05 * height, inset_w, inset_h],
            "bottom_left": [left + 0.05 * width, bottom + 0.05 * height, inset_w, inset_h],
        }

        # Candidate boxes are figure fractions: go figure -> display -> data.
        fig_to_data = fig.transFigure + ax.transData.inverted()

        def candidate_overlap(candidate_box):
            """Count histogram data points that would lie underneath the candidate inset."""
            box_left, box_bottom, box_width, box_height = candidate_box
            (x_min, y_min), (x_max, y_max) = fig_to_data.transform(
                [(box_left, box_bottom), (box_left + box_width, box_bottom + box_height)])
            overlap = 0
            for line in ax.lines:
                xdata = np.asarray(line.get_xdata(), dtype=float)
                ydata = np.asarray(line.get_ydata(), dtype=float)
                inside = (xdata >= x_min) & (xdata <= x_max) & (ydata >= y_min) & (ydata <= y_max)
                overlap += int(np.count_nonzero(inside))
            return overlap

        # Ties are resolved in dict order, i.e. top_right is preferred.
        candidate_scores = {key: candidate_overlap(box) for key, box in candidates.items()}
        best_box = candidates[min(candidate_scores, key=candidate_scores.get)]
        ax_density = fig.add_axes(best_box, zorder=3)

        # Semi-transparent white background slightly around the inset (incl. labels).
        fig.canvas.draw()
        renderer = fig.canvas.get_renderer()
        bbox = ax_density.get_tightbbox(renderer).transformed(fig.transFigure.inverted())
        margin = 0.02
        rect = Rectangle((bbox.x0, bbox.y0 - 3 * margin), bbox.width - 2 * margin, bbox.height + 2 * margin,
                         transform=fig.transFigure, facecolor='white', alpha=0.8,
                         edgecolor='none', zorder=2)
        fig.add_artist(rect)

        # --- Plot oxygen density in the inset ---
        if oxygen_density_scaled is not None and len(density_bin_centers) == len(oxygen_density_scaled):
            oxy_color = blend_color("#000000", 0.5)
            ax_density.fill_between(density_bin_centers, oxygen_density_scaled,
                                    color=oxy_color, alpha=0.3)
            if hydrogen_covered:
                ax_density.plot(density_bin_centers, oxygen_density_scaled,
                                color='black', lw=1, zorder=0)

        # --- Plot the ion density curves in the inset ---
        for ion in IONS:
            density = ion_density_dict[ion]
            if density is None:
                continue
            if hydrogen_covered:
                ax_density.plot(density_bin_centers, density, linewidth=5, color='black')
                ax_density.plot(density_bin_centers, density, linewidth=3, color=ION_COLORS[ion])
            else:
                ax_density.plot(density_bin_centers, density, linestyle='-', linewidth=3,
                                color=ION_COLORS[ion])
            # Fill the area of each region that contributed a histogram.
            for region_key, angles in angle_data[ion].items():
                if len(angles) < min_count:
                    continue
                region_min, region_max = region_key
                mask = (density_bin_centers >= region_min) & (density_bin_centers <= region_max)
                if np.any(mask):
                    color = ion_region_colors[ion].get(region_key, ION_COLORS[ion])
                    ax_density.fill_between(density_bin_centers[mask], density[mask], color=color)
        ax_density.set_xlabel("Distance to Surface (Å)", fontsize=28)
        ax_density.set_ylabel("ρ (a.u.)", fontsize=32)
        ax_density.set_yticks([])
        ax_density.tick_params(labelsize=28)
        ax_density.set_xticks(np.arange(2, 6, 1))
        ax_density.set_xlim(2, 5)
        ax_density.set_ylim(0)

    # --- Generate and display the plot using custom_plot ---
    custom_plot(
        xlabel="Opening angle (degrees)",
        ylabel="Probability (a.u.)",
        font_size=48,
        yticks_remove=True,
        grid=False,
        xlim=(0, 180),
        xticks=np.arange(0, 181, 30),
        ylim=0,  # scalar, not a tuple: presumably only the lower limit — confirm in custom_plot
        legend_options=None,  # No custom legend defined here.
        plot_func=plot_callback,
        show=True
    )

# --- Example usage ---
# Example simulation IDs: the "2" runs of every ion followed by the "4_H" runs
# (the "_H" suffix appears to mark hydrogen-covered simulations).
sim_ids = [f"{ion}{suffix}" for suffix in ("2", "4_H") for ion in ("Li", "Na", "K", "Cs")]

# Uncomment and adjust the following to process the desired simulations.
# for sim_id in sim_ids:
#     plot_opening_angles([f"data/simulations/Pt111_{sim_id}.pkl"], bins=36, region_width=0.5)

In [7]:
# Compute the local continuous coordination number (CCN) for a single ion using a switching function.
def compute_local_ccn(ion_position, o_positions, o_density, rdf_bin_edges, cutoff, box_lengths=None, ccn_d=0.1):
    """
    Compute the local continuous coordination number (CN) for an ion using a switching function.

    Each water oxygen contributes with a weight:
        f(r) = 1 / (1 + exp((r - cutoff) / ccn_d))

    Also computes the weight-averaged z-coordinate of the oxygen atoms.

    Parameters:
        ion_position (np.ndarray): Array of shape (3,) representing the ion's position.
        o_positions (np.ndarray): Array of shape (n_O, 3) with water oxygen positions.
        o_density (float): Global oxygen density (unused; kept for compatibility).
        rdf_bin_edges (np.ndarray): Bin edges for the RDF histogram (unused; kept for compatibility).
        cutoff (float): Characteristic distance (Å) where a neighbour's weight is 0.5.
        box_lengths (np.ndarray, optional): Box dimensions; if given, displacement vectors
            are passed through `apply_periodic_boundary`.
        ccn_d (float, optional): Smoothing parameter controlling the switching width.

    Returns:
        tuple: (cn, avg_z) where:
            cn   : Continuous coordination number (sum of the weights).
            avg_z: Weighted average z-coordinate of the O atoms (raw, not PBC-corrected).
            Returns (None, None) if the weights sum to zero (e.g. no oxygens given).
    """
    # Displacement vectors from the ion to every oxygen (PBC-corrected if a box is given).
    diff = o_positions - ion_position
    if box_lengths is not None:
        diff = apply_periodic_boundary(diff, box_lengths)

    distances = np.linalg.norm(diff, axis=1)

    # For oxygens far beyond the cutoff (or a very small ccn_d) the exponent can exceed
    # ~709 and np.exp overflows to inf. 1 / (1 + inf) == 0.0 is exactly the correct
    # limit, so silence the overflow warning instead of flooding the log.
    with np.errstate(over="ignore"):
        weights = 1.0 / (1.0 + np.exp((distances - cutoff) / ccn_d))

    # The continuous CN is the sum of the weights; compute it once and reuse it
    # as the normaliser for the weighted z-average.
    cn = np.sum(weights)
    if not cn > 0:
        return None, None

    avg_z = np.sum(o_positions[:, 2] * weights) / cn
    return cn, avg_z


def plot_ccn_vs_distance(sim_files, skip=5, default_cutoff=3.0, bins=100,
                         max_distance=5.0, ccn_d=0.1, threshold=1):
    """
    Process simulation pickle files to compute and plot the continuous coordination number (CN)
    versus distance to the surface.

    For each simulation:
      - For every ion, compute its local CN using a switching function (`compute_local_ccn`).
      - Determine the distance from the surface (ion z-coordinate minus `surface_z` of the frame).
      - Bin the CN values as a function of that distance.

    Special treatment:
      - For hydrogen-adsorbed (covered) systems, the CN curves are drawn with a black outline.
      - Oxygen density is shown as a shaded area, with an additional black outline if the system is covered.

    Parameters:
        sim_files (list): List of simulation pickle file paths.
        skip (int): Time (ps) to skip at the beginning of each simulation. It is compared against
            `frame * sim.timestep` after multiplying by 1000, i.e. the timestep is assumed to be
            in fs — NOTE(review): confirm against the Simulation class.
        default_cutoff (float): Default cutoff (Å) for the switching function if not specified per ion.
        bins (int): Number of bins for the distance-to-surface histogram.
        max_distance (float): Maximum distance (Å) for binning.
        ccn_d (float): Smoothing parameter for the switching function.
        threshold (int): Minimum number of samples a distance bin needs to be plotted.
    """
    # Define distance bins and compute bin centers.
    distance_bins = np.linspace(0, max_distance, bins + 1)
    bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])

    # Initialize accumulators.
    results = {}  # key: (ion_type, hydrogen_flag)
    oxygen_distances = {False: [], True: []}
    oxygen_frames = {False: 0, True: 0}
    global_area = None  # Area from the first simulation with valid cell dimensions.

    # `skip` is given in ps; frame times (frame * timestep) are compared in the timestep's units.
    skip_threshold = skip * 1000

    # Process each simulation file.
    for file_name in sim_files:
        sim = load_simulation(file_name)

        # Determine if hydrogen adsorbates (coverage) are present.
        hydrogen_present = (hasattr(sim.trajectories.positions, 'adsorbates') and
                            sim.trajectories.positions.adsorbates is not None and
                            sim.trajectories.positions.adsorbates.size > 0)

        # Determine the ion type and corresponding cutoff.
        ion_type = sim.ions[0] if isinstance(sim.ions, (list, tuple)) else sim.ions
        specific_cutoff = ION_CUTOFFS.get(ion_type, default_cutoff)
        rdf_bin_edges = np.linspace(0, specific_cutoff, 51)  # (Unused; kept for compatibility)

        # Ensure required atom positions are available.
        if sim.trajectories.positions.watO is None or sim.trajectories.positions.ions is None:
            print(f"Required atom positions (water O or {ion_type}) not found in {file_name}.")
            continue

        # Get simulation cell dimensions.
        cell = sim.cell_dimensions
        if cell is None:
            print(f"No cell dimensions found in {file_name}; skipping periodic corrections.")
            box_lengths = None
            area = 1.0
        else:
            box_lengths = cell if np.ndim(cell) == 1 else np.diag(cell)
            area = box_lengths[0] * box_lengths[1]
            if global_area is None:
                global_area = area

        volume = np.prod(box_lengths) if box_lengths is not None else 1.0
        o_count = sim.trajectories.positions.watO.shape[1]
        o_density_global = o_count / volume

        n_frames = sim.trajectories.positions.watO.shape[0]
        n_ions = sim.trajectories.positions.ions.shape[1]

        # Temporary accumulators for this simulation.
        cn_sum = np.zeros(bins)
        counts = np.zeros(bins)

        # Loop over frames; skip early frames based on simulation time.
        for frame in range(n_frames):
            if frame * sim.timestep < skip_threshold:
                continue

            surface_z = sim.trajectories.surface_z[frame]
            o_positions = sim.trajectories.positions.watO[frame]  # (n_water, 3)
            ion_positions = sim.trajectories.positions.ions[frame]  # (n_ions, 3)

            # Accumulate oxygen distances relative to the surface.
            oxygen_distances_frame = o_positions[:, 2] - surface_z
            valid_mask = (oxygen_distances_frame >= distance_bins[0]) & (oxygen_distances_frame <= distance_bins[-1])
            if np.any(valid_mask):
                oxygen_distances[hydrogen_present].extend(oxygen_distances_frame[valid_mask])
                oxygen_frames[hydrogen_present] += 1

            # Process each ion.
            for ion_pos in ion_positions:
                result = compute_local_ccn(ion_pos, o_positions, o_density_global,
                                           rdf_bin_edges, specific_cutoff, box_lengths, ccn_d=ccn_d)
                # compute_local_ccn reports "no neighbours" as (None, None), not None;
                # skip it here instead of adding None to the accumulator.
                if result is None or result[0] is None:
                    continue
                cn = result[0]
                # Compute the distance from the surface.
                dist_to_surface = ion_pos[2] - surface_z
                bin_idx = np.digitize(dist_to_surface, distance_bins) - 1
                if bin_idx < 0 or bin_idx >= bins:
                    continue
                cn_sum[bin_idx] += cn
                counts[bin_idx] += 1

        key = (ion_type, hydrogen_present)
        if key in results:
            results[key][0] += cn_sum
            results[key][1] += counts
            results[key][3] += n_frames * n_ions
            results[key][4] += n_frames
        else:
            # Structure: [total_cn_sum, total_counts, cutoff, total_ion_count, total_frames]
            results[key] = [cn_sum, counts, specific_cutoff, n_frames * n_ions, n_frames]

    # Compute normalized oxygen density for each hydrogen flag.
    oxygen_density_norm = {}
    for h_flag in [False, True]:
        if oxygen_distances[h_flag] and oxygen_frames[h_flag] > 0 and global_area is not None:
            counts_o, _ = np.histogram(oxygen_distances[h_flag], bins=distance_bins)
            bin_width = distance_bins[1] - distance_bins[0]
            oxygen_density = counts_o / (global_area * bin_width * oxygen_frames[h_flag])
            # Rescale into the CN axis range (2.5 .. 11) so it can share the plot.
            if oxygen_density.max() > 0:
                oxygen_density_norm[h_flag] = 2.5 + oxygen_density * (8.5 / oxygen_density.max())
            else:
                oxygen_density_norm[h_flag] = oxygen_density
        else:
            oxygen_density_norm[h_flag] = None

    # Prepare plotting data for each ion type's CN curve.
    line_segments_data = []  # List of tuples: (segments, lw_segments, color, hydrogen_flag)
    legend_dict = {}         # Mapping: ion -> (color, hydrogen_flag)
    baseline = 1
    max_lw = 12
    for (ion, h_flag), (cn_sum_total, counts_total, _, _, _) in results.items():
        # Compute average CN per bin.
        avg_cn = np.divide(cn_sum_total, counts_total, out=np.zeros_like(cn_sum_total), where=counts_total > 0)

        # Select only bins with enough samples (the default threshold=1 means "any data").
        valid = (counts_total > 0) & (counts_total >= threshold)
        if not np.any(valid):
            continue
        x_valid = bin_centers_distance[valid]
        y_valid = avg_cn[valid]

        # Line width encodes (smoothed) sampling density: thicker where more ions were seen.
        smoothed = np.convolve(counts_total, np.ones(3) / 3, mode='same')[valid]
        norm = (smoothed - smoothed.min()) / (smoothed.max() - smoothed.min() + 1e-8) if smoothed.max() > 0 else smoothed
        lw_valid = baseline + (max_lw - baseline) * norm

        # Interpolate onto a dense grid for smooth rendering.
        num_dense = 100
        x_dense = np.linspace(x_valid[0], x_valid[-1], num_dense)
        y_dense = np.interp(x_dense, x_valid, y_valid)
        lw_dense = np.interp(x_dense, x_valid, lw_valid)

        # Build line segments for plotting (one segment per consecutive point pair).
        points = np.array([x_dense, y_dense]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lw_segments = 0.5 * (lw_dense[:-1] + lw_dense[1:])

        # Get the ion color.
        color = ION_COLORS.get(ion, "#4472C4")
        line_segments_data.append((segments, lw_segments, color, h_flag))

        # For the legend, if an ion appears in both systems, favor the hydrogen-covered (outlined) version.
        if ion not in legend_dict or h_flag:
            legend_dict[ion] = (color, h_flag)

    # Build custom legend handles.
    custom_handles = []
    custom_labels = []
    for ion, (color, h_flag) in legend_dict.items():
        handle = Line2D([], [], color=color, lw=max_lw, linestyle='-')
        if h_flag:
            handle._outline = True
        custom_handles.append(handle)
        custom_labels.append(ion)

    # Define the plotting callback for custom_plot.
    def plot_callback(ax, legend_font_size, legend_loc):
        # Plot oxygen density as a shaded area.
        for h_flag in [False, True]:
            od = oxygen_density_norm[h_flag]
            if od is None:
                continue
            oxygen_color = blend_color("#000000", 0.5)
            ax.fill_between(bin_centers_distance, od, color=oxygen_color, alpha=0.3)
            # For hydrogen-covered systems, add an outline.
            if h_flag:
                ax.plot(bin_centers_distance, od, color='black', lw=1)

        # Plot CN curves using LineCollection.
        for segments, lw_segments, color, h_flag in line_segments_data:
            if h_flag:
                outline_width = lw_segments + 2.0
                lc_outline = LineCollection(segments, linewidths=outline_width,
                                            colors='black', capstyle='round', joinstyle='round')
                ax.add_collection(lc_outline)
                lc = LineCollection(segments, linewidths=lw_segments,
                                    colors=color, capstyle='round', joinstyle='round')
                ax.add_collection(lc)
            else:
                lc = LineCollection(segments, linewidths=lw_segments,
                                    colors=color, capstyle='round', joinstyle='round')
                ax.add_collection(lc)

        # Add the custom legend.
        ax.legend(handles=custom_handles, labels=custom_labels,
                  fontsize=legend_font_size, loc=legend_loc,
                  handler_map={Line2D: HandlerMaybeOutlinedLine2D()})

    # Generate and display the final plot.
    custom_plot(
        figsize=(12, 8),
        plot_func=plot_callback,
        font_size=32,
        xlabel="Distance to Surface (Å)",
        xlim=(2, 5),
        xticks=np.arange(2, 5.25, 0.5),
        ylabel="CN",
        ylim=(2.5, 11.2),
        yticks=np.arange(4, 12, 2),
        legend_options={
            'handles': custom_handles,
            'labels': custom_labels,
            'loc': 'upper left',
            'handler_map': {Line2D: HandlerMaybeOutlinedLine2D()},
            'font_size': 24
        },
        yticks_remove=False,
        grid='y',
        save_svg='ccn_vs_distance',
        show=True
    )


# ------------------------------------------------------------------------------
# Bare-surface runs (<ion>2) for every ion, then hydrogen-covered runs (<ion>1_H).
sim_files = [
    f"data/simulations/Pt111_{ion}{suffix}.pkl"
    for suffix in ("2", "1_H")
    for ion in IONS
]

# To run the CN-vs-distance analysis on these pickle files, uncomment the call below.
# plot_ccn_vs_distance(sim_files=sim_files)

In [8]:
def compute_local_hollow_cn_vectorized(ion_positions, o_positions,
                                         inner_radius=2.75, outer_radius=4.25,
                                         box_lengths=None):
    """
    Count water oxygens inside a hollow spherical shell around each ion position.

    An oxygen is counted for an ion when inner_radius <= |r_O - r_ion| <= outer_radius
    (both bounds inclusive), using the minimum-image convention when a box is given.

    Parameters:
        ion_positions (np.ndarray): Array of shape (N, 3) with N ion positions; a single
            position of shape (3,) is also accepted and treated as N = 1.
        o_positions (np.ndarray): Array of shape (n, 3) with water oxygen positions.
        inner_radius (float): Inner radius of the hollow sphere (Å).
        outer_radius (float): Outer radius of the hollow sphere (Å).
        box_lengths (np.ndarray, optional): Orthorhombic box lengths, shape (3,), for the
            PBC correction. A full (3, 3) cell matrix is also accepted; its diagonal is used
            (off-diagonal / triclinic terms are ignored).

    Returns:
        np.ndarray: Integer array of length N containing the CN for each ion position.
    """
    ion_positions = np.atleast_2d(np.asarray(ion_positions, dtype=float))
    o_positions = np.asarray(o_positions, dtype=float)

    # Pairwise difference vectors: shape (N, n, 3).
    diff = o_positions[None, :, :] - ion_positions[:, None, :]
    if box_lengths is not None:
        box = np.asarray(box_lengths, dtype=float)
        if box.ndim > 1:
            # A (3, 3) cell matrix would not broadcast against (N, n, 3); use its diagonal.
            box = np.diag(box)
        # Minimum-image convention for an orthorhombic box.
        diff = diff - box * np.round(diff / box)

    # Distances: shape (N, n); count oxygens falling inside the shell for each ion.
    distances = np.linalg.norm(diff, axis=2)
    in_shell = (distances >= inner_radius) & (distances <= outer_radius)
    return np.sum(in_shell, axis=1)


def plot_hollow_cn_vs_distance(sim_file, ion='Cs', d_min=2.0, d_max=5.0,
                               n_d=50, grid_resolution=10):
    """
    Plot the frame-averaged hollow-shell coordination number (CN) of a probe ion
    as a function of its height above the metal surface.

    For every frame, an imaginary ion is placed on a regular (x, y) grid at
    z = surface_z + d for each d in ``np.linspace(d_min, d_max, n_d)``. Water oxygens
    inside the hollow shell [inner, outer] around each probe are counted (minimum-image
    PBC), averaged over the grid and then over all frames. The oxygen height distribution
    within [d_min, d_max] is shaded behind the curve.

    A simulation is treated as hydrogen-covered when its filename ends with '_H.pkl';
    in that case the CN curve is drawn in black and the oxygen density gets a black outline.

    Parameters:
        sim_file (str): Path to the simulation file.
        ion (str): Ion type ('Li', 'Na', 'K', or 'Cs'); selects the shell radii and the
            y-axis window. Unknown values fall back to 'Cs' with a logged warning.
        d_min (float): Minimum distance above the surface (Å).
        d_max (float): Maximum distance above the surface (Å).
        n_d (int): Number of distance bins.
        grid_resolution (int): Number of grid points per dimension in the x-y plane.

    Raises:
        ValueError: If the simulation has no cell dimensions.
    """
    # Per ion: (inner shell radius, outer shell radius, y-axis baseline where the oxygen
    # density is drawn, height of the y-axis window above that baseline).
    shell_params = {
        "Li": (1.75, 2.75, 0.01, 3.0),
        "Na": (2.0, 3.2, 1.01, 3.5),
        "K": (2.5, 3.8, 2.51, 3.5),
        "Cs": (2.75, 4.25, 2.01, 6.5),
    }
    if ion not in shell_params:
        # Only warn for genuinely unknown ions (previously this also fired for ion='Cs').
        logging.warning("Invalid ion type; defaulting to Cs.")
        ion = "Cs"
    irad, orad, o_start, y_span = shell_params[ion]
    ylim = (o_start, o_start + y_span)

    # Load the simulation object.
    sim = load_simulation(sim_file)
    n_frames = sim.trajectories.positions.all.shape[0]

    # Ensure cell dimensions are available.
    if sim.cell_dimensions is None:
        raise ValueError("Cell dimensions are required for PBC correction and grid positioning.")
    cell = np.asarray(sim.cell_dimensions, dtype=float)
    if cell.ndim > 1:
        # A full cell matrix was stored: use its diagonal (orthorhombic box lengths), as the
        # other analyses in this notebook do. Otherwise cell[0] below would be a row vector.
        cell = np.diag(cell)

    # Set up grid points (cell-centred) in the x-y plane.
    x_centers = np.linspace(0, cell[0], grid_resolution, endpoint=False) + cell[0] / (2 * grid_resolution)
    y_centers = np.linspace(0, cell[1], grid_resolution, endpoint=False) + cell[1] / (2 * grid_resolution)
    xv, yv = np.meshgrid(x_centers, y_centers, indexing='ij')
    grid_xy = np.column_stack((xv.ravel(), yv.ravel()))
    n_grid = grid_xy.shape[0]

    # Distance values above the surface where CN is computed.
    d_values = np.linspace(d_min, d_max, n_d)
    cn_accum = np.zeros(n_d)

    # For the oxygen density shading.
    oxygen_distances_all = []
    total_oxygen_frames = 0
    distance_bins = np.linspace(d_min, d_max, n_d + 1)

    # Report progress roughly every 25% of frames; max(1, ...) avoids a modulo by zero
    # for trajectories with fewer than 4 frames.
    progress_every = max(1, n_frames // 4)

    for frame in range(n_frames):
        if frame % progress_every == 0:
            logging.info("Processing frame %d of %d", frame, n_frames)

        surface_z = sim.trajectories.surface_z[frame]
        o_positions = sim.trajectories.positions.watO[frame]

        # Accumulate oxygen distances relative to the surface.
        oxy_z = o_positions[:, 2] - surface_z
        valid_mask = (oxy_z >= d_min) & (oxy_z <= d_max)
        if np.any(valid_mask):
            oxygen_distances_all.extend(oxy_z[valid_mask])
            total_oxygen_frames += 1

        # For each distance d, compute the grid-averaged CN for this frame.
        for i, d in enumerate(d_values):
            # Probe positions: (x, y, surface_z + d)
            ion_positions = np.column_stack((grid_xy, np.full(n_grid, surface_z + d)))
            cn_grid = compute_local_hollow_cn_vectorized(ion_positions, o_positions,
                                                         inner_radius=irad, outer_radius=orad,
                                                         box_lengths=cell)
            cn_accum[i] += np.mean(cn_grid)

    # Average the CN over all frames (every frame contributes once to every distance).
    if n_frames > 0:
        avg_cn = cn_accum / n_frames
    else:
        logging.warning("No frames found in %s; CN curve will be empty.", sim_file)
        avg_cn = np.full(n_d, np.nan)

    # Compute the oxygen density profile, if data are available.
    oxygen_density_norm = None
    if oxygen_distances_all and total_oxygen_frames > 0:
        counts_o, _ = np.histogram(oxygen_distances_all, bins=distance_bins)
        bin_width = distance_bins[1] - distance_bins[0]
        global_area = cell[0] * cell[1]
        oxygen_density = counts_o / (global_area * bin_width * total_oxygen_frames)
        if np.max(oxygen_density) > 0:
            # Rescale and shift the density so it fills the y-axis window.
            oxygen_density_norm = o_start + oxygen_density * ((ylim[1] - ylim[0]) / np.max(oxygen_density))
        bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])
    else:
        bin_centers_distance = None

    # Coverage is inferred from the filename convention.
    is_covered = sim_file.endswith('_H.pkl')

    def plot_callback(ax, legend_font_size, legend_loc):
        # Plot oxygen density profile with coverage-dependent styling.
        if oxygen_density_norm is not None and bin_centers_distance is not None:
            ax.fill_between(bin_centers_distance, oxygen_density_norm, color="black", alpha=0.2)
            if is_covered:
                ax.plot(bin_centers_distance, oxygen_density_norm, color='black', lw=1)

        # Plot the average CN curve.
        if is_covered:
            ax.plot(d_values, avg_cn, linestyle='-', linewidth=8, color="black")
        else:
            ion_color = ION_COLORS.get(ion, "#4472C4")
            ax.plot(d_values, avg_cn, linestyle='-', linewidth=8, color=ion_color)

    # Generate and display the plot using the custom_plot helper.
    custom_plot(
        xlabel="Distance from Surface (Å)",
        ylabel="CN",
        font_size=40,
        xticks=np.arange(d_min, d_max + 1, 1),
        xlim=(d_min, d_max),
        ylim=ylim,
        legend_options={"font_size": 24, "alpha": 0.8, "loc": "upper left"},
        plot_func=plot_callback,
        show=True
    )

# --- Example usage ---
# Hydrogen-covered Cs run; the "_H.pkl" suffix is what marks a run as covered.
simulation_file = "/".join(("data", "simulations", "Pt111_Cs4_H.pkl"))

# To run the hollow-shell CN analysis on this file, uncomment the call below.
# plot_hollow_cn_vs_distance(simulation_file, ion='Cs', d_max=6, grid_resolution=5, n_d=20)

In [9]:
def compute_water_residence_time(sim_file, ion_type, skip_time=0):
    """
    Compute the average residence time (lifetime) of water molecules within the solvation sphere of a cation.
    
    This function:
      1. Loads a simulation from a pickle file using the helper function `load_simulation`.
      2. Selects cations of the specified type using sim_data.ions and the global ION_CUTOFFS.
      3. Checks, for each frame (after skipping frames with time < skip_time), whether each water molecule
         (water oxygen positions labeled "watO") is inside the solvation sphere of any selected ion.
      4. Builds an occupancy time series (True if a water molecule is within the cutoff, else False).
      5. Computes contiguous "in-shell" segments for each water molecule and converts these segments into time units,
         using the simulation timestep.
      6. Returns the average residence time (in ps), a list of all residence durations (in ps), and a coverage flag.
    
    The coverage flag is determined by checking if the simulation contains adsorbates (via 
    `sim_data.trajectories.positions.adsorbates`). Downstream plotting should display water densities and lines 
    with a black outline when coverage is True.
    
    Parameters:
        sim_file (str): Path to the simulation pickle file.
        ion_type (str): The ion element (e.g., "Cs", "K", etc.) defining the solvation sphere.
        skip_time (float): Skip simulation frames with time < skip_time (in ps).
    
    Returns:
        avg_lifetime (float): Average residence time (in ps) of water molecules in the solvation sphere.
        lifetimes (list): List of all residence time durations (in ps).
        coverage_flag (bool): True if the simulation is covered (adsorbates present), False otherwise.
    
    Raises:
        ValueError: If required data or parameters are missing.
    """
    # Load simulation using the helper function.
    sim_data = load_simulation(sim_file)
    
    # Determine coverage flag based on presence of adsorbates.
    if (hasattr(sim_data.trajectories.positions, 'adsorbates') and 
        sim_data.trajectories.positions.adsorbates is not None and 
        sim_data.trajectories.positions.adsorbates.size > 0):
        coverage_flag = True
    else:
        coverage_flag = False

    # Retrieve the cutoff distance for the specified ion type.
    cutoff = ION_CUTOFFS.get(ion_type, None)
    if cutoff is None:
        raise ValueError(f"No cutoff defined for ion type {ion_type}.")

    # Identify indices corresponding to the desired ion type.
    ion_indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]
    if not ion_indices:
        raise ValueError(f"No ions of type {ion_type} found in simulation.")

    # Extract water oxygen positions (expected shape: [n_frames, n_water, 3]).
    try:
        water_positions = getattr(sim_data.trajectories.positions, "watO")
    except AttributeError:
        raise ValueError("Water positions with label 'watO' not found in simulation.")

    # Extract positions of selected ions (expected shape: [n_frames, n_selected, 3]).
    ion_positions_all = sim_data.trajectories.positions.ions[:, ion_indices, :]

    # Get simulation times and select only frames after skip_time.
    times = sim_data.trajectories.times
    valid_frame_indices = np.where(times >= skip_time)[0]
    if valid_frame_indices.size == 0:
        raise ValueError("No frames available after applying skip_time.")
    
    water_positions = water_positions[valid_frame_indices]      # [n_valid, n_water, 3]
    ion_positions_all = ion_positions_all[valid_frame_indices]    # [n_valid, n_selected, 3]
    valid_times = times[valid_frame_indices]

    n_frames = water_positions.shape[0]
    n_water = water_positions.shape[1]

    # Initialize occupancy matrix: occupancy[i, j] is True if water molecule j is within cutoff 
    # of any selected ion in frame i.
    occupancy = np.zeros((n_frames, n_water), dtype=bool)

    # Loop over valid frames.
    for i in range(n_frames):
        # Compute pairwise differences between water and ion positions.
        diff = water_positions[i][:, np.newaxis, :] - ion_positions_all[i][np.newaxis, :, :]
        
        # Apply periodic boundary conditions if cell dimensions are provided.
        if sim_data.cell_dimensions is not None:
            cell = np.array(sim_data.cell_dimensions)
            box_lengths = np.diag(cell) if cell.ndim > 1 else cell
            diff = apply_periodic_boundary(diff, box_lengths)
        
        # Calculate Euclidean distances.
        distances = np.linalg.norm(diff, axis=-1)  # Shape: [n_water, n_selected]
        
        # Mark occupancy for each water molecule (True if within cutoff for any ion).
        occupancy[i, :] = np.any(distances < cutoff, axis=1)

    # Determine simulation timestep.
    if hasattr(sim_data, "timestep") and sim_data.timestep is not None:
        dt = sim_data.timestep
    else:
        dt_array = np.diff(valid_times)
        dt = np.mean(dt_array) if dt_array.size > 0 else 1.0

    # Helper: extract durations (in ps) from a 1D boolean occupancy array.
    def get_contiguous_durations(bool_array):
        durations = []
        current_length = 0
        for val in bool_array:
            if val:
                current_length += 1
            else:
                if current_length > 0:
                    durations.append(current_length * dt)
                    current_length = 0
        if current_length > 0:
            durations.append(current_length * dt)
        return durations

    # Compute residence times for each water molecule.
    all_lifetimes = []
    for j in range(n_water):
        durations = get_contiguous_durations(occupancy[:, j])
        all_lifetimes.extend(durations)

    # Calculate the average lifetime (if residence events were detected).
    if all_lifetimes:
        avg_lifetime = np.mean(all_lifetimes)
    else:
        avg_lifetime = 0.0
        logging.info(f"No water molecule residence events found for ion type {ion_type}.")

    return avg_lifetime, all_lifetimes, coverage_flag



# Uncomment and adjust the following to process the desired simulations.
# ----------------------------------------------------------------------

# Compute the average lifetime per ion type.
# ion_lifetimes = {}
# for ion in IONS:
#     lifetimes = []
#     for i in range(1, 5):
#         for h in ["", "_H"]:
#             sim_file = f"data/simulations/Pt111_{ion}{i}{h}.pkl"
#             avg_lifetime, _, _ = compute_water_residence_time(sim_file, ion)
#             lifetimes.append(avg_lifetime)
#     ion_lifetimes[ion] = np.mean(lifetimes)
# print("Average residence times (fs) per ion type:", ion_lifetimes)

In [None]:
def plot_water_near_surface(pickle_path, distance_threshold=3.0, metal_tol=0.5, bins=100):
    """
    Visualize the average XY density of water molecules near a metal surface using a heatmap,
    and overlay the average positions of the top metal layer atoms.

    For simulations with adsorbates (coverage), black contour lines of the water density
    are drawn on top of the heatmap.

    The heatmap value of each bin is the number of water oxygens found in that bin,
    summed over all frames and divided by the number of frames. It is a mean count per
    bin per frame and is not normalised by bin area.

    Parameters:
      pickle_path (str): Path to the simulation pickle file.
      distance_threshold (float): Maximum vertical distance (Å) between a water oxygen and
          the frame's surface_z for the molecule to be counted (above or below the surface).
      metal_tol (float): Tolerance (Å) below the highest metal atom of a frame within which
          a metal atom is counted as part of the top layer.
      bins (int): Number of bins per axis for the 2D histogram heatmap.

    Raises:
      ValueError: If the simulation contains no frames.
    """
    # Set up logging and load the simulation.
    setup_logging(verbose=False)
    sim = load_simulation(pickle_path)
    positions = sim.trajectories.positions
    n_frames = sim.trajectories.times.shape[0]
    if n_frames == 0:
        # Without frames the metal template (frame 0) does not exist and the
        # per-frame average below would divide by zero.
        raise ValueError(f"Simulation '{pickle_path}' contains no frames.")

    # A simulation counts as "covered" when it carries a non-empty adsorbate array.
    adsorbates = getattr(positions, 'adsorbates', None)
    coverage_flag = adsorbates is not None and adsorbates.size > 0

    # Determine cell dimensions for the plot extent.
    # NOTE(review): for a 2D cell matrix only its diagonal is used, which assumes an
    # orthogonal cell. Confirm before using this with triclinic cells.
    cell = sim.cell_dimensions
    cell_lengths = cell if np.ndim(cell) == 1 else np.diag(cell)
    xlim, ylim = (0, cell_lengths[0]), (0, cell_lengths[1])

    # Define 2D histogram bin edges.
    # np.histogram2d silently ignores points outside these edges, so any oxygens whose
    # XY coordinates lie outside [0, cell length] are not counted. This presumably assumes
    # wrapped coordinates; confirm against the trajectory format.
    xedges = np.linspace(xlim[0], xlim[1], bins + 1)
    yedges = np.linspace(ylim[0], ylim[1], bins + 1)

    # Initialize accumulators for water density and top metal positions.
    water_hist_sum = np.zeros((bins, bins))
    n_metal = positions.metal[0].shape[0]
    metal_coords_sum = np.zeros((n_metal, 2))
    metal_counts = np.zeros(n_metal)

    # Loop over all frames to accumulate water density and average metal positions.
    for frame in range(n_frames):
        # Water oxygens within distance_threshold of this frame's surface height.
        watO = positions.watO[frame]
        surface_z = sim.trajectories.surface_z[frame]
        water_near = watO[np.abs(watO[:, 2] - surface_z) < distance_threshold]
        if water_near.size > 0:
            hist, _, _ = np.histogram2d(water_near[:, 0], water_near[:, 1],
                                        bins=[xedges, yedges])
            water_hist_sum += hist

        # Accumulate XY positions of top-layer metal atoms (vectorised over atoms).
        metal = positions.metal[frame]
        if metal.size > 0:
            max_metal_z = np.max(metal[:, 2])
            tracked = metal[:n_metal]
            in_top_layer = (max_metal_z - tracked[:, 2]) < metal_tol
            metal_coords_sum[in_top_layer] += tracked[in_top_layer, :2]
            metal_counts[in_top_layer] += 1

    # Compute the average water density and average metal positions.
    water_hist_avg = water_hist_sum / n_frames
    # Average only over atoms that were in the top layer at least once.
    valid = metal_counts > 0
    avg_metal_positions = metal_coords_sum[valid] / metal_counts[valid, None]

    # Define a plotting callback to be used by custom_plot.
    def plot_callback(ax, legend_font_size, legend_loc):
        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

        # Plot the water density as a heatmap (transpose: histogram2d is indexed [x, y]).
        im = ax.imshow(water_hist_avg.T, extent=extent, origin='lower', aspect='auto',
                       cmap='Reds', interpolation='nearest')
        # Add a colorbar.
        cbar = ax.figure.colorbar(im, ax=ax)
        cbar.set_label("Average Water Density", fontsize=legend_font_size)

        # If covered, overlay black contour lines to emphasize the density.
        if coverage_flag:
            ax.contour(water_hist_avg.T, extent=extent, origin='lower', levels=5,
                       colors='black', linewidths=1)

        # Overlay average top metal atom positions.
        scatter_handle = ax.scatter(avg_metal_positions[:, 0], avg_metal_positions[:, 1],
                                    facecolors='none', edgecolors='blue', s=50,
                                    label="Average Top Metal Atoms")
        # Build the legend.
        ax.legend(handles=[scatter_handle], labels=["Average Top Metal Atoms"],
                  fontsize=legend_font_size, loc=legend_loc)

    # Create and display the final plot.
    custom_plot(
        xlabel="X (Å)",
        ylabel="Y (Å)",
        font_size=20,
        figsize=(10, 8),
        plot_func=plot_callback,
        legend_options={"font_size": 16, "loc": "upper right"}
    )


# ----- Example usage -----
# Simulation pickle used for the demonstration below. Judging by the file name
# ("Pt111_NoIon") this is presumably the Pt(111) run without a cation; confirm.
sim_file = "data/simulations/Pt111_NoIon.pkl"

# Uncomment to run the analysis on the provided simulation pickle file.
# The call uses a tighter 2.5 Å surface window than the function's default of 3.0 Å.
# plot_water_near_surface(sim_file, distance_threshold=2.5)