In [None]:
import jax.numpy as jnp
from static_potential import StaticPotential
import matplotlib.pyplot as plt
from utils import visualize_profiles
from utils import adaptive_grid

from moving_potential import MovingPotentials
from solver import Solver

In [None]:
# Uniform time grid on [0, T_max] with N_t points. The grid is stored with an
# explicit complex64 dtype; the step between consecutive points is taken from
# the real parts only.
N_t = 10_000
T_max = 30
time_grid = jnp.linspace(start=0, stop=T_max, num=N_t, dtype=jnp.complex64)
time_step = time_grid[1].real - time_grid[0].real
time_step

In [None]:
# Problem dimensions and scan parameters.
# eta is forwarded unchanged to MovingPotentials; NOTE(review): its physical
# meaning is defined in moving_potential — confirm there.
eta = 2 / 5

# Total times to scan: n_total_times evenly spaced values from 1 to T_max (inclusive).
n_total_times = 20
total_times = jnp.linspace(1, T_max, n_total_times)

# Tweezer amplitudes (depths) to scan: n_amplitudes evenly spaced values.
min_amplitude, max_amplitude = 1, 10
n_amplitudes = 20
amplitudes = jnp.linspace(min_amplitude, max_amplitude, n_amplitudes)

In [None]:
# Static trap parameters come from params.json; the moving potentials are
# built from eta, the amplitude and total-time scans, and the static trap's
# left/right positions.
sp = StaticPotential.from_json('params.json')
mps = MovingPotentials(
    eta,
    1,  # NOTE(review): meaning of this positional argument is not visible here — confirm against MovingPotentials
    amplitudes,
    total_times,
    sp.x_left,
    sp.x_right,
)
sp, mps

In [None]:
# Evaluate the moving-potential profiles on the time grid and plot them:
# one row per profile, two columns (what each column shows is decided by
# visualize_profiles). The previous mid-cell `coord_profiles.shape` was a
# no-op (only the last expression of a cell is displayed) and was removed.
coord_profiles, amp_profiles = mps.populate_profiles(time_grid)

# Create a figure for plotting
fig, ax = plt.subplots(mps.n_profiles, 2,
                       figsize=(12, 6), sharex=True)

visualize_profiles(time_grid, coord_profiles, amp_profiles,
                   total_times, mps.profile_kind_to_index, amplitudes, ax)

plt.show()

In [None]:
# Coordinate / momentum grids. The coordinate window spans from sp.x_left to
# sp.x_right, padded by 5 on each side; required_energy is derived from the
# largest scanned amplitude. NOTE(review): how adaptive_grid turns these into
# N_x and the step sizes is defined in utils — confirm there.
N_x, coord_grid, coord_step, momentum_grid, momentum_step = adaptive_grid(
    sp.x_left - 5,
    sp.x_right + 5,
    required_energy=2 * (1 + max_amplitude),
    B=sp.borne_parameter,
)

N_x, coord_step, momentum_step

In [None]:
# Build the solver from the static trap and the moving potentials, then
# initialise the starting wave function on the coordinate grid.
sv = Solver(sp, mps)
# NOTE(review): psi_0 is not referenced by the later visible cells; confirm
# whether init_psi_0 is called for its effect on `sv`.
psi_0 = sv.init_psi_0(coord_grid)

In [8]:
# Run the time evolution. solve returns the wave functions plus a triple of
# diagnostics, unpacked here into norm, kinetic and overlap.
psi, diagnostics = sv.solve(coord_grid, time_grid, momentum_grid)
norm, kinetic, overlap = diagnostics
# Axis layout as indexed by the plotting cells (inferred, NOTE(review): confirm
# against Solver.solve): diagnostics are [time, profile, amplitude, total_time];
# psi is [profile, amplitude, total_time, coordinate].
norm.shape, kinetic.shape, overlap.shape

In [None]:
# Final |psi|^2 for every total time, one subplot per profile, at tweezer
# depth amplitudes[selected_amp_index].
selected_amp_index = 0
# squeeze=False keeps `ax` an array (indexable / iterable) even when there is
# only one profile; column 0 gives a 1-D array of Axes.
fig, ax = plt.subplots(mps.n_profiles, 1, figsize=(15, 6), sharex=True, squeeze=False)
ax = ax[:, 0]

