Skip to content

Commit

Permalink
Add plotting test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
thangleiter committed Jan 24, 2020
1 parent 216e728 commit e6ead88
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
10 changes: 5 additions & 5 deletions filter_functions/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,20 @@
'plot_infidelity_convergence', 'plot_error_transfer_matrix']


def get_bloch_vector(states: Sequence[State], outtype: str = 'np') -> ndarray:
def get_bloch_vector(states: Sequence[State]) -> ndarray:
r"""
Get the Bloch vector from quantum states.
"""
if outtype == 'np':
a = np.array([[(states[i].T.conj() @ util.P_np[j+1] @ states[i])[0, 0]
for i in range(len(states))] for j in range(3)])
else:
if isinstance(states[0], Qobj):
a = np.empty((3, len(states)))
X, Y, Z = util.P_qt[1:]
for i, state in enumerate(states):
a[:, i] = [expect(X, state),
expect(Y, state),
expect(Z, state)]
else:
a = np.einsum('...ij,kil,...lm->k...', np.conj(states), util.P_np[1:],
states)
return a


Expand Down
42 changes: 41 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from numpy.random import randint, randn

import filter_functions as ff
from filter_functions.plotting import (init_bloch_sphere,
from filter_functions.plotting import (get_bloch_vector,
get_states_from_prop,
init_bloch_sphere,
plot_bloch_vector_evolution,
plot_filter_function,
plot_pulse_correlation_filter_function,
Expand Down Expand Up @@ -69,6 +71,26 @@

class PlottingTest(testutil.TestCase):

def test_get_bloch_vector(self):
states = [qt.rand_ket(2) for _ in range(10)]
bloch_vectors_qt = get_bloch_vector(states)
bloch_vectors_np = get_bloch_vector([state.full() for state in states])

for bv_qt, bv_np in zip(bloch_vectors_qt, bloch_vectors_np):
self.assertArrayAlmostEqual(bv_qt, bv_np)

def test_get_states_from_prop(self):
P = testutil.rand_unit(2, 10)
Q = np.empty((11, 2, 2), dtype=complex)
Q[0] = np.identity(2)
for i in range(10):
Q[i+1] = P[i] @ Q[i]

psi0 = qt.rand_ket(2)
states_piecewise = get_states_from_prop(P, psi0, 'piecewise')
states_total = get_states_from_prop(Q[1:], psi0, 'total')
self.assertArrayAlmostEqual(states_piecewise, states_total)

def test_plot_bloch_vector_evolution(self):
two_qubit_pulse = ff.PulseSequence(
[[qt.tensor(qt.sigmax(), qt.sigmax()), [np.pi/2]]],
Expand Down Expand Up @@ -96,6 +118,13 @@ def test_plot_pulse_train(self):
# Call with default args
fig, ax, leg = plot_pulse_train(simple_pulse)

# Call with no axes but figure
fig = plt.figure()
fig, ax, leg = plot_pulse_train(simple_pulse, fig=fig)

# Call with axes but no figure
fig, ax, leg = plot_pulse_train(simple_pulse, axes=ax)

# Call with custom args
c_oper_identifiers = sample(
complicated_pulse.c_oper_identifiers.tolist(), randint(2, 4)
Expand Down Expand Up @@ -126,8 +155,16 @@ def test_plot_pulse_train(self):

def test_plot_filter_function(self):
# Call with default args
simple_pulse.cleanup('all')
fig, ax, leg = plot_filter_function(simple_pulse)

# Call with no axes but figure
fig = plt.figure()
fig, ax, leg = plot_filter_function(simple_pulse, fig=fig)

# Call with axes but no figure
fig, ax, leg = plot_filter_function(simple_pulse, axes=ax)

# Non-default args
n_oper_identifiers = sample(
complicated_pulse.n_oper_identifiers.tolist(), randint(2, 4)
Expand Down Expand Up @@ -249,6 +286,9 @@ def test_plot_error_transfer_matrix(self):
U = ff.error_transfer_matrix(simple_pulse, S, omega)
fig, grid = plot_error_transfer_matrix(U=U)

# Log colorscale
fig, grid = plot_error_transfer_matrix(U=U, colorscale='log')

# Non-default args
n_oper_inds = sample(range(len(complicated_pulse.n_opers)),
randint(2, 4))
Expand Down
5 changes: 5 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,11 @@ def test_oper_equiv(self):
self.assertFalse(result[0])

def test_dot_HS(self):
U, V = randint(0, 100, (2, 2, 2))
S = util.dot_HS(U, V)
T = util.dot_HS(U, V, eps=0)
self.assertArrayEqual(S, T)

for d in randint(2, 10, (5,)):
U = qt.rand_herm(d)
V = qt.rand_herm(d)
Expand Down

0 comments on commit e6ead88

Please sign in to comment.