diff --git a/cirq/__init__.py b/cirq/__init__.py index 00b95a34a15..6d1a981fd90 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -419,8 +419,7 @@ NeutralAtomDevice, ) -from cirq.vis import ( - Heatmap,) +from cirq.vis import (Heatmap, plot) from cirq.work import ( CircuitSampleJob, diff --git a/cirq/experiments/cross_entropy_benchmarking.py b/cirq/experiments/cross_entropy_benchmarking.py index 3d2e16ec805..a858788552c 100644 --- a/cirq/experiments/cross_entropy_benchmarking.py +++ b/cirq/experiments/cross_entropy_benchmarking.py @@ -31,21 +31,31 @@ def data(self) -> Sequence[CrossEntropyPair]: return self._data def plot(self, **plot_kwargs: Any) -> None: - """Plots the average XEB fidelity vs the number of cycles. + """Plots mean XEB fidelity vs number of cycles on a new figure. Args: **plot_kwargs: Arguments to be passed to 'matplotlib.pyplot.plot'. """ - num_cycles = [d.num_cycle for d in self._data] - fidelities = [d.xeb_fidelity for d in self._data] + # TODO(pingyeh): deprecate this in favor of the SupportsPlot protocol. fig = plt.figure() ax = plt.gca() - ax.set_ylim([0, 1.1]) - plt.plot(num_cycles, fidelities, 'ro-', figure=fig, **plot_kwargs) - plt.xlabel('Number of Cycles', figure=fig) - plt.ylabel('XEB Fidelity', figure=fig) + self._plot_(ax, **plot_kwargs) fig.show(warn=False) + def _plot_(self, ax: plt.Axes, **plot_kwargs: Any) -> None: + """Plots mean XEB fidelity vs number of cycles onto ax. + + Args: + ax: the axes to plot onto. + **plot_kwargs: Arguments to be passed to 'ax.plot'. + """ + num_cycles = [d.num_cycle for d in self._data] + fidelities = [d.xeb_fidelity for d in self._data] + ax.set_ylim([0, 1.1]) + ax.plot(num_cycles, fidelities, 'ro-', **plot_kwargs) + ax.set_xlabel('Number of Cycles') + ax.set_ylabel('XEB Fidelity') + def cross_entropy_benchmarking( sampler: work.Sampler, diff --git a/cirq/experiments/qubit_characterizations.py b/cirq/experiments/qubit_characterizations.py index a4016b2fa11..b85239679a8 100644 --- a/cirq/experiments/qubit_characterizations.py +++ b/cirq/experiments/qubit_characterizations.py @@ -47,15 +47,27 @@ def plot(self, **plot_kwargs: Any) -> None: Args: **plot_kwargs: Arguments to be passed to matplotlib.pyplot.plot. """ + # TODO(pingyeh): deprecate this in favor of the SupportsPlot protocol. fig = plt.figure() ax = plt.gca() - ax.set_ylim([0, 1]) - plt.plot(self._rabi_angles, self._excited_state_probs, 'ro-', - figure=fig, **plot_kwargs) - plt.xlabel(r"Rabi Angle (Radian)", figure=fig) - plt.ylabel('Excited State Probability', figure=fig) + self._plot_(ax, **plot_kwargs) fig.show(warn=False) + def _plot_(self, ax: plt.Axes, **plot_kwargs: Any) -> None: + """Plots excited state probability vs the Rabi angle onto ax. + + Here the Rabi angle is the angle of rotation around the x-axis. + + Args: + ax: the axes to plot onto. + **plot_kwargs: Arguments to be passed to ax.plot. + """ + ax.set_ylim([0, 1]) + ax.plot(self._rabi_angles, self._excited_state_probs, 'ro-', + **plot_kwargs) + ax.set_xlabel(r"Rabi Angle (Radian)") + ax.set_ylabel('Excited State Probability') + class RandomizedBenchMarkResult: """Results from a randomized benchmarking experiment.""" @@ -90,13 +102,21 @@ def plot(self, **plot_kwargs: Any) -> None: """ fig = plt.figure() ax = plt.gca() + self._plot_(ax, **plot_kwargs) + fig.show(warn=False) + + def _plot_(self, ax: plt.Axes, **plot_kwargs: Any) -> None: + """Plots probability(|0>) vs number of Cliffords onto ax. + + Args: + ax: the axes to plot onto. + **plot_kwargs: Arguments to be passed to ax.plot. + """ ax.set_ylim([0, 1]) - plt.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', - figure=fig, **plot_kwargs) - plt.xlabel(r"Number of Cliffords", figure=fig) - plt.ylabel('Ground State Probability', figure=fig) - fig.show(warn=False) + ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs) + ax.set_xlabel(r"Number of Cliffords") + ax.set_ylabel('Ground State Probability') class TomographyResult: @@ -120,6 +140,7 @@ def plot(self) -> None: """Plots the real and imaginary parts of the density matrix as two 3D bar plots. """ + # TODO(pingyeh): convert this into _plot_(). fig = _plot_density_matrix(self._density_matrix) fig.show(warn=False) diff --git a/cirq/vis/__init__.py b/cirq/vis/__init__.py index 448785026fd..399628a4468 100644 --- a/cirq/vis/__init__.py +++ b/cirq/vis/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from cirq.vis.heatmap import Heatmap +from cirq.vis.plot import plot diff --git a/cirq/vis/examples/bristlecone_heatmap_example.py b/cirq/vis/examples/bristlecone_heatmap_example.py index 2ed21a4ffdf..7e6e8ef1140 100644 --- a/cirq/vis/examples/bristlecone_heatmap_example.py +++ b/cirq/vis/examples/bristlecone_heatmap_example.py @@ -10,9 +10,7 @@ def main(): for qubit in cirq.google.known_devices.Bristlecone.qubits} heatmap = cirq.Heatmap(value_map) - fig, ax = plt.subplots(figsize=(9, 9)) - heatmap.plot(ax) - fig.show(warn=False) + cirq.plot(heatmap) if __name__ == '__main__': diff --git a/cirq/vis/heatmap.py b/cirq/vis/heatmap.py index 1553233c1e0..e8f2d56e32c 100644 --- a/cirq/vis/heatmap.py +++ b/cirq/vis/heatmap.py @@ -253,6 +253,9 @@ def plot(self, ax: plt.Axes, **pcolor_options: Any return mesh, value_table + def _plot_(self, ax: plt.Axes, **pcolor_options: Any) -> None: + self.plot(ax, **pcolor_options) + def _plot_colorbar(self, mappable: mpl.cm.ScalarMappable, ax: plt.Axes) -> mpl.colorbar.Colorbar: """Plots the colorbar. Internal.""" diff --git a/cirq/vis/plot.py b/cirq/vis/plot.py new file mode 100644 index 00000000000..8178adeceac --- /dev/null +++ b/cirq/vis/plot.py @@ -0,0 +1,31 @@ +from typing import Any, Optional + +import matplotlib.pyplot as plt +from typing_extensions import Protocol + + +class SupportsPlot(Protocol): + """A class of objects that knows how to plot itself to an axes.""" + + def _plot_(self, ax: plt.Axes, **kwargs) -> Any: + raise NotImplementedError + + +def plot(obj: SupportsPlot, ax: Optional[plt.Axes] = None, **kwargs) -> Any: + """Plots an object to a given Axes or a new Axes and show it. + + Args: + obj: an object with a _plot_() method that knows how to plot itself + to an axes. + ax: if given, plot onto it. Otherwise, create a new Axes. + kwargs: additional arguments passed to obj._plot_(). + Returns: + A 2-tuple: + - The Axes that's plotted on. + - The return value of obj._plot_(). + """ + if ax is None: + _, ax = plt.subplots(1, 1, figsize=(10, 10)) + result = obj._plot_(ax, **kwargs) + ax.get_figure().show(warn=False) + return ax, result