Main Jupyter notebook containing all the code used to generate the plots in the paper by R.S. Kort, A. Hagopian, K. Doblhoff-Dier, and M. Koper on the Impact of Applied Potential and Hydrogen Coverage on Alkali Metal Cation Behaviour.





In [None]:
# Import all the required libraries
import numpy as np
import logging
import matplotlib.pyplot as plt
import pickle

In [None]:
# Define some parameters
# Ion species handled by the plotting functions; this list order is also the
# order in which the per-ion curves are processed and drawn.
IONS = ["Li", "Na", "K", "Cs"]
# Base plot color (hex "#RRGGBB") for each ion species.
ION_COLORS = {"Li": "#70AD47", "Na": "#FFC000", "K": "#A032A0", "Cs": "#4472C4"}
# Per-ion ion–oxygen distance cutoff used to select the oxygens inside an ion's
# solvation shell (presumably in Å, like the plotted distances — confirm).
ION_CUTOFFS = {"Li": 2.75, "Na": 3.2, "K": 3.8, "Cs": 4.25}

# Define some general functions
# -----------------------------
def blend_color(hex_color, fraction):
    """
    Blend the given hex color with white.

    Parameters:
      hex_color (str): Color as "#RRGGBB"; the leading '#' is optional.
      fraction (float): Blend amount. 0 returns the original color, 1 returns white.
                        Values outside [0, 1] are clamped to that range.

    Returns:
      str: The blended color as an uppercase "#RRGGBB" string.
    """
    # Clamp the fraction: an out-of-range value would push a channel outside
    # 0-255 and produce a malformed hex string (e.g. "#1FE1FE1FE" or "#-40...").
    fraction = min(max(fraction, 0.0), 1.0)

    hex_color = hex_color.lstrip('#')
    channels = [int(hex_color[start:start + 2], 16) for start in (0, 2, 4)]
    # Move each channel towards 255 (white); int() truncates towards zero.
    blended = [int(value + (255 - value) * fraction) for value in channels]
    return "#" + "".join(f"{value:02X}" for value in blended)

def init_plot(
    xlabel: str = "X",
    ylabel: str = "Y",
    font_size: int = 60,
    font_family: str = 'Times New Roman',
    figsize: tuple = (12, 8),
    yticks_remove: bool = False,
    grid: bool = False,
    tight_layout: bool = True
):
    """
    Open a new matplotlib figure and apply the notebook's common styling to it.

    Note that the font settings are written to the global ``plt.rcParams``,
    so they stay in effect after this call.

    Parameters:
      xlabel (str): Label for the x-axis.
      ylabel (str): Label for the y-axis.
      font_size (int): Global font size written to rcParams.
      font_family (str): Global font family written to rcParams.
      figsize (tuple): Size of the new figure.
      yticks_remove (bool): Whether to remove the y-axis ticks.
      grid (bool): Whether to display a grid.
      tight_layout (bool): Whether to apply tight layout.
    """
    # Start from a fresh figure of the requested size.
    plt.figure(figsize=figsize)

    # Fonts are set before any text is created so the labels pick them up.
    plt.rcParams.update({'font.size': font_size, 'font.family': font_family})

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    axes = plt.gca()

    # Drop the y ticks entirely when requested.
    if yticks_remove:
        axes.yaxis.set_ticks([])

    # Only the bottom spine stays visible.
    for side in ('top', 'right', 'left'):
        axes.spines[side].set_visible(False)

    plt.grid(grid)

    if tight_layout:
        plt.tight_layout()

def show_plot(legend_font_size: int = 36,
              legend_frameon: bool = False,
              xticks: list = None,
              xlim: tuple = None,
              ylim = None,
              legend_handles: list = None,
              legend_labels: list = None):
    """
    Finalize and display the current matplotlib figure: apply optional axis
    settings, draw the legend, tighten the layout and call ``plt.show()``.

    When both legend_handles and legend_labels are given they are passed to the
    legend in the given order; otherwise matplotlib collects the entries itself.
    The legend font size is set by writing 'font.size' to the global rcParams.

    Parameters:
        legend_font_size (int): Font size for the legend.
        legend_frameon (bool): Whether to display a frame around the legend.
        xticks (list, optional): Tick locations for the x-axis.
        xlim (tuple, optional): Limits for the x-axis as (min, max).
        ylim (tuple, list or number, optional): Limits for the y-axis. A single
            number sets only the lower limit and keeps the current upper limit.
        legend_handles (list, optional): Custom legend handles to use.
        legend_labels (list, optional): Custom legend labels to use.
    """
    if xticks is not None:
        plt.xticks(xticks)
    if xlim is not None:
        plt.xlim(xlim)

    if ylim is not None:
        if isinstance(ylim, (tuple, list)):
            plt.ylim(ylim)
        else:
            # A bare number is the new bottom; the present top is kept.
            _, upper = plt.ylim()
            plt.ylim((ylim, upper))

    # The legend inherits its text size from the global font size.
    plt.rcParams.update({'font.size': legend_font_size})

    # Explicit handles/labels are only used when both were supplied.
    explicit_entries = ()
    if legend_handles is not None and legend_labels is not None:
        explicit_entries = (legend_handles, legend_labels)
    plt.legend(*explicit_entries, frameon=legend_frameon)

    plt.tight_layout()
    plt.show()



