# Investigating solve speeds


In [1]:
# Quick check that ipywidgets work in this kernel (the tqdm.auto progress bars below render as
# ipywidgets HBox/IntProgress). ButtonStyle must be imported explicitly, otherwise this cell
# raises NameError on a fresh kernel.
from ipywidgets import Button, ButtonStyle

Button()

Button(style=ButtonStyle())

In [1]:
import pickle
import time
from collections import defaultdict
from typing import Sequence
import warnings

import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from qutip import sesolve, Options, fidelity

from demonstration_utils import *
import interaction_constants
from qubit_system.geometry.regular_lattice_1d import RegularLattice1D
from qubit_system.qubit_system_classes import EvolvingQubitSystem
from qubit_system.utils.ghz_states import StandardGHZState
from qubit_system.utils.interpolation import get_hamiltonian_coeff_linear_interpolation
from qubit_system.utils.states import get_states, get_label_from_state, get_product_basis_states_index


In [2]:
# Rydberg level and 1D lattice spacing (metres) used throughout this notebook.
N_RYD = 50
LATTICE_SPACING = 1.5e-6

C6 = interaction_constants.get_C6(N_RYD)
# Interaction scale at one lattice spacing, presumably assuming V(r) = C6 / r^6 — confirm.
characteristic_V = C6 / LATTICE_SPACING ** 6

print(f"C6: {C6:.3e}")
print(f"Characteristic V: {characteristic_V:.3e} Hz")


C6: 1.555e-26
Characteristic V: 1.365e+09 Hz


In [3]:
def _print_solve_option_summary(options_args, solve_durations, fidelities, errors, warnings_):
    """Print one summary per options set: error/warning counts (with the first example) and fidelity/timing stats."""
    for i, _options_args in enumerate(options_args):
        print("\n")
        print(f"args: {_options_args}")
        if len(errors[i]) > 0:
            print(f"\terrors: {len(errors[i])}")
            print(f"\t\t{errors[i][0]}")
        if len(warnings_[i]) > 0:
            print(f"\twarnings: {len(warnings_[i])}")
            print(f"\t\t{warnings_[i][0]}")
        if len(fidelities[i]) != 0:
            _durations = np.array(solve_durations[i])
            _fidelities = np.array(fidelities[i])
            # Solves are deterministic, so every repeat should give the same fidelities.
            all_fidelities_equal = np.isclose(_fidelities, _fidelities[0]).all()

            print(f"\tfidelities: {_fidelities[0]}, all equal: {all_fidelities_equal}")
            print(f"\ttimes:\tmin:\t{_durations.min():.3f}"
                  f"\n\t\tmean:\t{_durations.mean():.3f}"
                  f"\n\t\tmax:\t{_durations.max():.3f}"
                  )


def evaluate_solve_options(e_qs: EvolvingQubitSystem, options_args: Sequence[dict], repeats: int = 10):
    """Benchmark ``sesolve`` on ``e_qs`` under several sets of solver options and print a summary.

    Each dict in ``options_args`` is forwarded as ``Options(store_states=True, **options)``.
    Runs are interleaved (each options set once per repeat) so that drifts in machine load
    affect all option sets alike.

    :param e_qs: system supplying ``get_hamiltonian()``, ``psi_0``, ``t_list`` and ``ghz_state``.
    :param options_args: keyword-argument dicts for ``qutip.Options``, one per configuration.
    :param repeats: number of solves per configuration.
    """
    hamiltonian = e_qs.get_hamiltonian()
    # The targets do not depend on solver options; build them once rather than after every solve.
    target_ghz = e_qs.ghz_state.get_state_tensor(True)
    target_ghz_as = e_qs.ghz_state.get_state_tensor(False)

    solve_durations = defaultdict(list)
    fidelities = defaultdict(list)
    errors = defaultdict(list)
    warnings_ = defaultdict(list)

    for _ in tqdm(range(repeats)):
        for i, _options_args in enumerate(options_args):
            # process_time: CPU time of this process (summed over threads), not wall-clock time.
            start_time = time.process_time()
            with warnings.catch_warnings(record=True) as w:
                # Record every warning: without this, the active filters may ignore some categories
                # or de-duplicate repeats, so they would be missing from the counts below.
                warnings.simplefilter("always")
                try:
                    solve_result = sesolve(
                        hamiltonian,
                        e_qs.psi_0,
                        e_qs.t_list,
                        options=Options(
                            store_states=True,
                            **_options_args
                        ),
                    )
                except Exception as e:
                    # Failed solves (e.g. ODE integration errors) are recorded, not raised.
                    errors[i].append(e)
                    warnings_[i] += w
                    continue
                warnings_[i] += w
            end_time = time.process_time()

            solve_durations[i].append(end_time - start_time)
            final_state = solve_result.states[-1]
            # `fidelity` is squared here, presumably to turn qutip's root fidelity into an overlap probability — confirm.
            fidelity_with_ghz = fidelity(final_state, target_ghz) ** 2
            fidelity_with_ghz_as = fidelity(final_state, target_ghz_as) ** 2
            fidelities[i].append((fidelity_with_ghz, fidelity_with_ghz_as))

    _print_solve_option_summary(options_args, solve_durations, fidelities, errors, warnings_)
    

