Skip to content

Commit

Permalink
Merge pull request #139 from keflavich/wcs_update
Browse files Browse the repository at this point in the history
Update WCS slicing to fix bug noted in astropy #2909
  • Loading branch information
keflavich committed Sep 8, 2014
2 parents 05e3135 + c135914 commit 38c86ef
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 17 deletions.
12 changes: 11 additions & 1 deletion spectral_cube/spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,15 @@ def with_mask(self, mask, inherit_mask=True):
return self._new_cube_with(mask=self._mask & mask if inherit_mask else mask)

def __getitem__(self, view):

# Need to allow self[:], self[:,:]
if isinstance(view, slice):
view = (view, slice(None), slice(None))
elif len(view) == 2:
view = view + (slice(None),)
elif len(view) > 3:
raise IndexError("Too many indices")

meta = {}
meta.update(self._meta)
meta['slice'] = [(s.start, s.stop, s.step)
Expand All @@ -695,7 +704,8 @@ def __getitem__(self, view):
if len(intslices) > 1:
# TODO: return a Specutils Spectrum object
raise NotImplementedError("1D slices are not implemented yet.")
newwcs = wcs_utils.slice_wcs(self._wcs, view)
# only one element, so drop an axis
newwcs = wcs_utils.drop_axis(self._wcs, intslices[0])
return Slice(value=self.filled_data[view],
wcs=newwcs,
copy=False,
Expand Down
35 changes: 35 additions & 0 deletions spectral_cube/tests/test_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,38 @@ def test_endians():

assert xbig.dtype.byteorder == '>'
assert xlil.dtype.byteorder == '='

def test_slicing():

cube, data = cube_and_raw('advs.fits')

# just to check that we're starting in the right place
assert cube.shape == (2,3,4)

sl = cube[:,1,:]
assert sl.shape == (2,4)

v = cube[1:2,:,:]
assert v.shape == (1,3,4)

assert cube[:,:,:].shape == (2,3,4)
assert cube[:,:].shape == (2,3,4)
assert cube[:].shape == (2,3,4)
assert cube[:1,:1,:1].shape == (1,1,1)


@pytest.mark.parametrize(('view','naxis'),
[( (slice(None), 1, slice(None)), 2 ),
( (1, slice(None), slice(None)), 2 ),
( (slice(None), slice(None), 1), 2 ),
( (slice(None), slice(None), slice(1)), 3 ),
( (slice(1), slice(1), slice(1)), 3 ),
])
def test_slice_wcs(view, naxis):

cube, data = cube_and_raw('advs.fits')

sl = cube[view]
assert sl.wcs.naxis == naxis


59 changes: 43 additions & 16 deletions spectral_cube/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,38 +160,65 @@ def axis_names(wcs):
return names


def slice_wcs(wcs, view):
def slice_wcs(mywcs, view, numpy_order=True):
"""
Slice a WCS instance using a Numpy slice. The order of the slice should
be reversed (as for the data) compared to the natural WCS order.
Parameters
----------
view : tuple
A tuple containing the same number of slices as the WCS system
A tuple containing the same number of slices as the WCS system.
The ``step`` method, the third argument to a slice, is not
presently supported.
numpy_order : bool
Use numpy order, i.e. slice the WCS so that an identical slice
applied to a numpy array will slice the array and WCS in the same
way. If set to `False`, the WCS will be sliced in FITS order,
meaning the first slice will be applied to the *last* numpy index
but the *first* WCS axis.
Returns
-------
A new `~astropy.wcs.WCS` instance
wcs_new : `~astropy.wcs.WCS`
A new resampled WCS axis
"""
if len(view) != wcs.wcs.naxis:
raise ValueError("Must have same number of slices as number of WCS axes")
if hasattr(view, '__len__') and len(view) > mywcs.wcs.naxis:
raise ValueError("Must have # of slices <= # of WCS axes")
elif not hasattr(view, '__len__'): # view MUST be an iterable
view = [view]

wcs_new = wcs.deepcopy()

# Indexing the WCS: x,y,z order (not numpy z,y,x order)
intslices = [wcs.wcs.naxis-ii-1 for ii,s in enumerate(view) if not hasattr(s,'start')]

for ii in intslices:
wcs_new = wcs_new.dropaxis(ii)
view = [s for s in view if hasattr(s,'start')]
if not all([isinstance(x, slice) for x in view]):
raise ValueError("Cannot downsample a WCS with indexing. Use "
"wcs.sub or wcs.dropaxis if you want to remove "
"axes.")

wcs_new = mywcs.deepcopy()
for i, iview in enumerate(view):
if iview.step is not None and iview.start is None:
# Slice from "None" is equivalent to slice from 0 (but one
# might want to downsample, so allow slices with
# None,None,step or None,stop,step)
iview = slice(0, iview.stop, iview.step)

if iview.start is not None:
if numpy_order:
wcs_index = mywcs.wcs.naxis - 1 - i
else:
wcs_index = i

if iview.step not in (None, 1):
raise NotImplementedError("Cannot yet slice WCS with strides different from None or 1")
wcs_index = wcs.wcs.naxis - 1 - i
wcs_new.wcs.crpix[wcs_index] -= iview.start
crpix = mywcs.wcs.crpix[wcs_index]
cdelt = mywcs.wcs.cdelt[wcs_index]
# equivalently (keep this comment so you can compare eqns):
# wcs_new.wcs.crpix[wcs_index] =
# (crpix - iview.start)*iview.step + 0.5 - iview.step/2.
crp = ((crpix - iview.start - 1.)/iview.step
+ 0.5 + 1./iview.step/2.)
wcs_new.wcs.crpix[wcs_index] = crp
wcs_new.wcs.cdelt[wcs_index] = cdelt * iview.step
else:
wcs_new.wcs.crpix[wcs_index] -= iview.start
return wcs_new

def check_equality(wcs1, wcs2, warn_missing=False, ignore_keywords=['MJD-OBS',
Expand Down

0 comments on commit 38c86ef

Please sign in to comment.