In [None]:
import numpy as np
from scipy.linalg import expm
import matplotlib.pyplot as plt

# --- Fixed Physical Parameters for Rb87 ---
# Use more conservative energy scales to avoid overflow: energies are kept in
# plain Hz here, and the 2*pi factor is only applied when the propagator is
# built in safe_time_evolution.
# Zeeman scale factor; construct_H_k multiplies it by the field in Gauss.
ZEEMAN_HZ_PER_GAUSS = 700e3  # 700 kHz per Gauss (realistic for Rb87 F=1)
# Time constant of the exp(-t / T2*) damping that run_safe_k_scan applies to
# the infidelity signal.
COHERENCE_TIME_T2_STAR = 0.5  # seconds

def get_standard_gell_mann_matrices():
    """
    Build the eight standard 3x3 Gell-Mann matrices.

    Returns:
    - dict mapping 'lambda1' ... 'lambda8' (in that insertion order) to
      3x3 complex numpy arrays. lambda8 carries the 1/sqrt(3) prefactor.
    """
    i = 1j
    # Raw integer/imaginary entries; lambda8 is scaled after conversion.
    raw_entries = {
        'lambda1': [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
        'lambda2': [[0, -i, 0], [i, 0, 0], [0, 0, 0]],
        'lambda3': [[1, 0, 0], [0, -1, 0], [0, 0, 0]],
        'lambda4': [[0, 0, 1], [0, 0, 0], [1, 0, 0]],
        'lambda5': [[0, 0, -i], [0, 0, 0], [i, 0, 0]],
        'lambda6': [[0, 0, 0], [0, 0, 1], [0, 1, 0]],
        'lambda7': [[0, 0, 0], [0, 0, -i], [0, i, 0]],
        'lambda8': [[1, 0, 0], [0, 1, 0], [0, 0, -2]],
    }
    matrices = {
        name: np.array(rows, dtype=complex)
        for name, rows in raw_entries.items()
    }
    # Re-assigning an existing key keeps the dict's insertion order intact.
    matrices['lambda8'] = (1 / np.sqrt(3)) * matrices['lambda8']
    return matrices

def get_Lambda_3_k(k, lambda_3_std, lambda_8_std):
    """
    Return the k-rotated Lambda_3 matrix.

    The mixing angle is k * pi / 6 and the result is
    cos(angle) * lambda_3_std + sin(angle) * lambda_8_std.
    """
    angle = k * np.pi / 6.0
    cos_term = np.cos(angle) * lambda_3_std
    sin_term = np.sin(angle) * lambda_8_std
    return cos_term + sin_term

def get_Lambda_8_k(k, lambda_3_std, lambda_8_std):
    """
    Return the k-rotated Lambda_8 matrix.

    The mixing angle is k * pi / 6 and the result is
    -sin(angle) * lambda_3_std + cos(angle) * lambda_8_std.
    """
    angle = k * np.pi / 6.0
    sin_term = -np.sin(angle) * lambda_3_std
    cos_term = np.cos(angle) * lambda_8_std
    return sin_term + cos_term

def construct_H_k(k, b_field_gauss, std_matrices, alpha3_coeff, alpha8_coeff, omega1_hz):
    """
    Assemble the k-dependent Hamiltonian, with all entries expressed in Hz.

    H(k) = Zeeman_scale * (alpha3*Lambda_3(k) + alpha8*Lambda_8(k)) + omega1 * lambda_1

    where Zeeman_scale = ZEEMAN_HZ_PER_GAUSS * b_field_gauss and Lambda_3(k),
    Lambda_8(k) come from get_Lambda_3_k / get_Lambda_8_k.

    Parameters:
    - k: rotation parameter passed to the Lambda_3(k) / Lambda_8(k) helpers
    - b_field_gauss: magnetic field in Gauss
    - std_matrices: dict providing 'lambda1', 'lambda3' and 'lambda8'
    - alpha3_coeff, alpha8_coeff: dimensionless scaling coefficients
    - omega1_hz: coupling strength in Hz

    Returns:
    - the 3x3 Hamiltonian matrix. A warning is printed (nothing is raised)
      when its norm exceeds 1e8.
    """
    lam1 = std_matrices['lambda1']
    lam3 = std_matrices['lambda3']
    lam8 = std_matrices['lambda8']

    rotated_3 = get_Lambda_3_k(k, lam3, lam8)
    rotated_8 = get_Lambda_8_k(k, lam3, lam8)

    # Field-dependent energy scale (Hz) multiplying the diagonal mixture.
    field_scale_hz = ZEEMAN_HZ_PER_GAUSS * b_field_gauss
    diagonal_part = field_scale_hz * (alpha3_coeff * rotated_3 + alpha8_coeff * rotated_8)
    coupling_part = omega1_hz * lam1
    hamiltonian = diagonal_part + coupling_part

    # Purely informational guard: very large entries are reported, not rejected.
    norm_hz = np.linalg.norm(hamiltonian)
    if norm_hz > 1e8:  # If norm > 100 MHz, warn
        print(f"Warning: Large Hamiltonian norm {norm_hz/1e6:.1f} MHz may cause numerical issues")

    return hamiltonian

def get_initial_state():
    """
    Return the normalized equal superposition of the three basis states.

    Returns:
    - complex column vector of shape (3, 1) with every entry 1/sqrt(3).
    """
    amplitude = 1 / np.sqrt(3)
    return amplitude * np.ones((3, 1), dtype=complex)

def safe_time_evolution(H, psi_initial, t_evolution):
    """
    Evolve a state under H, shortening the evolution time if the phase is large.

    The propagator is U = expm(-1j * 2*pi * H * t), i.e. H is taken to be in
    Hz and t in seconds. Before propagating, (spectral radius of H) * t is
    compared against 50; if it is larger in magnitude, a warning is printed and
    t is rescaled so that the product equals 50 (the sign of t is preserved).

    NOTE(review): for a Hermitian H the propagator is unitary, so a large phase
    cannot overflow by itself; the cap changes the physical evolution time.
    Callers must use the returned time, not the requested one.

    Parameters:
    - H: square Hamiltonian matrix
    - psi_initial: initial state (column vector compatible with H)
    - t_evolution: requested evolution time

    Returns:
    - (psi_final, t_used): the evolved state and the time actually applied.
      A warning is printed when psi_final is not normalized.
    """
    # Use the largest eigenvalue *magnitude*. The largest signed real part
    # would miss a spectrum dominated by a negative (or imaginary) eigenvalue.
    spectral_radius = np.max(np.abs(np.linalg.eigvals(H)))
    evolution_phase = spectral_radius * t_evolution

    if abs(evolution_phase) > 50:  # Cap |spectral radius * t| at 50
        print(f"Warning: Large evolution phase {evolution_phase:.1f}, scaling down time")
        # spectral_radius > 0 here, otherwise the phase would be 0.
        # copysign keeps a backward (negative) evolution time negative.
        t_evolution = np.copysign(50 / spectral_radius, t_evolution)

    # Use 2π * Hz units for angular frequency
    U_t = expm(-1j * 2 * np.pi * H * t_evolution)
    psi_final = U_t @ psi_initial

    # Check normalization
    norm_sq = np.sum(np.abs(psi_final)**2)
    if not np.isclose(norm_sq, 1.0, atol=1e-10):
        print(f"Warning: Final state norm = {norm_sq:.10f}")

    return psi_final, t_evolution

def get_populations(psi_final):
    """
    Return the squared magnitude of every amplitude as a flat 1-D array.
    """
    magnitudes = np.abs(psi_final)
    return (magnitudes * magnitudes).flatten()

def run_safe_k_scan(std_matrices, psi0, b_field_gauss, time_seconds,
                    alpha3, alpha8, omega1_hz, num_k_points=101):
    """
    Scan k over [0, 1] and compare each evolved state with the k=0 evolution.

    For every k the Hamiltonian from construct_H_k is propagated with
    safe_time_evolution. The reported signal is
    (1 - |<psi_ref|psi_k>|^2) * exp(-t_actual / COHERENCE_TIME_T2_STAR),
    where t_actual is the time safe_time_evolution actually used. That time
    may be shorter than time_seconds and may differ between k points.

    Parameters:
    - std_matrices: dict of Gell-Mann matrices forwarded to construct_H_k
    - psi0: initial state
    - b_field_gauss, alpha3, alpha8, omega1_hz: Hamiltonian parameters
    - time_seconds: requested evolution time
    - num_k_points: number of evenly spaced k samples

    Returns:
    - (k_values, populations, signals, actual_times): k grid, array with one
      population row per k, array of signal values, and a list of the
      evolution times actually used.
    """
    def hamiltonian_at(k):
        return construct_H_k(k, b_field_gauss, std_matrices, alpha3, alpha8, omega1_hz)

    k_values = np.linspace(0, 1, num_k_points)

    # Reference evolution at k=0 (its evolution time is not needed afterwards).
    psi_ref, _ = safe_time_evolution(hamiltonian_at(0), psi0, time_seconds)

    divider = "-" * 50
    print(f"\nSafe k-scan: B={b_field_gauss}G, t_target={time_seconds}s")
    print(f"Alpha3={alpha3}, Alpha8={alpha8}, Omega1={omega1_hz/1e3:.1f}kHz")
    print(f"{'k':<8} | {'Infidelity':<12} | {'Time Used':<12} | {'Pop Sum':<12}")
    print(divider)

    populations_list, infidelity_signal_list, actual_times = [], [], []

    for idx, k_val in enumerate(k_values):
        psi_k, t_actual = safe_time_evolution(hamiltonian_at(k_val), psi0, time_seconds)

        pops = get_populations(psi_k)
        populations_list.append(pops)
        actual_times.append(t_actual)

        # Infidelity with respect to the reference, damped by decoherence.
        overlap_sq = np.abs(np.vdot(psi_ref, psi_k))**2
        damping = np.exp(-t_actual / COHERENCE_TIME_T2_STAR)
        signal = (1.0 - overlap_sq) * damping
        infidelity_signal_list.append(signal)

        # Only echo the first few rows and the neighbourhoods of k=0.5 and k=1.
        near_anchor = any(abs(k_val - anchor) < 0.02 for anchor in (0.5, 1.0))
        if idx < 3 or near_anchor:
            print(f"{k_val:<8.2f} | {signal:<12.2e} | {t_actual:<12.3f} | {np.sum(pops):<12.6f}")

    print(divider)
    return k_values, np.array(populations_list), np.array(infidelity_signal_list), actual_times

def plot_safe_results(k_values, populations, signals, params):
    """
    Draw a 2x2 summary figure of a k-scan and print a short signal analysis.

    Panels: the three state populations vs k, the signal vs k, the summed
    population per k (unitarity check), and a bar chart of the maximum signal
    and the signal range.

    Parameters:
    - k_values: k grid used for the x axes
    - populations: 2-D array indexed as [k index, state]; columns 0..2 are plotted
    - signals: array of signal values, one per k
    - params: dict with keys 'b_field', 'time', 'alpha3', 'alpha8', 'omega1'
      (omega1 in Hz), used only for the figure title
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    (ax_pops, ax_signal), (ax_total, ax_bars) = axes

    # Population dynamics
    state_styles = [('r-', '|0⟩'), ('g-', '|1⟩'), ('b-', '|2⟩')]
    for column, (line_fmt, label) in enumerate(state_styles):
        ax_pops.plot(k_values, populations[:, column], line_fmt, label=label)
    ax_pops.set_ylabel('Population')
    ax_pops.set_title('State Populations vs k')
    ax_pops.legend()
    ax_pops.grid(True)
    ax_pops.set_ylim(0, 1)

    # Infidelity signal
    ax_signal.plot(k_values, signals, 'purple', linewidth=2)
    ax_signal.set_ylabel('Signal (1 - Fidelity²)')
    ax_signal.set_title('k-Dependent Signal')
    ax_signal.grid(True)
    ax_signal.set_ylim(bottom=0)

    # Total population check
    total_pop = np.sum(populations, axis=1)
    ax_total.plot(k_values, total_pop, 'k-', linewidth=2)
    ax_total.axhline(1.0, color='gray', linestyle='--', alpha=0.7)
    ax_total.set_ylabel('Total Population')
    ax_total.set_xlabel('k parameter')
    ax_total.set_title('Unitarity Check')
    ax_total.grid(True)

    # Signal strength analysis (maximum computed once and reused)
    max_signal = np.max(signals)
    signal_range = max_signal - np.min(signals)
    ax_bars.bar(['Max Signal', 'Signal Range'], [max_signal, signal_range],
                color=['red', 'blue'], alpha=0.7)
    ax_bars.set_ylabel('Signal Strength')
    ax_bars.set_title('Signal Analysis')
    ax_bars.grid(True, axis='y')

    # Add parameter info
    param_text = (f"B = {params['b_field']}G, t = {params['time']}s\n"
                  f"α₃ = {params['alpha3']}, α₈ = {params['alpha8']}\n"
                  f"Ω₁ = {params['omega1']/1e3:.1f}kHz")
    fig.suptitle(f'k-Parameter Simulation Results\n{param_text}', fontsize=12)

    plt.tight_layout()
    plt.show()

    # Print analysis
    print("\nSignal Analysis:")
    print(f"Maximum signal: {max_signal:.2e}")
    print(f"Signal range: {signal_range:.2e}")
    # Keep the label on the fallback line too; the ratio is undefined when the
    # maximum signal is not positive.
    if max_signal > 0:
        print(f"Signal-to-max ratio: {signal_range/max_signal:.2f}")
    else:
        print("Signal-to-max ratio: N/A")

# --- Example Usage with Safe Parameters ---
if __name__ == '__main__':
    # Get matrices and initial state
    std_matrices = get_standard_gell_mann_matrices()
    initial_psi = get_initial_state()

    # Conservative parameters to avoid overflow
    simulation_params = {
        'b_field': 1.0,        # 1 Gauss (gives ~700 kHz Zeeman shift)
        'time': 0.01,          # 10 ms evolution (shorter to avoid decoherence)
        'alpha3': 0.5,         # Moderate Zeeman coupling
        'alpha8': 0.2,         # Weaker λ₈ mixing
        'omega1': 50e3         # 50 kHz Rabi frequency (reasonable for atoms)
    }

    print("Running safe k-parameter simulation...")
    print(f"Expected Zeeman scale: {ZEEMAN_HZ_PER_GAUSS * simulation_params['b_field']/1e6:.2f} MHz")

    k_vals, pops, signals, times = run_safe_k_scan(
        std_matrices, initial_psi,
        b_field_gauss=simulation_params['b_field'],
        time_seconds=simulation_params['time'],
        alpha3=simulation_params['alpha3'],
        alpha8=simulation_params['alpha8'],
        omega1_hz=simulation_params['omega1']
    )

    plot_safe_results(k_vals, pops, signals, simulation_params)

    # Report honestly whether the requested evolution time was actually used.
    # An unmodified time is passed straight through by the scan, so an exact
    # comparison is sufficient here. Only time scaling can be detected from the
    # returned values; other printed warnings are not tracked.
    requested_time = simulation_params['time']
    num_scaled = sum(1 for t_used in times if t_used != requested_time)
    if num_scaled:
        print(f"\nSimulation completed; evolution time was scaled down for "
              f"{num_scaled} of {len(times)} k points (see warnings above).")
    else:
        print("\nSimulation completed without overflow warnings.")