for profile, profile_index in mps.profile_kind_to_index.items():
    for total_time_index, total_time in enumerate(total_times):
        ax[profile_index].plot(
            coord_grid,
            jnp.abs(psi[profile_index, selected_amp_index, total_time_index])**2,
            label=f'Total Time: {total_time:.2f}' + r" $T_{st}$"
        )

    ax[profile_index].set_title(f'{profile} Profile. Final wave function, Tweezer_depth = {amplitudes[selected_amp_index]:.2f}' + r" $A_{st}$")
    ax[profile_index].set_ylabel(r'$|\psi|^2$')
    # Mark sp.x_left ('start') and sp.x_right ('finish').
    ax[profile_index].axvline(sp.x_left, ls=':', label='start')
    ax[profile_index].axvline(sp.x_right, ls='--', label='finish')
    ax[profile_index].grid(ls=':')

# Uncomment to show the legend (one entry per total time plus start/finish).
#ax[0].legend()

for ax_row in ax:
    ax_row.set_xlabel(r'Coordinate, units of $\sigma_{st}$')

plt.tight_layout()
plt.show()

In [None]:
# Overlay the norm trace of every (profile, amplitude, total time) run on a
# single set of axes, so that any run whose norm departs from the others
# stands out. time_grid is complex64 with zero imaginary part, so its real
# part is used as the x-axis (avoids matplotlib's ComplexWarning).
fig, ax = plt.subplots(figsize=(8, 4))
for profile_index in mps.profile_kind_to_index.values():
    for amp_index in range(len(amplitudes)):
        for total_time_index in range(len(total_times)):
            ax.plot(time_grid.real, norm[:, profile_index, amp_index, total_time_index])

ax.set_xlabel(r'Time, units of $T_{st}$')
ax.set_ylabel('Norm')
ax.set_title('Norm of the wave function, all profiles, amplitudes and total times')
ax.grid(ls=':')
plt.show()

In [None]:
# Time evolution of the infidelity log10(1 - F), with F = Re(overlap), for
# every total time, one subplot per profile, at tweezer depth
# amplitudes[selected_amp_index] (-1 = largest, since amplitudes increases).
selected_amp_index = -1
# squeeze=False keeps `ax` indexable even when there is only one profile.
fig, ax = plt.subplots(mps.n_profiles, 1, figsize=(8, 6), sharex=True, sharey=True, squeeze=False)
ax = ax[:, 0]

for profile, profile_index in mps.profile_kind_to_index.items():
    for total_time_index, total_time in enumerate(total_times):
        ax[profile_index].plot(
            # time_grid is complex64 with zero imaginary part; plot its real
            # part so matplotlib does not emit a ComplexWarning.
            time_grid.real,
            jnp.log10(1 - overlap[:, profile_index,
                                  selected_amp_index, total_time_index].real),
            label=f'Total Time: {total_time:.2f}' + r" $T_{st}$"
        )

    # Adjacent f-strings instead of a replacement field split across lines,
    # which is a SyntaxError on Python < 3.12.
    ax[profile_index].set_title(
        f'{profile} Profile. Evolution of infidelity for tweezer depth: '
        f'{amplitudes[selected_amp_index]:.2f}' + r" $A_{st}$"
    )
    ax[profile_index].set_ylabel(r'$log_{10}(1 - F)$')
    ax[profile_index].grid(ls=":")

# Uncomment to show the legend (one entry per total time).
#ax[0].legend()
for ax_row in ax:
    ax_row.set_xlabel(r'Time, units of $T_{st}$')

plt.tight_layout()
plt.show()

In [None]:
# Time evolution of <p^2 / 2m> for every total time, one subplot per profile,
# at tweezer depth amplitudes[selected_amp_index].
selected_amp_index = 1
# squeeze=False keeps `ax` indexable even when there is only one profile.
fig, ax = plt.subplots(mps.n_profiles, 1, figsize=(8, 6), sharex=True, sharey=True, squeeze=False)
ax = ax[:, 0]

