Skip to content

Commit

Permalink
Tests and small fix to concatenate()
Browse files Browse the repository at this point in the history
  • Loading branch information
thangleiter committed Jun 1, 2020
1 parent 4e6897f commit 56e0cc9
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 19 deletions.
36 changes: 21 additions & 15 deletions filter_functions/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
#
# Contact email: tobias.hangleiter@rwth-aachen.de
# =============================================================================
"""
This module defines the PulseSequence class, the central object of the
formalism.

"""This module defines the PulseSequence class, the package's central object.
Classes
-------
Expand Down Expand Up @@ -1033,10 +1032,6 @@ def _parse_Hamiltonian(H: Hamiltonian, n_dt: int,
raise ValueError('Expected all operators in {} '.format(H_str) +
'to be two-dimensional!')

if len(set(oper.shape for oper in opers)) != 1:
raise ValueError('Expected all operators in {} '.format(H_str) +
'to have the same dimensions!')

if len(set(opers[0].shape)) != 1:
raise ValueError('Expected operators in {} '.format(H_str) +
'to be square!')
Expand Down Expand Up @@ -1523,18 +1518,25 @@ def concatenate(pulses: Iterable[PulseSequence],
equal_n_opers = (n_opers_present.sum(axis=0) > 1).any()
if omega is None:
cached_ctrl_mat = [pls.is_cached('R') for pls in pulses]
equal_omega = util.all_array_equal(
(pls.omega for pls in compress(pulses, cached_ctrl_mat))
)
if any(cached_ctrl_mat):
equal_omega = util.all_array_equal(
(pls.omega for pls in compress(pulses, cached_ctrl_mat))
)
else:
cached_omega = [pls.is_cached('omega') for pls in pulses]
equal_omega = util.all_array_equal(
(pls.omega for pls in compress(pulses, cached_omega))
)

if not equal_omega:
if calc_filter_function:
raise ValueError("Calculation of filter function forced " +
"but not all pulses have the same " +
"frequencies cached and none were supplied!")
"but not all pulses have the same " +
"frequencies cached and none were supplied!")
if calc_pulse_correlation_ff:
raise ValueError("Cannot compute the pulse correlation " +
"filter functions; do not have the " +
"frequencies at which to evaluate.")
"filter functions; do not have the " +
"frequencies at which to evaluate.")

return newpulse