def setup_logging(verbose=False):
    """
    Configure logging based on verbosity.

    Parameters:
        verbose (bool): If True, set logging level to DEBUG, else INFO.
    """
    log_level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=log_level)
    # basicConfig does nothing once the root logger already has handlers (a
    # previous call, or handlers installed by the notebook environment), so the
    # level is also set explicitly; otherwise a later change of `verbose` would
    # silently be ignored.
    logging.getLogger().setLevel(log_level)



In [None]:
# Function to load the Simulation object from a pickle file.
def load_simulation(filename):
    """
    Read a pickled Simulation object from disk.

    NOTE: unpickling can execute arbitrary code; only load files from a
    trusted source.

    Parameters:
        filename (str): The path to the pickle file.

    Returns:
        Simulation: The object stored in the pickle file.

    Raises:
        Exception: Any error raised while opening or unpickling the file is
            logged at ERROR level and then re-raised unchanged.
    """
    try:
        with open(filename, 'rb') as handle:
            loaded = pickle.load(handle)
    except Exception as err:
        logging.error("Failed to load Simulation from %s: %s", filename, err)
        raise
    else:
        logging.debug("Successfully loaded Simulation from %s", filename)
        return loaded

"""    
# Load a Simulation object from a pickle file and print some metadata.
# Example usage:
pickle_file = "data/simulations/Pt111_Cs1.pkl"
sim = load_simulation(pickle_file)

# Access simulation-level metadata.
print("Simulation Metadata:\n----------------")
print("Project Name:", sim.project_name)
print("Timestep:", sim.timestep)
print("Cell Dimensions:", sim.cell_dimensions)
print("Lattice Dimensions:", sim.lattice_dimensions)
print("Electrode Potential:", sim.electrode_potential)
print("Metal Type:", sim.metal_type)
print("Ions:", sim.ions)

# Access trajectory-level metadata.
print("\nTrajectory Metadata:\n----------------")
print("Number of Trajectories:", len(sim.trajectories.times))
print("Time of the first few frames:", sim.trajectories.times[:5])
print("Surface Z-coordinate of the first few frames:", sim.trajectories.surface_z[:5])

# Access position data.
print("\nPosition data:\n----------------")
print("Metal positions:", sim.trajectories.positions.metal[:5])
if len(sim.trajectories.positions.adsorbates[0]) != 0:
    print("Adsorbate positions:", sim.trajectories.positions.adsorbates[:5])
print("Water positions:", sim.trajectories.positions.water[:5])
print("Ion positions:", sim.trajectories.positions.ions[:5])
print("Oxygen (of water) positions:", sim.trajectories.positions.watO[:5])
print("Hydrogen (of water) positions:", sim.trajectories.positions.watH[:5])

"""

In [None]:
# --- Helper function to compute simulation parameters for ion-O RDF --- 
def get_simulation_parameters_ion_O(data, bin_edges):
    """
    Derive the geometric quantities needed for the ion–O RDF.

    Parameters:
      data: simulation object exposing ``cell_dimensions`` (either a 1D list of
            box lengths or a cell matrix whose diagonal holds the box lengths).
      bin_edges (np.array): Edges of the radial bins.

    Returns:
      tuple: (volume, shell_volumes, bin_centers, bin_width, box_lengths).
             All five entries are None when ``data.cell_dimensions`` is None.
             ``bin_width`` is taken from the first bin only.
    """
    cell = data.cell_dimensions
    if cell is None:
        logging.warning("No cell dimensions found in file.")
        return (None,) * 5

    # A 1D cell already lists the box lengths; otherwise read them off the diagonal.
    cell_array = np.array(cell)
    box_lengths = cell_array if cell_array.ndim == 1 else np.diag(cell)
    volume = np.prod(box_lengths)

    inner_edges = bin_edges[:-1]
    outer_edges = bin_edges[1:]
    # Volume of each spherical shell: 4/3 * pi * (r_outer^3 - r_inner^3).
    shell_volumes = (4.0 / 3.0) * np.pi * (outer_edges**3 - inner_edges**3)
    bin_centers = 0.5 * (inner_edges + outer_edges)
    bin_width = bin_edges[1] - bin_edges[0]
    return volume, shell_volumes, bin_centers, bin_width, box_lengths

