Skip to content

Commit

Permalink
Merge pull request #84 from qutech/hotfix/cache_intermediates_hilbert
Browse files Browse the repository at this point in the history
Cache more intermediates in calculation with unitaries
  • Loading branch information
thangleiter committed May 22, 2022
2 parents aba0d05 + affb039 commit 2ebe601
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
11 changes: 10 additions & 1 deletion filter_functions/numeric.py
Expand Up @@ -591,8 +591,12 @@ def calculate_noise_operators_from_scratch(
noise_operators = np.zeros((len(omega), len(n_opers), d, d), dtype=complex)

if cache_intermediates:
phase_factors_cache = np.empty((len(dt), len(omega)), dtype=complex)
int_cache = np.empty((len(dt), len(omega), d, d), dtype=complex)
sum_cache = np.empty((len(dt), len(omega), len(n_opers), d, d), dtype=complex)
else:
phase_factors = np.empty((len(omega),), dtype=complex)
int_buf = np.empty((len(omega), d, d), dtype=complex)
sum_buf = np.empty((len(omega), len(n_opers), d, d), dtype=complex)

# Set up reusable expressions
Expand All @@ -607,17 +611,22 @@ def calculate_noise_operators_from_scratch(
if cache_intermediates:
# Assign references to the locations in the cache for the quantities
# that should be stored
phase_factors = phase_factors_cache[g]
int_buf = int_cache[g]
sum_buf = sum_cache[g]

phase_factors = util.cexp(omega*t[g], out=phase_factors)
int_buf = _first_order_integral(omega, eigvals[g], dt[g], exp_buf, int_buf)
sum_buf = expr_1(n_opers_transformed[:, g], util.cexp(omega*t[g])[:, None, None]*int_buf,
sum_buf = expr_1(n_opers_transformed[:, g], phase_factors[:, None, None]*int_buf,
out=sum_buf)

noise_operators += expr_2(eigvecs_propagated[g].conj(), sum_buf, eigvecs_propagated[g],
out=sum_buf)

if cache_intermediates:
intermediates = dict(n_opers_transformed=n_opers_transformed,
first_order_integral=int_cache,
phase_factors=phase_factors_cache,
noise_operators_step=sum_cache)
return noise_operators, intermediates

Expand Down
23 changes: 22 additions & 1 deletion tests/test_core.py
Expand Up @@ -608,7 +608,7 @@ def test_pulse_sequence_attributes_concat(self):
self.assertEqual(periodic_pulse._tau, pulse.tau * 7)
self.assertArrayAlmostEqual(periodic_pulse.t, [0, *periodic_pulse.dt.cumsum()])

def test_cache_intermediates(self):
def test_cache_intermediates_liouville(self):
"""Test caching of intermediate elements"""
pulse = testutil.rand_pulse_sequence(3, 4, 2, 3)
omega = util.get_sample_frequencies(pulse, 33, spacing='linear')
Expand All @@ -627,6 +627,27 @@ def test_cache_intermediates(self):
eigvecs_prop.conj(), pulse.basis, eigvecs_prop)
self.assertArrayAlmostEqual(pulse._intermediates['basis_transformed'], basis_transformed,
atol=1e-14)
self.assertArrayAlmostEqual(pulse._intermediates['phase_factors'],
util.cexp(omega*pulse.t[:-1, None]))

def test_cache_intermediates_hilbert(self):
pulse = testutil.rand_pulse_sequence(3, 4, 2, 3)
omega = util.get_sample_frequencies(pulse, 33, spacing='linear')
unitary, intermediates = numeric.calculate_noise_operators_from_scratch(
pulse.eigvals, pulse.eigvecs, pulse.propagators, omega, pulse.n_opers, pulse.n_coeffs,
pulse.dt, pulse.t, cache_intermediates=True
)

pulse._intermediates.update(**intermediates)

self.assertArrayAlmostEqual(pulse._intermediates['noise_operators_step'].sum(0), unitary)
self.assertArrayAlmostEqual(pulse._intermediates['n_opers_transformed'],
numeric._transform_hamiltonian(pulse.eigvecs,
pulse.n_opers,
pulse.n_coeffs))
self.assertArrayAlmostEqual(pulse._intermediates['phase_factors'],
util.cexp(omega*pulse.t[:-1, None]))


def test_cache_filter_function(self):
omega = rng.random(32)
Expand Down

0 comments on commit 2ebe601

Please sign in to comment.