In [4]:
N = 8
t = 0.5e-6

# Piecewise-linear knots (times, values) for the manually specified Omega and Delta profiles.
omega_knot_times = [0, t / 6, t * 5 / 6, t]
omega_knot_values = [0, 480e6, 490e6, 0]
delta_knot_times = [0, t]
delta_knot_values = [1.5e9, 1.2e9]

e_qs_8_manual = EvolvingQubitSystem(
    N=N,
    V=C6,
    geometry=RegularLattice1D(spacing=LATTICE_SPACING),
    Omega=get_hamiltonian_coeff_linear_interpolation(omega_knot_times, omega_knot_values),
    Delta=get_hamiltonian_coeff_linear_interpolation(delta_knot_times, delta_knot_values),
    t_list=np.linspace(0, t, 100),
    ghz_state=StandardGHZState(N),
)

In [5]:
# NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
with open("reinforcement_learning/results/20190814_211310.pkl", "rb") as f:
    data = pickle.load(f)

# Work on a copy so `data` keeps reflecting the file contents (the original popped 't_list' in place).
rl_qs_kwargs = dict(data['evolving_qubit_system_kwargs'])
# The protocol's Omega/Delta values are defined on the stored t_list; the solve itself uses a
# 300-point grid spanning the same interval.
t_list = rl_qs_kwargs.pop('t_list')
solve_t_list = np.linspace(t_list[0], t_list[-1], 300)

e_qs_8_rl = EvolvingQubitSystem(
    **rl_qs_kwargs,
    Omega=get_hamiltonian_coeff_linear_interpolation(t_list, data['protocol'].Omega),
    Delta=get_hamiltonian_coeff_linear_interpolation(t_list, data['protocol'].Delta),
    t_list=solve_t_list,
)

In [18]:
# Baseline: QuTiP's default solver options give the reference fidelities and solve times
# that the experiments below are compared against.
options_args = [dict()]

for qubit_system in (e_qs_8_rl, e_qs_8_manual):
    evaluate_solve_options(qubit_system, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	29.725
		mean:	31.673
		max:	32.906


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.245
		mean:	3.346
		max:	3.446


In [6]:
# Sweep the integrator order on the RL protocol (order 12 reproduces the baseline fidelities).
options_args = [{'order': order} for order in (12, 5, 4, 3)]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'order': 12}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	30.457
		mean:	35.152
		max:	45.790


args: {'order': 5}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	30.695
		mean:	35.839
		max:	45.965


args: {'order': 4}
	fidelities: [0.93549824 0.03461201], all equal: True
	times:	min:	31.498
		mean:	35.284
		max:	40.769


args: {'order': 3}
	errors: 10
		ODE integration error: Try to increase the allowed number of substeps by increasing the nsteps parameter in the Options class.


In [9]:
# Compare the two ODE methods at a fixed order on the manual protocol.
options_args = [{'method': method, 'order': 5} for method in ('bdf', 'adams')]
evaluate_solve_options(e_qs_8_manual, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'method': 'bdf', 'order': 5}
	fidelities: [7.72532112e-04 9.90266678e-01], all equal: True
	times:	min:	4.164
		mean:	4.441
		max:	4.872


args: {'method': 'adams', 'order': 5}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.305
		mean:	3.551
		max:	3.857


In [8]:
# Loosen atol with rtol held at 1e-6; the first entry (atol=1e-8, rtol=1e-6) is the default tolerances.
options_args = [{'atol': atol, 'rtol': 1e-6} for atol in (1e-8, 1e-7, 1e-6, 1e-5)]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	30.245
		mean:	30.714
		max:	31.045


args: {'atol': 1e-07, 'rtol': 1e-06}
	fidelities: [0.93537871 0.0345893 ], all equal: True
	times:	min:	30.655
		mean:	31.235
		max:	31.720


args: {'atol': 1e-06, 'rtol': 1e-06}
	fidelities: [0.92554539 0.0339033 ], all equal: True
	times:	min:	30.432
		mean:	31.023
		max:	31.660


args: {'atol': 1e-05, 'rtol': 1e-06}
	fidelities: [0.07638473 0.00277465], all equal: True
	times:	min:	30.091
		mean:	30.547
		max:	31.206


In [9]:
# Loosen rtol with atol held at 1e-8; the first entry (atol=1e-8, rtol=1e-6) is the default tolerances.
options_args = [{'atol': 1e-8, 'rtol': rtol} for rtol in (1e-6, 3e-5, 1e-5, 1e-4)]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	29.130
		mean:	30.501
		max:	31.729


args: {'atol': 1e-08, 'rtol': 3e-05}
	fidelities: [0.93586656 0.03388395], all equal: True
	times:	min:	28.722
		mean:	30.129
		max:	31.009


