Skip to content

Commit

Permalink
Merge pull request #1423 from Cadair/map_cube_func_plot
Browse files Browse the repository at this point in the history
Add plot_function argument to MapCubeAnimator
  • Loading branch information
ayshih committed Jun 15, 2015
2 parents 76f45af + ac6be8d commit c880de3
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 25 deletions.
29 changes: 17 additions & 12 deletions sunpy/map/mapbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,8 @@ def draw_grid(self, axes=None, grid_spacing=15*u.deg, **kwargs):
Returns
-------
matplotlib.axes object
lines: list
A list of `matplotlib.lines.Line2D` objects that have been plotted.
Notes
-----
Expand All @@ -1052,6 +1053,8 @@ def draw_grid(self, axes=None, grid_spacing=15*u.deg, **kwargs):
if not axes:
axes = wcsaxes_compat.gca_wcs(self.wcs)

lines = []

# Do not automatically rescale axes when plotting the overlay
axes.set_autoscale_on(False)

Expand Down Expand Up @@ -1087,7 +1090,7 @@ def draw_grid(self, axes=None, grid_spacing=15*u.deg, **kwargs):
if wcsaxes_compat.is_wcsaxes(axes):
x = (x*u.arcsec).to(u.deg).value
y = (y*u.arcsec).to(u.deg).value
axes.plot(x, y, **plot_kw)
lines += axes.plot(x, y, **plot_kw)

hg_longitude_deg = np.arange(-180, 180, grid_spacing.to(u.deg).value) + l0
hg_latitude_deg = np.linspace(-90, 90, num=181)
Expand All @@ -1103,11 +1106,11 @@ def draw_grid(self, axes=None, grid_spacing=15*u.deg, **kwargs):
if wcsaxes_compat.is_wcsaxes(axes):
x = (x*u.arcsec).to(u.deg).value
y = (y*u.arcsec).to(u.deg).value
axes.plot(x, y, **plot_kw)
lines += axes.plot(x, y, **plot_kw)

# Turn autoscaling back on.
axes.set_autoscale_on(True)
return axes
return lines

def draw_limb(self, axes=None, **kwargs):
"""Draws a circle representing the solar limb
Expand All @@ -1119,7 +1122,9 @@ def draw_limb(self, axes=None, **kwargs):
Returns
-------
matplotlib.axes object
circ: list
A list containing the `matplotlib.patches.Circle` object that
has been added to the axes.
Notes
-----
Expand Down Expand Up @@ -1147,7 +1152,7 @@ def draw_limb(self, axes=None, **kwargs):
circ = patches.Circle([0, 0], **c_kw)
axes.add_artist(circ)

return axes
return [circ]

@toggle_pylab
def peek(self, draw_limb=False, draw_grid=False, gamma=None,
Expand Down Expand Up @@ -1234,14 +1239,14 @@ def plot(self, gamma=None, annotate=True, axes=None, **imshow_args):
Examples
--------
#Simple Plot with color bar
plt.figure()
aiamap.plot()
plt.colorbar()
>>> aiamap.plot()
>>> plt.colorbar()
#Add a limb line and grid
aia.plot()
aia.draw_limb()
aia.draw_grid()
>>> aia.plot()
>>> aia.draw_limb()
>>> aia.draw_grid()
"""

#Get current axes
Expand Down
59 changes: 47 additions & 12 deletions sunpy/map/mapcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _derotate(self):
pass

def plot(self, gamma=None, axes=None, resample=None, annotate=True,
interval=200, **kwargs):
interval=200, plot_function=None, **kwargs):
"""
A animation plotting routine that animates each element in the
MapCube
Expand All @@ -117,6 +117,11 @@ def plot(self, gamma=None, axes=None, resample=None, annotate=True,
interval: int
Animation interval in ms
plot_function : function
A function to be called as each map is plotted. Any variables
returned from the function will have their ``remove()`` method called
at the start of the next frame so that they are removed from the plot.
Examples
--------
>>> cube = sunpy.Map(files, cube=True)
Expand All @@ -139,11 +144,25 @@ def plot(self, gamma=None, axes=None, resample=None, annotate=True,
>>> writer = Writer(fps=10, metadata=dict(artist='SunPy'), bitrate=1800)
>>> ani.save('mapcube_animation.mp4', writer=writer)
Save an animation with the limb at each time step
>>> def myplot(fig, ax, sunpy_map):
... p = sunpy_map.draw_limb()
... return p
>>> cube = sunpy.Map(files, cube=True)
>>> ani = cube.peek(plot_function=myplot)
>>> plt.show()
"""
if not axes:
axes = plt.gca()
fig = axes.get_figure()

if not plot_function:
plot_function = lambda fig, ax, smap: []
removes = []

# Normal plot
def annotate_frame(i):
axes.set_title("{s.name} {s.date!s}".format(s=self[i]))
Expand All @@ -167,27 +186,29 @@ def annotate_frame(i):
self[0].cmap.set_gamma(gamma)