# --- Helper to apply periodic boundary conditions for ion-O distances --- 
def apply_periodic_boundary(diff, box_lengths):
    """
    Wrap difference vectors into the minimum-image convention.

    Each component of ``diff`` is shifted by the whole number of box lengths
    nearest to it, so the result lies within half a box length of zero.
    """
    box_shifts = np.round(diff / box_lengths)
    return diff - box_lengths * box_shifts

# --- Compute the ion-O RDF for one frame --- 
def compute_frame_ion_O_rdf(oxygen_positions, ion_positions, oxygen_label, ion_label,
                           bin_edges, shell_volumes, bin_centers, bin_width, volume, box_lengths):
    """
    Computes the radial distribution function (RDF) for one frame, where distances
    are calculated between each oxygen atom (e.g. water oxygen) and every ion.

    If the oxygen and ion labels are identical, self-correlations are removed.

    Parameters:
      oxygen_positions (np.array): Oxygen coordinates of this frame, shape (n_oxygen, 3).
      ion_positions (np.array): Ion coordinates of this frame, shape (n_ions, 3).
      oxygen_label (str): Label of the oxygen species.
      ion_label (str): Label of the ion species; compared against oxygen_label.
      bin_edges (np.array): Edges of the radial bins.
      shell_volumes (np.array): Volume of the spherical shell belonging to each bin.
      bin_centers, bin_width: Accepted for interface compatibility; not used here.
      volume (float): Simulation cell volume.
      box_lengths (np.array): Box lengths used for the minimum-image wrap.

    Returns:
      rdf_frame (np.array): g(r) for this frame, one value per bin.
      density_oxygen (float): Oxygen number density (n_oxygen / volume).
    """
    # Pairwise oxygen-ion difference vectors, shape (n_oxygen, n_ions, 3),
    # wrapped with the minimum-image convention in every direction.
    diff = oxygen_positions[:, np.newaxis, :] - ion_positions
    diff = diff - box_lengths * np.round(diff / box_lengths)
    distances = np.linalg.norm(diff, axis=-1).ravel()

    # Exclude self-correlations when oxygen and ion labels match.
    if oxygen_label == ion_label:
        distances = distances[distances > 1e-6]

    hist, _ = np.histogram(distances, bins=bin_edges)
    num_oxygen = oxygen_positions.shape[0]
    density_oxygen = num_oxygen / volume
    # The ion count is the first axis of the (n_ions, 3) frame array; axis 1 is
    # the coordinate axis and always equals 3.
    num_ions = ion_positions.shape[0]
    # Pair count expected per bin for an ideal gas at the oxygen density.
    expected_total = num_ions * density_oxygen * shell_volumes
    with np.errstate(divide='ignore', invalid='ignore'):
        rdf_frame = hist / expected_total
        # 0/0 bins (e.g. a zero-volume shell) are reported as 0.
        rdf_frame[np.isnan(rdf_frame)] = 0
    return rdf_frame, density_oxygen

# --- Calculate ion-O RDF over multiple frames for a given oxygen species --- 
def calculate_ion_O_rdf(data, oxygen_label, ion_label, skip_time, bin_edges, ion_indices=None):
    """
    Average the ion–oxygen RDF over all frames at or after ``skip_time``.

    Ion positions always come from data.trajectories.positions.ions; the oxygen
    positions are looked up by attribute name on data.trajectories.positions.

    Parameters:
      data: simulation data object.
      oxygen_label (str): Attribute name in data.trajectories.positions for the oxygen species (e.g., "watO").
      ion_label (str): Label compared against oxygen_label; when both are equal,
                     self-correlations are excluded in the per-frame RDF.
      skip_time (float): Frames whose time is below this threshold are skipped.
      bin_edges (np.array): Edges of the radial bins.
      ion_indices (list or None): Optional indices selecting a subset of the ions.

    Returns:
      bin_centers (np.array): The centers of the radial bins.
      avg_rdf (np.array): The frame-averaged radial distribution function.
      avg_density (float): The frame-averaged density of the oxygen species.
      All three are None when the oxygen attribute is missing, the cell
      dimensions are missing, or no frame passes the time threshold.
    """
    try:
        oxygen_positions_all = getattr(data.trajectories.positions, oxygen_label)
    except AttributeError:
        logging.warning("No positions for oxygen species '%s' found.", oxygen_label)
        return None, None, None

    # Reference positions: every ion, or only the requested subset.
    every_ion = data.trajectories.positions.ions
    ion_positions_all = every_ion if ion_indices is None else every_ion[:, ion_indices, :]

    volume, shell_volumes, bin_centers, bin_width, box_lengths = get_simulation_parameters_ion_O(data, bin_edges)
    if volume is None:
        return None, None, None

    # One (rdf, density) pair per frame that is not earlier than skip_time.
    frame_results = [
        compute_frame_ion_O_rdf(
            oxygen_positions_all[frame_idx], ion_positions_all[frame_idx],
            oxygen_label, ion_label,
            bin_edges, shell_volumes, bin_centers, bin_width,
            volume, box_lengths
        )
        for frame_idx, frame_time in enumerate(data.trajectories.times)
        if not frame_time < skip_time
    ]

    if not frame_results:
        logging.warning("No frames passed the skip time threshold.")
        return None, None, None

    frame_rdfs = [rdf for rdf, _ in frame_results]
    frame_densities = [density for _, density in frame_results]
    return bin_centers, np.mean(frame_rdfs, axis=0), np.mean(frame_densities)

