Main Jupyter notebook containing all the code used to generate the plots in the paper by R.S. Kort, A. Hagopian, K. Doblhoff-Dier and M. Koper on the impact of applied potential and hydrogen coverage on alkali metal cation behaviour.





In [49]:
# Import all the required libraries
import numpy as np
import logging

# Import all matplotlib related libraries
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.patheffects as pe
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerLine2D

import pickle
import logging


In [None]:
# Definition of parameters, basic functions, and logging
# -------------------------------------------------------

# Define some parameters
# Alkali cations covered by the analysis, in plotting/legend order.
IONS = "Li Na K Cs".split()
# Hex colour used for every curve/marker belonging to a given cation.
ION_COLORS = dict(Li="#70AD47", Na="#FFC000", K="#A032A0", Cs="#4472C4")
# Per-cation distance cutoffs, presumably in Å (e.g. a first-shell radius).
# NOTE(review): not referenced in the visible part of this file — confirm use/units.
ION_CUTOFFS = dict(zip(IONS, (2.75, 3.2, 3.8, 4.25)))

# Define some general functions
# -----------------------------
def blend_color(hex_color, fraction):
    """
    Blend a ``#RRGGBB`` hex colour with white.

    fraction=0 returns the original colour, fraction=1 returns white; values in
    between move every RGB channel linearly towards 255 (the result is truncated
    to an integer). ``fraction`` is clamped to [0, 1] so the output is always a
    valid 6-digit hex colour: out-of-range values used to produce channels below
    0 or above 255 and therefore malformed strings such as "#1FE1FE1FE".

    Parameters:
        hex_color (str): Colour as "RRGGBB", optionally prefixed with "#".
        fraction (float): Amount of white to mix in.

    Returns:
        str: Upper-case "#RRGGBB" string.
    """
    fraction = min(max(fraction, 0.0), 1.0)
    hex_color = hex_color.lstrip('#')
    channels = (int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    blended = (int(c + (255 - c) * fraction) for c in channels)
    return "#" + "".join(f"{c:02X}" for c in blended)

def init_plot(
    xlabel: str = "X",
    ylabel: str = "Y",
    font_size: int = 60,
    font_family: str = 'Times New Roman',
    figsize: tuple = (12, 8),
    yticks_remove: bool = False,
    grid: str = "None",
    tight_layout: bool = True
):
    """
    Open a new figure and apply the notebook's common axis styling to it.

    Note: the font settings are written into the global ``plt.rcParams`` and
    therefore stay active for later figures until changed again.

    Parameters:
      xlabel (str): Text for the x-axis label.
      ylabel (str): Text for the y-axis label.
      font_size (int): Global font size applied before the labels are drawn.
      font_family (str): Global font family applied before the labels are drawn.
      figsize (tuple): Figure size passed to ``plt.figure``.
      yticks_remove (bool): If True, the y-axis gets no ticks at all.
      grid (str): "x" -> grid lines only for the x-axis, "y" -> only for the
                  y-axis, None / "None" / False -> no grid, anything else
                  (e.g. "both") -> grid on both axes.
      tight_layout (bool): If True, ``plt.tight_layout()`` is called at the end.
    """
    plt.figure(figsize=figsize)
    plt.rcParams.update({'font.size': font_size, 'font.family': font_family})
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    axis = plt.gca()
    if yticks_remove:
        axis.yaxis.set_ticks([])

    # Only the bottom spine is kept visible.
    for side in ('top', 'right', 'left'):
        axis.spines[side].set_visible(False)

    if grid in ("x", "y"):
        # Exactly one of the two axes shows grid lines.
        x_selected = grid == "x"
        axis.xaxis.grid(x_selected)
        axis.yaxis.grid(not x_selected)
    elif grid in (None, "None", False):
        axis.grid(False)
    else:
        axis.grid(True)

    if tight_layout:
        plt.tight_layout()

def show_plot(
    legend_font_size: int = 36,
    legend_draw_bg: bool = True,   # Whether to draw a background patch behind the legend.
    legend_border: bool = False,       # Whether the legend patch should have a border.
    legend_bgcolor: str = 'white',     # Background color of the legend patch.
    legend_edgecolor: str = 'black',   # Edge (border) color of the legend patch.
    legend_alpha: float = 0.8,         # Transparency of the legend background (0: transparent, 1: opaque).
    legend_linewidth: float = None,    # Width of the lines in the legend (if applicable).
    legend_loc: str = 'best',          # Location of the legend.
    xticks: list = None,
    yticks: list = None,
    xlim: tuple = None,
    ylim: tuple = None,
    legend_handles: list = None,
    legend_labels: list = None,
    legend_handler_map: dict = None,
    show: bool = True
):
    """
    Finalise the current matplotlib plot (ticks, limits, legend) and optionally show it.

    Legend construction:
      - handles and labels given -> ``plt.legend(handles, labels)``
      - only handles given       -> the handles are used with their own labels
      - only labels given        -> the labels are applied to the plotted artists in order
      - neither given            -> legend built from all labelled artists
    (Previously, handles or labels supplied on their own were silently ignored.)

    Parameters:
        legend_font_size (int): Global font size set before the legend is created.
            NOTE: this ``plt.rcParams`` change persists for later figures.
        legend_draw_bg (bool): Whether to display a background patch for the legend.
        legend_border (bool): Whether to draw a border around the legend background patch.
        legend_bgcolor (str): Background color for the legend.
        legend_edgecolor (str): Border color for the legend (used only if legend_border).
        legend_alpha (float): Transparency of the legend background (0 to 1).
        legend_linewidth (float, optional): Line width for the lines in the legend.
        legend_loc (str): Location of the legend.
        xticks (list, optional): Tick locations for the x-axis.
        yticks (list, optional): Tick locations for the y-axis.
        xlim (tuple, optional): Limits for the x-axis, passed to ``plt.xlim``.
        ylim (tuple, optional): Limits for the y-axis, passed to ``plt.ylim``
            (a single number sets only the lower limit).
        legend_handles (list, optional): Custom legend handles.
        legend_labels (list, optional): Custom legend labels.
        legend_handler_map (dict, optional): Custom legend handler map.
        show (bool): Whether to call ``plt.show()``.

    Returns:
        matplotlib.legend.Legend: The created legend.
    """
    if xticks is not None:
        plt.xticks(xticks)
    if yticks is not None:
        plt.yticks(yticks)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)

    # Update the global font size used for the legend text.
    plt.rcParams.update({'font.size': legend_font_size})

    # handler_map=None is matplotlib's own default, so it can always be passed.
    legend_kwargs = {
        "frameon": legend_draw_bg,
        "loc": legend_loc,
        "handler_map": legend_handler_map,
    }
    if legend_handles is not None and legend_labels is not None:
        leg = plt.legend(legend_handles, legend_labels, **legend_kwargs)
    elif legend_handles is not None:
        leg = plt.legend(handles=legend_handles, **legend_kwargs)
    elif legend_labels is not None:
        leg = plt.legend(labels=legend_labels, **legend_kwargs)
    else:
        leg = plt.legend(**legend_kwargs)

    # Style the background patch only when one is drawn.
    if legend_draw_bg:
        frame = leg.get_frame()
        frame.set_facecolor(legend_bgcolor)
        frame.set_alpha(legend_alpha)
        # 'none' hides the border entirely.
        frame.set_edgecolor(legend_edgecolor if legend_border else 'none')

    if legend_linewidth is not None:
        for line in leg.get_lines():
            line.set_linewidth(legend_linewidth)

    if show:
        plt.show()

    return leg


def apply_periodic_boundary(diff, box_lengths):
    """
    Map displacement vector(s) onto their nearest periodic image.

    Each component is shifted by an integer multiple of the matching box length
    so that it ends up within half a box length of zero (``np.round`` rounds
    exact halves to the nearest even integer). Works element-wise with NumPy
    broadcasting, e.g. ``diff`` of shape (..., 3) with ``box_lengths`` of shape (3,).
    """
    image_shift = np.round(diff / box_lengths)
    return diff - image_shift * box_lengths


def setup_logging(verbose=False):
    """
    Configure root logging and set its verbosity.

    ``logging.basicConfig`` does nothing once the root logger already has
    handlers (e.g. after the first call in a notebook session, or when the host
    environment installed its own). The level is therefore also applied to the
    root logger directly, so that repeated calls with a different ``verbose``
    value actually take effect.

    Parameters:
        verbose (bool): If True, set logging level to DEBUG, else INFO.
    """
    log_level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=log_level)
    logging.getLogger().setLevel(log_level)



In [None]:
# Function to load the Simulation object from a pickle file.

def load_simulation(filename):
    """
    Unpickle and return the Simulation object stored in ``filename``.

    Success is logged at DEBUG level. Any failure (missing file, corrupt pickle,
    ...) is logged at ERROR level and then re-raised unchanged.

    Security note: ``pickle.load`` can execute arbitrary code, so only pass
    files produced by this project's own pipeline.

    Parameters:
        filename (str): The path to the pickle file.

    Returns:
        Simulation: The loaded Simulation object.
    """
    try:
        with open(filename, "rb") as handle:
            loaded = pickle.load(handle)
        logging.debug("Successfully loaded Simulation from %s", filename)
    except Exception as exc:
        logging.error("Failed to load Simulation from %s: %s", filename, exc)
        raise
    return loaded

"""    
# Load a Simulation object from a pickle file and print some metadata.
# Example usage:
pickle_file = "data/simulations/Pt111_Cs1.pkl"
sim = load_simulation(pickle_file)

# Access simulation-level metadata.
print("Simulation Metadata:\n----------------")
print("Project Name:", sim.project_name)
print("Timestep:", sim.timestep)
print("Cell Dimensions:", sim.cell_dimensions)
print("Lattice Dimensions:", sim.lattice_dimensions)
print("Electrode Potential:", sim.electrode_potential)
print("Metal Type:", sim.metal_type)
print("Ions:", sim.ions)

# Access trajectory-level metadata.
print("\nTrajectory Metadata:\n----------------")
print("Number of Trajectories:", len(sim.trajectories.times))
print("Time of the first few frames:", sim.trajectories.times[:5])
print("Surface Z-coordinate of the first few frames:", sim.trajectories.surface_z[:5])

# Access position data.
print("\nPosition data:\n----------------")
print("Metal positions:", sim.trajectories.positions.metal[:5])
if len(sim.trajectories.positions.adsorbates[0]) != 0:
    print("Adsorbate positions:", sim.trajectories.positions.adsorbates[:5])
print("Water positions:", sim.trajectories.positions.water[:5])
print("Ion positions:", sim.trajectories.positions.ions[:5])
print("Oxygen (of water) positions:", sim.trajectories.positions.watO[:5])
print("Hydrogen (of water) positions:", sim.trajectories.positions.watH[:5])

"""

In [None]:
# Functions to compute the ion-O RDF (per frame and averaged over frames) and to plot it.

# --- Helper function to compute simulation parameters for ion-O RDF --- 
def get_simulation_parameters_ion_O(data, bin_edges):
    """
    Derive the geometric quantities needed to normalise the ion-O RDF.

    ``data.cell_dimensions`` may be a 1-D sequence of box lengths or a 2-D
    matrix, whose diagonal is then taken as the box lengths. The volume is the
    product of these lengths (i.e. an orthogonal box is assumed). The bin width
    is taken from the first bin only, so ``bin_edges`` should be uniform.

    Returns:
        tuple: (volume, shell_volumes, bin_centers, bin_width, box_lengths), or
        five ``None`` values (with a warning) when no cell dimensions are stored.
    """
    cell = data.cell_dimensions
    if cell is None:
        logging.warning("No cell dimensions found in file.")
        return None, None, None, None, None

    cell_arr = np.array(cell)
    box_lengths = cell_arr if cell_arr.ndim == 1 else np.diag(cell)
    volume = np.prod(box_lengths)

    inner, outer = bin_edges[:-1], bin_edges[1:]
    # Volume of each spherical shell between consecutive bin edges.
    shell_volumes = (4.0 / 3.0) * np.pi * (outer**3 - inner**3)
    bin_centers = 0.5 * (inner + outer)
    bin_width = bin_edges[1] - bin_edges[0]
    return volume, shell_volumes, bin_centers, bin_width, box_lengths


