In [168]:
import qutip as qt
from qutip_qip.operations.gates import hadamard_transform as hadamard

import numpy as np
from matplotlib import pyplot as plt

import jax
import jax.numpy as jnp
from diffrax import Dopri5, Dopri8, Tsit5, PIDController

In [169]:
from optimize import optimize_pulses
from time_interval import TimeInterval
from objective import Objective

# Toy model

In [170]:
initial = qt.qeye(2)
target  = hadamard()

initial = qt.sprepost(initial, initial.dag())
target  = qt.sprepost(target , target.dag() )

In [171]:
σx = qt.sigmax()
σy = qt.sigmay()
σz = qt.sigmaz()

ω, Δ, γ, π = 0.1, 1.0, 0.1, np.pi

Hd = 1/2 * (ω * σz + Δ * σx)

H_d =  qt.liouvillian(H=Hd, c_ops=[np.sqrt(γ) * qt.sigmam()])

In [172]:
interval = TimeInterval(tlist=[0, 2*π])

In [173]:
def sin(t, α):
    return α[0] * np.sin(α[1] * t + α[2])

def grad_sin(t, α, idx):
    if idx==0: return np.sin(α[1] * t + α[2])
    if idx==1: return α[0] * np.cos(α[1] * t + α[2]) * t
    if idx==2: return α[0] * np.cos(α[1] * t + α[2])
    if idx==3: return α[0] * np.cos(α[1] * t + α[2]) * α[1] # w.r.t. time

In [174]:
sin_x = lambda t, p: sin(t, p)
sin_y = lambda t, q: sin(t, q)
sin_z = lambda t, r: sin(t, r)

In [175]:
Hc  = [σx, σy, σz]
H_c = [qt.liouvillian(H) for H in Hc]

H = [ H_d,
     [H_c[0], sin_x, {"grad": grad_sin}],
     [H_c[1], sin_y, {"grad": grad_sin}],
     [H_c[2], sin_z, {"grad": grad_sin}]]

In [176]:
p_init = [1, 1, 0]
q_init = [1, 1, 0]
r_init = [1, 1, 0]

# Default integrator settings - local search only - fix time

GOAT

