In [2]:
#!/usr/bin/env python3
"""
Comparison Topology Plotter v3 - Clean Annotations
3-panel side-by-side comparison with repositioned annotations that don't hide plots
PES range [-5,5] with small starting point markers
"""

import numpy as np
import h5py
import json
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import random

# Set default renderer: every fig.show() below opens the figure in a web
# browser. This is a module-level side effect, applied as soon as this file
# is imported.
pio.renderers.default = 'browser'

class ComparisonTopologyPlotterV3:
    """Build side-by-side Plotly comparisons of trajectory ensembles.

    Every ``plot_*`` method takes a mapping ``{panel label: HDF5 filename}``
    (filenames are resolved relative to ``data_dir``) and returns a
    ``plotly.graph_objects.Figure`` with one column per file that loaded, or
    ``None`` when none of the files could be loaded.
    """

    # Bright colours that stay visible on top of the semi-transparent PES;
    # cycled when more trajectories than colours are drawn.
    TRAJ_COLORS = ('lime', 'cyan', 'yellow', 'orange', 'magenta', 'white')
    # Half-width of the square [-PES_EXTENT, PES_EXTENT]^2 PES window.
    PES_EXTENT = 5
    # Number of grid points per axis used to evaluate the PES.
    PES_GRID_POINTS = 50
    # Trajectories with fewer in-window samples than this are not drawn.
    MIN_VISIBLE_POINTS = 10
    # Per-trajectory datasets copied verbatim from each HDF5 trajectory group.
    _TRAJ_DATASETS = ('positions', 'velocities', 'momenta', 'energies',
                      'populations', 'adiabatic_pops', 'berry_curvatures',
                      'forces')
    # (legend label, legendgroup prefix, dash style) per trajectory subset,
    # in the order the subsets are drawn: all, pre-looped, regular.
    _SUBSETS = (('All', 'all', None),
                ('Pre-looped', 'pre', 'dash'),
                ('Regular', 'reg', 'dot'))

    def __init__(self, data_directory=None):
        """Initialize the plotter with the directory holding the HDF5 files."""
        self.data_dir = data_directory or os.path.expanduser("~/Desktop/hpc/upper/results/a1")

    # ------------------------------------------------------------------ I/O

    @staticmethod
    def _decode_attr(value):
        """Convert a raw HDF5 attribute into a Python value.

        ``bytes`` are decoded as UTF-8, so string comparisons such as
        ``parameters['z_choice'] == 'constant'`` behave the same however the
        string was stored. Strings that start with ``{`` or ``[`` are parsed
        as JSON. Anything that cannot be decoded or parsed is returned unchanged.
        """
        if isinstance(value, bytes):
            try:
                value = value.decode('utf-8')
            except UnicodeDecodeError:
                return value
        if isinstance(value, str) and value.startswith(('{', '[')):
            try:
                return json.loads(value)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return value
        return value

    @staticmethod
    def _trajectory_sort_key(name):
        """Sort key putting 'trajectory_2' before 'trajectory_10'.

        Names whose suffix is not a plain integer sort after the numbered
        ones, alphabetically.
        """
        suffix = name[len('trajectory_'):]
        if suffix.isdigit():
            return (0, int(suffix), name)
        return (1, 0, name)

    def load_h5_file(self, filename):
        """Load trajectory data and simulation parameters from an HDF5 file.

        Args:
            filename: file name relative to ``self.data_dir``.

        Returns:
            ``(results, parameters)``. ``results`` is a list of per-trajectory
            dicts with the keys ``type``, ``is_prelooping``, the datasets in
            ``_TRAJ_DATASETS``, ``psi_t`` (complex, rebuilt from
            ``psi_t_real``/``psi_t_imag``) and ``index``. ``parameters`` maps
            each attribute of the ``parameters`` group to its decoded value.
            Returns ``(None, None)`` if the file does not exist.
        """
        filepath = os.path.join(self.data_dir, filename)

        if not os.path.exists(filepath):
            print(f"Error: File {filepath} not found")
            return None, None

        print(f"Loading {filename}...")

        results = []
        with h5py.File(filepath, 'r') as f:
            param_attrs = f['parameters'].attrs
            parameters = {key: self._decode_attr(param_attrs[key]) for key in param_attrs}

            # Numeric ordering keeps 'index' aligned with the stored trajectory
            # number even when group names are not zero-padded.
            traj_names = sorted((k for k in f.keys() if k.startswith('trajectory_')),
                                key=self._trajectory_sort_key)
            for index, traj_name in enumerate(traj_names):
                group = f[traj_name]
                traj = {
                    'type': self._decode_attr(group.attrs.get('type', 'regular')),
                    'is_prelooping': group.attrs.get('is_prelooping', False),
                }
                for dataset in self._TRAJ_DATASETS:
                    traj[dataset] = group[dataset][()]
                traj['psi_t'] = group['psi_t_real'][()] + 1j * group['psi_t_imag'][()]
                traj['index'] = index
                results.append(traj)

        print(f"Successfully loaded {len(results)} trajectories")
        return results, parameters

    def _load_topologies(self, topology_files, verbose=False):
        """Load every file in ``topology_files``, skipping missing ones.

        Args:
            topology_files: mapping of panel label -> HDF5 filename.
            verbose: also print per-label trajectory-type counts.

        Returns:
            Dict ``{label: (results, parameters)}`` in the input order. It is
            empty (and an error is printed) when nothing could be loaded.
        """
        loaded_data = {}
        for label, filename in topology_files.items():
            results, parameters = self.load_h5_file(filename)
            if results is None:
                continue
            loaded_data[label] = (results, parameters)
            if verbose:
                prelooping_trajs, regular_trajs = self.analyze_trajectory_types(results)
                print(f"\n{label}:")
                print(f"  - Total trajectories: {len(results)}")
                print(f"  - Pre-looped: {len(prelooping_trajs)}, Regular: {len(regular_trajs)}")

        if not loaded_data:
            print("Error: No data files could be loaded!")
        return loaded_data

    # -------------------------------------------------------------- physics

    def hamiltonian(self, x, y, parameters):
        """Return the 2x2 complex Hamiltonian at the point (x, y).

        H = (s/2) * [[a*z, x - i*sqrt(1-e^2)*y], [conj(.), -z]]. Here
        z = ``z_val`` when ``z_choice == 'constant'``, and otherwise
        z = sqrt(x^2 + (1-e^2)*y^2). Uses the parameter keys
        ``z_choice``, ``z_val``, ``e``, ``s`` and ``a``.
        """
        if parameters['z_choice'] == 'constant':
            z = parameters['z_val']
        else:
            z = np.sqrt(x**2 + (1 - parameters['e']**2) * y**2)

        off_diag = x - 1j * np.sqrt(1 - parameters['e']**2) * y

        return 0.5 * parameters['s'] * np.array([
            [parameters['a'] * z, off_diag],
            [np.conjugate(off_diag), -z]
        ], dtype=complex)

    def _compute_pes(self, parameters, extent=PES_EXTENT, n_points=PES_GRID_POINTS):
        """Evaluate both adiabatic surfaces on a square grid.

        Returns:
            ``(X, Y, Z_lower, Z_upper)``: meshgrid coordinates over
            [-extent, extent]^2 and the ascending eigenvalues of the
            Hamiltonian at each grid point.
        """
        axis = np.linspace(-extent, extent, n_points)
        X, Y = np.meshgrid(axis, axis)
        Z_lower = np.empty_like(X)
        Z_upper = np.empty_like(X)
        # np.ndindex walks the grid's real shape, so X/Y/Z indices always agree.
        for idx in np.ndindex(X.shape):
            Z_lower[idx], Z_upper[idx] = np.linalg.eigvalsh(
                self.hamiltonian(X[idx], Y[idx], parameters))
        return X, Y, Z_lower, Z_upper

    def _trajectory_on_pes(self, traj, parameters, extent):
        """Project a trajectory onto the PES window for 3-D plotting.

        Returns:
            ``(xs, ys, energies, inside)``. ``inside`` marks samples with
            |x|, |y| <= extent. ``energies`` holds <psi|H|psi> at those
            samples. All three coordinate arrays are NaN outside the window,
            so the plotted line breaks there instead of jumping straight
            across the clipped region.
        """
        positions = np.asarray(traj['positions'])
        x_all = positions[:, 0]
        y_all = positions[:, 1]
        inside = (np.abs(x_all) <= extent) & (np.abs(y_all) <= extent)

        psi_t = traj['psi_t']
        energies = np.full(len(positions), np.nan)
        for t in np.flatnonzero(inside):
            H = self.hamiltonian(x_all[t], y_all[t], parameters)
            psi = psi_t[t]
            energies[t] = np.real(np.vdot(psi, H @ psi))

        xs = np.where(inside, x_all, np.nan)
        ys = np.where(inside, y_all, np.nan)
        return xs, ys, energies, inside

    # ------------------------------------------------------------- analysis

    def analyze_trajectory_types(self, results):
        """Split trajectories into ``(prelooping, regular)`` lists.

        Also (re)writes each trajectory's ``index`` to its position in
        ``results``.
        """
        prelooping_trajs = []
        regular_trajs = []

        for i, traj in enumerate(results):
            traj['index'] = i
            if traj.get('is_prelooping', False):
                prelooping_trajs.append(traj)
            else:
                regular_trajs.append(traj)

        return prelooping_trajs, regular_trajs

    def calculate_average_quantities(self, trajectories, ns):
        """Average ``adiabatic_pops`` and ``forces`` over a set of trajectories.

        Each trajectory array must broadcast onto shape ``(ns, 2)``.

        Returns:
            ``{'adiabatic_pops': (ns, 2) array, 'forces': (ns, 2) array}``,
            or ``None`` when ``trajectories`` is empty.
        """
        if not trajectories:
            return None

        adiabatic_avg = np.zeros((ns, 2))
        forces_avg = np.zeros((ns, 2))
        for traj in trajectories:
            adiabatic_avg += traj['adiabatic_pops']
            forces_avg += traj['forces']

        n_traj = len(trajectories)
        return {
            'adiabatic_pops': adiabatic_avg / n_traj,
            'forces': forces_avg / n_traj,
        }

    def _subset_averages(self, results, ns):
        """Return ``(prelooping, regular, [all_avg, pre_avg, reg_avg])``.

        The averages are ordered like ``_SUBSETS``. A subset with no
        trajectories has ``None`` in its slot.
        """
        prelooping_trajs, regular_trajs = self.analyze_trajectory_types(results)
        averages = [self.calculate_average_quantities(subset, ns)
                    for subset in (results, prelooping_trajs, regular_trajs)]
        return prelooping_trajs, regular_trajs, averages

    @staticmethod
    def _select_trajectories(prelooping_trajs, regular_trajs, n_total, rng):
        """Randomly pick up to ``n_total`` trajectories, balanced by type.

        Pre-looped trajectories get the extra slot when ``n_total`` is odd.
        If one type has too few members, the other type fills the remaining
        slots. Pre-looped picks come first in the returned list.
        """
        n_total = max(0, int(n_total))
        n_pre = min(len(prelooping_trajs), (n_total + 1) // 2)
        n_reg = min(len(regular_trajs), n_total - n_pre)
        # Top up from pre-looped if the regular set ran short.
        n_pre = min(len(prelooping_trajs), n_total - n_reg)
        return rng.sample(prelooping_trajs, n_pre) + rng.sample(regular_trajs, n_reg)

    # ------------------------------------------------------- figure helpers

    @staticmethod
    def _time_axis(parameters):
        """Return ``(ns, dt, time)`` where ``time = arange(ns) * dt``."""
        ns = parameters['ns']
        dt = parameters['dt']
        return ns, dt, np.arange(ns) * dt

    @staticmethod
    def _param_subtitle(parameters):
        """Shared subtitle text summarizing the simulation parameters."""
        return f"e={parameters['e']:.1f}, dt={parameters['dt']:.3f}, Berry+Geometric Forces"

    @staticmethod
    def _reference_line(time, level):
        """Faint dotted horizontal guide at ``level`` spanning the time axis."""
        return go.Scatter(x=[time[0], time[-1]], y=[level, level], mode='lines',
                          line=dict(color='gray', width=1, dash='dot'),
                          opacity=0.5, showlegend=False)

    def _add_by_type_traces(self, fig, col, time, averages, key, components, transform=None):
        """Add one line per (subset, component) to subplot column ``col``.

        Args:
            averages: per-subset average dicts ordered like ``_SUBSETS``;
                ``None`` entries are skipped.
            key: which averaged quantity to plot (``'adiabatic_pops'`` or
                ``'forces'``).
            components: ``(legend name, legendgroup suffix, colour)`` per
                column of ``averages[...][key]``.
            transform: optional function applied to each y series.
        """
        show = col == 1  # only the first panel contributes legend entries
        for (subset_label, group, dash), avg in zip(self._SUBSETS, averages):
            if not avg:
                continue
            for comp_idx, (comp_name, group_suffix, color) in enumerate(components):
                y = avg[key][:, comp_idx]
                if transform is not None:
                    y = transform(y)
                line = dict(color=color, width=2)
                if dash:
                    line['dash'] = dash
                fig.add_trace(
                    go.Scatter(x=time, y=y, mode='lines',
                               name=f'{subset_label} - {comp_name}' if show else None,
                               line=line,
                               legendgroup=f'{group}_{group_suffix}', showlegend=show),
                    row=1, col=col
                )

    @staticmethod
    def _finish_layout(fig, n_cases, title, height=600, legend_size=10):
        """Apply the common figure size, title and legend settings."""
        fig.update_layout(
            height=height,
            width=500 * n_cases,
            title=title,
            showlegend=True,
            legend=dict(font=dict(size=legend_size))
        )

    # ---------------------------------------------------------------- plots

    def plot_adiabatic_comparison_all(self, topology_files):
        """Side-by-side adiabatic populations averaged over all trajectories.

        Prints per-file trajectory-type counts while loading. Returns the
        figure, or ``None`` if no file could be loaded.
        """
        loaded_data = self._load_topologies(topology_files, verbose=True)
        if not loaded_data:
            return None

        n_cases = len(loaded_data)
        fig = make_subplots(
            rows=1, cols=n_cases,
            subplot_titles=[f"{label}<br><sub>Adiabatic Population (All Trajectories)</sub>" for label in loaded_data],
            horizontal_spacing=0.08
        )

        states = (('Lower Adiabatic', 'lower', 'blue'), ('Upper Adiabatic', 'upper', 'red'))
        for col, (results, parameters) in enumerate(loaded_data.values(), start=1):
            ns, dt, time = self._time_axis(parameters)
            prelooping_trajs, regular_trajs = self.analyze_trajectory_types(results)
            all_avg = self.calculate_average_quantities(results, ns)
            show = col == 1

            if all_avg:
                for state, (name, group, color) in enumerate(states):
                    fig.add_trace(
                        go.Scatter(x=time, y=all_avg['adiabatic_pops'][:, state], mode='lines',
                                   name=name if show else None,
                                   line=dict(color=color, width=4),
                                   legendgroup=group, showlegend=show),
                        row=1, col=col
                    )

            fig.add_trace(self._reference_line(time, 0.5), row=1, col=col)

            fig.update_xaxes(title_text=f"Time (dt={dt:.3f})", row=1, col=col)
            fig.update_yaxes(
                title_text=f"Population<br>Total: {len(results)}<br>P:{len(prelooping_trajs)}, R:{len(regular_trajs)}",
                range=[0, 1.05], row=1, col=col)

        sample_params = next(iter(loaded_data.values()))[1]
        self._finish_layout(
            fig, n_cases,
            f"Adiabatic Population Dynamics Comparison (All Trajectories)<br><sub>{self._param_subtitle(sample_params)}</sub>")
        return fig

    def plot_adiabatic_comparison_by_type(self, topology_files):
        """Side-by-side adiabatic populations for all / pre-looped / regular trajectories.

        Solid lines are all trajectories, dashed are pre-looped and dotted are
        regular. Returns the figure, or ``None`` if no file could be loaded.
        """
        loaded_data = self._load_topologies(topology_files)
        if not loaded_data:
            return None

        n_cases = len(loaded_data)
        fig = make_subplots(
            rows=1, cols=n_cases,
            subplot_titles=[f"{label}<br><sub>Adiabatic Population (By Trajectory Type)</sub>" for label in loaded_data],
            horizontal_spacing=0.08
        )

        components = (('Lower', 'lower', 'blue'), ('Upper', 'upper', 'red'))
        for col, (results, parameters) in enumerate(loaded_data.values(), start=1):
            ns, dt, time = self._time_axis(parameters)
            prelooping_trajs, regular_trajs, averages = self._subset_averages(results, ns)

            self._add_by_type_traces(fig, col, time, averages, 'adiabatic_pops', components)
            fig.add_trace(self._reference_line(time, 0.5), row=1, col=col)

            fig.update_xaxes(title_text=f"Time (dt={dt:.3f})", row=1, col=col)
            fig.update_yaxes(title_text=f"Population<br>P:{len(prelooping_trajs)}, R:{len(regular_trajs)}",
                             range=[0, 1.05], row=1, col=col)

        sample_params = next(iter(loaded_data.values()))[1]
        self._finish_layout(
            fig, n_cases,
            f"Adiabatic Population Dynamics Comparison (By Trajectory Type)<br><sub>{self._param_subtitle(sample_params)}</sub>")
        return fig

    def plot_forces_comparison(self, topology_files, force_limit=1000):
        """Side-by-side average force components by trajectory type.

        Args:
            topology_files: mapping of panel label -> HDF5 filename.
            force_limit: forces are clipped to, and the y axis spans,
                [-force_limit, force_limit].

        Returns:
            The figure, or ``None`` if no file could be loaded.
        """
        loaded_data = self._load_topologies(topology_files)
        if not loaded_data:
            return None

        n_cases = len(loaded_data)
        fig = make_subplots(
            rows=1, cols=n_cases,
            subplot_titles=[f"{label}<br><sub>Forces by Trajectory Type (±{force_limit:g})</sub>" for label in loaded_data],
            horizontal_spacing=0.08
        )

        components = (('Fx', 'fx', 'green'), ('Fy', 'fy', 'purple'))

        def clip(y):
            return np.clip(y, -force_limit, force_limit)

        for col, (results, parameters) in enumerate(loaded_data.values(), start=1):
            ns, dt, time = self._time_axis(parameters)
            prelooping_trajs, regular_trajs, averages = self._subset_averages(results, ns)

            self._add_by_type_traces(fig, col, time, averages, 'forces', components, transform=clip)
            fig.add_trace(self._reference_line(time, 0), row=1, col=col)

            fig.update_xaxes(title_text=f"Time (dt={dt:.3f})", row=1, col=col)
            fig.update_yaxes(title_text=f"Force<br>P:{len(prelooping_trajs)}, R:{len(regular_trajs)}",
                             range=[-force_limit, force_limit], row=1, col=col)

        sample_params = next(iter(loaded_data.values()))[1]
        self._finish_layout(
            fig, n_cases,
            f"Force Components Comparison by Trajectory Type<br><sub>{self._param_subtitle(sample_params)}, Range: ±{force_limit:g}</sub>")
        return fig

    def plot_pes_comparison(self, topology_files, n_traj_per_case=3, seed=None):
        """Side-by-side 3-D PES surfaces with sample trajectories drawn on top.

        Both adiabatic surfaces are evaluated on [-5, 5]^2. Each drawn
        trajectory's height is <psi|H|psi> along its path, and a small white
        marker shows its first in-window point.

        Args:
            topology_files: mapping of panel label -> HDF5 filename.
            n_traj_per_case: trajectories drawn per panel. They are split as
                evenly as possible between pre-looped and regular (pre-looped
                get the odd one), with the other type filling in when one
                runs short. Trajectories with fewer than
                ``MIN_VISIBLE_POINTS`` in-window samples are skipped.
            seed: optional seed for a reproducible selection; ``None`` uses
                the global ``random`` state.

        Returns:
            The figure, or ``None`` if no file could be loaded.
        """
        loaded_data = self._load_topologies(topology_files)
        if not loaded_data:
            return None

        n_cases = len(loaded_data)
        fig = make_subplots(
            rows=1, cols=n_cases,
            subplot_titles=[f"{label}<br><sub>PES +  Trajectories</sub>" for label in loaded_data],
            specs=[[{'type': 'surface'} for _ in range(n_cases)]],
            horizontal_spacing=0.05
        )

        rng = random if seed is None else random.Random(seed)
        extent = self.PES_EXTENT

        for col, (results, parameters) in enumerate(loaded_data.values(), start=1):
            X, Y, Z_lower, Z_upper = self._compute_pes(parameters)
            fig.add_trace(
                go.Surface(x=X, y=Y, z=Z_lower, colorscale='Blues', opacity=0.7,
                           showscale=False, name='Lower PES'),
                row=1, col=col
            )
            fig.add_trace(
                go.Surface(x=X, y=Y, z=Z_upper, colorscale='Reds', opacity=0.7,
                           showscale=False, name='Upper PES'),
                row=1, col=col
            )

            prelooping_trajs, regular_trajs = self.analyze_trajectory_types(results)
            selected = self._select_trajectories(prelooping_trajs, regular_trajs, n_traj_per_case, rng)

            color_idx = 0  # advances only for trajectories actually drawn
            for traj in selected:
                xs, ys, energies, inside = self._trajectory_on_pes(traj, parameters, extent)
                if np.count_nonzero(inside) < self.MIN_VISIBLE_POINTS:
                    continue

                color = self.TRAJ_COLORS[color_idx % len(self.TRAJ_COLORS)]
                traj_idx = traj['index']
                traj_type = "Pre-looped" if traj.get('is_prelooping', False) else "Regular"

                fig.add_trace(
                    go.Scatter3d(
                        x=xs, y=ys, z=energies,
                        mode='lines',
                        line=dict(color=color, width=6),
                        name=f'T{traj_idx + 1}({traj_type})' if col == 1 else None,
                        legendgroup=f'traj{color_idx}',
                        showlegend=(col == 1)
                    ),
                    row=1, col=col
                )

                # Small start marker (size 3) so it does not hide the path.
                start = np.flatnonzero(inside)[0]
                fig.add_trace(
                    go.Scatter3d(
                        x=[xs[start]], y=[ys[start]], z=[energies[start]],
                        mode='markers',
                        marker=dict(size=3, color='white',
                                    line=dict(color=color, width=1),
                                    symbol='circle'),
                        showlegend=False
                    ),
                    row=1, col=col
                )

                color_idx += 1

        sample_params = next(iter(loaded_data.values()))[1]
        title_text = (
            f"PES Comparison with Enhanced Trajectory Visibility<br>"
            f"<sub>e={sample_params['e']:.1f}, dt={sample_params['dt']:.3f}, "
            f"Forces: Berry+Geometric, PES Range: [-5,5], Trajectory Width: 6, Start Marker: 3</sub>"
        )
        self._finish_layout(fig, n_cases, title_text, height=700, legend_size=8)

        # Plotly names the first 3-D scene 'scene', then 'scene2', 'scene3', ...
        for i in range(1, n_cases + 1):
            scene_name = 'scene' if i == 1 else f'scene{i}'
            fig.update_layout(**{
                scene_name: dict(
                    xaxis=dict(title='X', range=[-extent, extent]),
                    yaxis=dict(title='Y', range=[-extent, extent]),
                    zaxis=dict(title='Energy'),
                    camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
                )
            })

        return fig


def main():
    """Build the four comparison figures, saving each as HTML and opening it.

    Figures that come back as ``None`` (no input file could be loaded) are
    skipped. A summary of the layout and features is printed at the end.
    """
    plotter = ComparisonTopologyPlotterV3(
        data_directory=os.path.expanduser("~/Desktop/hpc/upper/results/a1"))

    # Panel label -> HDF5 result file (one panel per topology).
    topology_files = {
        'Z=0 (Conical)': 'quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5',
        'Z=0.05 (Avoided)': 'quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5',
        'Z=f(x,y) (Parabolic)': 'quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5'
    }

    print("Comparison Topology Plotter v3 - Clean Annotations")
    print("=" * 55)

    # (progress banner, figure builder, output HTML file), run in order.
    jobs = (
        ("\n1. Creating adiabatic population comparison (all trajectories)...",
         lambda: plotter.plot_adiabatic_comparison_all(topology_files),
         "comparison_v3_adiabatic_all_trajectories.html"),
        ("\n2. Creating adiabatic population comparison (by trajectory type)...",
         lambda: plotter.plot_adiabatic_comparison_by_type(topology_files),
         "comparison_v3_adiabatic_by_type.html"),
        ("\n3. Creating forces comparison (by trajectory type)...",
         lambda: plotter.plot_forces_comparison(topology_files),
         "comparison_v3_forces_by_type.html"),
        ("\n4. Creating PES comparison with  trajectories...",
         lambda: plotter.plot_pes_comparison(topology_files, n_traj_per_case=4),
         "comparison_v3_pes_trajectories.html"),
    )

    for banner, build_figure, html_name in jobs:
        print(banner)
        figure = build_figure()
        if figure is None:
            continue
        figure.write_html(html_name)
        figure.show()
        print(f"Saved: {html_name}")

    summary = (
        "\nComparison topology plots v3 completed!",
        "\nLayout structure:",
        "┌─────────────┬─────────────┬─────────────┐",
        "│   Z=0       │   Z=0.05    │  Z=f(x,y)   │",
        "│ (Conical)   │ (Avoided)   │ (Parabolic) │",
        "└─────────────┴─────────────┴─────────────┘",
        "\nKey features:",
        "✓ Clean 3-panel comparison layout",
        "✓ PES range: [-5,5] × [-5,5] (focused visualization)",
        "✓  trajectories: width=6 (highly visible)",
        "✓ Small start markers: size=3 (don't hide trajectories)",
        "✓ No obstructing annotations - clean visualization",
        "✓ Four comparison plots created:",
        "  1. Adiabatic populations (all trajectories) - thick lines",
        "  2. Adiabatic populations (by trajectory type) - detailed analysis",
        "  3. Forces comparison (by trajectory type)",
        "  4. PES comparison with enhanced trajectory visibility",
    )
    for line in summary:
        print(line)


if __name__ == "__main__":
    # Build, save and open all comparison figures only when run as a script,
    # not when this module is imported.
    main()

Comparison Topology Plotter v3 - Clean Annotations

1. Creating adiabatic population comparison (all trajectories)...
Loading quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5...
Successfully loaded 100 trajectories

Z=0 (Conical):
  - Total trajectories: 100
  - Pre-looped: 50, Regular: 50
Loading quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5...
Successfully loaded 100 trajectories

Z=0.05 (Avoided):
  - Total trajectories: 100
  - Pre-looped: 50, Regular: 50
Loading quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5...
Successfully loaded 100 trajectories

Z=f(x,y) (Parabolic):
  - Total trajectories: 100
  - Pre-looped: 50, Regular: 50


Gtk-Message: 16:28:44.767: Failed to load module "xapp-gtk3-module"
Gtk-Message: 16:28:44.768: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:44.834: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:44.835: Failed to load module "canberra-gtk-module"


Saved: comparison_v3_adiabatic_all_trajectories.html

2. Creating adiabatic population comparison (by trajectory type)...
Loading quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5...
Successfully loaded 100 trajectories
Loading quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5...
Successfully loaded 100 trajectories
Loading quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5...
Successfully loaded 100 trajectories


Gtk-Message: 16:28:46.243: Failed to load module "xapp-gtk3-module"
Gtk-Message: 16:28:46.244: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:46.345: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:46.346: Failed to load module "canberra-gtk-module"


Saved: comparison_v3_adiabatic_by_type.html

3. Creating forces comparison (by trajectory type)...
Loading quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5...
Successfully loaded 100 trajectories
Loading quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5...
Successfully loaded 100 trajectories
Loading quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5...
Successfully loaded 100 trajectories


Gtk-Message: 16:28:47.732: Failed to load module "xapp-gtk3-module"
Gtk-Message: 16:28:47.732: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:47.804: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:47.805: Failed to load module "canberra-gtk-module"


Saved: comparison_v3_forces_by_type.html

4. Creating PES comparison with  trajectories...
Loading quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5...
Successfully loaded 100 trajectories
Loading quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5...
Successfully loaded 100 trajectories
Loading quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5...
Successfully loaded 100 trajectories
Saved: comparison_v3_pes_trajectories.html

Comparison topology plots v3 completed!

Layout structure:
┌─────────────┬─────────────┬─────────────┐
│   Z=0       │   Z=0.05    │  Z=f(x,y)   │
│ (Conical)   │ (Avoided)   │ (Parabolic) │
└─────────────┴─────────────┴─────────────┘

Key features:
✓ Clean 3-panel comparison layout
✓ PES range: [-5,5] × [-5,5] (focused visualization)
✓  trajectories: width=6 (highly visible)
✓ Small start markers: size=3 (don't hide trajectories)
✓ No obstructing annotations - clean visualization
✓ Four comparison plots created:
  1. Adiabatic populations (a

Gtk-Message: 16:28:51.470: Failed to load module "xapp-gtk3-module"
Gtk-Message: 16:28:51.471: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:51.566: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 16:28:51.566: Failed to load module "canberra-gtk-module"