# --- Compute the ion-O RDF for one frame --- 
def compute_frame_ion_O_rdf(oxygen_positions, ion_positions, oxygen_label, ion_label,
                           bin_edges, shell_volumes, bin_centers, bin_width, volume, box_lengths):
    """
    Compute the ion-oxygen radial distribution function g(r) for a single frame.

    All oxygen-ion pair distances are evaluated with the minimum-image
    convention, histogrammed with ``bin_edges`` and divided by the count
    expected for a uniform oxygen density, ``n_ions * rho_O * shell_volume``.
    If the oxygen and ion labels are identical, zero distances (self-pairs)
    are removed first.

    Parameters:
      oxygen_positions (np.ndarray): (n_oxygen, n_dim) positions (presumably Cartesian x, y, z).
      ion_positions (np.ndarray): (n_ions, n_dim) positions of the reference ions.
      oxygen_label, ion_label (str): Compared to decide whether self-pairs are removed.
      bin_edges (np.ndarray): Radial bin edges.
      shell_volumes (np.ndarray): Volume of each spherical shell (len(bin_edges) - 1 values).
      bin_centers, bin_width: Not used here; kept for interface compatibility.
      volume (float): Cell volume used for the oxygen number density.
      box_lengths (np.ndarray): Periodic box length along each coordinate axis.

    Returns:
      tuple: (rdf_frame, density_oxygen) — g(r) per bin (0 where it is undefined)
      and the oxygen number density n_oxygen / volume.
    """
    # (n_oxygen, 1, n_dim) - (n_ions, n_dim) broadcasts to (n_oxygen, n_ions, n_dim).
    diff = oxygen_positions[:, np.newaxis, :] - ion_positions
    diff = apply_periodic_boundary(diff, box_lengths)
    distances = np.linalg.norm(diff, axis=-1).ravel()

    # Exclude self-correlations when oxygen and ion labels match.
    if oxygen_label == ion_label:
        distances = distances[distances > 1e-6]

    hist, _ = np.histogram(distances, bins=bin_edges)
    density_oxygen = oxygen_positions.shape[0] / volume
    # The ion count is the FIRST axis of the per-frame ion array; shape[1] is the
    # coordinate dimension, which wrongly normalised every g(r) by 3 instead of n_ions.
    num_ions = ion_positions.shape[0]
    expected_total = num_ions * density_oxygen * shell_volumes
    with np.errstate(divide='ignore', invalid='ignore'):
        rdf_frame = hist / expected_total
    # 0/0 (e.g. no ions or no oxygens) gives NaN; report those bins as 0.
    rdf_frame[np.isnan(rdf_frame)] = 0
    return rdf_frame, density_oxygen

# --- Calculate ion-O RDF over multiple frames for a given oxygen species --- 
def calculate_ion_O_rdf(data, oxygen_label, ion_label, skip_time, bin_edges, ion_indices=None):
    """
    Compute the oxygen-ion RDF averaged over all frames that are not earlier than ``skip_time``.

    The reference ions come from ``data.trajectories.positions.ions`` (optionally
    restricted to ``ion_indices``). The oxygen species is read from the attribute
    named ``oxygen_label`` of ``data.trajectories.positions``.

    Parameters:
      data: simulation data object.
      oxygen_label (str): Attribute name in data.trajectories.positions for the oxygen species (e.g., "watO").
      ion_label (str): A string used for comparing oxygen and ion species.
                     If oxygen_label equals ion_label, self-correlations are excluded.
      skip_time (float): Frames with a time (ps) below this value are skipped.
      bin_edges (np.array): Edges of the radial bins.
      ion_indices (list or None): Optional indices selecting a subset of ions along axis 1 of
                     data.trajectories.positions.ions.

    Returns:
      tuple: (bin_centers, avg_rdf, avg_density) — the radial bin centers, the frame-averaged
      RDF and the frame-averaged oxygen number density; or (None, None, None) when the oxygen
      positions are missing, None or empty, the cell dimensions are missing, or no frame
      passes the ``skip_time`` threshold.
    """
    try:
        oxygen_positions_all = getattr(data.trajectories.positions, oxygen_label)
    except AttributeError:
        oxygen_positions_all = None
    # A None array would crash on indexing and an empty one only yields an all-zero RDF.
    if oxygen_positions_all is None or np.size(oxygen_positions_all) == 0:
        logging.warning("No positions for oxygen species '%s' found.", oxygen_label)
        return None, None, None

    # Always use the ions from the simulation as reference positions.
    ion_positions_all = data.trajectories.positions.ions
    if ion_indices is not None:
        ion_positions_all = ion_positions_all[:, ion_indices, :]

    volume, shell_volumes, bin_centers, bin_width, box_lengths = get_simulation_parameters_ion_O(data, bin_edges)
    if volume is None:
        return None, None, None

    rdf_list = []
    density_list = []
    for frame, t in enumerate(data.trajectories.times):
        if t < skip_time:
            continue
        rdf_frame, density_oxygen = compute_frame_ion_O_rdf(
            oxygen_positions_all[frame], ion_positions_all[frame],
            oxygen_label, ion_label,
            bin_edges, shell_volumes, bin_centers, bin_width,
            volume, box_lengths,
        )
        rdf_list.append(rdf_frame)
        density_list.append(density_oxygen)

    if not rdf_list:
        logging.warning("No frames passed the skip time threshold.")
        return None, None, None

    return bin_centers, np.mean(rdf_list, axis=0), np.mean(density_list)

# --- Main function to aggregate and plot ion-O RDF data --- 
def plot_ion_O_rdfs(files, skip=5, bins=50, max_distance=8, verbose=False):
    """
    Load simulation pickle files, compute the ion-O RDF per ion type and plot one averaged curve per type.

    For every file and every ion type present in ``sim_data.ions``, the water-oxygen
    RDF around those ions is computed with ``calculate_ion_O_rdf``. The per-file
    curves of one ion type are then averaged, weighted by the number of ions of
    that type in each file. Curves are plotted in the order in which the ion types
    first appear across ``files``, which keeps the legend order reproducible.

    Parameters:
      files (list): List of simulation pickle file paths.
      skip (float): Time (ps) to skip at the beginning of each simulation.
      bins (int): Number of bins for the RDF histogram.
      max_distance (float): Maximum distance for the RDF histogram.
      verbose (bool): If True, enable debug-level logging.
    """
    setup_logging(verbose)

    bin_edges = np.linspace(0, max_distance, bins + 1)
    # Per ion type:
    #   'bin_centers': bin centers (identical for all files, since bin_edges is fixed),
    #   'rdf_sum': sum of (rdf * ion_count) over files,
    #   'total_count': total ion count across files.
    aggregated = {}

    for file_path in files:
        logging.debug("Processing file: %s", file_path)
        try:
            sim_data = load_simulation(file_path)
        except Exception as e:
            logging.error("Failed to load file %s: %s", file_path, e)
            continue

        # dict.fromkeys keeps first-appearance order; iterating a set of strings
        # gives a hash-seed-dependent order and hence an unstable legend order.
        for ion_type in dict.fromkeys(sim_data.ions):
            indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]
            n_ions = len(indices)

            for oxygen in ["watO"]:
                # "ions" never equals the oxygen label, so no self-pair removal happens.
                bin_centers, rdf_avg, _ = calculate_ion_O_rdf(
                    sim_data, oxygen, "ions", skip, bin_edges, ion_indices=indices
                )
                if bin_centers is None or rdf_avg is None:
                    logging.warning("Skipping file %s for ion type %s and oxygen %s due to missing RDF data.", file_path, ion_type, oxygen)
                    continue

                # Weight the computed RDF by the number of ions of this type.
                weighted_rdf = rdf_avg * n_ions
                label = f"{ion_type}"
                entry = aggregated.get(label)
                if entry is None:
                    aggregated[label] = {
                        "bin_centers": bin_centers,
                        "rdf_sum": weighted_rdf,
                        "total_count": n_ions,
                    }
                else:
                    entry["rdf_sum"] = entry["rdf_sum"] + weighted_rdf
                    entry["total_count"] += n_ions

    rdf_data_list = []
    for label, entry in aggregated.items():
        # Ion-count-weighted average RDF for this ion type.
        avg_rdf = entry["rdf_sum"] / entry["total_count"]
        rdf_data_list.append((entry["bin_centers"], avg_rdf, label))
        logging.debug("Aggregated ion-O RDF for %s: Total ion count = %d", label, entry["total_count"])

    if not rdf_data_list:
        logging.error("No valid ion-O RDF data to plot.")
        return

    init_plot(xlabel="Ion–O Distance (Å)", ylabel="g(r)", yticks_remove=True)
    for bin_centers, rdf, label in rdf_data_list:
        plt.plot(bin_centers, rdf, linestyle='-', linewidth=8, marker=None,
                 label=label, color=ION_COLORS.get(label, 'black'))
    # Lower y-limit fixed at 0, upper limit left to matplotlib.
    show_plot(xticks=np.arange(2, 7, 1), xlim=(1.5, 6.5), ylim=(0, None))

# --- Example usage ---
# Simulation identifiers whose pickle files feed the ion-O RDF analysis.
sim_ids = ["Li2", "Na2", "K2", "Cs2"]
files = ["data/simulations/Pt111_{}.pkl".format(sid) for sid in sim_ids]

#plot_ion_O_rdfs(files=files)


In [None]:
# Functions to compute and plot ion and water-oxygen density profiles as a function of distance from the surface.

def calculate_histogram_density(all_distances, bin_edges, cell_dimensions, valid_frames):
    """
    Convert surface distances (Å) into a number density per nm³ and per frame.

    Each histogram count is divided by the lateral cell area
    (cell_dimensions[0] * cell_dimensions[1], converted to nm²), the bin width
    (converted to nm, taken from the first bin so uniform bins are assumed) and
    the number of frames that contributed distances.
    """
    to_nm = 0.1  # 1 Å = 0.1 nm
    bin_counts = np.histogram(all_distances, bins=bin_edges)[0]
    slab_thickness = (bin_edges[1] - bin_edges[0]) * to_nm
    surface_area = (cell_dimensions[0] * to_nm) * (cell_dimensions[1] * to_nm)
    return bin_counts / (surface_area * slab_thickness * valid_frames)

def extract_initial_positions(simulation, ion, skip):
    """
    Return the frame-0 distances of every ``ion`` atom from the surface.

    Uses |positions.ions[0, i, 2] - surface_z[0]| (coordinate index 2, presumably z)
    for every ion whose symbol in ``simulation.ions`` equals ``ion``. ``skip`` is
    accepted for signature symmetry with the other density helpers but is not used.

    Returns:
        list: One float per matching ion, or an empty list if there is none.
    """
    selected = np.array([symbol == ion for symbol in simulation.ions])
    if not selected.any():
        return []
    trajectories = simulation.trajectories
    z_first = trajectories.positions.ions[0, selected, 2]
    return np.abs(z_first - trajectories.surface_z[0]).tolist()

def process_density(simulation, target_positions, skip, bin_edges):
    """
    Build a surface-distance density histogram for one set of atoms (ions or water oxygen).

    For every frame whose time is not below ``skip``, |target_positions[frame] -
    surface_z[frame]| is collected. All collected distances are then turned into a
    density with ``calculate_histogram_density``.

    Returns:
        tuple: (density, valid_frames), or (None, 0) if every frame was skipped.
    """
    trajectories = simulation.trajectories
    per_frame = [
        np.abs(target_positions[idx] - trajectories.surface_z[idx])
        for idx, t in enumerate(trajectories.times)
        if not t < skip
    ]
    if not per_frame:
        logging.warning("No valid frames found after skipping %d ps.", skip)
        return None, 0
    frame_count = len(per_frame)
    density = calculate_histogram_density(
        np.concatenate(per_frame), bin_edges, simulation.cell_dimensions, frame_count
    )
    return density, frame_count

def _has_hydrogen_coverage(simulation):
    """Return True if ``simulation.trajectories.positions`` holds a non-empty ``adsorbates`` array."""
    positions = simulation.trajectories.positions
    return (hasattr(positions, 'adsorbates') and
            positions.adsorbates is not None and
            positions.adsorbates.size > 0)


