Skip to content

Commit

Permalink
Refactor state histograms (#3953)
Browse files Browse the repository at this point in the history
Solves #3041
  • Loading branch information
tanujkhattar committed Mar 24, 2021
1 parent 40bc856 commit 46efaae
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 94 deletions.
7 changes: 6 additions & 1 deletion cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,12 @@
NeutralAtomDevice,
)

from cirq.vis import Heatmap, TwoQubitInteractionHeatmap, integrated_histogram
from cirq.vis import (
Heatmap,
TwoQubitInteractionHeatmap,
get_state_histogram,
integrated_histogram,
)

from cirq.work import (
CircuitSampleJob,
Expand Down
37 changes: 10 additions & 27 deletions cirq/study/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@

from typing import TYPE_CHECKING
import numpy as np
from cirq._compat import deprecated

if TYPE_CHECKING:
from cirq.study import result


@deprecated(
deadline="v0.12",
fix="use cirq.vis.plot_state_histogram or cirq.vis.get_state_histogram instead",
name="cirq.study.visualize.plot_state_histogram",
)
def plot_state_histogram(result: 'result.Result') -> np.ndarray:
"""Plot the state histogram from a single result with repetitions.
Expand All @@ -33,32 +39,9 @@ def plot_state_histogram(result: 'result.Result') -> np.ndarray:
Returns:
The histogram. A list of values plotted on the y-axis.
"""
# Needed to avoid circular imports.
import cirq.vis as vis

# pyplot import is deferred because it requires a system dependency
# (python3-tk) that `python -m pip install cirq` can't handle for the user.
# This allows cirq to be usable without python3-tk.
import matplotlib.pyplot as plt

num_qubits = sum([value.shape[1] for value in result.measurements.values()])
states = 2 ** num_qubits
values = np.zeros(states)
# measurements is a dict of {measurement gate key:
# array(repetitions, boolean result)}
# Convert this to an array of repetitions, each with an array of booleans.
# e.g. {q1: array([[True, True]]), q2: array([[False, False]])}
# --> array([[True, False], [True, False]])
measurement_by_result = np.hstack(list(result.measurements.values()))

for meas in measurement_by_result:
# Convert each array of booleans to a string representation.
# e.g. [True, False] -> [1, 0] -> '10' -> 2
state_ind = int(''.join([str(x) for x in [int(x) for x in meas]]), 2)
values[state_ind] += 1

plot_labels = [bin(x)[2:].zfill(num_qubits) for x in range(states)]
plt.bar(np.arange(states), values, tick_label=plot_labels)
plt.xlabel('qubit state')
plt.ylabel('result count')
plt.show()

values = vis.get_state_histogram(result)
vis.plot_state_histogram(values)
return values
66 changes: 0 additions & 66 deletions cirq/study/visualize_test.py

This file was deleted.

2 changes: 2 additions & 0 deletions cirq/vis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@

from cirq.vis.histogram import integrated_histogram

from cirq.vis.state_histogram import get_state_histogram, plot_state_histogram

from cirq.vis.vis_utils import relative_luminance
81 changes: 81 additions & 0 deletions cirq/vis/state_histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tool to visualize the results of a study."""

from typing import Union, Optional
import math
import numpy as np
import matplotlib.pyplot as plt
import cirq.study.result as result


def get_state_histogram(result: 'result.Result') -> np.ndarray:
"""Computes a state histogram from a single result with repetitions.
Args:
result: The trial result containing measurement results from which the
state histogram should be computed.
Returns:
The state histogram (a numpy array) corresponding to the trial result.
"""
num_qubits = sum([value.shape[1] for value in result.measurements.values()])
states = 2 ** num_qubits
values = np.zeros(states)
# measurements is a dict of {measurement gate key:
# array(repetitions, boolean result)}
# Convert this to an array of repetitions, each with an array of booleans.
# e.g. {q1: array([[True, True]]), q2: array([[False, False]])}
# --> array([[True, False], [True, False]])
measurement_by_result = np.hstack(list(result.measurements.values()))

for meas in measurement_by_result:
# Convert each array of booleans to a string representation.
# e.g. [True, False] -> [1, 0] -> '10' -> 2
state_ind = int(''.join([str(x) for x in [int(x) for x in meas]]), 2)
values[state_ind] += 1
return values


def plot_state_histogram(
values: Union['result.Result', np.ndarray], ax: Optional['plt.Axis'] = None
) -> 'plt.Axis':
"""Plot the state histogram from either a single result with repetitions or
a histogram of measurement results computed using `get_state_histogram`.
Args:
values: The histogram values to plot. If `result.Result` is passed, the
values are computed by calling `get_state_histogram`.
ax: The Axes to plot on. If not given, a new figure is created,
plotted on, and shown.
Returns:
The axis that was plotted on.
"""
show_fig = not ax
if not ax:
fig, ax = plt.subplots(1, 1)
print(values, isinstance(values, result.Result))
if isinstance(values, result.Result):
values = get_state_histogram(values)
states = len(values)
num_qubits = math.ceil(math.log(states, 2))
plot_labels = [bin(x)[2:].zfill(num_qubits) for x in range(states)]
ax.bar(np.arange(states), values, tick_label=plot_labels)
ax.set_xlabel('qubit state')
ax.set_ylabel('result count')
if show_fig:
fig.show()
return ax
95 changes: 95 additions & 0 deletions cirq/vis/state_histogram_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for state_histogram."""

import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl

import cirq
from cirq.devices import GridQubit
from cirq.vis import state_histogram


def test_get_state_histogram():
simulator = cirq.Simulator()

q0 = GridQubit(0, 0)
q1 = GridQubit(1, 0)
circuit = cirq.Circuit()
circuit.append([cirq.X(q0), cirq.X(q1)])
circuit.append([cirq.measure(q0, key='q0'), cirq.measure(q1, key='q1')])
result = simulator.run(program=circuit, repetitions=5)

values_to_plot = state_histogram.get_state_histogram(result)
expected_values = [0.0, 0.0, 0.0, 5.0]

np.testing.assert_equal(values_to_plot, expected_values)


def test_get_state_histogram_multi_1():
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
cirq.measure(*qubits), # One multi-qubit measurement
)
r = cirq.sample(c, repetitions=5)
values_to_plot = state_histogram.get_state_histogram(r)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
np.testing.assert_equal(values_to_plot, expected_values)


