Skip to content

Commit

Permalink
Add integrated histogram visualizations (#3942)
Browse files Browse the repository at this point in the history
This PR adds integrated histogram visualizations support to Cirq. Closes #3941

See https://tinyurl.com/cirq-visualizations for the larger roadmap item. 

**Usage:**
- `calibration.plot('single_qubit_errors')` now produces the following image: 

![image](https://user-images.githubusercontent.com/7863287/111851641-06cae200-893a-11eb-8aa6-c8f7d37acdf6.png)


**Next steps**
- The `visualizing_calibration_metrics` tutorial will be updated in a follow up PR once this is checked in.
  • Loading branch information
tanujkhattar committed Mar 20, 2021
1 parent cfb2559 commit bee1a44
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 5 deletions.
5 changes: 1 addition & 4 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,10 +546,7 @@
NeutralAtomDevice,
)

from cirq.vis import (
Heatmap,
TwoQubitInteractionHeatmap,
)
from cirq.vis import Heatmap, TwoQubitInteractionHeatmap, integrated_histogram

from cirq.work import (
CircuitSampleJob,
Expand Down
79 changes: 78 additions & 1 deletion cirq/google/engine/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

from collections import abc, defaultdict
import datetime
from itertools import cycle

from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union, Sequence

import matplotlib as mpl
import matplotlib.pyplot as plt
import google.protobuf.json_format as json_format
from cirq import devices, vis
from cirq.google.api import v2
Expand Down Expand Up @@ -252,3 +255,77 @@ def heatmap(self, key: str) -> vis.Heatmap:
'Heatmaps are only supported if all the targets in a metric are one or two qubits.'
+ f'{key} has target qubits {value_map.keys()}'
)

def plot_histograms(
self,
keys: Sequence[str],
ax: Optional[plt.Axes] = None,
*,
labels: Optional[Sequence[str]] = None,
) -> plt.Axes:
"""Plots integrated histograms of metric values corresponding to keys
Args:
keys: List of metric keys for which an integrated histogram should be plot
ax: The axis to plot on. If None, we generate one.
Returns:
The axis that was plotted on.
Raises:
ValueError if the metric values are not single floats.
"""
show_plot = not ax
if not ax:
fig, ax = plt.subplots(1, 1)

if isinstance(keys, str):
keys = [keys]
if not labels:
labels = keys
colors = ['b', 'r', 'k', 'g', 'c', 'm']
for key, label, color in zip(keys, labels, cycle(colors)):
metrics = self[key]
if not all(len(k) == 1 for k in metrics.values()):
raise ValueError(
'Histograms are only supported if all values in a metric '
+ 'are single metric values.'
+ f'{key} has metric values {metrics.values()}'
)
vis.integrated_histogram(
[self.value_to_float(v) for v in metrics.values()],
ax,
label=label,
color=color,
title=key.replace('_', ' ').title(),
)
if show_plot:
fig.show()

return ax

def plot(
self, key: str, fig: Optional[mpl.figure.Figure] = None
) -> Tuple[mpl.figure.Figure, List[plt.Axes]]:
"""Plots a heatmap and an integrated histogram for the given key.
Args:
key: The metric key to plot a heatmap and integrated histogram for.
fig: The figure to plot on. If none, we generate one.
Returns:
The figure and list of axis that was plotted on.
Raises:
ValueError if the key is not for one/two qubits metric or the metric
values are not single floats.
"""
show_plot = not fig
if not fig:
fig = plt.figure()
axs = fig.subplots(1, 2)
self.heatmap(key).plot(axs[0])
self.plot_histograms(key, axs[1])
if show_plot:
fig.show()
return fig, axs
24 changes: 24 additions & 0 deletions cirq/google/engine/calibration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,27 @@ def test_calibration_heatmap():
v2.metrics_pb2.MetricsSnapshot(),
)
cg.Calibration(multi_qubit_data).heatmap('multi_value')


def test_calibration_plot_histograms():
calibration = cg.Calibration(_CALIBRATION_DATA)
_, ax = mpl.pyplot.subplots(1, 1)
calibration.plot_histograms(['t1', 'two_qubit_xeb'], ax, labels=['T1', 'XEB'])
assert len(ax.get_lines()) == 4

with pytest.raises(ValueError, match="single metric values.*multi_value"):
multi_qubit_data = Merge(
"""metrics: [{
name: 'multi_value',
targets: ['0_0'],
values: [{double_val: 0.999}, {double_val: 0.001}]}]""",
v2.metrics_pb2.MetricsSnapshot(),
)
cg.Calibration(multi_qubit_data).plot_histograms('multi_value')


def test_calibration_plot():
calibration = cg.Calibration(_CALIBRATION_DATA)
_, axs = calibration.plot('two_qubit_xeb')
assert axs[0].get_title() == 'Two Qubit Xeb'
assert len(axs[1].get_lines()) == 2
2 changes: 2 additions & 0 deletions cirq/vis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@
from cirq.vis.heatmap import Heatmap
from cirq.vis.heatmap import TwoQubitInteractionHeatmap

from cirq.vis.histogram import integrated_histogram

from cirq.vis.vis_utils import relative_luminance
148 changes: 148 additions & 0 deletions cirq/vis/histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping, Optional, Sequence, Union, SupportsFloat