Expand All @@ -1547,7 +1549,11 @@ def concatenate(pulses: Iterable[PulseSequence],
# Can reuse cached filter functions or calculation explicitly asked
# for; run calculation. Get the index of the first pulse with cached FF
# to steal some attributes from.
ind = np.nonzero(cached_ctrl_mat)[0][0]
if any(cached_ctrl_mat):
ind = np.nonzero(cached_ctrl_mat)[0][0]
else:
ind = np.nonzero(cached_omega)[0][0]

omega = pulses[ind].omega

if not equal_n_opers:
Expand Down
20 changes: 18 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_pulse_sequence_constructor(self):

dt[idx] *= -1
with self.assertRaises(ValueError):
# imagniary dt
# imaginary dt
dt = dt.astype(complex)
dt[idx] += 1j
ff.PulseSequence(H_c, H_n, dt)
Expand All @@ -86,6 +86,14 @@ def test_pulse_sequence_constructor(self):
# Noise Hamiltonian not list or tuple
ff.PulseSequence(H_c, np.array(H_n), dt)

with self.assertRaises(TypeError):
# Element of control Hamiltonian not list or tuple
ff.PulseSequence([np.array(H_c[0])], H_n, dt)

with self.assertRaises(TypeError):
# Element of noise Hamiltonian not list or tuple
ff.PulseSequence(H_c, [np.array(H_n[0])], dt)

idx = testutil.rng.randint(0, 3)
with self.assertRaises(TypeError):
# Control Hamiltonian element not list or tuple
Expand Down Expand Up @@ -350,7 +358,9 @@ def test_pulse_sequence_attributes(self):

A.cleanup('conservative')
self.assertIsNotNone(A.HD)
A.cleanup('conservative')
self.assertIsNotNone(A.HV)
A.cleanup('conservative')
self.assertIsNotNone(A.Q)

aliases = {'eigenvalues': '_HD',
Expand Down Expand Up @@ -455,6 +465,13 @@ def test_pulse_sequence_attributes_concat(self):
[[Z, np.abs(testutil.rng.randn(2))]],
[1, 1])

# Concatenate with different noise opers
pulses = [testutil.rand_pulse_sequence(2, 1) for _ in range(2)]
pulses[0].omega = np.arange(10)
pulses[1].omega = np.arange(10)
newpulse = ff.concatenate(pulses, calc_filter_function=True)
self.assertTrue(newpulse.is_cached('filter function'))

pulse_12 = pulse_1 @ pulse_2
pulse_21 = pulse_2 @ pulse_1

Expand Down Expand Up @@ -617,7 +634,6 @@ def test_filter_function(self):
self.assertArrayAlmostEqual(F_fidelity,
F_generalized.trace(axis1=2, axis2=3))


def test_pulse_correlation_filter_function(self):
"""
Test calculation of pulse correlation filter function and control
Expand Down
59 changes: 57 additions & 2 deletions tests/test_sequencing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,45 @@
This module tests the concatenation functionality for PulseSequence's.
"""

from copy import copy
import string
from itertools import product
from random import sample

import numpy as np

import filter_functions as ff
from filter_functions import util
from filter_functions import pulse_sequence, util
from tests import testutil


class ConcatenationTest(testutil.TestCase):

def test_concatenate_base(self):
"""Basic functionality."""
pulse_1, pulse_2 = [testutil.rand_pulse_sequence(2, 1, 2, 3)
for _ in range(2)]

# Trivial case, copy
c_pulse = ff.concatenate([pulse_1])
self.assertEqual(pulse_1, c_pulse)
self.assertFalse(pulse_1 is c_pulse)

# Don't cache filter function, expect same result as with
# concatenate_without_filter_function
c_pulse_1 = ff.concatenate([pulse_1, pulse_2],
calc_filter_function=False)
c_pulse_2 = pulse_sequence.concatenate_without_filter_function(
[pulse_1, pulse_2], return_identifier_mappings=False
)
self.assertEqual(c_pulse_1, c_pulse_2)

# Try concatenation with different frequencies but FF calc. forced
with self.assertRaises(ValueError):
pulse_1.omega = [1, 2]
pulse_2.omega = [3, 4]
ff.concatenate([pulse_1, pulse_2], calc_filter_function=True)

def test_concatenate_without_filter_function(self):
"""Concatenate two Spin Echos without filter functions."""
tau = 10
Expand Down Expand Up @@ -67,6 +93,35 @@ def test_concatenate_without_filter_function(self):
CPMG_concat = ff.concatenate((SE_1, SE_2), omega=omega)
self.assertIsNotNone(CPMG_concat._F)

pulse = testutil.rand_pulse_sequence(2, 1, 2, 3)
# Concatenate pulses without filter functions
with self.assertRaises(TypeError):
# Not all pulse sequence
pulse_sequence.concatenate_without_filter_function([pulse, 2])

with self.assertRaises(TypeError):
# Not iterable
pulse_sequence.concatenate_without_filter_function(pulse)

with self.assertRaises(ValueError):
# Incompatible Hamiltonian shapes
pulse_sequence.concatenate_without_filter_function(
[testutil.rand_pulse_sequence(2, 1),
testutil.rand_pulse_sequence(3, 1)]
)

with self.assertRaises(ValueError):
# Incompatible bases
pulse = testutil.rand_pulse_sequence(4, 1, btype='GGM')
cpulse = copy(pulse)
cpulse.basis = ff.Basis.pauli(2)
pulse_sequence.concatenate_without_filter_function([pulse, cpulse])

pulse = pulse_sequence.concatenate_without_filter_function(
[pulse, pulse], return_identifier_mappings=False
)
self.assertFalse(pulse.is_cached('filter function'))

def test_concatenate_with_filter_function_SE1(self):
"""
Concatenate two Spin Echos with the first having a filter function.
Expand Down Expand Up @@ -312,7 +367,7 @@ def test_concatenate_split_cnot(self):
rtol, atol)

def test_different_n_opers(self):
"""Test behavior when concatenating with different n_opers"""
"""Test behavior when concatenating with different n_opers."""
for d, n_dt in zip(testutil.rng.randint(2, 5, 20),
testutil.rng.randint(1, 11, 20)):
opers = testutil.rand_herm_traceless(d, 10)
Expand Down

0 comments on commit 56e0cc9

Please sign in to comment.