Extend test coverage (#8)
* Add tests for liouville_representation, calculate_error_vector_correlation_functions, and the analytic module

* Add tests for basis

* Add some PulseSequence attribute tests

* Hotfix test_error_vector_... to be compatible with numpy < 1.18

* Add plotting test coverage
thangleiter committed Jan 24, 2020
1 parent 98f9064 commit d43832e
Showing 10 changed files with 395 additions and 29 deletions.
15 changes: 7 additions & 8 deletions filter_functions/basis.py
@@ -249,17 +249,16 @@ def isorthonorm(self) -> bool:
            # unitary.
            if self.ndim == 2:
                # Only one basis element
-                dim = 1
+                self._isorthonorm = True
            else:
                # Size of the result after multiplication
                dim = self.shape[0]
-
-            U = self.reshape((dim, -1))
-            actual = U.conj() @ U.T
-            target = np.identity(dim)
-            atol = self._eps*(self.d**2)**3
-            self._isorthonorm = np.allclose(actual.view(ndarray), target,
-                                            atol=atol, rtol=self._rtol)
+                U = self.reshape((dim, -1))
+                actual = U.conj() @ U.T
+                target = np.identity(dim)
+                atol = self._eps*(self.d**2)**3
+                self._isorthonorm = np.allclose(actual.view(ndarray), target,
+                                                atol=atol, rtol=self._rtol)

        return self._isorthonorm

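For reference, the reworked check verifies that the matrix formed by writing every basis element as one row is unitary with respect to the Hilbert-Schmidt inner product. A minimal standalone NumPy sketch of the same computation, using a hand-written normalized Pauli basis as example data (not taken from the package):

import numpy as np

# Normalized 2x2 Pauli basis (sigma_0 .. sigma_3) / sqrt(2) as example data
basis = np.array([[[1, 0], [0, 1]],
                  [[0, 1], [1, 0]],
                  [[0, -1j], [1j, 0]],
                  [[1, 0], [0, -1]]]) / np.sqrt(2)

dim = basis.shape[0]                 # number of elements, here d**2 = 4
U = basis.reshape((dim, -1))         # each element flattened into one row
actual = U.conj() @ U.T              # Gram matrix of Hilbert-Schmidt products
print(np.allclose(actual, np.identity(dim)))   # True for an orthonormal basis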
1 change: 1 addition & 0 deletions filter_functions/numeric.py
@@ -581,6 +581,7 @@ def liouville_representation(U: ndarray, basis: Basis) -> ndarray:
    :math:`\mathbb{C}^{d\times d}` with :math:`d` the dimension of the Hilbert
    space.
    """
+    U = np.asanyarray(U)
    if basis.btype == 'GGM' and basis.d > 12:
        # Can do closed form expansion and overhead compensated
        path = ['einsum_path', (0, 1), (0, 1)]
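The added np.asanyarray call means callers no longer have to convert the propagator themselves; any array-like input is accepted. A hedged usage sketch, assuming only the public signature shown in the hunk header above:

import filter_functions as ff
from filter_functions.numeric import liouville_representation

basis = ff.Basis.pauli(1)       # 2x2 Pauli basis
U = [[0, 1], [1, 0]]            # a plain nested list is now coerced to an ndarray
U_liouville = liouville_representation(U, basis)
print(U_liouville.shape)        # expected: (4, 4) for a single qubit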
10 changes: 5 additions & 5 deletions filter_functions/plotting.py
@@ -58,20 +58,20 @@
           'plot_infidelity_convergence', 'plot_error_transfer_matrix']


-def get_bloch_vector(states: Sequence[State], outtype: str = 'np') -> ndarray:
+def get_bloch_vector(states: Sequence[State]) -> ndarray:
    r"""
    Get the Bloch vector from quantum states.
    """
-    if outtype == 'np':
-        a = np.array([[(states[i].T.conj() @ util.P_np[j+1] @ states[i])[0, 0]
-                       for i in range(len(states))] for j in range(3)])
-    else:
+    if isinstance(states[0], Qobj):
        a = np.empty((3, len(states)))
        X, Y, Z = util.P_qt[1:]
        for i, state in enumerate(states):
            a[:, i] = [expect(X, state),
                       expect(Y, state),
                       expect(Z, state)]
+    else:
+        a = np.einsum('...ij,kil,...lm->k...', np.conj(states), util.P_np[1:],
+                      states)
    return a


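The rewritten ndarray branch replaces the Python double loop with a single einsum contraction that evaluates the Bloch components ⟨ψ|σ_k|ψ⟩ for all states at once. A self-contained NumPy sketch of that contraction, with the Pauli matrices written out and two example column-vector states:

import numpy as np

# Pauli matrices sigma_x, sigma_y, sigma_z
sigma = np.array([[[0, 1], [1, 0]],
                  [[0, -1j], [1j, 0]],
                  [[1, 0], [0, -1]]])

# Two states as column vectors, shape (n_states, 2, 1): |0> and |+>
states = np.array([[[1], [0]],
                   [[1], [1]]], dtype=complex)
states[1] /= np.sqrt(2)

# Same contraction as in get_bloch_vector: a[k, n] = <psi_n|sigma_k|psi_n>
a = np.einsum('...ij,kil,...lm->k...', states.conj(), sigma, states)
print(a.real)   # columns are the Bloch vectors (0, 0, 1) and (1, 0, 0)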
5 changes: 5 additions & 0 deletions filter_functions/pulse_sequence.py
@@ -842,6 +842,11 @@ def _parse_args(H_c: Hamiltonian, H_n: Hamiltonian, dt: Coefficients,
    Function to parse the arguments given at instantiation of the PulseSequence
    object.
    """
+
+    if not hasattr(dt, '__getitem__'):
+        raise TypeError('Expected a sequence of time steps, not {}'.format(
+            type(dt)))
+
    dt = np.asarray(dt)
    # Check the time argument for data type and monotonicity (should be
    # increasing)
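With this check, a scalar dt now fails fast with a descriptive TypeError instead of a less obvious error further down the parsing code. A sketch of the new behaviour, using the minimal constructor call format that appears in the tests added below:

import numpy as np
import filter_functions as ff

X, Z = ff.util.P_np[1], ff.util.P_np[3]
try:
    # dt must be a sequence of time steps; a bare float is rejected now
    ff.PulseSequence([[X, [np.pi/2]]], [[Z, [1]]], np.pi/2)
except TypeError as err:
    print(err)   # "Expected a sequence of time steps, not <class 'float'>"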
44 changes: 43 additions & 1 deletion tests/test_basis.py
@@ -26,12 +26,42 @@
import numpy as np

import filter_functions as ff
+from sparse import COO
from tests import testutil


class BasisTest(testutil.TestCase):

-    def test_basis_class(self):
+    def test_basis_constructor(self):
+        """Test the constructor for several failure modes"""
+
+        # Constructing from given elements should check for __getitem__
+        with self.assertRaises(TypeError):
+            _ = ff.Basis(1)
+
+        # All elements should be either COO, Qobj, or ndarray
+        elems = [ff.util.P_np[1], ff.util.P_qt[2],
+                 COO.from_numpy(ff.util.P_np[3]), ff.util.P_qt[0].data]
+        with self.assertRaises(TypeError):
+            _ = ff.Basis(elems)
+
+        # Too many elements
+        with self.assertRaises(ValueError):
+            _ = ff.Basis(np.random.randn(5, 2, 2))
+
+        # Properly normalized
+        self.assertTrue(ff.Basis.pauli(1) == ff.Basis(ff.util.P_np))
+
+        # Warns if orthonormal basis couldn't be generated
+        basis = ff.Basis(
+            [np.pad(b, (0, 1), 'constant') for b in ff.Basis.pauli(1)],
+            skip_check=True,
+            btype='Pauli'
+        )
+        with self.assertWarns(UserWarning):
+            _ = ff.Basis(basis)
+
+    def test_basis_properties(self):
        """Basis orthonormal and of correct dimensions"""
        d = np.random.randint(2, 17)
        ggm_basis = ff.Basis.ggm(d)
@@ -41,6 +71,9 @@ def test_basis_class(self):
        btypes = ('Pauli', 'GGM')
        bases = (pauli_basis, ggm_basis)
        for btype, base in zip(btypes, bases):
+            base.tidyup(eps_scale=0)
+            self.assertTrue(base == base)
+            self.assertFalse(base == ff.Basis.ggm(d+1))
            self.assertEqual(btype, base.btype)
            if not btype == 'Pauli':
                self.assertEqual(d, base.d)
@@ -51,13 +84,18 @@ def test_basis_class(self):
            # Check if __contains__ works as expected
            self.assertTrue(base[np.random.randint(0, (2**n)**2)] in base)
            # Check if all elements of each basis are orthonormal and hermitian
+            self.assertArrayEqual(base.T,
+                                  base.view(np.ndarray).swapaxes(-1, -2))
            self.assertTrue(base.isorthonorm)
            self.assertTrue(base.isherm)
            # Check if basis spans the whole space and all elems are traceless
            self.assertTrue(base.istraceless)
            self.assertTrue(base.iscomplete)
            # Check sparse representation
            self.assertArrayEqual(base.sparse.todense(), base)
+            # Test sparse cache
+            self.assertArrayEqual(base.sparse.todense(), base)
+
            if base.d < 8:
                # Test very resource intense
                self.assertArrayAlmostEqual(base.four_element_traces.todense(),
@@ -67,6 +105,10 @@ def test_basis_class(self):

            base._print_checks()

+        basis = ff.util.P_np[1].view(ff.Basis)
+        self.assertTrue(basis.isorthonorm)
+        self.assertArrayEqual(basis.T, basis.view(np.ndarray).T)
+
    def test_basis_expansion_and_normalization(self):
        """Correct expansion of operators and normalization of bases"""
        for _ in range(10):
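The property flags asserted above can also be checked interactively; a short sketch assuming only the Basis API already referenced in this diff:

import filter_functions as ff

# The same property flags the tests above assert on, checked by hand
basis = ff.Basis.ggm(3)    # generalized Gell-Mann basis for d = 3
print(basis.isorthonorm, basis.isherm, basis.istraceless, basis.iscomplete)
# expected: True True True True, mirroring the assertions in test_basis_properties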
174 changes: 169 additions & 5 deletions tests/test_core.py
@@ -28,9 +28,14 @@
from numpy.random import choice, randint, randn

import filter_functions as ff
-from filter_functions.numeric import (calculate_control_matrix_from_atomic,
-                                      calculate_control_matrix_from_scratch,
-                                      diagonalize, liouville_representation)
+from filter_functions.numeric import (
+    calculate_control_matrix_from_atomic,
+    calculate_control_matrix_from_scratch,
+    calculate_error_vector_correlation_functions,
+    calculate_pulse_correlation_filter_function,
+    diagonalize,
+    liouville_representation
+)
from tests import testutil


@@ -50,6 +55,10 @@ def test_pulse_sequence_constructor(self):
            # Not enough positional arguments
            ff.PulseSequence(H_c, H_n)

+        with self.assertRaises(TypeError):
+            # dt not a sequence
+            ff.PulseSequence(H_c, H_n, dt[0])
+
        idx = randint(0, 5)
        with self.assertRaises(ValueError):
            # negative dt
@@ -204,7 +213,7 @@ def test_pulse_sequence_constructor(self):
        print(pulse)

        # Hit __copy__ method
-        pulse_copy = copy(pulse)
+        _ = copy(pulse)

        # Fewer identifiers than opers
        pulse_2 = ff.PulseSequence(
@@ -218,6 +227,140 @@ def test_pulse_sequence_constructor(self):
        self.assertArrayEqual(pulse_2.n_oper_identifiers, ('B_0', 'Y'))

    def test_pulse_sequence_attributes(self):
+        """Test attributes of single instance"""
+        X, Y, Z = ff.util.P_np[1:]
+        n_dt = randint(1, 10)
+
+        # trivial case
+        A = ff.PulseSequence([[X, randn(n_dt), 'X']],
+                             [[Z, randn(n_dt), 'Z']],
+                             np.abs(randn(n_dt)))
+        self.assertFalse(A == 1)
+        self.assertTrue(A != 1)
+
+        # different number of time steps
+        B = ff.PulseSequence([[X, randn(n_dt+1), 'X']],
+                             [[Z, randn(n_dt+1), 'Z']],
+                             np.abs(randn(n_dt+1)))
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different time steps
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, A.c_coeffs, A.c_oper_identifiers)),
+            list(zip(A.n_opers, A.n_coeffs, A.n_oper_identifiers)),
+            np.abs(randn(n_dt))
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different control opers
+        B = ff.PulseSequence(
+            list(zip([Y], A.c_coeffs, A.c_oper_identifiers)),
+            list(zip(A.n_opers, A.n_coeffs, A.n_oper_identifiers)),
+            A.dt
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different control coeffs
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, [randn(n_dt)], A.c_oper_identifiers)),
+            list(zip(A.n_opers, A.n_coeffs, A.n_oper_identifiers)),
+            A.dt
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different noise opers
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, A.c_coeffs, A.c_oper_identifiers)),
+            list(zip([Y], A.n_coeffs, A.n_oper_identifiers)),
+            A.dt
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different noise coeffs
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, A.c_coeffs, A.c_oper_identifiers)),
+            list(zip(A.n_opers, [randn(n_dt)], A.n_oper_identifiers)),
+            A.dt
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different control oper identifiers
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, A.c_coeffs, ['foobar'])),
+            list(zip(A.n_opers, A.n_coeffs, A.n_oper_identifiers)),
+            A.dt
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different noise oper identifiers
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, A.c_coeffs, A.c_oper_identifiers)),
+            list(zip(A.n_opers, A.n_coeffs, ['foobar'])),
+            A.dt
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # different bases
+        elem = testutil.rand_herm(2)
+        elem -= np.eye(2)*np.trace(elem)/2
+        B = ff.PulseSequence(
+            list(zip(A.c_opers, A.c_coeffs, A.c_oper_identifiers)),
+            list(zip(A.n_opers, A.n_coeffs, A.n_oper_identifiers)),
+            A.dt,
+            ff.Basis([elem])
+        )
+        self.assertFalse(A == B)
+        self.assertTrue(A != B)
+
+        # Test for attributes
+        for attr in A.__dict__.keys():
+            if not (attr.startswith('_') or '_' + attr in A.__dict__.keys()):
+                # not a cached attribute
+                with self.assertRaises(AttributeError):
+                    _ = A.is_cached(attr)
+            else:
+                self.assertFalse(A.is_cached(attr))
+
+        # Test cleanup
+        C = ff.concatenate((A, A), calc_pulse_correlation_ff=True,
+                           omega=ff.util.get_sample_frequencies(A))
+        C.diagonalize()
+        C.cache_filter_function(ff.util.get_sample_frequencies(A))
+        attrs = ['_HD', '_HV', '_Q']
+        for attr in attrs:
+            self.assertIsNotNone(getattr(C, attr))
+
+        C.cleanup()
+        for attr in attrs:
+            self.assertIsNone(getattr(C, attr))
+
+        C.diagonalize()
+        attrs.extend(['_R', '_total_phases', '_total_Q', '_total_Q_liouville'])
+        for attr in attrs:
+            self.assertIsNotNone(getattr(C, attr))
+
+        C.cleanup('greedy')
+        for attr in attrs:
+            self.assertIsNone(getattr(C, attr))
+
+        C.cache_filter_function(ff.util.get_sample_frequencies(A))
+        attrs.extend(['omega', '_F', '_F_pc'])
+        for attr in attrs:
+            self.assertIsNotNone(getattr(C, attr))
+
+        C.cleanup('all')
+        for attr in attrs:
+            self.assertIsNone(getattr(C, attr))
+
+    def test_pulse_sequence_attributes_concat(self):
        """Test attributes of concatenated sequence."""
        X, Y, Z = ff.util.P_np[1:]
        n_dt_1 = randint(5, 11)
@@ -242,12 +385,15 @@ def test_pulse_sequence_attributes(self):
        pulse_12 = pulse_1 @ pulse_2
        pulse_21 = pulse_2 @ pulse_1

+        with self.assertRaises(TypeError):
+            _ = pulse_1 @ randn(2, 2)
+
        # Concatenate pulses with same operators but different labels
        with self.assertRaises(ValueError):
            pulse_1 @ pulse_3

        # Test nbytes property
-        nbytes = pulse_1.nbytes
+        _ = pulse_1.nbytes

        self.assertArrayEqual(pulse_12.dt, [*dt_1, *dt_2])
        self.assertArrayEqual(pulse_21.dt, [*dt_2, *dt_1])
@@ -400,6 +546,9 @@ def test_pulse_correlation_filter_function(self):
        pulse_2 = ff.concatenate([pulses['X'], pulses['Y']],
                                 calc_pulse_correlation_ff=True)

+        with self.assertRaises(ValueError):
+            calculate_pulse_correlation_filter_function(pulse_1._R)
+
        # Check if the filter functions on the diagonals are real
        F = pulse_2.get_pulse_correlation_filter_function()
        diag_1 = np.eye(2, dtype=bool)
@@ -451,6 +600,21 @@ def test_pulse_correlation_filter_function(self):
        self.assertAlmostEqual(infid_1.sum(), infid_2.sum())
        self.assertArrayAlmostEqual(infid_1, infid_2.sum(axis=(0, 1)))

+    def test_calculate_error_vector_correlation_functions(self):
+        """Test raises of numeric.error_transfer_matrix"""
+        pulse = ff.PulseSequence([[ff.util.P_np[1], [np.pi/2]]],
+                                 [[ff.util.P_np[1], [1]]],
+                                 [1])
+
+        omega = randn(43)
+        # single spectrum
+        S = randn(78)
+        for i in range(4):
+            with self.assertRaises(ValueError):
+                calculate_error_vector_correlation_functions(
+                    pulse, np.tile(S, [1]*i), omega
+                )
+
    def test_infidelity_convergence(self):
        import matplotlib
        matplotlib.use('Agg')