From 01d3d3dcd3267ba6e742465331d31dbda0221e02 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 30 Aug 2018 10:04:28 +0200 Subject: [PATCH] More flexible plotting --- qctoolkit/pulses/plotting.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/qctoolkit/pulses/plotting.py b/qctoolkit/pulses/plotting.py index 2c6748188..11ca201cc 100644 --- a/qctoolkit/pulses/plotting.py +++ b/qctoolkit/pulses/plotting.py @@ -187,7 +187,9 @@ def plot(pulse: PulseTemplate, show: bool=True, plot_channels: Optional[Set[ChannelID]]=None, plot_measurements: Optional[Set[str]]=None, - maximum_points: int=10**6) -> Any: # pragma: no cover + stepped: bool=True, + maximum_points: int=10**6, + **kwargs) -> Any: # pragma: no cover """Plot a pulse using matplotlib. The given pulse will first be sequenced using the Sequencer class. The resulting @@ -203,8 +205,10 @@ def plot(pulse: PulseTemplate, axes: matplotlib Axes object the pulse will be drawn into if provided show: If true, the figure will be shown plot_channels: If specified only channels from this set will be plotted. If omitted all channels will be. + stepped: If true pyplot.step is used for plotting plot_measurements: If specified measurements in this set will be plotted. If omitted no measurements will be. maximum_points: If the sampled waveform is bigger, it is not plotted + kwargs: Forwarded to pyplot. Overwrites other settings. Returns: matplotlib.pyplot.Figure instance in which the pulse is rendered Raises: @@ -248,10 +252,19 @@ def plot(pulse: PulseTemplate, # plot to figure figure = plt.figure() axes = figure.add_subplot(111) + + if plot_channels is not None: + voltages = {ch: voltage + for ch, voltage in voltages.items() + if ch in plot_channels} + for ch_name, voltage in voltages.items(): - if plot_channels is None or ch_name in plot_channels: - line, = axes.step(times, voltage, where='post', label='channel {}'.format(ch_name)) - legend_handles.append(line) + label = 'channel {}'.format(ch_name) + if stepped: + line, = axes.step(times, voltage, **{**dict(where='post', label=label), **kwargs}) + else: + line, = axes.plot(times, voltage, **{**dict(label=label), **kwargs}) + legend_handles.append(line) if plot_measurements: measurement_dict = dict() @@ -273,14 +286,13 @@ def plot(pulse: PulseTemplate, min_voltage = min((min(channel, default=0) for channel in voltages.values()), default=0) # add some margins in the presentation - plt.plot() - plt.xlim(-0.5, duration + 0.5) - plt.ylim(min_voltage - 0.5, max_voltage + 0.5) - plt.xlabel('Time in ns') - plt.ylabel('Voltage') + axes.set_xlim(-0.5, duration + 0.5) + axes.set_ylim(min_voltage - 0.1*(max_voltage-min_voltage), max_voltage + 0.1*(max_voltage-min_voltage)) + axes.set_xlabel('Time (ns)') + axes.set_ylabel('Voltage (a.u.)') if pulse.identifier: - plt.title(pulse.identifier) + axes.set_title(pulse.identifier) if show: axes.get_figure().show()