for profile, profile_index in mps.profile_kind_to_index.items():
    for total_time_index, total_time in enumerate(total_times):
        ax[profile_index].plot(
            # time_grid is complex64 with zero imaginary part; plot its real
            # part so matplotlib does not emit a ComplexWarning.
            time_grid.real,
            kinetic[:, profile_index, selected_amp_index, total_time_index],
            label=f'Total Time: {total_time:.2f}' + r" $T_{st}$"
        )

    # Set ticks, titles and labels for each subplot.
    # One tick per time unit from 0 up to (excluding) the end of the grid.
    ax[profile_index].set_xticks(jnp.arange(0, time_grid.real.max(), 1))
    # Adjacent f-strings instead of a replacement field split across lines,
    # which is a SyntaxError on Python < 3.12.
    ax[profile_index].set_title(
        f'{profile} Profile: Evolution of <p^2 / 2m> \n Amplitude: '
        f'{amplitudes[selected_amp_index]:.2f}' + r" $A_{st}$"
    )
    ax[profile_index].set_ylabel(r'$<p^2 / 2m>$, units of $A_{st}$')
    ax[profile_index].grid(ls=":")
    #ax[profile_index].set_xlim(0,10)

# Uncomment to show the legend (one entry per total time).
#ax[0].legend()

# Set common xlabel for all subplots
for ax_row in ax:
    ax_row.set_xlabel(r'Time, units of $T_{st}$')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Final-time infidelity log10(1 - F), F = Re(overlap), as a heatmap over
# (total time, amplitude), one subplot per profile.
overlap_last_time = overlap[-1, ...]
# Take the real part, as the time-evolution plot does: log10 of a complex
# array stays complex and imshow cannot render complex data.
log_infid = jnp.log10(1 - overlap_last_time.real)

# Shared colour scale across all profiles.
vmin, vmax = log_infid.min(), log_infid.max()

# Create a figure for plotting; squeeze=False keeps `ax` indexable even when
# there is only one profile.
fig, ax = plt.subplots(mps.n_profiles, 1, figsize=(10, 6), sharex=True, sharey=True, squeeze=False)
ax = ax[:, 0]

# Loop through each profile type (Linear, Minjerk, STA, etc.)
for profile, profile_index in mps.profile_kind_to_index.items():

    # Rows of log_infid[profile_index] are amplitudes (y), columns are total
    # times (x); colour is the log infidelity.
    c = ax[profile_index].imshow(
        log_infid[profile_index],
        aspect='auto',
        cmap='inferno_r',  # reversed 'inferno': lower infidelity is brighter
        extent=[total_times[0], total_times[-1], amplitudes[0], amplitudes[-1]],  # Set axis limits
        origin='lower',  # So the amplitude is on the y-axis from bottom to top
        vmin=vmin, vmax=vmax  # Common color scale across all profiles
    )

    # Set titles, labels, and colorbar for each subplot
    ax[profile_index].set_title(f'{profile} Overlap Heatmap')
    ax[profile_index].set_xlabel('Time')
    ax[profile_index].set_ylabel('Amplitude')
    fig.colorbar(c, ax=ax[profile_index], label=r'$log_{10}(1 - F)$')

# Adjust layout and display the plot
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Final-time infidelity log10(1 - F), F = Re(overlap), versus total time,
# one subplot per profile, at tweezer depth amplitudes[selected_amp_index].
selected_amp_index = 2
overlap_last_time = overlap[-1, ...]
# Real part, as in the time-evolution plot: log10 of complex data stays complex.
log_infid = jnp.log10(1 - overlap_last_time.real)

# Create a figure for plotting (previously referenced the undefined name
# `npn_profiles`). squeeze=False keeps `ax` indexable for a single profile.
fig, ax = plt.subplots(mps.n_profiles, 1, figsize=(10, 6), sharex=True, sharey=True, squeeze=False)
ax = ax[:, 0]

# Keys of mps.profile_kind_to_index are profile names, values are subplot
# rows (the previous loop used an undefined global and swapped the pair).
for profile, profile_index in mps.profile_kind_to_index.items():

    ax[profile_index].plot(total_times, log_infid[profile_index, selected_amp_index])

    # Set titles and labels for each subplot
    ax[profile_index].set_title(
        f'{profile} Overlap, tweezer depth: {amplitudes[selected_amp_index]:.2f}' + r" $A_{st}$"
    )
    ax[profile_index].set_xlabel('Total Time')
    ax[profile_index].set_ylabel(r'$log_{10}(1 - F)$')
    ax[profile_index].grid(ls=':')

# Adjust layout and display the plot
plt.tight_layout()
plt.show()