# --- Main function to aggregate and plot ion-O RDF data --- 
def plot_ion_O_rdfs(pickle_files, skip=5, bins=50, max_distance=8, verbose=False):
    """
    Loads simulation pickle files, computes the oxygen-ions RDF (ion–O RDF) for each file,
    aggregates the RDF data per ion type (as defined in the simulation's ``ions``), and plots
    one averaged RDF curve per ion type.

    Files that fail to load are logged and skipped. Per ion type, the RDFs of the
    individual files are combined as an average weighted by the number of ions of
    that type in each file.

    Parameters:
      pickle_files (list): List of simulation pickle file paths.
      skip (float): Time (ps) to skip at the beginning of each simulation.
      bins (int): Number of bins for the RDF histogram.
      max_distance (float): Maximum distance for the RDF histogram.
      verbose (bool): If True, enable debug-level logging.
    """
    setup_logging(verbose)

    bin_edges = np.linspace(0, max_distance, bins + 1)
    oxygen = "watO"  # the RDF is computed against the water oxygens only

    # Aggregated ion-O RDF data per ion type. Each entry stores:
    #   'bin_centers': bin centers (assumed identical for all files),
    #   'rdf_sum': sum of (rdf * ion_count) over files,
    #   'total_count': total ion count across files.
    aggregated = {}

    for file_path in pickle_files:
        logging.debug("Processing file: %s", file_path)
        try:
            sim_data = load_simulation(file_path)
        except Exception as e:
            logging.error("Failed to load file %s: %s", file_path, e)
            continue

        # dict.fromkeys de-duplicates while keeping first-appearance order, so the
        # curve/legend order is reproducible (iterating a set of strings is not).
        for ion_type in dict.fromkeys(sim_data.ions):
            # Indices of all ions of the current type.
            indices = [i for i, ion in enumerate(sim_data.ions) if ion == ion_type]
            ion_count = len(indices)

            # "ions" is passed as ion_label purely for the self-correlation check.
            bin_centers, rdf_avg, _ = calculate_ion_O_rdf(
                sim_data, oxygen, "ions", skip, bin_edges, ion_indices=indices
            )
            if bin_centers is None or rdf_avg is None:
                logging.warning("Skipping file %s for ion type %s and oxygen %s due to missing RDF data.", file_path, ion_type, oxygen)
                continue

            # Weight the computed RDF by the number of ions of this type.
            weighted_rdf = rdf_avg * ion_count
            label = f"{ion_type}"
            if label not in aggregated:
                aggregated[label] = {
                    "bin_centers": bin_centers,
                    "rdf_sum": weighted_rdf,
                    "total_count": ion_count
                }
            else:
                aggregated[label]["rdf_sum"] += weighted_rdf
                aggregated[label]["total_count"] += ion_count

    # Prepare aggregated data for plotting.
    rdf_data_list = []
    for label, data_dict in aggregated.items():
        # Weighted average ion-O RDF for this ion type.
        avg_rdf = data_dict["rdf_sum"] / data_dict["total_count"]
        rdf_data_list.append((data_dict["bin_centers"], avg_rdf, label))
        logging.debug("Aggregated ion-O RDF for %s: Total ion count = %d", label, data_dict["total_count"])

    if not rdf_data_list:
        logging.error("No valid ion-O RDF data to plot.")
        return

    init_plot(xlabel="Ion–O Distance (Å)", ylabel="g(r)", yticks_remove=True)
    for bin_centers, rdf, label in rdf_data_list:
        # Ion types without a configured color fall back to black.
        plt.plot(bin_centers, rdf, linestyle='-', linewidth=8, marker=None,
                 label=label, color=ION_COLORS.get(label, 'black'))
    # A scalar ylim only fixes the lower y-limit.
    show_plot(xticks=np.arange(2, 7, 1), xlim=(1.5, 6.5), ylim=0)

# --- Example usage ---
# List of simulation IDs (used to generate file paths) for ion-O RDF analysis.
# NOTE: `sim_ids` and `pickle_files` are notebook-level globals; `sim_ids` is
# reassigned by a later cell.
sim_ids = ["Li2", "Na2", "K2", "Cs2"]
# One pickle per simulation ID, following the data/simulations/Pt111_<id>.pkl layout.
pickle_files = [f"data/simulations/Pt111_{sim_id}.pkl" for sim_id in sim_ids]

