Skip to content

Commit

Permalink
Merge pull request #4939 from dstansby/animator-test
Browse files Browse the repository at this point in the history
Fix animator axes test
  • Loading branch information
nabobalis committed Jan 31, 2021
2 parents de97d05 + 75fe6e2 commit df3a0c7
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 20 deletions.
11 changes: 6 additions & 5 deletions sunpy/visualization/animator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None,
self.timer = None

# Set up axes
self.axes = None
self._make_axes_grid()
self._add_widgets()
self._set_active_slider(0)
Expand Down Expand Up @@ -248,16 +249,16 @@ def _dehighlight_slider(self, ind):
# =============================================================================
# Build the figure and place the widgets
# =============================================================================
def _get_main_axes(self):
def _setup_main_axes(self):
"""
Allow replacement of main axes by subclassing.
This method must set the ``axes`` attribute.
"""
if not len(self.fig.axes):
self.fig.add_subplot(111)
return self.fig.axes[0]
if self.axes is None:
self.axes = self.fig.add_subplot(111)

def _make_axes_grid(self):
self.axes = self._get_main_axes()
self._setup_main_axes()

# Split up the current axes so there is space for start & stop buttons
self.divider = make_axes_locatable(self.axes)
Expand Down
9 changes: 3 additions & 6 deletions sunpy/visualization/animator/mapsequenceanimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,12 @@ def _annotate_plot(self, ind):
self.axes.set_ylabel(axis_labels_from_ctype(self.data[ind].coordinate_system[1],
self.data[ind].spatial_units[1]))

def _get_main_axes(self):
def _setup_main_axes(self):
"""
Create an axes which is a `~astropy.visualization.wcsaxes.WCSAxes`.
"""
# If axes already exist, just return them
if len(self.fig.axes):
return self.fig.axes[0]
else:
return self.fig.add_subplot(111, projection=self.mapsequence[0].wcs)
if self.axes is None:
self.axes = self.fig.add_subplot(111, projection=self.mapsequence[0].wcs)

def plot_start_image(self, ax):
im = self.mapsequence[0].plot(
Expand Down
3 changes: 1 addition & 2 deletions sunpy/visualization/animator/tests/test_basefuncanimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ def test_to_anim(funcanimator):


def test_to_axes(funcanimator):
ax = funcanimator._get_main_axes()
assert isinstance(ax, maxes._subplots.SubplotBase)
assert isinstance(funcanimator.axes, maxes.SubplotBase)


def test_edges_to_centers_nd():
Expand Down
7 changes: 7 additions & 0 deletions sunpy/visualization/animator/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import astropy.units as u
from astropy.io import fits
from astropy.visualization.wcsaxes import WCSAxes
from astropy.wcs import WCS

from sunpy.tests.helpers import figure_test
Expand Down Expand Up @@ -140,6 +141,12 @@ def test_array_animator_wcs_2d_celestial_sliders(wcs_4d):
return a.fig


def test_to_axes(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, ['x', 'y', 0, 0])
assert isinstance(a.axes, WCSAxes)


@figure_test
def test_array_animator_wcs_2d_update_plot(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
Expand Down
11 changes: 4 additions & 7 deletions sunpy/visualization/animator/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,10 @@ def _apply_coord_params(self, axes):
"The 'ticks' value in the coord_params dictionary must be a dict or a boolean."
)

def _get_main_axes(self):
axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=self.wcs,
slices=self.slices_wcsaxes)

self._apply_coord_params(axes)

return axes
def _setup_main_axes(self):
self.axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=self.wcs,
slices=self.slices_wcsaxes)
self._apply_coord_params(self.axes)

def plot_start_image(self, ax):
if self.plot_dimensionality == 1:
Expand Down

0 comments on commit df3a0c7

Please sign in to comment.