def plot_density(files, skip=5, bins=50, smoothing=3, initial_position=False, verbose=False, normalize=True):
    """
    Plot ion and water-oxygen densities as a function of distance from the surface.

    Density per bin = atom count / (lateral cell area * bin width * number of valid
    frames), computed for distances between 2 Å and 5 Å. Simulations are grouped by
    hydrogen coverage (a non-empty ``adsorbates`` position array), and the curves of
    each group are averaged over its files.

    Styling:
      - Ion curves of hydrogen-covered simulations are drawn "hollow": a thick
        ion-coloured stroke is drawn underneath a thin white line (path effects).
      - Oxygen densities are drawn as grey filled areas; hydrogen-covered ones get a hatch.
      - Within each group the oxygen curve is rescaled so that its maximum equals
        the largest ion density of that group (done regardless of ``normalize``).

    Parameters:
      - files (list of str): List of simulation file paths.
      - skip (int): Skip frames with time (in ps) less than this value.
      - bins (int): Number of bins for the histogram.
      - smoothing (int): Moving-average window for the ion curves (<= 1 disables it;
                         windows wider than the histogram are clamped to its length).
      - initial_position (bool): Whether to mark the frame-0 ion positions.
      - verbose (bool): Enable verbose logging.
      - normalize (bool): If True, normalize each averaged ion and oxygen density curve
                          to unit area (rectangle rule, bin width in nm).
    """
    setup_logging(verbose)
    init_plot(xlabel="Distance to Surface (Å)",
              ylabel="ρ (a.u.)",
              yticks_remove=True,
              font_size=32)

    # Histogram bins between 2 Å and 5 Å from the surface.
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    bin_width_nm = (bin_edges[1] - bin_edges[0]) * 0.1

    # Results are keyed by the hydrogen-coverage flag (False / True).
    ion_density_dict = {False: {ion: [] for ion in IONS}, True: {ion: [] for ion in IONS}}
    oxygen_density_list = {False: [], True: []}
    initial_positions = {ion: [] for ion in IONS}

    for file_path in files:
        simulation = load_simulation(file_path)
        hydrogen_present = _has_hydrogen_coverage(simulation)

        if initial_position:
            for ion in IONS:
                initial_positions[ion].extend(extract_initial_positions(simulation, ion, skip))

        for ion in IONS:
            target_mask = np.array([elem == ion for elem in simulation.ions])
            if not np.any(target_mask):
                continue
            target_positions = simulation.trajectories.positions.ions[:, target_mask, 2]
            density, _ = process_density(simulation, target_positions, skip, bin_edges)
            if density is None:
                logging.warning("No valid frames for ion %s in file %s", ion, file_path)
                continue
            ion_density_dict[hydrogen_present][ion].append(density)

        water_o = simulation.trajectories.positions.watO
        if water_o is None or water_o.size == 0:
            logging.debug("No water oxygen positions found in file %s", file_path)
        else:
            density, valid_frames = process_density(simulation, water_o[:, :, 2], skip, bin_edges)
            if valid_frames > 0 and density is not None:
                oxygen_density_list[hydrogen_present].append(density)
            else:
                logging.warning("No valid frames for oxygen in file %s", file_path)

    # Average over files per group; track the per-group ion maximum for oxygen scaling.
    group_max_ion = {False: 0.0, True: 0.0}
    avg_ion_density = {False: {}, True: {}}
    for h_flag in [False, True]:
        for ion in IONS:
            if not ion_density_dict[h_flag][ion]:
                avg_ion_density[h_flag][ion] = None
                continue
            avg_density = np.mean(ion_density_dict[h_flag][ion], axis=0)
            if normalize:
                area = np.sum(avg_density * bin_width_nm)
                if area != 0:
                    avg_density = avg_density / area
            avg_ion_density[h_flag][ion] = avg_density
            group_max_ion[h_flag] = max(group_max_ion[h_flag], np.max(avg_density))

    avg_oxygen_density = {False: None, True: None}
    for h_flag in [False, True]:
        if not oxygen_density_list[h_flag]:
            continue
        oxygen_density = np.mean(oxygen_density_list[h_flag], axis=0)
        if normalize:
            # Same rectangle-rule normalisation as the ion curves (np.trapz is
            # deprecated in NumPy 2.0 and used Å instead of nm as the x unit).
            area = np.sum(oxygen_density * bin_width_nm)
            if area != 0:
                oxygen_density = oxygen_density / area
        avg_oxygen_density[h_flag] = oxygen_density

    # Scale oxygen density so its maximum matches the ion density maximum (per group).
    scaled_oxygen_density = {False: None, True: None}
    for h_flag in [False, True]:
        od = avg_oxygen_density[h_flag]
        if od is not None and np.max(od) > 0 and group_max_ion[h_flag] > 0:
            scaled_oxygen_density[h_flag] = od * (group_max_ion[h_flag] / np.max(od))
        else:
            scaled_oxygen_density[h_flag] = od

    # --- Plotting ---
    # Oxygen densities first, so the ion curves are drawn on top.
    for h_flag in [False, True]:
        od = scaled_oxygen_density[h_flag]
        if od is not None and len(od) == len(bin_centers):
            poly = plt.fill_between(bin_centers, od, color="black", alpha=0.2)
            if h_flag:
                poly.set_hatch('//')

    # np.convolve(mode='same') returns max(len(a), len(v)) samples, so a window wider
    # than the histogram would yield a curve longer than bin_centers.
    window = min(smoothing, len(bin_centers))
    for h_flag in [False, True]:
        for ion in IONS:
            density = avg_ion_density[h_flag][ion]
            if density is None or len(density) != len(bin_centers):
                continue
            if window > 1:
                density = np.convolve(density, np.ones(window) / window, mode='same')
            if h_flag:
                # Hollow line: path effects draw the thick coloured stroke first,
                # then the thin white line itself on top of it.
                line, = plt.plot(bin_centers, density, linewidth=1, color='white',
                                 label=f"{ion}")
                line.set_path_effects([pe.Stroke(linewidth=8, foreground=ION_COLORS[ion]), pe.Normal()])
            else:
                plt.plot(bin_centers, density, linestyle='-', linewidth=8, marker=None,
                         color=ION_COLORS[ion], label=f"{ion}")

    if initial_position:
        for ion in IONS:
            for pos in initial_positions[ion]:
                plt.plot(pos, 0.2, marker='o', markersize=10, linestyle='None', color=ION_COLORS[ion])

    # NOTE(review): (0.15) is a plain float, so only the lower y-limit is set to 0.15 —
    # confirm this clipping of the density curves is intended.
    show_plot(xticks=np.arange(2, 6, 0.5), xlim=(2, 5), ylim=(0.15), legend_font_size=32)
    
# Example calls:
# plot_density([f"data/simulations/Pt111_{ion}2.pkl" for ion in IONS])
# plot_density([f"data/simulations/Pt111_{ion}4_H.pkl" for ion in IONS])

In [None]:
# Function to compare the densities of each ion per simulation.

def compare_ion_densities(simulation_files, skip=5, bins=50, smoothing=3, normalize=True, verbose=False):
    """
    Compare the densities of each ion and the oxygen density per simulation.
    For each simulation, the ion density curves are plotted as separate lines,
    and the oxygen density is plotted as a shaded region.

    The oxygen density is scaled such that its highest peak (across all simulations)
    matches the maximum ion density (for distances below 5 Å), while preserving
    the relative differences between multiple oxygen density curves.

    Line colours are the ion colour blended towards white according to the
    simulation's electrode potential within its coverage group (bare or
    hydrogen-covered): the lowest potential of a group keeps the pure colour.
    Hydrogen-covered systems are drawn as hollow lines with a hatched oxygen fill.

    Each density profile is computed once and reused for both the
    global-maximum pass and the plotting pass.

    Parameters:
    - simulation_files (list of str): List of simulation pickle files.
    - skip (int): Skip frames with time (in ps) less than this value.
    - bins (int): Number of bins for the histogram.
    - smoothing (int): Window for smoothing the density curves.
    - normalize (bool): If True, normalize the density curves (area under the curve = 1).
    - verbose (bool): Enable verbose logging.
    """
    setup_logging(verbose)
    init_plot(xlabel="Distance to Surface (Å)",
              ylabel="ρ (a.u.)",
              yticks_remove=True,
              font_size=32)

    # Histogram bins span 2 Å to 5 Å from the surface.
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    def has_positions(sim, name):
        """True if sim.trajectories.positions.<name> exists and is a non-empty array."""
        positions = getattr(sim.trajectories.positions, name, None)
        return positions is not None and positions.size > 0

    def density_profile(sim, target_positions):
        """Smoothed (optionally area-normalised) density, or None if process_density returned None."""
        density, _ = process_density(sim, target_positions, skip, bin_edges)
        if density is None:
            return None
        if smoothing > 1:
            density = np.convolve(density, np.ones(smoothing) / smoothing, mode='same')
        if normalize:
            area = integrate(density, bin_centers)
            if area != 0:
                # Out-of-place division: works for integer histogram counts and
                # never mutates the array handed back by process_density.
                density = density / area
        return density

    # Load simulations, their electrode potentials and hydrogen-coverage flags.
    simulations = [load_simulation(file_path) for file_path in simulation_files]
    potentials = [sim.electrode_potential for sim in simulations]
    coverage_flags = [has_positions(sim, "adsorbates") for sim in simulations]

    # (min, max) potential per coverage group; colours are blended within a group.
    group_ranges = {}
    for covered in (False, True):
        group = [pot for pot, flag in zip(potentials, coverage_flags) if flag == covered]
        if group:
            group_ranges[covered] = (min(group), max(group))

    # ---- Pass 1: compute every profile once and track global maxima ----
    ion_profiles = []     # per simulation: [(ion, density or None), ...] for ions present
    oxygen_profiles = []  # per simulation: (has_water, density or None)
    global_ion_max = 0
    global_oxygen_max = 0
    for sim in simulations:
        per_ion = []
        for ion in IONS:
            target_mask = np.array([elem == ion for elem in sim.ions])
            if not np.any(target_mask):
                continue
            density = density_profile(sim, sim.trajectories.positions.ions[:, target_mask, 2])
            per_ion.append((ion, density))
            if density is not None:
                global_ion_max = max(global_ion_max, density.max())
        ion_profiles.append(per_ion)

        has_water = has_positions(sim, "watO")
        oxygen_density = (density_profile(sim, sim.trajectories.positions.watO[:, :, 2])
                          if has_water else None)
        if oxygen_density is not None:
            global_oxygen_max = max(global_oxygen_max, oxygen_density.max())
        oxygen_profiles.append((has_water, oxygen_density))

    # Scale oxygen so that its highest peak matches the highest ion peak.
    scaling_factor = global_ion_max / global_oxygen_max if global_oxygen_max > 0 else 1

    # ---- Pass 2: plot each simulation ----
    for sim, hydrogen_present, per_ion, (has_water, oxygen_density) in zip(
            simulations, coverage_flags, ion_profiles, oxygen_profiles):
        # Each simulation is a member of its own coverage group, so the lookup always succeeds.
        group_min, group_max = group_ranges[hydrogen_present]
        group_range = group_max - group_min if group_max != group_min else 1  # avoid division by zero
        exact_fraction = (sim.electrode_potential - group_min) / group_range
        color_fraction = 0.75 * exact_fraction

        for ion, density in per_ion:
            if density is None:
                logging.warning("No valid frames for ion %s in simulation %s", ion, sim.filename)
                continue
            ion_color = blend_color(ION_COLORS[ion], color_fraction)
            label = f"{ion} ({sim.electrode_potential:.2f} V)"
            if hydrogen_present:
                # Hollow line: thin white core overlaid by a thicker coloured stroke.
                line, = plt.plot(bin_centers, density, linewidth=1, color='white', label=label)
                line.set_path_effects([pe.Stroke(linewidth=8, foreground=ion_color), pe.Normal()])
            else:
                plt.plot(bin_centers, density, linestyle='-', linewidth=8, marker=None,
                         color=ion_color, label=label)

        if not has_water:
            logging.debug("No water oxygen positions found in simulation %s", sim.filename)
        elif oxygen_density is None:
            logging.warning("No valid frames for oxygen in simulation %s", sim.filename)
        else:
            poly = plt.fill_between(bin_centers, oxygen_density * scaling_factor,
                                    color=blend_color("#000000", 0.5), alpha=0.3)
            if hydrogen_present:
                poly.set_hatch('//')

    show_plot(xticks=np.arange(2, 5.5, 0.5), xlim=(2, 5), ylim=(0), legend_font_size=24, legend_alpha=0.8)

