In [1]:
#!/usr/bin/env python3
"""
Individual Topology Plotter v3 - 2x2 Layout
PES range [-5,5] with repositioned annotations that don't hide the PES
"""

import numpy as np
import h5py
import json
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import random

# Set default renderer: fig.show() uses plotly's 'browser' renderer, which
# opens each figure in a web browser tab.
pio.renderers.default = 'browser'

class IndividualTopologyPlotterV3:
    """Plot one topology's trajectory ensemble as a 2x2 Plotly figure.

    Layout:
        top-left     averaged adiabatic populations over all trajectories
        top-right    the same averages split into pre-looped / regular subsets
        bottom-left  lower/upper PES on a square window with sampled trajectories
        bottom-right averaged force components per subset
    """

    # The PES panel covers [-PES_HALF_WIDTH, PES_HALF_WIDTH] on both axes.
    PES_HALF_WIDTH = 5
    # Grid points per axis used to sample the PES surfaces.
    PES_GRID_POINTS = 50
    # Force traces and the force axis are limited to +/- FORCE_CLIP.
    FORCE_CLIP = 1000
    # Number of trajectories sampled from each type for the PES overlay.
    TRAJECTORIES_PER_TYPE = 3
    # Line colours for the overlaid trajectories (reused cyclically if needed).
    TRAJECTORY_COLORS = ('lime', 'cyan', 'yellow', 'orange', 'magenta', 'white')

    def __init__(self, data_directory=None):
        """Initialize the plotter.

        Args:
            data_directory: Directory holding the HDF5 result files. Defaults to
                ``~/Desktop/hpc/upper/results/a1``.
        """
        self.data_dir = data_directory or os.path.expanduser("~/Desktop/hpc/upper/results/a1")

    # ------------------------------------------------------------------ loading

    @staticmethod
    def _trajectory_sort_key(name):
        """Order ``trajectory_<n>`` names numerically (2 before 10).

        Names whose suffix is not a plain integer sort after the numeric ones,
        alphabetically among themselves.
        """
        suffix = name[len('trajectory_'):]
        if suffix.isdecimal():
            return (0, int(suffix), name)
        return (1, 0, name)

    @staticmethod
    def _decode_attr(value):
        """Decode string attributes that look like JSON; return anything else as-is."""
        if isinstance(value, str) and value.startswith(('{', '[')):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                # Looks like JSON but isn't valid JSON: keep the raw string.
                return value
        return value

    @staticmethod
    def _read_trajectory(traj_group, index):
        """Read one ``trajectory_*`` HDF5 group into a plain dict of metadata and arrays."""
        attrs = traj_group.attrs
        traj = {
            'type': attrs.get('type', 'regular'),
            'is_prelooping': attrs.get('is_prelooping', False),
            'initial_position': [attrs.get('initial_x', 0), attrs.get('initial_y', 0)],
            'initial_velocity': [attrs.get('initial_vx', 0), attrs.get('initial_vy', 0)],
        }
        for key in ('positions', 'velocities', 'momenta'):
            traj[key] = traj_group[key][()]
        # psi is stored as separate real/imaginary datasets; recombine into complex.
        traj['psi_t'] = traj_group['psi_t_real'][()] + 1j * traj_group['psi_t_imag'][()]
        for key in ('energies', 'populations', 'adiabatic_pops', 'berry_curvatures', 'forces'):
            traj[key] = traj_group[key][()]
        traj['index'] = index
        return traj

    def load_h5_file(self, filename):
        """Load simulation parameters and all trajectories from an HDF5 file.

        Args:
            filename: File name relative to ``self.data_dir``.

        Returns:
            ``(results, parameters)``: a list of trajectory dicts (ordered by the
            numeric suffix of their ``trajectory_<n>`` group) and a dict of the
            ``parameters`` group attributes, with JSON-looking strings decoded.
            Returns ``(None, None)`` if the file does not exist.
        """
        filepath = os.path.join(self.data_dir, filename)

        if not os.path.exists(filepath):
            print(f"Error: File {filepath} not found")
            return None, None

        print(f"Loading {filename}...")

        with h5py.File(filepath, 'r') as f:
            param_attrs = f['parameters'].attrs
            parameters = {key: self._decode_attr(param_attrs[key]) for key in param_attrs}

            traj_names = sorted(
                (k for k in f.keys() if k.startswith('trajectory_')),
                key=self._trajectory_sort_key,
            )
            results = [self._read_trajectory(f[name], idx) for idx, name in enumerate(traj_names)]

        print(f"Successfully loaded {len(results)} trajectories")
        return results, parameters

    # ------------------------------------------------------------------ physics

    def hamiltonian(self, x, y, parameters):
        """Return the 2x2 complex Hamiltonian matrix at position (x, y).

        H = (s/2) * [[a*z, x - i*sqrt(1-e^2)*y], [conj(offdiag), -z]], where z is
        ``parameters['z_val']`` when ``parameters['z_choice'] == 'constant'`` and
        ``sqrt(x^2 + (1-e^2) y^2)`` otherwise.
        """
        if parameters['z_choice'] == 'constant':
            z = parameters['z_val']
        else:
            z = np.sqrt(x**2 + (1 - parameters['e']**2) * y**2)

        off_diag = x - 1j * np.sqrt(1 - parameters['e']**2) * y

        return 0.5 * parameters['s'] * np.array([
            [parameters['a'] * z, off_diag],
            [np.conjugate(off_diag), -z]
        ], dtype=complex)

    def _compute_pes(self, x_range, y_range, parameters):
        """Evaluate both eigenvalue surfaces of ``hamiltonian`` on a grid.

        Returns:
            ``(X, Y, lower, upper)`` where ``X, Y = np.meshgrid(x_range, y_range)``
            and ``lower``/``upper`` have the same ``(len(y_range), len(x_range))``
            shape, holding the smaller/larger eigenvalue at each grid point.
        """
        X, Y = np.meshgrid(x_range, y_range)
        lower = np.zeros(X.shape, dtype=float)
        upper = np.zeros(X.shape, dtype=float)
        # ndindex walks the actual meshgrid shape, so non-square grids index correctly.
        for idx in np.ndindex(X.shape):
            lower[idx], upper[idx] = np.linalg.eigvalsh(self.hamiltonian(X[idx], Y[idx], parameters))
        return X, Y, lower, upper

    def _trajectory_energies(self, traj, indices, parameters):
        """Return Re <psi|H(x, y)|psi> at the given time indices of a trajectory."""
        positions = traj['positions']
        psi_t = traj['psi_t']
        energies = []
        for k in indices:
            H = self.hamiltonian(positions[k, 0], positions[k, 1], parameters)
            psi = psi_t[k]
            energies.append(np.real(np.vdot(psi, np.dot(H, psi))))
        return np.array(energies)

    # ------------------------------------------------------------------ analysis

    def analyze_trajectory_types(self, results):
        """Split trajectories into (pre-looped, regular) lists.

        Side effect: overwrites each trajectory's ``'index'`` with its position
        in ``results``.
        """
        prelooping_trajs = []
        regular_trajs = []

        for i, traj in enumerate(results):
            traj['index'] = i
            if traj.get('is_prelooping', False):
                prelooping_trajs.append(traj)
            else:
                regular_trajs.append(traj)

        return prelooping_trajs, regular_trajs

    def calculate_average_quantities(self, trajectories, ns):
        """Average adiabatic populations, forces and energies over trajectories.

        Assumes each trajectory's ``adiabatic_pops`` and ``forces`` are ``(ns, 2)``
        and ``energies`` is ``(ns,)``; other shapes fail to broadcast.

        Returns:
            A dict with keys ``'adiabatic_pops'``, ``'forces'`` and ``'energies'``,
            or ``None`` when ``trajectories`` is empty.
        """
        if not trajectories:
            return None

        n_traj = len(trajectories)

        adiabatic_avg = np.zeros((ns, 2))
        forces_avg = np.zeros((ns, 2))
        energies_avg = np.zeros(ns)

        for traj in trajectories:
            adiabatic_avg += traj['adiabatic_pops']
            forces_avg += traj['forces']
            energies_avg += traj['energies']

        adiabatic_avg /= n_traj
        forces_avg /= n_traj
        energies_avg /= n_traj

        return {
            'adiabatic_pops': adiabatic_avg,
            'forces': forces_avg,
            'energies': energies_avg
        }

    @staticmethod
    def _select_trajectories(prelooping_trajs, regular_trajs, per_type=3, seed=None):
        """Randomly pick up to ``per_type`` trajectories of each type (pre-looped first).

        With ``seed=None`` the global ``random`` module state is used; otherwise a
        private ``random.Random(seed)`` makes the choice reproducible.
        """
        rng = random if seed is None else random.Random(seed)
        chosen = []
        for group in (prelooping_trajs, regular_trajs):
            if group:
                chosen.extend(rng.sample(group, min(per_type, len(group))))
        return chosen

    @staticmethod
    def _masked_line(values, mask):
        """Return ``values`` as a list with ``None`` wherever ``mask`` is False.

        Plotly leaves a gap at ``None`` entries, so a path that leaves the plot
        window and re-enters is not joined by a spurious straight segment.
        """
        return [float(v) if keep else None for v, keep in zip(values, mask)]

    # ------------------------------------------------------------------ drawing

    @staticmethod
    def _add_reference_line(fig, time, y_value, row, col):
        """Draw a faint dotted horizontal line at ``y_value`` across the time axis."""
        fig.add_trace(
            go.Scatter(x=[time[0], time[-1]], y=[y_value, y_value], mode='lines',
                       line=dict(color='gray', width=1, dash='dot'),
                       opacity=0.5, showlegend=False),
            row=row, col=col
        )

    @staticmethod
    def _add_two_component_traces(fig, time, values, names, colors, row, col, width=2, dash=None):
        """Add one line per column of a ``(ns, 2)`` array (lower/upper or Fx/Fy)."""
        for component, (name, color) in enumerate(zip(names, colors)):
            line = dict(color=color, width=width)
            if dash is not None:
                line['dash'] = dash
            fig.add_trace(
                go.Scatter(x=time, y=values[:, component], mode='lines', name=name, line=line),
                row=row, col=col
            )

    def _add_population_panels(self, fig, time, all_avg, preloop_avg, regular_avg,
                               n_preloop, n_regular):
        """Fill the two top panels with averaged adiabatic populations."""
        colors = ('blue', 'red')

        # Top left: ensemble average only.
        if all_avg:
            self._add_two_component_traces(
                fig, time, all_avg['adiabatic_pops'],
                ('Lower Adiabatic (All)', 'Upper Adiabatic (All)'), colors,
                row=1, col=1, width=3)
        self._add_reference_line(fig, time, 0.5, row=1, col=1)

        # Top right: ensemble average plus per-type averages (dashed / dotted).
        if all_avg:
            self._add_two_component_traces(
                fig, time, all_avg['adiabatic_pops'],
                ('All - Lower', 'All - Upper'), colors, row=1, col=2)
        if preloop_avg:
            label = f'Pre-looped ({n_preloop})'
            self._add_two_component_traces(
                fig, time, preloop_avg['adiabatic_pops'],
                (f'{label} - Lower', f'{label} - Upper'), colors,
                row=1, col=2, dash='dash')
        if regular_avg:
            label = f'Regular ({n_regular})'
            self._add_two_component_traces(
                fig, time, regular_avg['adiabatic_pops'],
                (f'{label} - Lower', f'{label} - Upper'), colors,
                row=1, col=2, dash='dot')
        self._add_reference_line(fig, time, 0.5, row=1, col=2)

    def _add_pes_panel(self, fig, parameters, prelooping_trajs, regular_trajs, seed=None):
        """Draw both PES sheets and overlay a random sample of trajectories (bottom left).

        Each sampled trajectory is drawn at height Re <psi|H|psi> along its path.
        Points outside the plot window are left as gaps, and trajectories with
        fewer than 5 points inside the window are skipped.
        """
        half = self.PES_HALF_WIDTH
        axis = np.linspace(-half, half, self.PES_GRID_POINTS)
        X, Y, z_lower, z_upper = self._compute_pes(axis, axis, parameters)

        for surface, colorscale, name in ((z_lower, 'Blues', 'Lower PES'),
                                          (z_upper, 'Reds', 'Upper PES')):
            fig.add_trace(
                go.Surface(x=X, y=Y, z=surface, colorscale=colorscale, opacity=0.6,
                           showscale=False, name=name),
                row=2, col=1
            )

        selected = self._select_trajectories(
            prelooping_trajs, regular_trajs,
            per_type=self.TRAJECTORIES_PER_TYPE, seed=seed)
        colors = self.TRAJECTORY_COLORS

        for color_idx, traj in enumerate(selected):
            # Colour follows the sample slot, so a skipped trajectory still uses one up.
            color = colors[color_idx % len(colors)]
            positions = traj['positions']
            mask = (np.abs(positions[:, 0]) <= half) & (np.abs(positions[:, 1]) <= half)
            inside = np.flatnonzero(mask)
            if len(inside) < 5:
                continue

            energies = self._trajectory_energies(traj, inside, parameters)
            z_path = np.full(len(positions), np.nan)
            z_path[inside] = energies

            traj_idx = traj['index']
            traj_type = "Pre-looped" if traj.get('is_prelooping', False) else "Regular"

            # Wide line so the path stays visible over the semi-transparent surfaces.
            fig.add_trace(
                go.Scatter3d(
                    x=self._masked_line(positions[:, 0], mask),
                    y=self._masked_line(positions[:, 1], mask),
                    z=self._masked_line(z_path, mask),
                    mode='lines',
                    line=dict(color=color, width=8),
                    name=f'Traj {traj_idx+1} ({traj_type})',
                    showlegend=False
                ),
                row=2, col=1
            )

            # Small start marker (first in-window point) so it doesn't hide the path.
            start = inside[0]
            fig.add_trace(
                go.Scatter3d(
                    x=[positions[start, 0]],
                    y=[positions[start, 1]],
                    z=[energies[0]],
                    mode='markers',
                    marker=dict(size=3, color='white',
                                line=dict(color=color, width=1),
                                symbol='circle'),
                    showlegend=False
                ),
                row=2, col=1
            )

    def _add_force_panel(self, fig, time, all_avg, preloop_avg, regular_avg):
        """Fill the bottom-right panel with clipped average force components."""
        clip = self.FORCE_CLIP
        colors = ('green', 'purple')
        series = (
            (all_avg, ('All - Force X', 'All - Force Y'), None),
            (preloop_avg, ('Pre-looped - Fx', 'Pre-looped - Fy'), 'dash'),
            (regular_avg, ('Regular - Fx', 'Regular - Fy'), 'dot'),
        )
        for avg, names, dash in series:
            if avg:
                self._add_two_component_traces(
                    fig, time, np.clip(avg['forces'], -clip, clip), names, colors,
                    row=2, col=2, dash=dash)
        self._add_reference_line(fig, time, 0, row=2, col=2)

    def _apply_layout(self, fig, parameters, topology_label, n_total, n_preloop, n_regular):
        """Set the title, axis labels/ranges, 3D scene and the details annotation."""
        ns = parameters['ns']
        dt = parameters['dt']
        total_time = parameters['Ttotal']
        half = self.PES_HALF_WIDTH
        clip = self.FORCE_CLIP

        z_info = f"z={parameters['z_val']:.2f}" if parameters['z_choice'] == 'constant' else "z=f(x,y)"
        params_text = f'e={parameters["e"]:.1f}, a={parameters["a"]:.1f}, s={parameters["s"]:.1f}'

        title_text = (
            f'{topology_label} - Quantum Nonadiabatic Dynamics Analysis<br>'
            f'<sub>System Parameters: {params_text}, {z_info} | '
            f'Simulation: T={total_time:.1f}, dt={dt:.3f}, Steps={ns} | '
            f'Forces: Berry + Geometric | Trajectories: {n_total} total ({n_preloop} Pre-looped, {n_regular} Regular) | '
            f'PES Range: [-{half},{half}] × [-{half},{half}] | Trajectory Width: 8, Start Marker: 3</sub>'
        )

        fig.update_layout(
            height=1000,
            width=1600,
            title=dict(
                text=title_text,
                x=0.5,
                font=dict(size=16)
            ),
            showlegend=True,
            legend=dict(
                x=1.02, y=1.0,
                font=dict(size=9),
                bgcolor='rgba(255, 255, 255, 0.9)',
                bordercolor='gray',
                borderwidth=1
            )
        )

        # Top row: population plots share the same axes setup.
        time_title = f"Time (dt={dt:.3f})"
        for row, col in ((1, 1), (1, 2)):
            fig.update_xaxes(title_text=time_title, row=row, col=col)
            fig.update_yaxes(title_text="Adiabatic Population", range=[0, 1.05], row=row, col=col)

        # Bottom right: forces, axis range matches the clipping applied to the traces.
        fig.update_xaxes(title_text=time_title, row=2, col=2)
        fig.update_yaxes(title_text=f"Force Components (±{clip})", range=[-clip, clip], row=2, col=2)

        # Bottom left: 3D PES scene limited to the plotted window.
        fig.update_scenes(
            xaxis=dict(title='X Position', range=[-half, half]),
            yaxis=dict(title='Y Position', range=[-half, half]),
            zaxis=dict(title='Adiabatic Energy'),
            camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            row=2, col=1
        )

        detailed_info = (
            f'<b>Simulation Details:</b><br>'
            f'System: {topology_label}<br>'
            f'Params: {params_text}<br>'
            f'Z-coupling: {z_info}<br>'
            f'Time: 0→{total_time:.1f} (dt={dt:.4f})<br>'
            f'Steps: {ns}<br>'
            f'Forces: Berry + Geometric<br>'
            f'Trajectories: {n_total} total<br>'
            f'Pre-looped: {n_preloop}<br>'
            f'Regular: {n_regular}<br>'
            f'PES: [-{half},{half}]×[-{half},{half}]<br>'
            f'Traj Width: 8px, Start: 3px'
        )

        # Anchored in the paper's top-right corner so it does not cover the PES panel.
        fig.add_annotation(
            xref='paper', yref='paper',
            x=0.98, y=0.98,
            text=detailed_info,
            showarrow=False,
            align='right',
            bgcolor='rgba(255, 255, 255, 0.95)',
            bordercolor='gray',
            borderwidth=1,
            borderpad=4,
            font=dict(size=8),
            xanchor='right',
            yanchor='top'
        )

    # ------------------------------------------------------------------ entry

    def plot_individual_topology(self, filename, topology_label, seed=None):
        """Create the 2x2 analysis figure for one HDF5 result file.

        Top row: two adiabatic population plots. Bottom left: PES on the
        [-5,5] x [-5,5] window with sampled trajectories. Bottom right: forces.
        The details annotation sits in the top-right corner so it does not
        cover the PES.

        Args:
            filename: HDF5 file name relative to ``self.data_dir``.
            topology_label: Text used in the figure title and annotation.
            seed: Optional seed for choosing which trajectories are overlaid on
                the PES. ``None`` (default) uses the global ``random`` state.

        Returns:
            The Plotly figure, or ``None`` if the file could not be found.
        """
        results, parameters = self.load_h5_file(filename)
        if results is None:
            return None

        prelooping_trajs, regular_trajs = self.analyze_trajectory_types(results)

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=[
                'Adiabatic Population Dynamics (All Trajectories)',
                'Adiabatic Population Dynamics (By Trajectory Type)',
                'PES with  Trajectories',
                'Forces by Trajectory Type'
            ],
            specs=[
                [{'type': 'scatter'}, {'type': 'scatter'}],
                [{'type': 'scatter3d'}, {'type': 'scatter'}]
            ],
            vertical_spacing=0.12,
            horizontal_spacing=0.10
        )

        ns = parameters['ns']
        time = np.arange(ns) * parameters['dt']

        # calculate_average_quantities returns None for an empty subset.
        all_avg = self.calculate_average_quantities(results, ns)
        preloop_avg = self.calculate_average_quantities(prelooping_trajs, ns)
        regular_avg = self.calculate_average_quantities(regular_trajs, ns)

        self._add_population_panels(fig, time, all_avg, preloop_avg, regular_avg,
                                    len(prelooping_trajs), len(regular_trajs))
        self._add_pes_panel(fig, parameters, prelooping_trajs, regular_trajs, seed=seed)
        self._add_force_panel(fig, time, all_avg, preloop_avg, regular_avg)
        self._apply_layout(fig, parameters, topology_label, len(results),
                           len(prelooping_trajs), len(regular_trajs))

        return fig