In [177]:
res_goat = optimize_pulses(
    objectives = [Objective(initial, H, target)],
    pulse_options={
        "p": {
            "guess":  p_init, # p0 * sin(p1 * t + p2)
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "q": {
            "guess":  q_init, # q0 * sin(q1 * t + q2)
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "r": {
            "guess":  r_init, # r0 * sin(r1 * t + r2)
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        }
    },
    time_interval = interval,
    algorithm_kwargs = {
        "alg": "GOAT",
        "fid_err_targ": 0.01,
        "method": "basinhopping",
        "disp": False,
        "max_iter": 0, # global optimizer steps
        "seed": 1,
    },
    
)

In [178]:
res_goat

Control Optimization Result
--------------------------
- Started at 2023-11-10 17:41:54
- Number of objectives: 1
- Final fidelity error: 0.04403656960277009
- Final parameters: [[1.0, 0.6462771162438175, 2.1425609700529074], [0.9996894160117288, 0.990898249375864, 0.004499661412899851], [0.8444724950457054, 2.379880939421669e-09, 0.2993300539999815]]
- Number of iterations: 1
- Reason for termination: ['requested number of basinhopping iterations completed successfully']
- Ended at 2023-11-10 17:42:02 (7.8043s)

JAX

In [179]:
def sin_jax(t, α):
    return α[0] * jnp.sin(α[1] * t + α[2])

In [180]:
@jax.jit
def sin_x_jax(t, p, **kwargs):
    return sin_jax(t, p)

@jax.jit
def sin_y_jax(t, q, **kwargs):
    return sin_jax(t, q)

@jax.jit
def sin_z_jax(t, r, **kwargs):
    return sin_jax(t, r)

In [181]:
H_jax = [ H_d,
         [H_c[0], sin_x_jax],
         [H_c[1], sin_y_jax],
         [H_c[2], sin_z_jax]]

In [182]:
res_joat = optimize_pulses(
    objectives = [Objective(initial, H_jax, target)],
    pulse_options={
        "p": {
            "guess":  p_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "q": {
            "guess":  q_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "r": {
            "guess":  r_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        }
    },
    time_interval = interval,
    algorithm_kwargs = {
        "alg": "JOAT",
        "fid_err_targ": 0.01,
        "method": "basinhopping",
        "disp": False,
        "max_iter": 0,
        "seed": 1,
    },
)

In [183]:
res_joat

Control Optimization Result
--------------------------
- Started at 2023-11-10 17:42:03
- Number of objectives: 1
- Final fidelity error: 0.04455113743096623
- Final parameters: [[0.999553068793097, 0.9677108939313139, 1.3896276385936615], [0.7368374722530612, 0.7802926694405331, 1.5571836032831659], [0.3879646737896883, 0.0, 1.7601127919800026]]
- Number of iterations: 1
- Reason for termination: ['requested number of basinhopping iterations completed successfully']
- Ended at 2023-11-10 17:42:34 (30.8221s)

# Default integrator settings - local search only - variable time

GOAT

In [184]:
res_joat = optimize_pulses(
    objectives = [Objective(initial, H_jax, target)],
    pulse_options={
        "p": {
            "guess":  p_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "q": {
            "guess":  q_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "r": {
            "guess":  r_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        }
    },
    time_interval = interval,
    time_options = {
        "guess": interval.evo_time,
        "bounds": (0, 2*interval.evo_time),
    },
    algorithm_kwargs = {
        "alg": "JOAT",
        "fid_err_targ": 0.01,
        "method": "basinhopping",
        "disp": False,
        "max_iter": 0,
        "seed": 1,
    },
)

In [185]:
res_goat

Control Optimization Result
--------------------------
- Started at 2023-11-10 17:41:54
- Number of objectives: 1
- Final fidelity error: 0.04403656960277009
- Final parameters: [[1.0, 0.6462771162438175, 2.1425609700529074], [0.9996894160117288, 0.990898249375864, 0.004499661412899851], [0.8444724950457054, 2.379880939421669e-09, 0.2993300539999815]]
- Number of iterations: 1
- Reason for termination: ['requested number of basinhopping iterations completed successfully']
- Ended at 2023-11-10 17:42:02 (7.8043s)

JAX

In [186]:
res_joat = optimize_pulses(
    objectives = [Objective(initial, H_jax, target)],
    pulse_options={
        "p": {
            "guess":  p_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "q": {
            "guess":  q_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "r": {
            "guess":  r_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        }
    },
    time_interval = interval,
    time_options = {
        "guess": interval.evo_time,
        "bounds": (0, 2*interval.evo_time),
    },
    algorithm_kwargs = {
        "alg": "JOAT",
        "fid_err_targ": 0.01,
        "method": "basinhopping",
        "disp": False,
        "max_iter": 0,
        "seed": 1,
    },
)

In [187]:
res_joat

Control Optimization Result
--------------------------
- Started at 2023-11-10 17:42:59
- Number of objectives: 1
- Final fidelity error: 0.02492020332425142
- Final parameters: [[1.0, 0.9485234732252087, 2.0175236634621174], [0.9999985248371083, 0.9174976046839698, 0.9463379704948467], [0.9605607074573544, 1.0, 0.9800389764496323], [3.846593105352696]]
- Number of iterations: 1
- Reason for termination: ['requested number of basinhopping iterations completed successfully']
- Ended at 2023-11-10 17:43:27 (27.9107s)

# Same integrator settings - global search

GOAT

In [200]:
res_goat = optimize_pulses(
    objectives = [Objective(initial, H, target)],
    pulse_options={
        "p": {
            "guess":  p_init, # p0 * sin(p1 * t + p2)
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "q": {
            "guess":  q_init, # q0 * sin(q1 * t + q2)
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "r": {
            "guess":  r_init, # r0 * sin(r1 * t + r2)
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        }
    },
    time_interval = interval,
    time_options = {
        "guess": interval.evo_time,
        "bounds": (0, 2*interval.evo_time),
    },
    algorithm_kwargs = {
        "alg": "GOAT",
        "method": "dual_annealing",
        "fid_err_targ": 0.01,
        "disp": False,
        "max_iter": 100,
        "seed": 1,
    },
    integrator_kwargs = {
        "atol": 1e-6,
        "rtol": 1e-6,
        "method": "dop853",
    },
)

In [201]:
res_goat

Control Optimization Result
--------------------------
- Started at 2023-11-10 17:47:06
- Number of objectives: 1
- Final fidelity error: 0.011313517577368495
- Final parameters: [[-0.5869472475953812, 0.604422073378828, 5.622922553448685], [0.7824922242136919, 0.973781729282308, 2.836355179096559], [-0.9156613739350358, 0.6829865555995811, 2.3392860719886186], [2.4695917434767205]]
- Number of iterations: 100
- Reason for termination: ['Maximum number of iteration reached']
- Ended at 2023-11-10 17:47:39 (33.052s)

JAX

In [202]:
res_joat = optimize_pulses(
    objectives = [Objective(initial, H_jax, target)],
    pulse_options={
        "p": {
            "guess":  p_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "q": {
            "guess":  q_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        },
        "r": {
            "guess":  r_init,
            "bounds": [(-1, 1), (0, 1), (0, 2*np.pi)],
        }
    },
    time_interval = interval,
    time_options = {
        "guess": interval.evo_time,
        "bounds": (0, 2*interval.evo_time),
    },
    algorithm_kwargs = {
        "alg": "JOAT",
        "method": "dual_annealing",
        "fid_err_targ": 0.01,
        "disp": False,
        "max_iter": 100,
        "seed": 1,
    },
    integrator_kwargs = {
        "stepsize_controller": PIDController(
            atol = 1e-6,
            rtol = 1e-6,
        ),
        "solver": Dopri8(),
    },
)

In [203]:
res_joat

Control Optimization Result
--------------------------
- Started at 2023-11-10 17:47:39
- Number of objectives: 1
- Final fidelity error: 0.009547201430487956
- Final parameters: [[-0.03375980138181052, 0.05141741975191592, 4.5515703640921386], [-0.915119546944092, 0.04875635041843301, 3.0809082065847653], [0.5851213371793615, 0.15078751496709192, 0.9344299824482413], [1.9927253517080148]]
- Number of iterations: 45
- Reason for termination: fid_err_targ reached
- Ended at 2023-11-10 17:48:13 (33.713s)