# Example calls for different ions and potentials:
# for ion in IONS:
#     compare_ion_densities([f"data/simulations/Pt111_{ion}{i}.pkl" for i in [2, 3, 4]])
    
# for ion in IONS:
#     compare_ion_densities([f"data/simulations/Pt111_{ion}{i}_H.pkl" for i in range(2, 5)])

# for ion in IONS:
#     compare_ion_densities([f"data/simulations/Pt111_{ion}4.pkl", f"data/simulations/Pt111_{ion}4_H.pkl"])

In [None]:
def plot_opening_angles(files, bins=180, region_width=0.5, min_count=1000, smoothing=3):
    """
    Plot the distribution of opening angles between ions, water oxygen, and the surface-normal.
    The angles are computed for each ion type and grouped by the ion's distance from the surface.
    Only regions with region_min > 1.5 Å and region_max < 5.5 Å (and with at least min_count angles)
    are plotted. In addition to the opening-angle curves (drawn in blended colors), an inset density
    plot is added where the ion density (computed via process_density) is drawn versus distance.
    The inset also shows the oxygen density (averaged and scaled to the ion density maximum) as
    a shaded region, similar to the original density plot functions.

    The inset is placed in whichever candidate corner hides the fewest plotted
    data points of the main axes (ties resolve to the top-right corner).
    Each simulation file is unpickled once and kept in memory for the duration of the call.

    Parameters:
      files (list): List of simulation pickle file paths.
      bins (int): Number of bins for the opening-angle histogram (angles from 0 to 180°).
      region_width (float): Width (in Å) of the distance regions.
      min_count (int): Minimum number of angles required for a region to be included.
      smoothing (int): Smoothing window size for density curves.
    """
    # Load every simulation once; the same objects feed the angle statistics,
    # the ion densities and the oxygen density below.
    simulations = [(file, load_simulation(file)) for file in files]

    # Hydrogen coverage flag for the whole set (assumes a homogeneous set):
    # True if any simulation has a non-empty adsorbate position array.
    hydrogen_covered = any(
        getattr(sim.trajectories.positions, 'adsorbates', None) is not None
        and sim.trajectories.positions.adsorbates.size > 0
        for _, sim in simulations
    )

    # --- Accumulate opening-angle data per ion and distance region ---
    angle_data = {ion: {} for ion in IONS}
    for file, sim in simulations:
        n_frames = sim.trajectories.times.shape[0]
        surface_zs = sim.trajectories.surface_z  # one surface z per frame

        ion_positions_all = sim.trajectories.positions.ions  # (n_frames, n_ions, 3)
        o_positions_all = sim.trajectories.positions.watO    # (n_frames, n_water, 3)
        if o_positions_all is None or o_positions_all.size == 0:
            print(f"No water oxygen positions found in file: {file}")
            continue

        # Indices of each ion type within the ion position array.
        ion_indices = {}
        for i, ion in enumerate(sim.ions):
            ion_indices.setdefault(ion, []).append(i)

        for frame in range(n_frames):
            frame_ions = ion_positions_all[frame]  # (n_ions, 3)
            frame_o = o_positions_all[frame]       # (n_water, 3)
            surface_z = surface_zs[frame]

            for ion in IONS:
                if ion not in ion_indices:
                    continue
                for idx in ion_indices[ion]:
                    ion_pos = frame_ions[idx]
                    distance = ion_pos[2] - surface_z
                    if distance < 0:
                        continue
                    region_min = np.floor(distance / region_width) * region_width
                    region_max = region_min + region_width
                    if not (region_min > 1.5 and region_max < 5.5):
                        continue
                    # Register the region even if no oxygen is within the cutoff.
                    region_angles = angle_data[ion].setdefault((region_min, region_max), [])

                    # Ion -> O vectors for oxygens within the ion-specific cutoff.
                    diff = frame_o - ion_pos
                    dists = np.linalg.norm(diff, axis=1)
                    valid = dists <= ION_CUTOFFS[ion]
                    if not np.any(valid):
                        continue
                    cos_theta = np.clip(diff[valid, 2] / dists[valid], -1.0, 1.0)
                    # arccos gives the angle of ion->O to +z; 180° minus it is the angle of
                    # O->ion to +z (arccos is within [0, 180], so abs() never changes the value).
                    angles = np.abs(180 - np.degrees(np.arccos(cos_theta)))
                    region_angles.extend(angles.tolist())

    # --- Build histogram data for each (ion, region) meeting min_count ---
    histogram_data = []
    angle_bins = np.linspace(0, 180, bins + 1)
    for ion in IONS:
        for region_key, angles in angle_data[ion].items():
            if len(angles) < min_count:
                continue
            counts, edges = np.histogram(np.array(angles), bins=angle_bins)
            probability = counts / np.sum(counts)
            bin_centers = (edges[:-1] + edges[1:]) / 2
            histogram_data.append((ion, region_key, bin_centers, probability))

    # --- Generate gradient colors per ion and region (closest region = pure colour) ---
    ion_region_colors = {}
    for ion in IONS:
        valid_regions = sorted([rk for rk in angle_data[ion].keys() if len(angle_data[ion][rk]) >= min_count],
                               key=lambda x: x[0])
        n_regions = len(valid_regions)
        colors = {}
        if n_regions > 0:
            for idx, region_key in enumerate(valid_regions):
                fraction = (idx / (n_regions - 1) / 2) if n_regions > 1 else 0
                colors[region_key] = blend_color(ION_COLORS[ion], fraction)
        ion_region_colors[ion] = colors

    histogram_data_sorted = sorted(histogram_data,
                                   key=lambda x: (IONS.index(x[0]), -x[1][0]))

    # --- Plot the opening-angle histograms ---
    init_plot(xlabel="Opening angle (degrees)", ylabel="Probability (a.u.)",
              font_size=48, yticks_remove=True, grid=False)
    for ion, region_key, bin_centers, probability in histogram_data_sorted:
        color = ion_region_colors[ion].get(region_key, ION_COLORS[ion])
        if hydrogen_covered:
            # Hollow line: thin white base then coloured overlay.
            line, = plt.plot(bin_centers, probability, linewidth=1, color='white')
            line.set_path_effects([pe.Stroke(linewidth=8, foreground=color), pe.Normal()])
        else:
            plt.plot(bin_centers, probability, linestyle='-', linewidth=8, marker=None,
                     color=color)
    ax = plt.gca()
    ax.set_xticks(np.arange(0, 181, 30))
    ax.set_xlim(0, 180)
    ax.set_ylim(0)
    show_plot(legend_font_size=32, legend_draw_bg=False, xticks=np.arange(0, 181, 30), show=False)
    plt.draw()  # Render the main plot

    # --- Compute ion density data (using process_density) ---
    ion_density_dict = {ion: [] for ion in IONS}
    density_bin_edges = np.linspace(2, 5, 51)  # 50 bins from 2 Å to 5 Å.
    density_bin_centers = 0.5 * (density_bin_edges[:-1] + density_bin_edges[1:])

    for _, sim in simulations:
        for ion in IONS:
            target_mask = np.array([elem == ion for elem in sim.ions])
            if not np.any(target_mask):
                continue
            target_positions = sim.trajectories.positions.ions[:, target_mask, 2]
            density, _ = process_density(sim, target_positions, skip=0, bin_edges=density_bin_edges)
            if density is not None:
                if smoothing > 1:
                    density = np.convolve(density, np.ones(smoothing) / smoothing, mode='same')
                ion_density_dict[ion].append(density)
    for ion in IONS:
        if ion_density_dict[ion]:
            ion_density_dict[ion] = np.mean(ion_density_dict[ion], axis=0)
        else:
            ion_density_dict[ion] = None

    # --- Compute oxygen density (using water oxygen positions) ---
    oxygen_density_list = []
    for _, sim in simulations:
        water_o = sim.trajectories.positions.watO
        if water_o is not None and water_o.size > 0:
            density, _ = process_density(sim, water_o[:, :, 2], skip=0, bin_edges=density_bin_edges)
            if density is not None:
                if smoothing > 1:
                    density = np.convolve(density, np.ones(smoothing) / smoothing, mode='same')
                oxygen_density_list.append(density)
    oxygen_density = np.mean(oxygen_density_list, axis=0) if oxygen_density_list else None

    # --- Scale oxygen density to match the ion density maximum ---
    max_ion_value = 0.0
    for ion in IONS:
        d = ion_density_dict[ion]
        if d is not None:
            max_ion_value = max(max_ion_value, np.max(d))
    if oxygen_density is not None and np.max(oxygen_density) > 0 and max_ion_value > 0:
        oxygen_density_scaled = oxygen_density * (max_ion_value / np.max(oxygen_density))
    else:
        oxygen_density_scaled = oxygen_density

    # --- Determine the "best" inset location without external modules ---
    fig = plt.gcf()
    main_pos = ax.get_position().bounds  # (left, bottom, width, height) in figure coordinates
    candidates = {
        "top_right": [main_pos[0] + 0.68 * main_pos[2], main_pos[1] + 0.62 * main_pos[3],
                      0.35 * main_pos[2], 0.35 * main_pos[3]],
        "top_left": [main_pos[0] + 0.05 * main_pos[2], main_pos[1] + 0.60 * main_pos[3],
                     0.35 * main_pos[2], 0.35 * main_pos[3]],
        "bottom_right": [main_pos[0] + 0.60 * main_pos[2], main_pos[1] + 0.05 * main_pos[3],
                         0.35 * main_pos[2], 0.35 * main_pos[3]],
        "bottom_left": [main_pos[0] + 0.05 * main_pos[2], main_pos[1] + 0.05 * main_pos[3],
                        0.35 * main_pos[2], 0.35 * main_pos[3]]
    }

    # Candidate boxes are in figure-fraction coordinates. ax.transData.inverted()
    # expects display (pixel) coordinates, so go figure -> display -> data.
    fig_to_data = fig.transFigure + ax.transData.inverted()

    def candidate_score(candidate_box):
        """Number of main-axes data points that the inset box would hide (lower is better)."""
        fig_left, fig_bottom, fig_width, fig_height = candidate_box
        x_min, y_min = fig_to_data.transform((fig_left, fig_bottom))
        x_max, y_max = fig_to_data.transform((fig_left + fig_width, fig_bottom + fig_height))
        hidden = 0
        for line in ax.lines:
            xdata = np.asarray(line.get_xdata(), dtype=float)
            ydata = np.asarray(line.get_ydata(), dtype=float)
            inside = ((xdata >= x_min) & (xdata <= x_max) &
                      (ydata >= y_min) & (ydata <= y_max))
            hidden += int(np.count_nonzero(inside))
        return hidden

    # min() keeps dict order on ties, so "top_right" wins when nothing is hidden.
    candidate_scores = {key: candidate_score(box) for key, box in candidates.items()}
    best_candidate = min(candidate_scores, key=candidate_scores.get)
    best_box = candidates[best_candidate]

    # --- Create the inset axis ---
    ax_density = fig.add_axes(best_box, zorder=3)

    # Force a draw to get the inset's tight bounding box (including labels).
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    bbox = ax_density.get_tightbbox(renderer)
    bbox = bbox.transformed(fig.transFigure.inverted())

    # Hand-tuned padding in figure fractions: shift down by 3 margins,
    # narrow by 2 margins and grow taller by 2 margins.
    margin = 0.02
    new_x0 = bbox.x0
    new_y0 = bbox.y0 - 3 * margin
    new_width = bbox.width - 2 * margin
    new_height = bbox.height + 2 * margin

    # Semi-transparent white rectangle behind the inset (zorder between main axes and inset).
    rect = Rectangle((new_x0, new_y0), new_width, new_height,
                     transform=fig.transFigure, facecolor='white', alpha=0.8,
                     edgecolor='none', zorder=2)
    fig.add_artist(rect)

    # --- Plot oxygen density in the inset ---
    if oxygen_density_scaled is not None and len(density_bin_centers) == len(oxygen_density_scaled):
        poly = ax_density.fill_between(density_bin_centers, oxygen_density_scaled,
                                       color="black", alpha=0.2)
        # Hatch the oxygen fill for hydrogen-covered sets.
        if hydrogen_covered:
            poly.set_hatch('//')

    # --- Plot the ion density curves in the inset and fill regions ---
    for ion in IONS:
        density = ion_density_dict[ion]
        if density is None:
            continue
        if hydrogen_covered:
            # Hollow line: thin white base overlaid by a thicker coloured stroke.
            line, = ax_density.plot(density_bin_centers, density, linewidth=1, color='white')
            line.set_path_effects([pe.Stroke(linewidth=8, foreground=ION_COLORS[ion]), pe.Normal()])
        else:
            ax_density.plot(density_bin_centers, density, linestyle='-', linewidth=8,
                            marker=None, color=ION_COLORS[ion])
        # Fill under the curve for each region with sufficient angle data,
        # using the same colour as that region's angle histogram.
        for region_key in angle_data[ion]:
            if len(angle_data[ion][region_key]) < min_count:
                continue
            region_min, region_max = region_key
            mask = (density_bin_centers >= region_min) & (density_bin_centers <= region_max)
            if np.any(mask):
                color = ion_region_colors[ion].get(region_key, ION_COLORS[ion])
                ax_density.fill_between(density_bin_centers[mask],
                                        density[mask],
                                        color=color)

    ax_density.set_xlabel("Distance to Surface (Å)", fontsize=28)
    ax_density.set_ylabel("ρ (a.u.)", fontsize=32)
    ax_density.set_yticks([])
    ax_density.tick_params(labelsize=32)
    ax_density.set_xticks(np.arange(2, 6, 1))
    ax_density.set_xlim(2, 5)
    ax_density.set_ylim(0)

    plt.show()
    