def main():
    """Build, save (as HTML in the working directory) and show the v3 plot for each topology."""
    plotter = IndividualTopologyPlotterV3(
        data_directory=os.path.expanduser("~/Desktop/hpc/upper/results/a1")
    )

    # Display label -> HDF5 result file inside the data directory.
    topology_files = {
        'Z=0 (Conical)': 'quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5',
        'Z=0.05 (Avoided)': 'quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5',
        'Z=f(x,y) (Parabolic)': 'quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5'
    }

    # Output file names: drop '=', '(', ')', ',' from the label and turn spaces into '_'.
    filename_table = str.maketrans(' ', '_', '=(),')

    print("Individual Topology Plotter v3 - Repositioned Annotations")
    print("=" * 60)

    for label, filename in topology_files.items():
        print(f"\nCreating clean 2x2 layout plot for {label}...")

        fig = plotter.plot_individual_topology(filename, label)
        if fig is None:
            print(f"Failed to create plot for {label}")
            continue

        html_filename = f"individual_v3_{label.translate(filename_table).lower()}.html"
        fig.write_html(html_filename)
        print(f"Saved: {html_filename}")
        fig.show()

    summary_lines = (
        "\nIndividual topology plots v3 completed!",
        "\nLayout structure:",
        "┌─────────────────────────┬─────────────────────────┐",
        "│ Adiabatic Pop (All)     │ Adiabatic Pop (By Type) │",
        "├─────────────────────────┼─────────────────────────┤",
        "│ PES +  Trajectories  │ Forces by Type          │",
        "└─────────────────────────┴─────────────────────────┘",
        "\nKey improvements:",
        "✓ PES range: [-5,5] × [-5,5] (focused visualization)",
        "✓  trajectories: width=8 (highly visible)",
        "✓ Small start markers: size=3 (don't hide trajectories)",
        "✓ Annotation repositioned: Top right corner (doesn't hide PES)",
        "✓ Compact annotation text for better readability",
        "✓ Clear PES visualization without obstruction",
    )
    for line in summary_lines:
        print(line)


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Individual Topology Plotter v3 - Repositioned Annotations