import numpy as np
from matplotlib import pyplot as plt


def integrated_histogram(
data: Union[Sequence[SupportsFloat], Mapping[Any, SupportsFloat]],
ax: Optional[plt.Axes] = None,
*,
cdf_on_x: bool = False,
axis_label: str = '',
semilog: bool = True,
median_line: bool = True,
median_label: Optional[str] = 'median',
mean_line: bool = False,
mean_label: Optional[str] = 'mean',
show_zero: bool = False,
title: Optional[str] = None,
**kwargs,
) -> plt.Axes:
"""Plot the integrated histogram for an array of data.
Suppose the input is a list of gate fidelities. The x-axis of the plot will
be gate fidelity, and the y-axis will be the probability that a random gate
fidelity from the list is less than the x-value. It will look something like
this
1.0
| |
| ___|
| |
| ____|
| |
| |
|_____|_______________
0.0
Another way of saying this is that we assume the probability distribution
function (pdf) of gate fidelities is a set of equally weighted delta
functions at each value in the list. Then, the "integrated histogram"
is the cumulative distribution function (cdf) for this pdf.
Args:
data: Data to histogram. If the data is a `Mapping`, we histogram the
values. All nans will be removed.
ax: The axis to plot on. If None, we generate one.
cdf_on_x: If True, flip the axes compared the above example.
axis_label: Label for x axis (y-axis if cdf_on_x is True).
semilog: If True, force the x-axis to be logarithmic.
median_line: If True, draw a vertical line on the median value.
median_label: If drawing median line, optional label for it.
mean_line: If True, draw a vertical line on the mean value.
mean_label: If drawing mean line, optional label for it.
title: Title of the plot. If None, we assign "N={len(data)}".
show_zero: If True, moves the step plot up by one unit by prepending 0
to the data.
**kwargs: Kwargs to forward to `ax.step()`. Some examples are
color: Color of the line.
linestyle: Linestyle to use for the plot.
lw: linewidth for integrated histogram.
ms: marker size for a histogram trace.
label: An optional label which can be used in a legend.
Returns:
The axis that was plotted on.
"""
show_plot = not ax
if ax is None:
fig, ax = plt.subplots(1, 1)

if isinstance(data, Mapping):
data = list(data.values())

data = [d for d in data if not np.isnan(d)]
n = len(data)

if not show_zero:
bin_values = np.linspace(0, 1, n + 1)
parameter_values = sorted(np.concatenate(([0], data)))
else:
bin_values = np.linspace(0, 1, n)
parameter_values = sorted(data)
plot_options = {
"where": 'post',
"color": 'b',
"linestyle": '-',
"lw": 1.0,
"ms": 0.0,
}
plot_options.update(kwargs)

if cdf_on_x:
ax.step(bin_values, parameter_values, **plot_options)
else:
ax.step(parameter_values, bin_values, **plot_options)

set_semilog = ax.semilogy if cdf_on_x else ax.semilogx
set_lim = ax.set_xlim if cdf_on_x else ax.set_ylim
set_ticks = ax.set_xticks if cdf_on_x else ax.set_yticks
set_line = ax.axhline if cdf_on_x else ax.axvline
cdf_label = ax.set_xlabel if cdf_on_x else ax.set_ylabel
ax_label = ax.set_ylabel if cdf_on_x else ax.set_xlabel

if not title:
title = f'N={n}'
ax.set_title(title)

if semilog:
set_semilog()
set_lim(0, 1)
set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
ax.grid(True)
cdf_label('Integrated histogram')
if axis_label:
ax_label(axis_label)
if 'label' in plot_options:
ax.legend()

if median_line:
set_line(
np.median(data),
linestyle='--',
color=plot_options['color'],
alpha=0.5,
label=median_label,
)
if mean_line:
set_line(
np.mean(data), linestyle='-.', color=plot_options['color'], alpha=0.5, label=mean_label
)
if show_plot:
fig.show()
return ax
59 changes: 59 additions & 0 deletions cirq/vis/histogram_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Histogram."""

import numpy as np
import pytest

import matplotlib.pyplot as plt

from cirq.vis import integrated_histogram


@pytest.mark.parametrize('data', [range(10), {f'key_{i}': i for i in range(10)}])
def test_integrated_histogram(data):
ax = integrated_histogram(
data,
title='Test Plot',
axis_label='Y Axis Label',
color='r',
label='line label',
cdf_on_x=True,
show_zero=True,
)
assert ax.get_title() == 'Test Plot'
assert ax.get_ylabel() == 'Y Axis Label'
assert len(ax.get_lines()) == 2
for line in ax.get_lines():
assert line.get_color() == 'r'


def test_multiple_plots():
_, ax = plt.subplots(1, 1)
n = 53
data = np.random.random_sample((2, n))
integrated_histogram(
data[0],
ax,
color='r',
label='data_1',
median_line=False,
mean_line=True,
mean_label='mean_1',
)
integrated_histogram(data[1], ax, color='k', label='data_2', median_label='median_2')
assert ax.get_title() == 'N=53'
for line in ax.get_lines():
assert line.get_color() in ['r', 'k']
assert line.get_label() in ['data_1', 'data_2', 'mean_1', 'median_2']

0 comments on commit bee1a44

Please sign in to comment.