# --- Example usage ---
# Every simulation ID: each ion, numeric index 2-4, bare ("") and H-covered ("_H").
# Kept for reference; the commented loop below uses the curated subset in sim_ids.
all_sim_ids = [f"{ion}{i}{h}" for ion in IONS for i in range(2, 5) for h in ("", "_H")]

# Subset actually processed: bare systems at index 2 and H-covered systems at index 4.
sim_ids = ["Li2", "Na2", "K2", "Cs2",
           "Li4_H", "Na4_H", "K4_H", "Cs4_H"]

# for sim_id in sim_ids:
#     print(f"Processing simulation: Pt111_{sim_id}")
#     plot_opening_angles([f"data/simulations/Pt111_{sim_id}.pkl"], bins=36, region_width=0.5)


In [None]:
# Function to plot the CN vs. distance to the surface for each ion type.

# Function to compute the local coordination number (CN) for an ion.
def compute_local_cn(ion_position, o_positions, o_density, rdf_bin_edges, cutoff, box_lengths=None):
    """
    Coordination number of one ion with respect to O atoms, via its local RDF.

    The ion->O distances are histogrammed on ``rdf_bin_edges`` and turned into
    g(r) by dividing by ``o_density`` times each spherical shell volume. The
    integral 4*pi*rho * sum(r^2 g(r) dr) is then taken over the bins whose
    centres lie at or below ``cutoff``.

    Parameters:
        ion_position (np.ndarray): (3,) position of the ion.
        o_positions (np.ndarray): (n_O, 3) positions of the O atoms in this frame.
        o_density (float): Global oxygen number density (n_O / simulation volume).
        rdf_bin_edges (np.ndarray): Evenly spaced bin edges for the RDF histogram
            (the bin width is taken from the first two edges).
        cutoff (float): Upper distance for the CN integration.
        box_lengths (np.ndarray, optional): (3,) box lengths; when given, the
            displacements are passed through apply_periodic_boundary.

    Returns:
        tuple: (cn, avg_z), where avg_z is the mean z-coordinate of the O atoms
        with r <= cutoff. The value is (None, None) when no bin centre lies
        within the cutoff or no O atom is within the cutoff. It is always a
        2-tuple, so callers should test ``cn is None`` rather than the tuple itself.
    """
    # Ion -> O displacement vectors, PBC-corrected only when a box is supplied.
    displacements = o_positions - ion_position
    if box_lengths is not None:
        displacements = apply_periodic_boundary(displacements, box_lengths)
    radii = np.linalg.norm(displacements, axis=1)

    inner_edges, outer_edges = rdf_bin_edges[:-1], rdf_bin_edges[1:]
    dr = rdf_bin_edges[1] - rdf_bin_edges[0]
    centres = 0.5 * (inner_edges + outer_edges)

    shell_counts, _ = np.histogram(radii, bins=rdf_bin_edges)
    shell_volumes = (4.0 / 3.0) * np.pi * (outer_edges ** 3 - inner_edges ** 3)

    with np.errstate(divide='ignore', invalid='ignore'):
        rdf = shell_counts / (o_density * shell_volumes)
        rdf[np.isnan(rdf)] = 0.0

    in_range = centres <= cutoff
    if not in_range.any():
        return None, None
    cn = 4 * np.pi * o_density * np.sum((centres[in_range] ** 2) * rdf[in_range] * dr)

    neighbours = radii <= cutoff
    if not neighbours.any():
        return None, None
    return cn, np.mean(o_positions[neighbours, 2])

# ------------------------------------------------------------------------------
def plot_cn_vs_distance(sim_files, default_cutoff=3.0, bins=100, max_distance=5.0, threshold=20):
    """
    Process simulation pickle files to compute and plot Coordination Number (CN) vs.
    distance to the surface, with oxygen density (normalized) shown as a shaded area.

    For each simulation:
        - For every ion (per ion type), compute its local CN by integrating the local RDF
          (using ion-specific cutoffs when available).
        - Compute the distance from the surface (average O z - surface_z).
        - Bin the CN per distance.
    Special treatment:
        - If hydrogen (adsorbate) coverage is detected in a simulation,
          the ion CN curves are plotted as hollow lines (thin white base overlaid by a colored stroke),
          and the oxygen density is drawn with a hatched fill pattern.
          
    Parameters:
        sim_files (list): List of simulation pickle file paths.
        default_cutoff (float): Default cutoff for CN integration (Å) if not specified per ion.
        bins (int): Number of bins for distance to surface.
        max_distance (float): Maximum distance (Å) for binning.
        threshold (int): Minimum count threshold for bin validity.
    """

    # Define binning for distance to surface.
    distance_bins = np.linspace(0, max_distance, bins + 1)
    bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])
    
    # Dictionaries to accumulate results.
    # Keys are tuples: (ion_type, hydrogen_flag), where hydrogen_flag is True if adsorbates are present.
    results = {}

    # Separate accumulators for oxygen distances (for oxygen density) for bare and H-covered systems.
    oxygen_distances = {False: [], True: []}
    oxygen_frames = {False: 0, True: 0}
    global_area = None  # to store simulation area (from first file with cell dimensions)

    # Process each simulation file.
    for file_name in sim_files:
        sim = load_simulation(file_name)

        # Determine if hydrogen (adsorbate) coverage is present.
        hydrogen_present = (hasattr(sim.trajectories.positions, 'adsorbates') and
                            sim.trajectories.positions.adsorbates is not None and
                            sim.trajectories.positions.adsorbates.size > 0)

        # Determine the ion type from the Simulation attribute.
        ion_type = sim.ions[0] if isinstance(sim.ions, (list, tuple)) else sim.ions
        specific_cutoff = ION_CUTOFFS.get(ion_type, default_cutoff)
        rdf_bin_edges = np.linspace(0, specific_cutoff, 51)

        # Check that required atom positions are available.
        if sim.trajectories.positions.watO is None or sim.trajectories.positions.ions is None:
            print(f"Required atoms (water O or {ion_type}) not found in file {file_name}.")
            continue

        cell = sim.cell_dimensions
        if cell is None:
            print(f"No cell dimensions found in simulation file {file_name}; skipping PBC corrections.")
            box_lengths = None
            area = 1.0
        else:
            if np.ndim(cell) == 1:
                box_lengths = cell
            else:
                box_lengths = np.diag(cell)
            area = box_lengths[0] * box_lengths[1]
            if global_area is None:
                global_area = area

        volume = np.prod(box_lengths) if box_lengths is not None else 1.0
        # Global oxygen density from the number of water oxygen atoms in the first frame.
        o_count = sim.trajectories.positions.watO.shape[1]
        o_density_global = o_count / volume

        n_frames = sim.trajectories.positions.all.shape[0]
        n_ions = sim.trajectories.positions.ions.shape[1]

        # Temporary accumulators for CN per bin for this simulation.
        cn_sum = np.zeros(bins)
        counts = np.zeros(bins)

        # Loop over frames.
        for frame in range(n_frames):
            surface_z = sim.trajectories.surface_z[frame]
            o_positions = sim.trajectories.positions.watO[frame]   # shape: (n_water, 3)
            ion_positions = sim.trajectories.positions.ions[frame]   # shape: (n_ions, 3)

            # Accumulate oxygen distances for density.
            oxygen_distances_frame = o_positions[:, 2] - surface_z
            valid_mask = (oxygen_distances_frame >= distance_bins[0]) & (oxygen_distances_frame <= distance_bins[-1])
            if np.any(valid_mask):
                oxygen_distances[hydrogen_present].extend(oxygen_distances_frame[valid_mask])
                oxygen_frames[hydrogen_present] += 1

            # Process ions for CN.
            for ion_pos in ion_positions:
                result = compute_local_cn(ion_pos, o_positions, o_density_global, rdf_bin_edges, specific_cutoff, box_lengths)
                if result is None:
                    continue
                cn, avg_z = result
                dist_to_surface = ion_pos[2] - surface_z
                bin_idx = np.digitize(dist_to_surface, distance_bins) - 1
                if bin_idx < 0 or bin_idx >= bins:
                    continue
                cn_sum[bin_idx] += cn
                counts[bin_idx] += 1

        # Use a key that distinguishes between bare and hydrogen-covered simulations.
        key = (ion_type, hydrogen_present)
        if key in results:
            results[key][0] += cn_sum
            results[key][1] += counts
            results[key][3] += n_frames * n_ions
            results[key][4] += n_frames
        else:
            # Structure: [cn_sum_total, counts_total, specific_cutoff, total_ion_count, total_frames]
            results[key] = [cn_sum, counts, specific_cutoff, n_frames * n_ions, n_frames]

    # --- Compute and normalize oxygen density for each group ---
    # We compute separate oxygen densities for bare (False) and hydrogen covered (True) systems.
    oxygen_density_norm = {}
    for h_flag in [False, True]:
        if oxygen_distances[h_flag] and oxygen_frames[h_flag] > 0 and global_area is not None:
            counts_o, _ = np.histogram(oxygen_distances[h_flag], bins=distance_bins)
            bin_width = distance_bins[1] - distance_bins[0]
            oxygen_density = counts_o / (global_area * bin_width * oxygen_frames[h_flag])
            # Scale oxygen density (here we use the same normalization as before).
            if np.max(oxygen_density) > 0:
                oxygen_density_norm[h_flag] = 2 + oxygen_density * (8 / np.max(oxygen_density))
            else:
                oxygen_density_norm[h_flag] = oxygen_density
        else:
            oxygen_density_norm[h_flag] = None

    # --- Set up the plot ---
    init_plot(xlabel="Distance to Surface (Å)", ylabel="CN", yticks_remove=False, grid='y')

    # Plot oxygen density for each group with appropriate style.
    for h_flag in [False, True]:
        od = oxygen_density_norm[h_flag]
        if od is None:
            continue
        # For hydrogen-covered systems, use a hatch pattern.
        hatch_pattern = '//' if h_flag else None
        poly = plt.fill_between(bin_centers_distance, od, color="black", alpha=0.2)
        if hatch_pattern:
            poly.set_hatch(hatch_pattern)

    # Plot CN vs distance for each ion type, separately for bare and H-covered simulations.
    for (ion, h_flag), (cn_sum_total, counts_total, _, total_ion_count, total_frames) in results.items():
        # Calculate average CN per bin.
        avg_cn = np.divide(cn_sum_total, counts_total, out=np.zeros_like(cn_sum_total), where=counts_total > 0)
        # Determine average ions per frame.
        avg_ions = total_ion_count / total_frames  # average number of ions per frame
        sim_threshold = avg_ions * threshold

        # Create a validity mask for each bin.
        valid_mask = counts_total >= sim_threshold
        # Fill in gaps if needed.
        y_plot = avg_cn.copy()
        for i in range(1, len(y_plot)-1):
            if (not valid_mask[i]) and valid_mask[i-1] and valid_mask[i+1]:
                y_plot[i] = (y_plot[i-1] + y_plot[i+1]) / 2
                valid_mask[i] = True

        # Find contiguous segments of valid bins.
        valid_indices = np.where(valid_mask)[0]
        if valid_indices.size == 0:
            continue
        segments = np.split(valid_indices, np.where(np.diff(valid_indices) != 1)[0]+1)
        color = ION_COLORS.get(ion, "#4472C4")
        label = f"{ion}"
        # For each contiguous segment, plot the CN curve.
        for seg in segments:
            if h_flag:
                # For H-covered systems, use the special hollow-line style.
                # First plot a thin white line.
                line, = plt.plot(bin_centers_distance[seg], y_plot[seg], linewidth=1, color='white', label=label)
                # Then overlay a thicker colored line using a stroke.
                line.set_path_effects([pe.Stroke(linewidth=8, foreground=color), pe.Normal()])
            else:
                plt.plot(bin_centers_distance[seg], y_plot[seg],
                         linestyle='-', linewidth=8, marker=None, color=color, label=label)

    # --- Add square markers at 5 Å with bulk values ---
    bulk_values = {"Li": 4.3, "Na": 6.0, "K": 7.9, "Cs": 10.2}
    for (ion, _), _ in results.items():
        if ion in bulk_values:
            color = ION_COLORS.get(ion, "#4472C4")
            plt.plot(4.974, bulk_values[ion], marker='s', markersize=15,
                     linestyle='None', color=color, markeredgecolor='black')

    # Reverse the legend order to match the plot order.
    # Limit the legend to contain every unique label only once.
    handles, labels = plt.gca().get_legend_handles_labels()
    unique_labels = []
    unique_handles = []
    for i, label in enumerate(labels):
        if label not in unique_labels:
            unique_labels.append(label)
            unique_handles.append(handles[i])
    unique_labels.reverse()
    unique_handles.reverse()
    
    show_plot(xticks=np.arange(2, 5.5, 1), yticks=np.arange(3, 10, 2), 
              legend_handles=unique_handles, legend_labels=unique_labels, 
              xlim=(2, 5), ylim=(2, 10.5), legend_font_size=24)


