Skip to content

Commit

Permalink
Improve plot_state_histogram configuration options. (#4015)
Browse files Browse the repository at this point in the history
This PR introduces the following changes
- Make `tick_labels` (and other plot properties) configurable so that users can provide different labels
- Changes default tick labels from binary string ('0111') to integers ('7') to make larger plots more readable. 
- Add support for `collections.Counter` so users can directly pass the result of `result.histogram()` so that processed results can be plotted easily. 

Once this is checked-in, I'll send a PR for state_histogram.ipynb tutorial.
  • Loading branch information
tanujkhattar committed Apr 13, 2021
1 parent 76edeb1 commit 4799246
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 18 deletions.
57 changes: 40 additions & 17 deletions cirq/vis/state_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

"""Tool to visualize the results of a study."""

from typing import Union, Optional
import math
from typing import Union, Optional, Sequence, SupportsFloat
import collections
import numpy as np
import matplotlib.pyplot as plt
import cirq.study.result as result
Expand Down Expand Up @@ -50,32 +50,55 @@ def get_state_histogram(result: 'result.Result') -> np.ndarray:


def plot_state_histogram(
values: Union['result.Result', np.ndarray], ax: Optional['plt.Axis'] = None
data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]],
ax: Optional['plt.Axis'] = None,
*,
tick_label: Optional[Sequence[str]] = None,
xlabel: Optional[str] = 'qubit state',
ylabel: Optional[str] = 'result count',
title: Optional[str] = 'Result State Histogram',
) -> 'plt.Axis':
"""Plot the state histogram from either a single result with repetitions or
a histogram of measurement results computed using `get_state_histogram`.
a histogram computed using `result.histogram()` or a flattened histogram
of measurement results computed using `get_state_histogram`.
Args:
values: The histogram values to plot. If `result.Result` is passed, the
values are computed by calling `get_state_histogram`.
ax: The Axes to plot on. If not given, a new figure is created,
plotted on, and shown.
data: The histogram values to plot. Possible options are:
`result.Result`: Histogram is computed using
`get_state_histogram` and all 2 ** num_qubits values are
plotted, including 0s.
`collections.Counter`: Only (key, value) pairs present in
collection are plotted.
`Sequence[SupportsFloat]`: Values in the input sequence are
plotted. i'th entry corresponds to height of the i'th
bar in histogram.
ax: The Axes to plot on. If not given, a new figure is created,
plotted on, and shown.
tick_label: Tick labels for the histogram plot in case input is not
`collections.Counter`. By default, label for i'th entry
is |i>.
xlabel: Label for the x-axis.
ylabel: Label for the y-axis.
title: Title of the plot.
Returns:
The axis that was plotted on.
"""
show_fig = not ax
if not ax:
fig, ax = plt.subplots(1, 1)
print(values, isinstance(values, result.Result))
if isinstance(values, result.Result):
values = get_state_histogram(values)
states = len(values)
num_qubits = math.ceil(math.log(states, 2))
plot_labels = [bin(x)[2:].zfill(num_qubits) for x in range(states)]
ax.bar(np.arange(states), values, tick_label=plot_labels)
ax.set_xlabel('qubit state')
ax.set_ylabel('result count')
if isinstance(data, result.Result):
values = get_state_histogram(data)
elif isinstance(data, collections.Counter):
tick_label, values = zip(*sorted(data.items()))
else:
values = data
if not tick_label:
tick_label = np.arange(len(values))
ax.bar(np.arange(len(values)), values, tick_label=tick_label)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
if show_fig:
fig.show()
return ax
19 changes: 18 additions & 1 deletion cirq/vis/state_histogram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_get_state_histogram_multi_2():
np.testing.assert_equal(values_to_plot, expected_values)


def test_plot_state_histogram():
def test_plot_state_histogram_result():
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
Expand All @@ -80,6 +80,23 @@ def test_plot_state_histogram():
assert str(r1) == str(r2)


def test_plot_state_histogram_collection():
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
cirq.measure(*qubits), # One multi-qubit measurement
)
r = cirq.sample(c, repetitions=5)
_, (ax1, ax2) = plt.subplots(1, 2)
state_histogram.plot_state_histogram(r.histogram(key='0,1,2,3'), ax1)
expected_values = [5]
tick_label = ['7']
state_histogram.plot_state_histogram(expected_values, ax2, tick_label=tick_label, xlabel=None)
for r1, r2 in zip(ax1.get_children(), ax2.get_children()):
if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle):
assert str(r1) == str(r2)


def test_plot_state_histogram_deprecation():
with cirq.testing.assert_deprecated(
'cirq.study.visualize.plot_state_histogram was used but is deprecated.\n'
Expand Down

0 comments on commit 4799246

Please sign in to comment.