# Investigating solve speeds


In [20]:
import pickle
import time
from collections import defaultdict
from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from qutip import sesolve, Options, fidelity

from demonstration_utils import *
import interaction_constants
from qubit_system.geometry.regular_lattice_1d import RegularLattice1D
from qubit_system.qubit_system_classes import EvolvingQubitSystem
from qubit_system.utils.ghz_states import StandardGHZState
from qubit_system.utils.interpolation import get_hamiltonian_coeff_linear_interpolation
from qubit_system.utils.states import get_states, get_label_from_state, get_product_basis_states_index


In [2]:
N_RYD = 50
C6 = interaction_constants.get_C6(N_RYD)

LATTICE_SPACING = 1.5e-6

print(f"C6: {C6:.3e}")
characteristic_V = C6 / (LATTICE_SPACING ** 6)
print(f"Characteristic V: {characteristic_V:.3e} Hz")


C6: 1.555e-26
Characteristic V: 1.365e+09 Hz


In [41]:
def evaluate_solve_options(e_qs: EvolvingQubitSystem, options_args: Sequence[dict], repeats: int = 10):
    hamiltonian = e_qs.get_hamiltonian()
    solve_durations = defaultdict(list)
    fidelities = defaultdict(list)
    errors = defaultdict(list)
    for _ in tqdm(range(repeats)):
        for i, _options_args in enumerate(options_args):
            start_time = time.process_time() 
            try:
                solve_result = sesolve(
                    hamiltonian,
                    e_qs.psi_0,
                    e_qs.t_list,
                    options=Options(
                        store_states=True,
                        **_options_args
                    ),
                )
            except Exception as e:
                errors[i].append(e)
                continue
            end_time = time.process_time()
            
            solve_duration = end_time - start_time 
            # print(f"duration: {solve_duration}")
            
            solve_durations[i].append(solve_duration)
            final_state = solve_result.states[-1]
            fidelity_with_ghz = fidelity(final_state, e_qs.ghz_state.get_state_tensor(True)) ** 2
            fidelity_with_ghz_as = fidelity(final_state, e_qs.ghz_state.get_state_tensor(False)) ** 2
            fidelities[i].append((fidelity_with_ghz, fidelity_with_ghz_as))
            
    for i, _options_args in enumerate(options_args):
        
        print("\n")
        print(f"args: {_options_args}")
        if len(errors[i]) > 0:
            print(f"\terrors: {len(errors)}")
        if len(fidelities[i]) != 0:
            _durations = np.array(solve_durations[i])
            _fidelities = np.array(fidelities[i])
            all_fidelities_equal = np.isclose(_fidelities, _fidelities[0]).all()

            print(f"\tfidelities: {_fidelities[0]}, all equal: {all_fidelities_equal}")
            print(f"\ttimes:\tmin:\t{_durations.min():.3f}"
                  f"\n\t\tmean:\t{_durations.mean():.3f}"
                  f"\n\t\tmax:\t{_durations.max():.3f}"
                  )
    # print(solve_durations)
    # print(fidelities)
    

In [42]:
N = 8
t = 0.5e-6
e_qs_8_manual = EvolvingQubitSystem(
    N=N, V=C6, geometry=RegularLattice1D(spacing=LATTICE_SPACING),
    Omega=get_hamiltonian_coeff_linear_interpolation([0, t / 6, t * 5 / 6, t], [0, 480e6, 490e6, 0]),
    Delta=get_hamiltonian_coeff_linear_interpolation([0, t], [1.5e9, 1.2e9]),
    t_list=np.linspace(0, t, 100),
    ghz_state=StandardGHZState(N)
)

In [43]:
with open("reinforcement_learning/results/20190814_211310.pkl", "rb") as f:
    data = pickle.load(f)
t_list = data['evolving_qubit_system_kwargs']['t_list']
solve_t_list = np.linspace(t_list[0], t_list[-1], 300)

data['evolving_qubit_system_kwargs'].pop('t_list')
e_qs_8_rl = EvolvingQubitSystem(
    **data['evolving_qubit_system_kwargs'],
    Omega=get_hamiltonian_coeff_linear_interpolation(t_list, data['protocol'].Omega),
    Delta=get_hamiltonian_coeff_linear_interpolation(t_list, data['protocol'].Delta),
    t_list=solve_t_list,
)

In [29]:
options_args = [
    {'order': 12},
    {'order': 5},
    {'order': 4},
    {'order': 3},
]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))




args: {'order': 12}
	fidelities: [0.96721226 0.18609636], all equal: True
	times:	min:	12.578
		mean:	13.444
		max:	14.078


args: {'order': 5}
	fidelities: [0.96721226 0.18609636], all equal: True
	times:	min:	12.688
		mean:	14.378
		max:	16.406


args: {'order': 4}
	fidelities: [0.96721158 0.18604303], all equal: True
	times:	min:	13.969
		mean:	14.544
		max:	15.203


args: {'order': 3}
	errors: 4


In [34]:
options_args = [
    {'method': 'bdf', 'order': 5},
    {'method': 'adams', 'order': 5},
]
evaluate_solve_options(e_qs_8_manual, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))



args: {'method': 'bdf', 'order': 5}
	fidelities: [0.02779446 0.99512143], all equal: True
	times:	min:	1.781
		mean:	2.536
		max:	2.984


args: {'method': 'adams', 'order': 5}
	fidelities: [0.02783744 0.99512253], all equal: True
	times:	min:	2.250
		mean:	2.283
		max:	2.406


In [39]:
options_args = [
    {'atol': 1e-8, 'rtol': 1e-6},  # default tolerances
    {'atol': 1e-5, 'rtol': 1e-4},
    {'atol': 1e-4, 'rtol': 1e-3},
]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))



args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.96721226 0.18609636], all equal: True
	times:	min:	10.797
		mean:	11.981
		max:	13.391


args: {'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [0.00667188 0.00126767], all equal: True
	times:	min:	10.969
		mean:	11.703
		max:	13.766


args: {'atol': 0.0001, 'rtol': 0.001}
	fidelities: [0. 0.], all equal: True
	times:	min:	10.109
		mean:	10.475
		max:	11.750


In [40]:
options_args = [
    {'atol': 1e-8, 'rtol': 1e-6},  # default tolerances
    {'order': 5, 'atol': 1e-5, 'rtol': 1e-4},
    {'order': 5, 'atol': 1e-4, 'rtol': 1e-3},
]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))



args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.96721226 0.18609636], all equal: True
	times:	min:	11.234
		mean:	12.259
		max:	12.875


args: {'order': 5, 'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [0.00667188 0.00126767], all equal: True
	times:	min:	11.250
		mean:	11.952
		max:	13.203


args: {'order': 5, 'atol': 0.0001, 'rtol': 0.001}
	fidelities: [0. 0.], all equal: True
	times:	min:	10.047
		mean:	10.530
		max:	11.078


In [36]:
options_args = [
    {'atol': 1e-8, 'rtol': 1e-6},  # default
    {'order': 8, 'atol': 1e-5, 'rtol': 1e-4},
    {'order': 5, 'atol': 1e-5, 'rtol': 1e-4},
]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))



args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.96721226 0.18609636], all equal: True
	times:	min:	11.484
		mean:	11.859
		max:	12.547


args: {'order': 8, 'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [0.00667188 0.00126767], all equal: True
	times:	min:	10.797
		mean:	11.316
		max:	11.938


args: {'order': 5, 'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [0.00667188 0.00126767], all equal: True
	times:	min:	10.781
		mean:	11.408
		max:	13.016