# ------------------------------------------------------------------------------
# Run the CN-vs-distance analysis on the provided simulation pickle files.
# The commented-out calls further below are alternative file sets (e.g. the `_H` runs).

cn_simulation_files = [
    "data/simulations/Pt111_Li2.pkl",
    "data/simulations/Pt111_Na2.pkl",
    "data/simulations/Pt111_K2.pkl",
    "data/simulations/Pt111_Cs2.pkl",
]
plot_cn_vs_distance(cn_simulation_files, threshold=1)

# plot_cn_vs_distance(["data/simulations/Pt111_Li2.pkl",
#                      "data/simulations/Pt111_Na2.pkl",
#                      "data/simulations/Pt111_K2.pkl",
#                      "data/simulations/Pt111_Cs2.pkl",
#                      "data/simulations/Pt111_Li1_H.pkl",
#                      "data/simulations/Pt111_Na1_H.pkl",
#                      "data/simulations/Pt111_K1_H.pkl",
#                      "data/simulations/Pt111_Cs1_H.pkl"])

# plot_cn_vs_distance(["data/simulations/Pt111_Li4_H.pkl",
#                      "data/simulations/Pt111_Na4_H.pkl",
#                      "data/simulations/Pt111_K4_H.pkl",
#                      "data/simulations/Pt111_Cs4_H.pkl"])

In [None]:
# Function to compute the continuous coordination number (CCN) vs. distance to the surface for each ion type.

# Function to compute the local continuous coordination number (CN) for an ion.
def compute_local_ccn(ion_position, o_positions, o_density, rdf_bin_edges, cutoff, box_lengths=None, ccn_d=0.1):
    """
    For a single ion, compute the local continuous coordination number (CN) using a switching function.
    Instead of a hard cutoff (i.e. counting only water molecules with r <= cutoff), each water molecule
    contributes with a weight given by:

        f(r) = 1 / (1 + exp((r - cutoff) / d))

    where r is the distance from the ion to a water O atom, cutoff is the characteristic distance, and d
    controls the smoothness of the transition. The continuous CN is the sum over all water molecules.

    Additionally, the function computes a weighted average of the z-coordinate of the O atoms using the same weights.

    Parameters:
        ion_position (np.ndarray): (3,) array for the ion position.
        o_positions (np.ndarray): (n_O, 3) positions of water O atoms.
        o_density (float): Global oxygen density (unused in this formulation; kept for compatibility).
        rdf_bin_edges (np.ndarray): Bin edges for the RDF histogram (unused here; kept for compatibility).
        cutoff (float): Characteristic distance (Å) for the switching function.
        box_lengths (np.ndarray, optional): Box lengths (shape (3,)) for periodic boundary corrections,
            applied via apply_periodic_boundary when given.
        ccn_d (float, optional): Smoothing parameter (default 0.1 Å) controlling the transition width.

    Returns:
        tuple: (cn, avg_z) where:
           cn   : Continuous coordination number computed for the ion.
           avg_z: Weighted average z-coordinate of O atoms using the switching weights.
        Returns (None, None) when the weights do not sum to a positive number, e.g. when
        o_positions is empty or every O atom is so far away that its weight underflows to 0.
        Note that the return value is always a tuple, never None itself.
    """
    # Displacement vectors ion -> O, optionally wrapped with the minimum-image convention.
    diff = o_positions - ion_position
    if box_lengths is not None:
        diff = apply_periodic_boundary(diff, box_lengths)

    distances = np.linalg.norm(diff, axis=1)

    # Fermi-type switching function. For O atoms far beyond the cutoff the exponent
    # overflows to +inf, which correctly yields a weight of exactly 0; silence the
    # otherwise-emitted (harmless) overflow RuntimeWarning.
    with np.errstate(over="ignore"):
        weights = 1.0 / (1.0 + np.exp((distances - cutoff) / ccn_d))

    # The continuous CN is the sum of the weights (computed once and reused below).
    weight_total = np.sum(weights)
    # `not > 0` (rather than `<= 0`) also routes a NaN total to the (None, None) branch.
    if not weight_total > 0:
        return None, None

    cn = weight_total
    avg_z = np.sum(o_positions[:, 2] * weights) / weight_total
    return cn, avg_z


# Custom legend handler that adds a black outline only if the handle has _outline=True.
class HandlerMaybeOutlinedLine2D(HandlerLine2D):
    """
    Legend handler that draws a horizontal line sample, optionally with a black outline.

    A legend handle carrying a truthy ``_outline`` attribute is rendered as two lines:
    a black line whose width is the handle's linewidth + 2, underneath a line in the
    handle's own colour and width. Any other handle is rendered as the coloured line only.
    Both lines span the full legend-key width at its vertical centre and copy the
    handle's linestyle and marker.
    """

    def create_artists(self, legend, orig_handle,
                       x0, y0, width, height, fontsize, trans):
        y_mid = y0 + height / 2
        dash = orig_handle.get_linestyle()
        glyph = orig_handle.get_marker()
        base_width = orig_handle.get_linewidth()

        def _sample(colour, lw):
            # One horizontal legend line in the legend's coordinate transform.
            sample = Line2D([x0, x0 + width],
                            [y_mid, y_mid],
                            linestyle=dash,
                            color=colour,
                            linewidth=lw,
                            marker=glyph)
            sample.set_transform(trans)
            return sample

        artists = []
        if getattr(orig_handle, '_outline', False):
            # Outline first so the coloured line is drawn on top of it.
            artists.append(_sample('black', base_width + 2))
        artists.append(_sample(orig_handle.get_color(), base_width))
        return artists

