diff --git a/changelog/306.bugfix.rst b/changelog/306.bugfix.rst new file mode 100644 index 000000000..98b06ddbe --- /dev/null +++ b/changelog/306.bugfix.rst @@ -0,0 +1 @@ +Move ImageAnimatorWCS class into ndcube from sunpy as it is no longer supported from sunpy 2.1 onwards. diff --git a/ndcube/mixins/plotting.py b/ndcube/mixins/plotting.py index c0e2e7e6c..db563d55c 100644 --- a/ndcube/mixins/plotting.py +++ b/ndcube/mixins/plotting.py @@ -8,13 +8,14 @@ from astropy.visualization.wcsaxes import WCSAxes import sunpy.visualization.wcsaxes_compat as wcsaxes_compat try: - from sunpy.visualization.animator import ImageAnimator, ImageAnimatorWCS, LineAnimator + from sunpy.visualization.animator import ImageAnimator, LineAnimator except ImportError: - from sunpy.visualization.imageanimator import ImageAnimator, ImageAnimatorWCS, LineAnimator + from sunpy.visualization.imageanimator import ImageAnimator, LineAnimator from ndcube import utils from ndcube.utils.cube import _get_extra_coord_edges from ndcube.mixins import sequence_plotting +from ndcube.visualization.animator import ImageAnimatorWCS __all__ = ['NDCubePlotMixin'] diff --git a/ndcube/mixins/sequence_plotting.py b/ndcube/mixins/sequence_plotting.py index 6943324a9..98c5d1b40 100644 --- a/ndcube/mixins/sequence_plotting.py +++ b/ndcube/mixins/sequence_plotting.py @@ -6,12 +6,13 @@ import matplotlib.pyplot as plt import astropy.units as u try: - from sunpy.visualization.animator import ImageAnimatorWCS, LineAnimator + from sunpy.visualization.animator import LineAnimator except ImportError: - from sunpy.visualization.imageanimator import ImageAnimatorWCS, LineAnimator + from sunpy.visualization.imageanimator import LineAnimator from ndcube import utils from ndcube.utils.cube import _get_extra_coord_edges +from ndcube.visualization.animator import ImageAnimatorWCS __all__ = ['NDCubeSequencePlotMixin'] diff --git a/ndcube/tests/test_plotting.py b/ndcube/tests/test_plotting.py index b92be27af..004c4e4a2 100644 --- a/ndcube/tests/test_plotting.py +++ b/ndcube/tests/test_plotting.py @@ -8,13 +8,14 @@ import matplotlib import matplotlib.pyplot as plt try: - from sunpy.visualization.animator import ImageAnimatorWCS, LineAnimator + from sunpy.visualization.animator import LineAnimator except ImportError: - from sunpy.visualization.imageanimator import ImageAnimatorWCS, LineAnimator + from sunpy.visualization.imageanimator import LineAnimator from ndcube import NDCube from ndcube.utils.wcs import WCS from ndcube.mixins import plotting +from ndcube.visualization.animator import ImageAnimatorWCS # sample data for tests diff --git a/ndcube/visualization/animator/__init__.py b/ndcube/visualization/animator/__init__.py new file mode 100644 index 000000000..ab898b5a2 --- /dev/null +++ b/ndcube/visualization/animator/__init__.py @@ -0,0 +1 @@ +from ndcube.visualization.animator.image import * diff --git a/ndcube/visualization/animator/image.py b/ndcube/visualization/animator/image.py new file mode 100644 index 000000000..da58a51c1 --- /dev/null +++ b/ndcube/visualization/animator/image.py @@ -0,0 +1,104 @@ +from astropy.wcs.wcsapi import BaseLowLevelWCS +from sunpy.visualization.animator import ImageAnimator + +__all__ = ["ImageAnimatorWCS"] + + +class ImageAnimatorWCS(ImageAnimator): + """ + Animates N-dimensional data with an associated World Coordinate System. + The following keyboard shortcuts are defined in the viewer: + * 'left': previous step on active slider. + * 'right': next step on active slider. + * 'top': change the active slider up one. + * 'bottom': change the active slider down one. + * 'p': play/pause active slider. + This viewer can have user defined buttons added by specifying the labels + and functions called when those buttons are clicked as keyword arguments. + Parameters + ---------- + data: `numpy.ndarray` + The data to be visualized. + wcs : `~astropy.wcs.wcsapi.BaseLowLevelWCS` + The WCS object describing the physical coordinates of the data. + image_axes: `list`, optional + A list of the axes order that make up the image. + unit_x_axis: `astropy.units.Unit`, optional + The unit of X axis. + unit_y_axis: `astropy.units.Unit`, optional + The unit of Y axis. + axis_ranges: `list`, optional + Defaults to `None` and array indices will be used for all axes. + The `list` should contain one element for each axis of the input data array. + For the image axes a ``[min, max]`` pair should be specified which will be + passed to `matplotlib.pyplot.imshow` as an extent. + For the slider axes a ``[min, max]`` pair can be specified or an array the + same length as the axis which will provide all values for that slider. + Notes + ----- + Extra keywords are passed to `~sunpy.visualization.animator.ArrayAnimator`. + """ + + def __init__(self, data, wcs, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=None, + axis_ranges=None, **kwargs): + if not isinstance(wcs, BaseLowLevelWCS): + raise ValueError("A WCS object should be provided that implements the astropy WCS API.") + if wcs.pixel_n_dim is not data.ndim: + raise ValueError("Dimensionality of the data and WCS object do not match.") + self.wcs = wcs + list_slices_wcsaxes = [0 for i in range(self.wcs.pixel_n_dim)] + list_slices_wcsaxes[image_axes[0]] = 'x' + list_slices_wcsaxes[image_axes[1]] = 'y' + self.slices_wcsaxes = list_slices_wcsaxes[::-1] + self.unit_x_axis = unit_x_axis + self.unit_y_axis = unit_y_axis + + # Using `super()` here causes an error with the @deprecated decorator. + ImageAnimator.__init__(self, data, image_axes=image_axes, axis_ranges=axis_ranges, **kwargs) + + def _get_main_axes(self): + axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=self.wcs, + slices=self.slices_wcsaxes) + self._set_unit_in_axis(axes) + return axes + + def _set_unit_in_axis(self, axes): + x_index = self.slices_wcsaxes.index("x") + y_index = self.slices_wcsaxes.index("y") + if self.unit_x_axis is not None: + axes.coords[x_index].set_format_unit(self.unit_x_axis) + axes.coords[x_index].set_ticks(exclude_overlapping=True) + if self.unit_y_axis is not None: + axes.coords[y_index].set_format_unit(self.unit_y_axis) + axes.coords[y_index].set_ticks(exclude_overlapping=True) + + def plot_start_image(self, ax): + """ + Sets up a plot of initial image. + """ + imshow_args = {'interpolation': 'nearest', + 'origin': 'lower', + } + imshow_args.update(self.imshow_kwargs) + im = ax.imshow(self.data[self.frame_index], **imshow_args) + if self.if_colorbar: + self._add_colorbar(im) + return im + + def update_plot(self, val, im, slider): + """ + Updates plot based on slider/array dimension being iterated. + """ + ind = int(val) + ax_ind = self.slider_axes[slider.slider_ind] + self.frame_slice[ax_ind] = ind + list_slices_wcsaxes = list(self.slices_wcsaxes) + list_slices_wcsaxes[self.wcs.pixel_n_dim-ax_ind-1] = val + self.slices_wcsaxes = list_slices_wcsaxes + if val != slider.cval: + self.axes.reset_wcs(wcs=self.wcs, slices=self.slices_wcsaxes) + self._set_unit_in_axis(self.axes) + im.set_array(self.data[self.frame_index]) + slider.cval = val + # Update slider label to reflect real world values in axis_ranges. + super().update_plot(val, im, slider)