def test_get_state_histogram_multi_2():
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
cirq.measure(*qubits[:2]), # One multi-qubit measurement
cirq.measure_each(*qubits[2:]), # Multiple single-qubit measurement
)
r = cirq.sample(c, repetitions=5)
values_to_plot = state_histogram.get_state_histogram(r)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
np.testing.assert_equal(values_to_plot, expected_values)


def test_plot_state_histogram():
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
cirq.measure(*qubits), # One multi-qubit measurement
)
r = cirq.sample(c, repetitions=5)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
_, (ax1, ax2) = plt.subplots(1, 2)
state_histogram.plot_state_histogram(r, ax1)
state_histogram.plot_state_histogram(expected_values, ax2)
for r1, r2 in zip(ax1.get_children(), ax2.get_children()):
if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle):
assert str(r1) == str(r2)


def test_plot_state_histogram_deprecation():
with cirq.testing.assert_deprecated(
'cirq.study.visualize.plot_state_histogram was used but is deprecated.\n'
'It will be removed in cirq v0.12.\n'
'use cirq.vis.plot_state_histogram or cirq.vis.get_state_histogram instead',
deadline="v0.12",
count=None, # Another warning is due to matplotlib.
):
simulator = cirq.Simulator()
q = cirq.NamedQubit("a")
circuit = cirq.Circuit([cirq.X(q), cirq.measure(q)])
result = simulator.run(program=circuit, repetitions=5)
cirq.study.visualize.plot_state_histogram(result)

0 comments on commit 46efaae

Please sign in to comment.