if resample:
#This assumes that the maps are homogenous!
#TODO: Update this!
# This assumes that the maps are homogeneous!
# TODO: Update this!
resample = np.array(len(self.maps)-1) * np.array(resample)
ani_data = [x.resample(resample) for x in self.maps]
else:
ani_data = self.maps

im = ani_data[0].plot(axes=axes, **kwargs)

def updatefig(i, im, annotate, ani_data):

def updatefig(i, im, annotate, ani_data, removes):
while removes:
removes.pop(0).remove()
im.set_array(ani_data[i].data)
im.set_cmap(self.maps[i].cmap)
im.set_norm(self.maps[i].mpl_color_normalizer)
im.set_extent(self.maps[i].xrange + self.maps[i].yrange)
im.set_extent(np.concatenate((self.maps[i].xrange.value, self.maps[i].yrange.value)))
if annotate:
annotate_frame(i)
removes += list(plot_function(fig, axes, self.maps[i]))

ani = matplotlib.animation.FuncAnimation(fig, updatefig,
frames=range(0,len(self.maps)),
fargs=[im,annotate,ani_data],
frames=range(0, len(self.maps)),
fargs=[im, annotate, ani_data, removes],
interval=interval,
blit=False)

Expand Down Expand Up @@ -221,9 +242,14 @@ def peek(self, gamma=None, resample=None, **kwargs):
colorbar: bool
Plot colorbar
plot_function : function
A function to call to overplot extra items on the map plot.
For more information see `sunpy.visualization.MapCubeAnimator`.
Returns
-------
Returns a MapCubeAnimator object
mapcubeanim : `sunpy.visualization.MapCubeAnimator`
A mapcube animator instance.
See Also
--------
Expand All @@ -232,19 +258,28 @@ def peek(self, gamma=None, resample=None, **kwargs):
Examples
--------
>>> cube = sunpy.Map(files, cube=True)
>>> ani = cube.plot(colorbar=True)
>>> ani = cube.peek(colorbar=True)
>>> plt.show()
Plot the map at 1/2 original resolution
>>> cube = sunpy.Map(files, cube=True)
>>> ani = cube.plot(resample=[0.5, 0.5], colorbar=True)
>>> ani = cube.peek(resample=[0.5, 0.5], colorbar=True)
>>> plt.show()
Plot the map with the limb at each time step
>>> def myplot(fig, ax, sunpy_map):
... p = sunpy_map.draw_limb()
... return p
>>> cube = sunpy.Map(files, cube=True)
>>> ani = cube.peek(plot_function=myplot)
>>> plt.show()
Decide you want an animation:
>>> cube = sunpy.Map(files, cube=True)
>>> ani = cube.plot(resample=[0.5, 0.5], colorbar=True)
>>> ani = cube.peek(resample=[0.5, 0.5], colorbar=True)
>>> mplani = ani.get_animation()
"""

Expand Down
21 changes: 20 additions & 1 deletion sunpy/visualization/mapcubeanimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,26 @@ class MapCubeAnimator(imageanimator.BaseFuncAnimator):
colorbar: bool
Plot colorbar
plot_function: function
A function to call when each map is plotted, the function must have
the signature `(fig, axes, smap)` where fig and axes are the figure and
axes objects of the plot and smap is the current frames Map object.
Any objects returned from this function will have their `remove()` method
called at the start of the next frame to clear them from the plot.
Notes
-----
Extra keywords are passed to `mapcube[0].plot()` i.e. the `plot()` routine of
the first map in the cube.
the maps in the cube.
"""
def __init__(self, mapcube, annotate=True, **kwargs):

self.mapcube = mapcube
self.annotate = annotate
self.user_plot_function = kwargs.pop('plot_function',
lambda fig, ax, smap: [])
# List of object to remove at the start of each plot step
self.remove_obj = []
slider_functions = [self.updatefig]
slider_ranges = [[0,len(mapcube.maps)]]

Expand All @@ -52,6 +63,11 @@ def __init__(self, mapcube, annotate=True, **kwargs):
self._annotate_plot(0)

def updatefig(self, val, im, slider):
# Remove all the objects that need to be removed from the
# plot
while self.remove_obj:
self.remove_obj.pop(0).remove()

i = int(val)
im.set_array(self.data[i].data)
im.set_cmap(self.mapcube[i].cmap)
Expand All @@ -63,6 +79,8 @@ def updatefig(self, val, im, slider):
if self.annotate:
self._annotate_plot(i)

self.remove_obj += list(self.user_plot_function(self.fig, self.axes, self.mapcube[i]))

def _annotate_plot(self, ind):
"""
Annotate the image.
Expand Down Expand Up @@ -90,4 +108,5 @@ def _annotate_plot(self, ind):
def plot_start_image(self, ax):
im = self.mapcube[0].plot(annotate=self.annotate, axes=ax,
**self.imshow_kwargs)
self.remove_obj += list(self.user_plot_function(self.fig, self.axes, self.mapcube[0]))
return im

0 comments on commit c880de3

Please sign in to comment.