From afaad6fc693b6270ac51e4ac7ba22c8e09f2429b Mon Sep 17 00:00:00 2001 From: "Adam Ginsburg (keflavich)" Date: Fri, 14 Nov 2014 15:52:05 +0100 Subject: [PATCH] make it possible for the GUI to use a spectral_cube directly --- pvextractor/gui.py | 54 +++++++++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/pvextractor/gui.py b/pvextractor/gui.py index 16505328f0..c1afe89da4 100644 --- a/pvextractor/gui.py +++ b/pvextractor/gui.py @@ -208,21 +208,32 @@ def _update_segments(self): self.set_linestyles(('solid', 'dashed')) self.set_linewidths((2, 1)) +def unitless(x): + if hasattr(x, 'unit'): + return x.value + else: + return x class PVSlicer(object): - def __init__(self, filename, backend="Qt4Agg", clim=None): + def __init__(self, filename_or_cube, backend="Qt4Agg", clim=None): - self.filename = filename try: from spectral_cube import SpectralCube - cube = SpectralCube.read(filename, format='fits') - self.array = cube._data - except: + if isinstance(filename_or_cube, SpectralCube): + cube = filename_or_cube + else: + cube = SpectralCube.read(filename_or_cube, format='fits') + self.cube = cube + self.array = self.cube + self.shape = cube.shape + except ImportError: warnings.warn("spectral_cube package is not available - using astropy.io.fits directly") from astropy.io import fits - self.array = fits.getdata(filename) + self.array = fits.getdata(filename_or_cube) + self.shape =array.shape + self.ok_mask = np.isfinite(self.array) if self.array.ndim != 3: raise ValueError("dataset does not have 3 dimensions (install the spectral_cube package to avoid this error)") @@ -240,29 +251,38 @@ def __init__(self, filename, backend="Qt4Agg", clim=None): warnings.warn("clim not defined and will be determined from the data") # To work with large arrays, sub-sample the data # (but don't do it for small arrays) - n1 = max(self.array.shape[0] / 10, 1) - n2 = max(self.array.shape[1] / 10, 1) - n3 = max(self.array.shape[2] / 10, 1) - sub_array = self.array[::n1,::n2,::n3] - cmin = np.min(sub_array[~np.isnan(sub_array) & ~np.isinf(sub_array)]) - cmax = np.max(sub_array[~np.isnan(sub_array) & ~np.isinf(sub_array)]) + n1 = max(self.shape[0] / 10, 1) + n2 = max(self.shape[1] / 10, 1) + n3 = max(self.shape[2] / 10, 1) + if hasattr(self,'cube'): + sub_cube = self.cube[::n1,::n2,::n3] + cmin = sub_cube.min().value + cmax = sub_cube.max().value + else: + sub_array = self.array[::n1,::n2,::n3] + sub_mask = self.ok_mask[::n1,::n2,::n3] + cmin = sub_array[sub_mask].min() + cmax = sub_array[sub_mask].max() crange = cmax - cmin self._clim = (cmin - crange, cmax + crange) else: self._clim = clim - self.slice = int(round(self.array.shape[0] / 2.)) + self.slice = int(round(self.shape[0] / 2.)) from matplotlib.widgets import Slider self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03]) self.slice_slider_ax.set_xticklabels("") self.slice_slider_ax.set_yticklabels("") - self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, self.array.shape[0], valinit=self.slice, valfmt="%i") + self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, self.shape[0]-1, valinit=self.slice, valfmt="%i") self.slice_slider.on_changed(self.update_slice) self.slice_slider.drawon = False - self.image = self.ax1.imshow(self.array[self.slice, :,:], origin='lower', interpolation='nearest', vmin=self._clim[0], vmax=self._clim[1], cmap=plt.cm.gray) + self.image = self.ax1.imshow(unitless(self.array[self.slice, :,:]), + origin='lower', interpolation='nearest', + vmin=self._clim[0], vmax=self._clim[1], + cmap=plt.cm.gray) self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03]) self.vmin_slider_ax.set_xticklabels("") @@ -360,10 +380,10 @@ def show(self, block=True): def update_slice(self, pos=None): if self.array.ndim == 2: - self.image.set_array(self.array) + self.image.set_array(unitless(self.array)) else: self.slice = int(round(pos)) - self.image.set_array(self.array[self.slice, :, :]) + self.image.set_array(unitless(self.array[self.slice, :, :])) self.fig.canvas.draw()