def plot_ccn_vs_distance(sim_files, skip=5, default_cutoff=3.0, bins=100,
                         max_distance=5.0, ccn_d=0.1, threshold=1):
    """
    Compute and plot the continuous Coordination Number (CN) of each cation vs. its
    distance to the surface, with the normalized water-oxygen density shown as a shaded area.

    For each simulation:
      - Frames earlier than `skip` ps are ignored.
      - For every ion, its local continuous CN is computed with compute_local_ccn
        (switching function of width `ccn_d`, ion-specific cutoff from ION_CUTOFFS).
      - The ion's distance to the surface is its z coordinate minus
        sim.trajectories.surface_z of that frame.
      - CN values are averaged per distance bin and accumulated per
        (ion type, hydrogen adsorbates present).

    Special treatment:
      - For hydrogen (adsorbate) covered systems, the ion CN curves are drawn with a black outline:
          a black line (with a slightly larger linewidth) is drawn behind the variable-width colored line,
          and the oxygen density area is hatched.
      - For bare surfaces the colored curve has no outline.
      - In the legend an ion gets a black outline if it was plotted for any hydrogen-covered system.
      - The line width of each curve scales with the (3-bin smoothed) number of samples per bin.

    Bins with no data, or fewer than `threshold` samples, are omitted. The kept data is then
    linearly interpolated (from the first to the last kept bin) onto a dense grid, and the
    segments are drawn with round join and cap styles so that no segment boundaries are visible.

    Parameters:
      sim_files (list): List of simulation pickle file paths.
      skip (int): Number of ps to skip at the beginning of each simulation
                  (frames with frame * sim.timestep < skip * 1000 are skipped; presumably
                  sim.timestep is in fs — TODO confirm).
      default_cutoff (float): Default cutoff for the switching function (Å) if not specified per ion.
      bins (int): Number of bins for distance to surface.
      max_distance (float): Maximum distance (Å) for binning.
      ccn_d (float): Smoothing width (Å) of the switching function passed to compute_local_ccn.
      threshold (int): Minimum number of ion samples a distance bin needs to be plotted
                       (bins without samples are always omitted).
    """

    # Define binning for distance to surface.
    distance_bins = np.linspace(0, max_distance, bins + 1)
    bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])

    # Line-width range of the CN curves. Defined before the plotting loop so that the
    # legend (which uses max_lw) still works when no curve could be drawn.
    baseline = 1
    max_lw = 12

    # Accumulated results, keyed by (ion_type, hydrogen_flag).
    # Structure: [total_cn_sum, total_counts, cutoff, total_ion_count, total_frames]
    results = {}

    # Separate accumulators for oxygen distances for density (bare and H-covered).
    oxygen_distances = {False: [], True: []}
    oxygen_frames = {False: 0, True: 0}
    global_area = None  # x*y area taken from the first file with cell dimensions

    # Process each simulation file.
    for file_name in sim_files:
        sim = load_simulation(file_name)

        # Determine if hydrogen adsorbates are present.
        hydrogen_present = (hasattr(sim.trajectories.positions, 'adsorbates') and
                            sim.trajectories.positions.adsorbates is not None and
                            sim.trajectories.positions.adsorbates.size > 0)

        # Get ion type and cutoff.
        ion_type = sim.ions[0] if isinstance(sim.ions, (list, tuple)) else sim.ions
        specific_cutoff = ION_CUTOFFS.get(ion_type, default_cutoff)
        rdf_bin_edges = np.linspace(0, specific_cutoff, 51)  # unused by compute_local_ccn; kept for its signature

        # Check required atom positions.
        if sim.trajectories.positions.watO is None or sim.trajectories.positions.ions is None:
            print(f"Required atoms (water O or {ion_type}) not found in file {file_name}.")
            continue

        cell = sim.cell_dimensions
        if cell is None:
            print(f"No cell dimensions found in {file_name}; skipping PBC corrections.")
            box_lengths = None
        else:
            # Accept either box lengths or a cell matrix (diagonal used).
            box_lengths = cell if np.ndim(cell) == 1 else np.diag(cell)
            if global_area is None:
                global_area = box_lengths[0] * box_lengths[1]

        volume = np.prod(box_lengths) if box_lengths is not None else 1.0
        o_count = sim.trajectories.positions.watO.shape[1]
        o_density_global = o_count / volume

        n_frames = sim.trajectories.positions.all.shape[0]
        n_ions = sim.trajectories.positions.ions.shape[1]

        # Temporary accumulators for this simulation.
        cn_sum = np.zeros(bins)
        counts = np.zeros(bins)

        for frame in range(n_frames):
            if frame * sim.timestep < (skip * 1000):
                continue

            surface_z = sim.trajectories.surface_z[frame]
            o_positions = sim.trajectories.positions.watO[frame]  # (n_water, 3)
            ion_positions = sim.trajectories.positions.ions[frame]  # (n_ions, 3)

            # Accumulate oxygen distances.
            oxygen_distances_frame = o_positions[:, 2] - surface_z
            valid_mask = (oxygen_distances_frame >= distance_bins[0]) & (oxygen_distances_frame <= distance_bins[-1])
            if np.any(valid_mask):
                oxygen_distances[hydrogen_present].extend(oxygen_distances_frame[valid_mask])
                oxygen_frames[hydrogen_present] += 1

            # Process each ion.
            for ion_pos in ion_positions:
                result = compute_local_ccn(ion_pos, o_positions, o_density_global,
                                           rdf_bin_edges, specific_cutoff, box_lengths, ccn_d=ccn_d)
                # compute_local_ccn reports "no contributing O atoms" as (None, None), so the
                # tuple itself is never None: check the CN value too before accumulating.
                if result is None or result[0] is None:
                    continue
                cn = result[0]
                dist_to_surface = ion_pos[2] - surface_z
                bin_idx = np.digitize(dist_to_surface, distance_bins) - 1
                if bin_idx < 0 or bin_idx >= bins:
                    continue
                cn_sum[bin_idx] += cn
                counts[bin_idx] += 1

        key = (ion_type, hydrogen_present)
        if key in results:
            results[key][0] += cn_sum
            results[key][1] += counts
            results[key][3] += n_frames * n_ions
            results[key][4] += n_frames
        else:
            results[key] = [cn_sum, counts, specific_cutoff, n_frames * n_ions, n_frames]

    # Compute and normalize oxygen density for each group; the profile is rescaled to
    # span y = 2 .. 10 so it fits behind the CN curves.
    oxygen_density_norm = {}
    for h_flag in [False, True]:
        if oxygen_distances[h_flag] and oxygen_frames[h_flag] > 0 and global_area is not None:
            counts_o, _ = np.histogram(oxygen_distances[h_flag], bins=distance_bins)
            bin_width = distance_bins[1] - distance_bins[0]
            oxygen_density = counts_o / (global_area * bin_width * oxygen_frames[h_flag])
            if oxygen_density.max() > 0:
                oxygen_density_norm[h_flag] = 2 + oxygen_density * (8 / oxygen_density.max())
            else:
                oxygen_density_norm[h_flag] = oxygen_density
        else:
            oxygen_density_norm[h_flag] = None

    # Set up plot.
    init_plot(xlabel="Distance to Surface (Å)", ylabel="CN", yticks_remove=False,
              grid='y', font_size=32, figsize=(12, 6))

    for h_flag in [False, True]:
        od = oxygen_density_norm[h_flag]
        if od is None:
            continue
        poly = plt.fill_between(bin_centers_distance, od, color="black", alpha=0.2)
        if h_flag:
            # Hatch the density of hydrogen-covered systems.
            poly.set_hatch('//')

    # Collect legend info: mapping ion -> (color, h_flag).
    legend_dict = {}

    # Plot CN vs. distance as one continuous variable-width line.
    for (ion, h_flag), (cn_sum_total, counts_total, _, _, _) in results.items():
        # Average CN per bin.
        avg_cn = np.divide(cn_sum_total, counts_total, out=np.zeros_like(cn_sum_total), where=counts_total > 0)

        # Keep only bins with data that also reach the requested sample threshold.
        valid = (counts_total > 0) & (counts_total >= threshold)
        if not np.any(valid):
            continue
        x_valid = bin_centers_distance[valid]
        y_valid = avg_cn[valid]

        # Smooth counts (3-bin moving average) for the line width.
        smoothed = np.convolve(counts_total, np.ones(3) / 3, mode='same')[valid]
        if smoothed.max() > 0:
            norm = (smoothed - smoothed.min()) / (smoothed.max() - smoothed.min() + 1e-8)
        else:
            norm = smoothed
        lw_valid = baseline + (max_lw - baseline) * norm

        # Global interpolation: from the first valid to the last valid bin.
        num_dense = 100  # dense enough for a smooth appearance
        x_dense = np.linspace(x_valid[0], x_valid[-1], num_dense)
        y_dense = np.interp(x_dense, x_valid, y_valid)
        lw_dense = np.interp(x_dense, x_valid, lw_valid)

        # Build consecutive point pairs as line segments.
        points = np.array([x_dense, y_dense]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        # For each segment, average the linewidth at the endpoints.
        lw_segments = 0.5 * (lw_dense[:-1] + lw_dense[1:])

        color = ION_COLORS.get(ion, "#4472C4")
        ax = plt.gca()
        if h_flag:
            # For covered systems, draw a black outline behind the colored line.
            outline_width = lw_segments + 2.0
            lc_outline = LineCollection(segments, linewidths=outline_width, colors='black',
                                        capstyle='round', joinstyle='round')
            ax.add_collection(lc_outline)
        lc = LineCollection(segments, linewidths=lw_segments, colors=color,
                            capstyle='round', joinstyle='round')
        ax.add_collection(lc)

        # If an ion appears in both hydrogen-covered and bare systems, favor the outlined version.
        if ion not in legend_dict or h_flag:
            legend_dict[ion] = (color, h_flag)

    # Build custom legend handles with a representative (maximum) linewidth.
    custom_handles = []
    custom_labels = []
    legend_lw = max_lw
    for ion, (color, h_flag) in legend_dict.items():
        handle = Line2D([], [], color=color, lw=legend_lw, linestyle='-')
        # Mark the handle so HandlerMaybeOutlinedLine2D draws an outline for it.
        if h_flag:
            handle._outline = True
        custom_handles.append(handle)
        custom_labels.append(ion)

    custom_handles.reverse()
    custom_labels.reverse()

    show_plot(xlim=(2, 5), xticks=np.arange(2, 5.25, 0.5),
              ylim=(2, 10.5), yticks=np.arange(3, 10, 2),
              legend_handles=custom_handles, legend_labels=custom_labels,
              legend_font_size=24, legend_loc='upper left',
              legend_handler_map={Line2D: HandlerMaybeOutlinedLine2D()})

# ------------------------------------------------------------------------------
# Run the analysis on the provided simulation pickle files.

# plot_ccn_vs_distance(["data/simulations/Pt111_Li2.pkl",
#                      "data/simulations/Pt111_Na2.pkl",
#                      "data/simulations/Pt111_K2.pkl",
#                      "data/simulations/Pt111_Cs2.pkl"])

# plot_ccn_vs_distance(["data/simulations/Pt111_Li4_H.pkl",
#                      "data/simulations/Pt111_Na4_H.pkl",
#                      "data/simulations/Pt111_K4_H.pkl",
#                      "data/simulations/Pt111_Cs4_H.pkl"])


In [None]:
# ------------------------
# Function to compute the hollow coordination number (CN) for an imaginary ion.
def compute_local_hollow_cn_vectorized(ion_positions, o_positions, inner_radius=2.75, outer_radius=4.25, box_lengths=None):
    """
    Count, for each of many probe (ion) positions, the water oxygens inside a spherical shell.

    An oxygen counts towards a probe position when its distance r satisfies
    inner_radius <= r <= outer_radius (both bounds inclusive). When box_lengths is
    given, distances use the minimum-image convention along every box axis.

    Parameters:
        ion_positions (np.ndarray): (N, 3) probe positions.
        o_positions (np.ndarray): (n, 3) water oxygen positions.
        inner_radius (float): Inner shell radius (Å).
        outer_radius (float): Outer shell radius (Å).
        box_lengths (np.ndarray, optional): Box lengths for the periodic correction.

    Returns:
        np.ndarray: Length-N integer array with the shell occupancy of each probe position.
    """
    # Every probe/oxygen pair at once: displacement array of shape (N, n, 3).
    separation = o_positions[np.newaxis, :, :] - ion_positions[:, np.newaxis, :]
    if box_lengths is not None:
        # Minimum-image wrap.
        separation = separation - box_lengths * np.round(separation / box_lengths)

    pair_distance = np.linalg.norm(separation, axis=-1)  # (N, n)
    in_shell = np.logical_and(pair_distance >= inner_radius, pair_distance <= outer_radius)
    return in_shell.sum(axis=1)

# ------------------------
# Main function to compute and plot average CN vs. distance from the surface.
def plot_hollow_cn_vs_distance(sim_file, ion='Cs', d_min=2.0, d_max=5.0, n_d=50, grid_resolution=10):
    """
    Plot the average hollow-shell coordination number (CN) of an imaginary ion vs. its height above the surface.

    For each frame in the simulation and each distance d in linspace(d_min, d_max, n_d), an imaginary ion
    is placed at the centres of a grid_resolution x grid_resolution grid of (x, y) cells spanning the cell,
    at z = surface_z + d. Its CN is the number of water O atoms whose minimum-image distance lies within an
    ion-specific hollow shell (Li: 1.75-2.75 Å, Na: 2.0-3.2 Å, K: 2.5-3.8 Å, Cs: 2.75-4.25 Å), computed
    vectorized over the grid. CNs are averaged over all grid points and frames. The water-oxygen density
    between d_min and d_max is shown as a shaded area rescaled to the y-axis range.

    Parameters:
        sim_file (str): Path to the simulation file (any format accepted by load_simulation).
        ion (str): Ion type (Li, Na, K, Cs) selecting the shell radii and the y-axis range.
                   Any other value falls back to Cs (with a printed message).
        d_min (float): Minimum distance above the metal surface (Å).
        d_max (float): Maximum distance above the metal surface (Å).
        n_d (int): Number of distance bins between d_min and d_max.
        grid_resolution (int): Number of grid points per dimension in the x-y plane.

    Raises:
        ValueError: If the simulation has no cell dimensions.
    """
    # Per ion: (inner shell radius, outer shell radius, lower y limit / start of the shaded
    # O-density profile, height of the y-axis range).
    ion_settings = {
        "Li": (1.75, 2.75, 0.01, 3),
        "Na": (2.0, 3.2, 1.01, 3.5),
        "K": (2.5, 3.8, 2.51, 3.5),
        "Cs": (2.75, 4.25, 2.01, 6.5),
    }
    if ion not in ion_settings:
        # Only warn for genuinely unknown ions (a valid "Cs" input must not trigger this).
        print("Invalid ion type; defaulting to Cs.")
        ion = "Cs"
    irad, orad, o_start, y_span = ion_settings[ion]
    ylim = (o_start, o_start + y_span)

    # Initialize the simulation.
    sim = load_simulation(sim_file)
    n_frames = sim.trajectories.positions.all.shape[0]

    # Ensure cell dimensions are available.
    if sim.cell_dimensions is None:
        raise ValueError("Cell dimensions are required for PBC correction and grid positioning.")
    cell = np.array(sim.cell_dimensions)
    # Accept a cell matrix as well as a vector of box lengths (diagonal used).
    if cell.ndim != 1:
        cell = np.diag(cell)

    # Define grid points in the x-y plane (grid cell centres).
    x_centers = np.linspace(0, cell[0], grid_resolution, endpoint=False) + cell[0] / (2 * grid_resolution)
    y_centers = np.linspace(0, cell[1], grid_resolution, endpoint=False) + cell[1] / (2 * grid_resolution)
    # Create a mesh grid and reshape into (n_grid, 2).
    xv, yv = np.meshgrid(x_centers, y_centers, indexing='ij')
    grid_xy = np.column_stack((xv.ravel(), yv.ravel()))
    n_grid = grid_xy.shape[0]

    # Distances above the surface at which the imaginary ion is probed.
    d_values = np.linspace(d_min, d_max, n_d)
    cn_accum = np.zeros(n_d)
    counts = np.zeros(n_d)

    # --- For oxygen density (optional shading) ---
    oxygen_distances_all = []
    total_oxygen_frames = 0
    distance_bins = np.linspace(d_min, d_max, n_d + 1)

    # Print progress roughly every 25% of the frames; max() avoids a modulo by
    # zero for trajectories with fewer than 4 frames.
    progress_every = max(1, n_frames // 4)

    for frame in range(n_frames):
        if frame % progress_every == 0:
            print(f"Processing frame {frame} of {n_frames}")

        surface_z = sim.trajectories.surface_z[frame]
        o_positions = sim.trajectories.positions.watO[frame]

        # Accumulate oxygen z distances relative to surface.
        oxy_z = o_positions[:, 2] - surface_z
        valid_mask = (oxy_z >= d_min) & (oxy_z <= d_max)
        if np.any(valid_mask):
            oxygen_distances_all.extend(oxy_z[valid_mask])
            total_oxygen_frames += 1

        # For each d, vectorize over grid points.
        for i, d in enumerate(d_values):
            # Ion positions for all grid points: (x, y) from the grid, z = surface_z + d.
            ion_positions = np.hstack((grid_xy, np.full((n_grid, 1), surface_z + d)))
            cn_grid = compute_local_hollow_cn_vectorized(ion_positions, o_positions,
                                                         inner_radius=irad, outer_radius=orad,
                                                         box_lengths=cell)
            # Average CN for this d in the current frame.
            cn_accum[i] += np.mean(cn_grid)
            counts[i] += 1

    # Average over frames; with an empty trajectory the values stay NaN (not drawn)
    # instead of triggering a division-by-zero warning.
    avg_cn = np.divide(cn_accum, counts, out=np.full(n_d, np.nan), where=counts > 0)

    # --- Compute oxygen density (optional shading), rescaled to the y-axis range ---
    oxygen_density_norm = None
    if oxygen_distances_all and total_oxygen_frames > 0:
        counts_o, _ = np.histogram(oxygen_distances_all, bins=distance_bins)
        bin_width = distance_bins[1] - distance_bins[0]
        global_area = cell[0] * cell[1]
        oxygen_density = counts_o / (global_area * bin_width * total_oxygen_frames)
        if np.max(oxygen_density) > 0:
            oxygen_density_norm = o_start + oxygen_density * ((ylim[1] - ylim[0]) / np.max(oxygen_density))
        bin_centers_distance = 0.5 * (distance_bins[:-1] + distance_bins[1:])
    else:
        bin_centers_distance = None

    # --- Plot the results ---
    init_plot(xlabel="Distance from Surface (Å)", ylabel="CN")
    if oxygen_density_norm is not None and bin_centers_distance is not None:
        plt.fill_between(bin_centers_distance, oxygen_density_norm, color="black", alpha=0.2)
    plt.plot(d_values, avg_cn, linestyle='-', linewidth=8, color="black")
    show_plot(xticks=np.arange(2, 10.5, 1), yticks=np.arange(1, 10, 1),
              xlim=(2, 10), ylim=ylim, legend_font_size=32)

# ------------------------
# Script entry point: hollow-shell CN profile for the Cs run in the `_H` pickle,
# probed from the default d_min up to 10 Å with a single (x, y) grid point per height.
if __name__ == "__main__":
    simulation_file = "data/simulations/Pt111_Cs4_H.pkl"
    plot_hollow_cn_vs_distance(
        simulation_file,
        ion='Cs',
        grid_resolution=1,
        d_max=10,
    )

In [None]:
# Function to compute water molecule residence time in a cation solvation sphere

def _true_run_lengths(occupancy):
    """
    Return the lengths of all contiguous True runs in each column of a 2-D boolean array.

    Runs are reported column by column (column 0 first) and, within a column, in row
    order, i.e. the same order as scanning each column from top to bottom.

    Parameters:
        occupancy (np.ndarray): Boolean array of shape (n_rows, n_cols).

    Returns:
        np.ndarray: 1-D integer array with one run length per entry.
    """
    occupancy = np.asarray(occupancy, dtype=bool)
    n_rows, n_cols = occupancy.shape
    # Pad every column with False at both ends so each run has exactly one rising (+1)
    # and one falling (-1) edge after differencing.
    padded = np.zeros((n_cols, n_rows + 2), dtype=np.int8)
    padded[:, 1:-1] = occupancy.T
    edges = np.diff(padded, axis=1)
    # np.nonzero returns indices in row-major order, so starts and ends are grouped per
    # column of `occupancy` (row of `edges`) and pair up one-to-one within each group.
    _, start_idx = np.nonzero(edges == 1)
    _, end_idx = np.nonzero(edges == -1)
    return end_idx - start_idx


def _water_occupancy(water_positions, ion_positions, cutoff, box_lengths=None):
    """
    Flag, for every frame, which water oxygens lie within `cutoff` of ANY of the given ions.

    Parameters:
        water_positions (np.ndarray): Water oxygen positions, shape (n_frames, n_water, 3).
        ion_positions (np.ndarray): Ion positions, shape (n_frames, n_ions, 3).
        cutoff (float): Solvation-sphere radius (same length unit as the positions).
        box_lengths (array-like or None): Orthogonal box lengths used for the minimum-image
            convention. Non-positive entries are treated as non-periodic directions.
            None disables periodic wrapping.

    Returns:
        np.ndarray: Boolean occupancy matrix of shape (n_frames, n_water).
    """
    n_frames, n_water = water_positions.shape[:2]
    occupancy = np.zeros((n_frames, n_water), dtype=bool)

    # The box is the same for every frame, so prepare it once outside the loop.
    if box_lengths is not None:
        box_lengths = np.asarray(box_lengths, dtype=float)
        periodic = box_lengths > 0
        # Use 1.0 in non-periodic directions so the division cannot produce 0/0 -> NaN;
        # the wrap is masked out there anyway.
        safe_box = np.where(periodic, box_lengths, 1.0)

    for i in range(n_frames):
        # Pairwise water-ion displacement vectors, shape (n_water, n_ions, 3).
        diff = water_positions[i][:, np.newaxis, :] - ion_positions[i][np.newaxis, :, :]
        if box_lengths is not None:
            diff = diff - np.where(periodic, safe_box * np.round(diff / safe_box), 0.0)
        distances = np.linalg.norm(diff, axis=-1)  # shape: (n_water, n_ions)
        occupancy[i] = np.any(distances < cutoff, axis=1)
    return occupancy


def compute_water_residence_time(sim_file, ion_type, skip_time=0):
    """
    Compute the average residence time of water molecules in the solvation sphere of a cation type.

    Steps:
      1. Load the simulation from ``sim_file`` and select the ions whose label in
         ``sim_data.ions`` equals ``ion_type``. The sphere radius is ``ION_CUTOFFS[ion_type]``.
      2. Discard frames with time < ``skip_time``.
      3. In every remaining frame, flag each water oxygen (positions stored under the
         ``watO`` attribute) that lies within the cutoff of ANY selected ion. The
         minimum-image convention is applied when ``sim_data.cell_dimensions`` is set.
      4. For each water molecule, split its flag time series into contiguous "in-shell"
         runs and convert each run length (in frames) into a duration ``n_frames * dt``.
      5. Return the mean duration and the list of all durations.

    Notes:
      - ``dt`` is ``sim_data.timestep`` when available and not None. Otherwise it is the
        mean spacing of the retained frame times, or 1.0 if only one frame remains.
        Durations carry the unit of that value, documented here as ps. TODO confirm that
        ``sim_data.timestep`` uses the same unit as ``trajectories.times``.
      - Runs cut off by the start or end of the analysed window are counted with their
        truncated length, and a single-frame visit counts as one ``dt``.
      - Only orthogonal boxes are handled: for a cell matrix, its diagonal is used as the
        box lengths. Non-positive box lengths are treated as non-periodic directions.

    Parameters:
        sim_file (str): Path to the simulation file (read with ``load_simulation``).
        ion_type (str): Ion element (e.g. "Cs", "K") defining the solvation sphere.
        skip_time (float): Skip frames with time < skip_time (same unit as the trajectory
            times, ps).

    Returns:
        avg_lifetime (float): Mean residence duration; 0.0 if no residence events were found.
        lifetimes (list): All residence durations, grouped per water molecule and in time
            order within each molecule.

    Raises:
        ValueError: If no cutoff is defined for ``ion_type``, no ions of that type exist,
            the ``watO`` positions are missing, or no frames remain after ``skip_time``.
    """
    sim_data = load_simulation(sim_file)

    # Solvation-sphere radius for this ion type (from the global ION_CUTOFFS table).
    cutoff = ION_CUTOFFS.get(ion_type)
    if cutoff is None:
        raise ValueError(f"No cutoff defined for ion type {ion_type}.")

    ion_indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]
    if not ion_indices:
        raise ValueError(f"No ions of type {ion_type} found in simulation.")

    # Water oxygen positions, shape (n_frames, n_water, 3).
    try:
        water_positions = getattr(sim_data.trajectories.positions, "watO")
    except AttributeError:
        raise ValueError("Water positions with label 'watO' not found in simulation.") from None
    water_positions = np.asarray(water_positions)

    # Positions of the selected ions, shape (n_frames, n_selected_ions, 3).
    ion_positions_all = np.asarray(sim_data.trajectories.positions.ions)[:, ion_indices, :]

    # asarray so the comparison below also works when times are stored as a plain list.
    times = np.asarray(sim_data.trajectories.times)
    valid_frame_indices = np.flatnonzero(times >= skip_time)
    if valid_frame_indices.size == 0:
        raise ValueError("No frames available after applying skip_time.")

    water_positions = water_positions[valid_frame_indices]
    ion_positions_all = ion_positions_all[valid_frame_indices]
    valid_times = times[valid_frame_indices]

    box_lengths = None
    if sim_data.cell_dimensions is not None:
        cell = np.asarray(sim_data.cell_dimensions)
        # For a cell matrix, take its diagonal (assumes an axis-aligned box).
        box_lengths = np.diag(cell) if cell.ndim > 1 else cell

    occupancy = _water_occupancy(water_positions, ion_positions_all, cutoff, box_lengths)

    # Simulation timestep: explicit value if present, otherwise inferred from frame times.
    dt = getattr(sim_data, "timestep", None)
    if dt is None:
        dt_array = np.diff(valid_times)
        dt = np.mean(dt_array) if dt_array.size > 0 else 1.0

    all_lifetimes = (_true_run_lengths(occupancy) * dt).tolist()

    if all_lifetimes:
        avg_lifetime = np.mean(all_lifetimes)
    else:
        avg_lifetime = 0.0
        print("No water molecule residence events found for ion type", ion_type)

    return avg_lifetime, all_lifetimes

# ----- Example usage -----


# for ion in IONS:
#     for i in range(1, 5):
#         for h in ["", "_H"]:
#             sim_file = f"data/simulations/Pt111_{ion}{i}{h}.pkl"
#             avg_lifetime, lifetimes = compute_water_residence_time(sim_file, ion)
#             print(f"Ion: {ion}{i}{h}, Average Residence Time: {avg_lifetime:.2f} ps")

# Compute the average lifetime per ion type.
# ion_lifetimes = {}
# for ion in IONS:
#     lifetimes = []
#     for i in range(1, 5):
#         for h in ["", "_H"]:
#             sim_file = f"data/simulations/Pt111_{ion}{i}{h}.pkl"
#             avg_lifetime, _ = compute_water_residence_time(sim_file, ion)
#             lifetimes.append(avg_lifetime)
#     ion_lifetimes[ion] = np.mean(lifetimes)
# print("Average residence times (ps) per ion type:", ion_lifetimes)