args: {'atol': 1e-08, 'rtol': 1e-05}
	fidelities: [0.9356986  0.03443109], all equal: True
	times:	min:	28.443
		mean:	29.745
		max:	30.318


args: {'atol': 1e-08, 'rtol': 0.0001}
	fidelities: [0.92883073 0.03350507], all equal: True
	times:	min:	29.646
		mean:	30.918
		max:	31.717


In [11]:
# Default tolerances vs. two much looser (atol, rtol) pairs at order 5.
options_args = [
    {'atol': 1e-8, 'rtol': 1e-6},  # default tolerances
    *({'order': 5, 'atol': atol, 'rtol': rtol} for atol, rtol in ((1e-5, 1e-4), (1e-4, 1e-3))),
]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	27.982
		mean:	28.221
		max:	28.364


args: {'order': 5, 'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [4.45141690e-05 1.60698363e-06], all equal: True
	times:	min:	26.644
		mean:	26.720
		max:	26.816


args: {'order': 5, 'atol': 0.0001, 'rtol': 0.001}
	fidelities: [0. 0.], all equal: True
	times:	min:	24.627
		mean:	24.758
		max:	24.859


In [12]:
# Does the order matter once tolerances are loosened to atol=1e-5, rtol=1e-4?
options_args = [{'atol': 1e-8, 'rtol': 1e-6}]  # default
options_args += [{'order': order, 'atol': 1e-5, 'rtol': 1e-4} for order in (8, 5)]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'atol': 1e-08, 'rtol': 1e-06}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	27.943
		mean:	28.838
		max:	30.961


args: {'order': 8, 'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [4.45141690e-05 1.60698363e-06], all equal: True
	times:	min:	26.614
		mean:	27.034
		max:	27.717


args: {'order': 5, 'atol': 1e-05, 'rtol': 0.0001}
	fidelities: [4.45141690e-05 1.60698363e-06], all equal: True
	times:	min:	26.669
		mean:	27.336
		max:	29.518


In [16]:
# Sweep min_step on the manual protocol; min_step=0 is the default.
options_args = [{'min_step': min_step} for min_step in (0, 1e-25, 1e-20, 1e-15, 1e-10)]
evaluate_solve_options(e_qs_8_manual, options_args)


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'min_step': 0}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.195
		mean:	3.298
		max:	3.438


args: {'min_step': 1e-25}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.240
		mean:	3.324
		max:	3.516


args: {'min_step': 1e-20}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.184
		mean:	3.289
		max:	3.415


args: {'min_step': 1e-15}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.166
		mean:	3.285
		max:	3.489


args: {'min_step': 1e-10}
	errors: 10
		ODE integration error: Try to increase the allowed number of substeps by increasing the nsteps parameter in the Options class.


In [14]:
# Larger min_step values on the manual protocol.
options_args = [{'min_step': min_step} for min_step in (1e-12, 1e-10, 1e-8)]
evaluate_solve_options(e_qs_8_manual, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'min_step': 1e-12}
	errors: 3


args: {'min_step': 1e-10}
	errors: 3


args: {'min_step': 1e-08}
	errors: 3


In [15]:
# Smaller min_step values on the manual protocol.
options_args = [{'min_step': min_step} for min_step in (1e-25, 1e-20, 1e-15)]
evaluate_solve_options(e_qs_8_manual, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'min_step': 1e-25}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	2.920
		mean:	3.308
		max:	3.501


args: {'min_step': 1e-20}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.318
		mean:	3.338
		max:	3.382


args: {'min_step': 1e-15}
	fidelities: [7.74922829e-04 9.90268868e-01], all equal: True
	times:	min:	3.316
		mean:	3.345
		max:	3.408


In [16]:
# min_step values around 1e-14 on the RL protocol.
options_args = [{'min_step': min_step} for min_step in (1e-15, 1e-14, 1e-13)]
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




args: {'min_step': 1e-15}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	28.625
		mean:	29.184
		max:	31.669


args: {'min_step': 1e-14}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	28.230
		mean:	28.915
		max:	29.261


args: {'min_step': 1e-13}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	28.212
		mean:	28.720
		max:	29.248


In [17]:
# Combine min_step with a lower integrator order on the RL protocol.
options_args = [{'min_step': 1e-13}]
options_args.extend({'min_step': min_step, 'order': 5} for min_step in (1e-13, 1e-12))
evaluate_solve_options(e_qs_8_rl, options_args)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))
  self.messages.get(istate, unexpected_istate_msg)))





args: {'min_step': 1e-13}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	28.068
		mean:	28.861
		max:	29.309


args: {'min_step': 1e-13, 'order': 5}
	fidelities: [0.93543184 0.03470263], all equal: True
	times:	min:	28.847
		mean:	29.093
		max:	29.412


args: {'min_step': 1e-12, 'order': 5}
	errors: 3


  self.messages.get(istate, unexpected_istate_msg)))