Creating clean 2x2 layout plot for Z=0 (Conical)...
Loading quantum_dynamics_e0.0_z0.0_berry_geom_20250422_142720.h5...
Successfully loaded 100 trajectories
Saved: individual_v3_z0_conical.html


Gtk-Message: 10:35:58.994: Failed to load module "xapp-gtk3-module"
Gtk-Message: 10:35:58.994: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 10:35:59.073: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 10:35:59.074: Failed to load module "canberra-gtk-module"



Creating clean 2x2 layout plot for Z=0.05 (Avoided)...
Loading quantum_dynamics_e0.0_z0.05_berry_geom_20250422_132932.h5...
Successfully loaded 100 trajectories
Saved: individual_v3_z0.05_avoided.html


Gtk-Message: 10:36:00.263: Failed to load module "xapp-gtk3-module"
Gtk-Message: 10:36:00.263: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 10:36:00.343: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 10:36:00.344: Failed to load module "canberra-gtk-module"



Creating clean 2x2 layout plot for Z=f(x,y) (Parabolic)...
Loading quantum_dynamics_e0.0_zfunc_berry_geom_20250422_192616.h5...
Successfully loaded 100 trajectories
Saved: individual_v3_zfxy_parabolic.html


Gtk-Message: 10:36:01.927: Failed to load module "xapp-gtk3-module"
Gtk-Message: 10:36:01.928: Not loading module "atk-bridge": The functionality is provided by GTK natively. Please try to not load it.

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 10:36:02.011: Failed to load module "canberra-gtk-module"

GTK+ 2.x symbols detected. Using GTK+ 2.x and GTK+ 3 in the same process is not supported.
Gtk-Message: 10:36:02.012: Failed to load module "canberra-gtk-module"



Individual topology plots v3 completed!

Layout structure:
┌─────────────────────────┬─────────────────────────┐
│ Adiabatic Pop (All)     │ Adiabatic Pop (By Type) │
├─────────────────────────┼─────────────────────────┤
│ PES +  Trajectories  │ Forces by Type          │
└─────────────────────────┴─────────────────────────┘

Key improvements:
✓ PES range: [-5,5] × [-5,5] (focused visualization)
✓  trajectories: width=8 (highly visible)
✓ Small start markers: size=3 (don't hide trajectories)
✓ Annotation repositioned: Top right corner (doesn't hide PES)
✓ Compact annotation text for better readability
✓ Clear PES visualization without obstruction