# Runs the analysis with the default skip time, bin count and maximum distance.
plot_ion_O_rdfs(pickle_files=pickle_files)


In [None]:
def calculate_histogram_density(all_distances, bin_edges, cell_dimensions, valid_frames):
    """
    Compute the density histogram given distances, cell dimensions, bin_edges, and the number of valid frames.

    Parameters:
      all_distances (array-like): Distances (Å) collected over all valid frames.
      bin_edges (np.array): Histogram bin edges (Å); the width of the first bin
                            is used for every bin, so the bins must be uniform.
      cell_dimensions (array-like): Either a 1D sequence of box lengths (Å) or a
                            cell matrix whose diagonal holds the box lengths, as
                            also accepted by get_simulation_parameters_ion_O.
      valid_frames (int): Number of frames that contributed distances.

    Returns:
      np.array: Counts per bin divided by (x-y area in nm² * bin width in nm * valid_frames).
    """
    counts, _ = np.histogram(all_distances, bins=bin_edges)

    # Accept both cell representations; indexing a matrix directly would
    # multiply whole rows together instead of two box lengths.
    cell = np.asarray(cell_dimensions)
    box_lengths = cell if cell.ndim == 1 else np.diag(cell)

    # Convert bin width and lateral box lengths from Å to nm (factor 0.1).
    bin_width_nm = (bin_edges[1] - bin_edges[0]) * 0.1
    area_nm2 = (box_lengths[0] * 0.1) * (box_lengths[1] * 0.1)
    density = counts / (area_nm2 * bin_width_nm * valid_frames)
    return density

def extract_initial_positions(simulation, ion, skip):
    """
    Return the first-frame distances to the surface for all ions of one type.

    Parameters:
      simulation: Simulation object (reads ``ions``, ``trajectories.positions.ions``
                  and ``trajectories.surface_z``).
      ion (str): Element symbol of the ion type to select.
      skip: Unused; kept for interface compatibility.

    Returns:
      list: Absolute z-distance between each selected ion and the surface in
            frame 0; an empty list when the simulation has no ion of this type.
    """
    # Boolean selector built from the stored ion element symbols.
    is_target = np.array([symbol == ion for symbol in simulation.ions])
    if not np.any(is_target):
        return []

    trajectories = simulation.trajectories
    # z-coordinates of the selected ions in the very first frame.
    z_first_frame = trajectories.positions.ions[0, is_target, 2]
    heights = np.abs(z_first_frame - trajectories.surface_z[0])
    return heights.tolist()

def process_density(simulation, target_positions, skip, bin_edges):
    """
    Build the surface-distance density histogram for one set of particles.

    Frames whose time is below ``skip`` are ignored. For every remaining frame
    the absolute difference between the particles' positions and that frame's
    surface z is collected, and the pooled distances are histogrammed.

    Parameters:
      simulation: Simulation object (reads ``trajectories.times``,
                  ``trajectories.surface_z`` and ``cell_dimensions``).
      target_positions: Per-frame z-coordinates of the particles, indexed by frame.
      skip: Time threshold below which frames are skipped.
      bin_edges (np.array): Histogram bin edges.

    Returns:
      tuple: (density, valid_frames); (None, 0) when no frame passes the threshold.
    """
    trajectories = simulation.trajectories

    # Distance to the surface for each frame that is not earlier than `skip`.
    per_frame_distances = [
        np.abs(target_positions[frame_idx] - trajectories.surface_z[frame_idx])
        for frame_idx, frame_time in enumerate(trajectories.times)
        if not frame_time < skip
    ]

    valid_frames = len(per_frame_distances)
    if valid_frames == 0:
        logging.warning("No valid frames found after skipping %d ps.", skip)
        return None, 0

    pooled = np.concatenate(per_frame_distances)
    density = calculate_histogram_density(pooled, bin_edges, simulation.cell_dimensions, valid_frames)
    return density, valid_frames

