Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,7 @@
NeutralAtomDevice,
)

from cirq.vis import (
Heatmap,)
from cirq.vis import (Heatmap, plot)

from cirq.work import (
CircuitSampleJob,
Expand Down
24 changes: 17 additions & 7 deletions cirq/experiments/cross_entropy_benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,31 @@ def data(self) -> Sequence[CrossEntropyPair]:
return self._data

def plot(self, **plot_kwargs: Any) -> None:
"""Plots the average XEB fidelity vs the number of cycles.
"""Plots mean XEB fidelity vs number of cycles on a new figure.

Args:
**plot_kwargs: Arguments to be passed to 'matplotlib.pyplot.plot'.
"""
num_cycles = [d.num_cycle for d in self._data]
fidelities = [d.xeb_fidelity for d in self._data]
# TODO(pingyeh): deprecate this in favor of the SupportsPlot protocol.
fig = plt.figure()
ax = plt.gca()
ax.set_ylim([0, 1.1])
plt.plot(num_cycles, fidelities, 'ro-', figure=fig, **plot_kwargs)
plt.xlabel('Number of Cycles', figure=fig)
plt.ylabel('XEB Fidelity', figure=fig)
self._plot_(ax, **plot_kwargs)
fig.show(warn=False)

def _plot_(self, ax: plt.Axes, **plot_kwargs: Any) -> None:
"""Plots mean XEB fidelity vs number of cycles onto ax.

Args:
ax: the axes to plot onto.
**plot_kwargs: Arguments to be passed to 'ax.plot'.
"""
num_cycles = [d.num_cycle for d in self._data]
fidelities = [d.xeb_fidelity for d in self._data]
ax.set_ylim([0, 1.1])
ax.plot(num_cycles, fidelities, 'ro-', **plot_kwargs)
ax.set_xlabel('Number of Cycles')
ax.set_ylabel('XEB Fidelity')


def cross_entropy_benchmarking(
sampler: work.Sampler,
Expand Down
41 changes: 31 additions & 10 deletions cirq/experiments/qubit_characterizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,27 @@ def plot(self, **plot_kwargs: Any) -> None:
Args:
**plot_kwargs: Arguments to be passed to matplotlib.pyplot.plot.
"""
# TODO(pingyeh): deprecate this in favor of the SupportsPlot protocol.
fig = plt.figure()
ax = plt.gca()
ax.set_ylim([0, 1])
plt.plot(self._rabi_angles, self._excited_state_probs, 'ro-',
figure=fig, **plot_kwargs)
plt.xlabel(r"Rabi Angle (Radian)", figure=fig)
plt.ylabel('Excited State Probability', figure=fig)
self._plot_(ax, **plot_kwargs)
fig.show(warn=False)

def _plot_(self, ax: plt.Axes, **plot_kwargs: Any) -> None:
"""Plots excited state probability vs the Rabi angle onto ax.

Here the Rabi angle is the angle of rotation around the x-axis.

Args:
ax: the axes to plot onto.
**plot_kwargs: Arguments to be passed to ax.plot.
"""
ax.set_ylim([0, 1])
ax.plot(self._rabi_angles, self._excited_state_probs, 'ro-',
**plot_kwargs)
ax.set_xlabel(r"Rabi Angle (Radian)")
ax.set_ylabel('Excited State Probability')


class RandomizedBenchMarkResult:
"""Results from a randomized benchmarking experiment."""
Expand Down Expand Up @@ -90,13 +102,21 @@ def plot(self, **plot_kwargs: Any) -> None:
"""
fig = plt.figure()
ax = plt.gca()
self._plot_(ax, **plot_kwargs)
fig.show(warn=False)

def _plot_(self, ax: plt.Axes, **plot_kwargs: Any) -> None:
"""Plots probability(|0>) vs number of Cliffords onto ax.

Args:
ax: the axes to plot onto.
**plot_kwargs: Arguments to be passed to ax.plot.
"""
ax.set_ylim([0, 1])

plt.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-',
figure=fig, **plot_kwargs)
plt.xlabel(r"Number of Cliffords", figure=fig)
plt.ylabel('Ground State Probability', figure=fig)
fig.show(warn=False)
ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs)
ax.set_xlabel(r"Number of Cliffords")
ax.set_ylabel('Ground State Probability')


class TomographyResult:
Expand All @@ -120,6 +140,7 @@ def plot(self) -> None:
"""Plots the real and imaginary parts of the density matrix as two
3D bar plots.
"""
# TODO(pingyeh): convert this into _plot_().
fig = _plot_density_matrix(self._density_matrix)
fig.show(warn=False)

Expand Down
1 change: 1 addition & 0 deletions cirq/vis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.

from cirq.vis.heatmap import Heatmap
from cirq.vis.plot import plot
4 changes: 1 addition & 3 deletions cirq/vis/examples/bristlecone_heatmap_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ def main():
for qubit in cirq.google.known_devices.Bristlecone.qubits}

heatmap = cirq.Heatmap(value_map)
fig, ax = plt.subplots(figsize=(9, 9))
heatmap.plot(ax)
fig.show(warn=False)
cirq.plot(heatmap)


if __name__ == '__main__':
Expand Down
3 changes: 3 additions & 0 deletions cirq/vis/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def plot(self, ax: plt.Axes, **pcolor_options: Any

return mesh, value_table

def _plot_(self, ax: plt.Axes, **pcolor_options: Any) -> None:
self.plot(ax, **pcolor_options)

def _plot_colorbar(self, mappable: mpl.cm.ScalarMappable,
ax: plt.Axes) -> mpl.colorbar.Colorbar:
"""Plots the colorbar. Internal."""
Expand Down
31 changes: 31 additions & 0 deletions cirq/vis/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Any, Optional

import matplotlib.pyplot as plt
from typing_extensions import Protocol


class SupportsPlot(Protocol):
"""A class of objects that knows how to plot itself to an axes."""

def _plot_(self, ax: plt.Axes, **kwargs) -> Any:
raise NotImplementedError


def plot(obj: SupportsPlot, ax: Optional[plt.Axes] = None, **kwargs) -> Any:
"""Plots an object to a given Axes or a new Axes and show it.

Args:
obj: an object with a _plot_() method that knows how to plot itself
to an axes.
ax: if given, plot onto it. Otherwise, create a new Axes.
kwargs: additional arguments passed to obj._plot_().
Returns:
A 2-tuple:
- The Axes that's plotted on.
- The return value of obj._plot_().
"""
if ax is None:
_, ax = plt.subplots(1, 1, figsize=(10, 10))
result = obj._plot_(ax, **kwargs)
ax.get_figure().show(warn=False)
return ax, result