-
Notifications
You must be signed in to change notification settings - Fork 989
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
6 changed files
with
194 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tool to visualize the results of a study.""" | ||
|
||
from typing import Union, Optional | ||
import math | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import cirq.study.result as result | ||
|
||
|
||
def get_state_histogram(result: 'result.Result') -> np.ndarray: | ||
"""Computes a state histogram from a single result with repetitions. | ||
Args: | ||
result: The trial result containing measurement results from which the | ||
state histogram should be computed. | ||
Returns: | ||
The state histogram (a numpy array) corresponding to the trial result. | ||
""" | ||
num_qubits = sum([value.shape[1] for value in result.measurements.values()]) | ||
states = 2 ** num_qubits | ||
values = np.zeros(states) | ||
# measurements is a dict of {measurement gate key: | ||
# array(repetitions, boolean result)} | ||
# Convert this to an array of repetitions, each with an array of booleans. | ||
# e.g. {q1: array([[True, True]]), q2: array([[False, False]])} | ||
# --> array([[True, False], [True, False]]) | ||
measurement_by_result = np.hstack(list(result.measurements.values())) | ||
|
||
for meas in measurement_by_result: | ||
# Convert each array of booleans to a string representation. | ||
# e.g. [True, False] -> [1, 0] -> '10' -> 2 | ||
state_ind = int(''.join([str(x) for x in [int(x) for x in meas]]), 2) | ||
values[state_ind] += 1 | ||
return values | ||
|
||
|
||
def plot_state_histogram( | ||
values: Union['result.Result', np.ndarray], ax: Optional['plt.Axis'] = None | ||
) -> 'plt.Axis': | ||
"""Plot the state histogram from either a single result with repetitions or | ||
a histogram of measurement results computed using `get_state_histogram`. | ||
Args: | ||
values: The histogram values to plot. If `result.Result` is passed, the | ||
values are computed by calling `get_state_histogram`. | ||
ax: The Axes to plot on. If not given, a new figure is created, | ||
plotted on, and shown. | ||
Returns: | ||
The axis that was plotted on. | ||
""" | ||
show_fig = not ax | ||
if not ax: | ||
fig, ax = plt.subplots(1, 1) | ||
print(values, isinstance(values, result.Result)) | ||
if isinstance(values, result.Result): | ||
values = get_state_histogram(values) | ||
states = len(values) | ||
num_qubits = math.ceil(math.log(states, 2)) | ||
plot_labels = [bin(x)[2:].zfill(num_qubits) for x in range(states)] | ||
ax.bar(np.arange(states), values, tick_label=plot_labels) | ||
ax.set_xlabel('qubit state') | ||
ax.set_ylabel('result count') | ||
if show_fig: | ||
fig.show() | ||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for state_histogram.""" | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import matplotlib as mpl | ||
|
||
import cirq | ||
from cirq.devices import GridQubit | ||
from cirq.vis import state_histogram | ||
|
||
|
||
def test_get_state_histogram(): | ||
simulator = cirq.Simulator() | ||
|
||
q0 = GridQubit(0, 0) | ||
q1 = GridQubit(1, 0) | ||
circuit = cirq.Circuit() | ||
circuit.append([cirq.X(q0), cirq.X(q1)]) | ||
circuit.append([cirq.measure(q0, key='q0'), cirq.measure(q1, key='q1')]) | ||
result = simulator.run(program=circuit, repetitions=5) | ||
|
||
values_to_plot = state_histogram.get_state_histogram(result) | ||
expected_values = [0.0, 0.0, 0.0, 5.0] | ||
|
||
np.testing.assert_equal(values_to_plot, expected_values) | ||
|
||
|
||
def test_get_state_histogram_multi_1(): | ||
qubits = cirq.LineQubit.range(4) | ||
c = cirq.Circuit( | ||
cirq.X.on_each(*qubits[1:]), | ||
cirq.measure(*qubits), # One multi-qubit measurement | ||
) | ||
r = cirq.sample(c, repetitions=5) | ||
values_to_plot = state_histogram.get_state_histogram(r) | ||
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0] | ||
np.testing.assert_equal(values_to_plot, expected_values) | ||
|
||
|
||
def test_get_state_histogram_multi_2(): | ||
qubits = cirq.LineQubit.range(4) | ||
c = cirq.Circuit( | ||
cirq.X.on_each(*qubits[1:]), | ||
cirq.measure(*qubits[:2]), # One multi-qubit measurement | ||
cirq.measure_each(*qubits[2:]), # Multiple single-qubit measurement | ||
) | ||
r = cirq.sample(c, repetitions=5) | ||
values_to_plot = state_histogram.get_state_histogram(r) | ||
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0] | ||
np.testing.assert_equal(values_to_plot, expected_values) | ||
|
||
|
||
def test_plot_state_histogram(): | ||
qubits = cirq.LineQubit.range(4) | ||
c = cirq.Circuit( | ||
cirq.X.on_each(*qubits[1:]), | ||
cirq.measure(*qubits), # One multi-qubit measurement | ||
) | ||
r = cirq.sample(c, repetitions=5) | ||
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0] | ||
_, (ax1, ax2) = plt.subplots(1, 2) | ||
state_histogram.plot_state_histogram(r, ax1) | ||
state_histogram.plot_state_histogram(expected_values, ax2) | ||
for r1, r2 in zip(ax1.get_children(), ax2.get_children()): | ||
if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle): | ||
assert str(r1) == str(r2) | ||
|
||
|
||
def test_plot_state_histogram_deprecation(): | ||
with cirq.testing.assert_deprecated( | ||
'cirq.study.visualize.plot_state_histogram was used but is deprecated.\n' | ||
'It will be removed in cirq v0.12.\n' | ||
'use cirq.vis.plot_state_histogram or cirq.vis.get_state_histogram instead', | ||
deadline="v0.12", | ||
count=None, # Another warning is due to matplotlib. | ||
): | ||
simulator = cirq.Simulator() | ||
q = cirq.NamedQubit("a") | ||
circuit = cirq.Circuit([cirq.X(q), cirq.measure(q)]) | ||
result = simulator.run(program=circuit, repetitions=5) | ||
cirq.study.visualize.plot_state_histogram(result) |