def plot_density(files, skip=5, bins=50, smoothing=3, initial_position=False, verbose=False, normalize=True):
    """
    Plot ion and water-oxygen density profiles versus distance from the surface.

    For every file the per-ion and oxygen densities are computed on a 2–5 Å
    grid, then averaged over files. The oxygen profile is drawn as a shaded
    area, rescaled so that its maximum equals the highest ion-density value.

    Parameters:
    - files (list of str): List of simulation file paths.
    - skip (int): Skip frames with time (in ps) less than this value.
    - bins (int): Number of bins for the histogram.
    - smoothing (int): Width of the moving-average window applied to the ion
          curves when plotting (values <= 1 disable smoothing).
    - initial_position (bool): Whether to mark the first-frame ion positions.
    - verbose (bool): Enable verbose logging.
    - normalize (bool):
          If True, each file-averaged ion curve is divided by its integral
          before the oxygen curve is scaled to the ion maximum.
          If False, the raw averaged ion densities are used instead.
    """
    setup_logging(verbose)
    init_plot(xlabel="Distance to Surface (Å)", ylabel="ρ (a.u.)", yticks_remove=True)

    # Histogram grid spanning 2 Å to 5 Å.
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    bin_width_nm = (bin_edges[1] - bin_edges[0]) * 0.1

    # Per-ion collectors (one entry per file) plus the oxygen collector.
    ion_density_dict = {ion: [] for ion in IONS}
    initial_positions = {ion: [] for ion in IONS}
    oxygen_density_list = []

    for file_path in files:
        simulation = load_simulation(file_path)

        # First-frame ion heights, only when markers were requested.
        if initial_position:
            for ion in IONS:
                initial_positions[ion].extend(extract_initial_positions(simulation, ion, skip))

        # Density profile of every ion type present in this file.
        for ion in IONS:
            selector = np.array([elem == ion for elem in simulation.ions])
            if not np.any(selector):
                continue
            ion_z = simulation.trajectories.positions.ions[:, selector, 2]
            density, valid_frames = process_density(simulation, ion_z, skip, bin_edges)
            if density is None:
                logging.warning("No valid frames for ion %s in file %s", ion, file_path)
                continue
            ion_density_dict[ion].append(density)

        # Water-oxygen density profile, when oxygen positions exist.
        water_oxygen = simulation.trajectories.positions.watO
        if water_oxygen is None or water_oxygen.size == 0:
            logging.debug("No water oxygen positions found in file %s", file_path)
            continue
        density, valid_frames = process_density(simulation, water_oxygen[:, :, 2], skip, bin_edges)
        if valid_frames > 0:
            oxygen_density_list.append(density)
        else:
            logging.warning("No valid frames for oxygen in file %s", file_path)

    # Collapse each ion's per-file list into one curve (or None) and track the
    # largest value over all ion curves.
    max_ion_value = 0.0
    for ion in IONS:
        per_file = ion_density_dict[ion]
        if not per_file:
            ion_density_dict[ion] = None
            continue
        avg_density = np.mean(per_file, axis=0)
        if normalize:
            area = np.sum(avg_density * bin_width_nm)
            if area != 0:
                avg_density = avg_density / area
        ion_density_dict[ion] = avg_density
        max_ion_value = max(max_ion_value, np.max(avg_density))

    # File-averaged oxygen curve, rescaled to the ion maximum when possible.
    oxygen_density = np.mean(oxygen_density_list, axis=0) if oxygen_density_list else None
    if oxygen_density is not None and np.max(oxygen_density) > 0 and max_ion_value > 0:
        oxygen_density_scaled = oxygen_density * (max_ion_value / np.max(oxygen_density))
        logging.debug("Scaling oxygen density to match ion density maximum.")
    else:
        oxygen_density_scaled = oxygen_density
        logging.warning("No valid oxygen or ion density data to scale.")

    # Oxygen goes first, as a translucent black area behind the ion curves.
    if oxygen_density_scaled is not None and len(bin_centers) == len(oxygen_density_scaled):
        plt.fill_between(bin_centers, oxygen_density_scaled, color="black", alpha=0.2)

    # One line per ion, smoothed with a moving average when requested.
    for ion in IONS:
        curve = ion_density_dict[ion]
        if curve is None or len(bin_centers) != len(curve):
            continue
        if smoothing > 1:
            curve = np.convolve(curve, np.ones(smoothing) / smoothing, mode='same')
        plt.plot(bin_centers, curve, linestyle='-', linewidth=8, marker=None,
                 color=ION_COLORS[ion], label=ion)

    # First-frame positions are drawn as dots at a fixed height of 0.2.
    if initial_position:
        for ion in IONS:
            for pos in initial_positions[ion]:
                plt.plot(pos, 0.2, marker='o', markersize=10, linestyle='None', color=ION_COLORS[ion])

    # A scalar ylim only fixes the lower y-limit.
    show_plot(xticks=np.arange(2, 6, 1), xlim=(2, 5), ylim=0.15)
    
# Example calls:
# Each call builds one file path per ion in IONS and draws one figure.
# First figure: the Pt111_<ion>2.pkl files.
plot_density([f"data/simulations/Pt111_{ion}2.pkl" for ion in IONS])
# Second figure: the Pt111_<ion>4_H.pkl files (presumably the runs with
# adsorbed hydrogen — confirm against the data set).
plot_density([f"data/simulations/Pt111_{ion}4_H.pkl" for ion in IONS])


In [None]:
def compare_ion_densities(simulation_files, skip=5, bins=50, smoothing=3, normalize=True, verbose=False):
    """
    Compare the densities of each ion per simulation.
    Each simulation's density curve for an ion is plotted as a separate line.
    The line color is determined by its electrode potential: the lowest potential gets the original
    ION_COLORS hex color and the highest potential gets the lightest hue (75 % blended with white
    via blend_color()).

    Parameters:
    - simulation_files (list of str): List of simulation pickle files.
    - skip (int): Skip frames with time (in ps) less than this value.
    - bins (int): Number of bins for the histogram.
    - smoothing (int): Window for smoothing the density curves (values <= 1 disable smoothing).
    - normalize (bool): If True, scale each (smoothed) curve so that its trapezoidal
          integral over the bin centers is 1.
    - verbose (bool): Enable verbose logging.
    """
    setup_logging(verbose)
    init_plot(xlabel="Distance to Surface (Å)",
              ylabel="ρ (a.u.)",
              yticks_remove=True)

    # Define histogram bins (from 2 Å to 5 Å).
    bin_edges = np.linspace(2, 5, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; use
    # whichever this NumPy version provides.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    # Load all simulations up front: the full potential range is needed before
    # any curve can be colored. The path is kept with each simulation for log messages.
    loaded = [(file_path, load_simulation(file_path)) for file_path in simulation_files]
    potentials = [sim.electrode_potential for _, sim in loaded]

    # Determine the potential range.
    max_pot = max(potentials)
    min_pot = min(potentials)
    pot_range = max_pot - min_pot if max_pot != min_pot else 1  # avoid division by zero

    # For each simulation, process each ion and plot its density.
    for file_path, sim in loaded:
        for ion in IONS:
            # Build a mask to select positions for the current ion.
            target_mask = np.array([elem == ion for elem in sim.ions])
            if not np.any(target_mask):
                continue
            # target_positions: z-coordinates for the ion across all frames.
            target_positions = sim.trajectories.positions.ions[:, target_mask, 2]
            density, valid_frames = process_density(sim, target_positions, skip, bin_edges)
            if density is None:
                # Report the file path: it is always known here, whereas the
                # simulation object is not guaranteed to carry a filename.
                logging.warning("No valid frames for ion %s in simulation %s", ion, file_path)
                continue

            # Optionally smooth the density curve.
            if smoothing > 1:
                density = np.convolve(density, np.ones(smoothing) / smoothing, mode='same')

            # Normalize the density so the area under the curve (from 2 to 5 Å) is 1,
            # using trapezoidal integration over the bin centers.
            if normalize:
                area = integrate(density, bin_centers)
                if area != 0:
                    density = density / area

            # Compute the fraction for color blending:
            # Lowest (most negative potential) gets fraction=0 (original color)
            # Highest (most positive potential) gets fraction=0.75 (lightest)
            exact_fraction = (sim.electrode_potential - min_pot) / pot_range
            color_fraction = 0.75 * exact_fraction
            color = blend_color(ION_COLORS[ion], color_fraction)

            label = f"{ion} ({sim.electrode_potential:.2f} V)"
            plt.plot(bin_centers, density, linestyle='-', linewidth=8, marker=None,
                     color=color, label=label)

    # A scalar ylim only fixes the lower y-limit.
    show_plot(xticks=np.arange(2, 6, 1), xlim=(2, 5), ylim=0, legend_font_size=24)

# Two figures per ion, each comparing the four simulations numbered 1-4:
# one for the Pt111_<ion><i>.pkl files and one for the Pt111_<ion><i>_H.pkl files.
for ion in IONS:
    compare_ion_densities([f"data/simulations/Pt111_{ion}{i}.pkl" for i in range(1, 5)])
    compare_ion_densities([f"data/simulations/Pt111_{ion}{i}_H.pkl" for i in range(1, 5)])

In [None]:
def plot_opening_angles(files, bins=180, region_width=0.5, min_count=1000):
    """
    Plot opening-angle distributions of the water oxygens around each ion.

    For every ion in every frame, the oxygens within the ion's ION_CUTOFFS
    distance are selected and the angle between the ion→oxygen vector and the
    +z surface normal is converted to an opening angle |180° - angle|. Angles
    are grouped by ion type and by the ion's height above the surface, in slabs
    of ``region_width``. Only slabs with region_min > 1.5 Å and
    region_max < 5.5 Å are kept. All frames are used and ion–oxygen vectors are
    taken directly from the stored coordinates.

    Parameters:
        files (list): List of simulation pickle file paths.
        bins (int): Number of bins for the histogram (angles from 0 to 180°).
        region_width (float): Width of the distance regions in Å.
        min_count (int): Minimum number of angles required for a region to be included in the plot.
    """
    # angle_data[ion][(region_min, region_max)] -> list of opening angles.
    angle_data = {ion: {} for ion in IONS}

    for path in files:
        sim = load_simulation(path)

        ion_positions_all = sim.trajectories.positions.ions  # indexed [frame, ion, xyz]
        o_positions_all = sim.trajectories.positions.watO    # indexed [frame, water, xyz]
        if o_positions_all is None or o_positions_all.size == 0:
            print(f"No water oxygen positions found in file: {path}")
            continue

        n_frames = sim.trajectories.times.shape[0]
        surface_zs = sim.trajectories.surface_z  # one surface z per frame

        # Map each element symbol to the indices of its ions.
        ion_indices = {}
        for index, symbol in enumerate(sim.ions):
            ion_indices.setdefault(symbol, []).append(index)

        for frame in range(n_frames):
            frame_ions = ion_positions_all[frame]
            frame_o = o_positions_all[frame]
            surface_z = surface_zs[frame]

            for ion in IONS:
                if ion not in ion_indices:
                    continue
                cutoff = ION_CUTOFFS[ion]
                for idx in ion_indices[ion]:
                    ion_pos = frame_ions[idx]

                    # Height above the surface (surface normal taken as +z);
                    # ions below the surface are ignored.
                    height = ion_pos[2] - surface_z
                    if height < 0:
                        continue

                    # Slab containing this ion.
                    region_min = np.floor(height / region_width) * region_width
                    region_max = region_min + region_width
                    if not (region_min > 1.5 and region_max < 5.5):
                        continue
                    region_key = (region_min, region_max)
                    region_angles = angle_data[ion].setdefault(region_key, [])

                    # Vectors from the ion to every water oxygen.
                    to_oxygen = frame_o - ion_pos
                    separations = np.linalg.norm(to_oxygen, axis=1)

                    # Keep only the oxygens inside the solvation cutoff.
                    in_shell = separations <= cutoff
                    if not np.any(in_shell):
                        continue

                    # Cosine of the angle to the surface normal [0, 0, 1],
                    # clipped to guard arccos against rounding.
                    cos_theta = np.clip(to_oxygen[in_shell][:, 2] / separations[in_shell], -1.0, 1.0)
                    # Opening angle: |180° - angle| in degrees.
                    angles = np.abs(180 - np.degrees(np.arccos(cos_theta)))
                    region_angles.extend(angles.tolist())

    # Histogram and color for every (ion, region) pair with enough samples.
    angle_bins = np.linspace(0, 180, bins + 1)
    histogram_data = []
    ion_region_colors = {}
    for ion in IONS:
        regions = angle_data[ion]

        for region_key, angle_list in regions.items():
            if len(angle_list) < min_count:
                continue
            counts, edges = np.histogram(np.array(angle_list), bins=angle_bins)
            probability = counts / np.sum(counts)
            centers = (edges[:-1] + edges[1:]) / 2
            histogram_data.append((ion, region_key, centers, probability))

        # Gradient: the slab closest to the surface keeps the base color
        # (fraction 0); the farthest slab is blended halfway to white.
        kept_regions = sorted(
            (rk for rk in regions.keys() if len(regions[rk]) >= min_count),
            key=lambda rk: rk[0],
        )
        n_regions = len(kept_regions)
        shades = {}
        for rank, region_key in enumerate(kept_regions):
            fraction = (rank / (n_regions - 1) / 2) if n_regions > 1 else 0
            shades[region_key] = blend_color(ION_COLORS[ion], fraction)
        ion_region_colors[ion] = shades

    # Draw order: ions in IONS order and, within an ion, farthest slab first so
    # that the darkest (closest) curve is drawn last and ends up on top.
    histogram_data_sorted = sorted(histogram_data, key=lambda item: (IONS.index(item[0]), -item[1][0]))

    init_plot(xlabel="Opening angle (degrees)", ylabel="Probability (a.u.)",
              font_size=48, yticks_remove=True, grid=False)

    for ion, region_key, centers, probability in histogram_data_sorted:
        region_min, region_max = region_key
        plt.plot(centers, probability, linestyle='-', linewidth=8, marker=None,
                 color=ion_region_colors[ion].get(region_key, ION_COLORS[ion]),
                 label=f"{ion} ({region_min:.1f}–{region_max:.1f} Å)")

    # Pass the legend entries explicitly so they follow the plotting order.
    handles, labels = plt.gca().get_legend_handles_labels()
    show_plot(xticks=np.arange(0, 181, 30), xlim=(0, 180), ylim=(0.001, None),
              legend_handles=handles, legend_labels=labels, legend_font_size=28)

# Example usage:
# Build the IDs "<ion><i>" and "<ion><i>_H" for every ion in IONS and i = 2, 3, 4.
# NOTE: this reassigns the notebook-level `sim_ids` defined in an earlier cell.
sim_ids = []
for ion in IONS:
    for i in range(2, 5):
        for h in ["", "_H"]:
            sim_ids.append(f"{ion}{i}{h}")
# One figure per simulation: each call receives a single-file list.
for sim_id in sim_ids:
    print(f"Processing simulation: Pt111_{sim_id}")
    plot_opening_angles([f"data/simulations/Pt111_{sim_id}.pkl"])
