Skip to content

Commit

Permalink
Merge pull request #319 from nden/steps-pipeline
Browse files Browse the repository at this point in the history
Represent the pipeline as a list of Step instances
  • Loading branch information
nden committed Aug 17, 2020
2 parents bb50c3b + a41cce2 commit de6cb8d
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ New Features
- Added a ``WCS.steps`` attribute and a ``wcs.Step`` class to allow serialization
to ASDF to use references. [#317]

- ``wcs.pipeline`` now is a list of ``Step`` instances instead of
a (frame, transform) tuple. Use ``WCS.pipeline.transform`` and
``WCS.pipeline.frame`` to access them. [#319]

Bug Fixes
^^^^^^^^^

Expand Down
8 changes: 4 additions & 4 deletions gwcs/tags/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ def test_references(tmpdir):
af = asdf.AsdfFile(tree)
output_path = os.path.join(str(tmpdir), "test.asdf")
af.write_to(output_path)

with asdf.open(output_path) as af:
gw1 = af.tree['wcs1']
gw2 = af.tree['wcs2']
assert gw1.steps[0].transform is gw1.steps[1].transform
assert gw2.steps[0].transform is gw2.steps[1].transform
assert gw2.steps[0].frame is gw2.steps[1].frame
assert gw1.pipeline[0].transform is gw1.pipeline[1].transform
assert gw2.pipeline[0].transform is gw2.pipeline[1].transform
assert gw2.pipeline[0].frame is gw2.pipeline[1].frame
8 changes: 4 additions & 4 deletions gwcs/tags/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ def from_tree(cls, node, ctx):
@classmethod
def to_tree(cls, gwcsobj, ctx):
return {'name': gwcsobj.name,
'steps': gwcsobj.steps
'steps': gwcsobj.pipeline
}

@classmethod
def assert_equal(cls, old, new):
from asdf.tests import helpers
assert old.name == new.name # nosec
assert len(old.available_frames) == len(new.available_frames) # nosec
for (old_frame, old_transform), (new_frame, new_transform) in zip(
for old_step, new_step in zip(
old.pipeline, new.pipeline):
helpers.assert_tree_match(old_frame, new_frame)
helpers.assert_tree_match(old_transform, new_transform)
helpers.assert_tree_match(old_step.frame, new_step.frame)
helpers.assert_tree_match(old_step.transform, new_step.transform)


class StepType(dict, GWCSType):
Expand Down
2 changes: 1 addition & 1 deletion gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_high_level_wrapper(wcsobj, request):

# Remove the bounding box because the type test is a little broken with the
# bounding box.
del wcsobj._pipeline[0][1].bounding_box
del wcsobj._pipeline[0].transform.bounding_box

hlvl = HighLevelWCSWrapper(wcsobj)

Expand Down
12 changes: 10 additions & 2 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,17 @@ def test_init_no_transform():
Test initializing a WCS object without a forward_transform.
"""
gw = wcs.WCS(output_frame='icrs')
assert gw._pipeline == [('detector', None), ('icrs', None)]
assert len(gw._pipeline) == 2
assert gw.pipeline[0].frame == "detector"
assert gw.pipeline[0][0] == "detector"
assert gw.pipeline[1].frame == "icrs"
assert gw.pipeline[1][0] == "icrs"
assert np.in1d(gw.available_frames, ['detector', 'icrs']).all()
gw = wcs.WCS(output_frame=icrs, input_frame=detector)
assert gw._pipeline == [('detector', None), ('icrs', None)]
assert gw._pipeline[0].frame == "detector"
assert gw._pipeline[0][0] == "detector"
assert gw._pipeline[1].frame == "icrs"
assert gw._pipeline[1][0] == "icrs"
assert np.in1d(gw.available_frames, ['detector', 'icrs']).all()
with pytest.raises(NotImplementedError):
gw(1, 2)
Expand Down Expand Up @@ -110,6 +117,7 @@ def test_get_transform():
tr_back = w.get_transform('icrs', 'detector')
x, y = 1, 2
fx, fy = tr_forward(1, 2)
assert_allclose(w.pipeline[0].transform(x, y), (fx, fy))
assert_allclose(w.pipeline[0][1](x, y), (fx, fy))
assert_allclose((x, y), tr_back(*w(x, y)))
assert(w.get_transform('detector', 'detector') is None)
Expand Down
85 changes: 64 additions & 21 deletions gwcs/wcs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import functools
import itertools
import warnings
import numpy as np
import numpy.linalg as npla
from astropy.modeling.core import Model # , fix_inputs
Expand Down Expand Up @@ -58,10 +59,7 @@ def __init__(self, forward_transform=None, input_frame='detector', output_frame=
self._array_shape = None
self._initialize_wcs(forward_transform, input_frame, output_frame)
self._pixel_shape = None

@property
def steps(self):
return [Step(*step) for step in self._pipeline]
self._pipeline = [Step(*step) for step in self._pipeline]

def _initialize_wcs(self, forward_transform, input_frame, output_frame):
if forward_transform is not None:
Expand Down Expand Up @@ -127,12 +125,14 @@ def get_transform(self, from_frame, to_frame):
except ValueError:
raise CoordinateFrameError("Frame {0} is not in the available frames".format(to_frame))
if to_ind < from_ind:
transforms = np.array(self._pipeline[to_ind: from_ind], dtype="object")[:, 1].tolist()
#transforms = np.array(self._pipeline[to_ind: from_ind], dtype="object")[:, 1].tolist()
transforms = [step.transform for step in self._pipeline[to_ind: from_ind]]
transforms = [tr.inverse for tr in transforms[::-1]]
elif to_ind == from_ind:
return None
else:
transforms = np.array(self._pipeline[from_ind: to_ind], dtype="object")[:, 1].copy()
#transforms = np.array(self._pipeline[from_ind: to_ind], dtype="object")[:, 1].copy()
transforms = [step.transform for step in self._pipeline[from_ind: to_ind]]
return functools.reduce(lambda x, y: x | y, transforms)

def set_transform(self, from_frame, to_frame, transform):
Expand Down Expand Up @@ -168,7 +168,7 @@ def set_transform(self, from_frame, to_frame, transform):

if from_ind + 1 != to_ind:
raise ValueError("Frames {0} and {1} are not in sequence".format(from_name, to_name))
self._pipeline[from_ind] = (self._pipeline[from_ind][0], transform)
self._pipeline[from_ind].transform = transform

@property
def forward_transform(self):
Expand All @@ -178,7 +178,8 @@ def forward_transform(self):
"""

if self._pipeline:
return functools.reduce(lambda x, y: x | y, [step[1] for step in self._pipeline[: -1]])
#return functools.reduce(lambda x, y: x | y, [step[1] for step in self._pipeline[: -1]])
return functools.reduce(lambda x, y: x | y, [step.transform for step in self._pipeline[:-1]])
else:
return None

Expand All @@ -205,7 +206,8 @@ def _get_frame_index(self, frame):
"""
if isinstance(frame, coordinate_frames.CoordinateFrame):
frame = frame.name
frame_names = [getattr(item[0], "name", item[0]) for item in self._pipeline]
#frame_names = [getattr(item[0], "name", item[0]) for item in self._pipeline]
frame_names = [step.frame if isinstance(step.frame, str) else step.frame.name for step in self._pipeline]
return frame_names.index(frame)

def _get_frame_name(self, frame):
Expand Down Expand Up @@ -392,7 +394,8 @@ def available_frames(self):
{frame_name: frame_object or None}
"""
if self._pipeline:
return [getattr(frame[0], "name", frame[0]) for frame in self._pipeline]
#return [getattr(frame[0], "name", frame[0]) for frame in self._pipeline]
return [step.frame if isinstance(step.frame, str) else step.frame.name for step in self._pipeline ]
else:
return None

Expand All @@ -415,18 +418,19 @@ def insert_transform(self, frame, transform, after=False):
name, _ = self._get_frame_name(frame)
frame_ind = self._get_frame_index(name)
if not after:
fr, current_transform = self._pipeline[frame_ind - 1]
self._pipeline[frame_ind - 1] = (fr, current_transform | transform)
current_transform = self._pipeline[frame_ind - 1].transform
self._pipeline[frame_ind - 1].transform = current_transform | transform
else:
fr, current_transform = self._pipeline[frame_ind]
self._pipeline[frame_ind] = (fr, transform | current_transform)
current_transform = self._pipeline[frame_ind].transform
self._pipeline[frame_ind].transform = transform | current_transform

@property
def unit(self):
"""The unit of the coordinates in the output coordinate system."""
if self._pipeline:
try:
return getattr(self, self._pipeline[-1][0].name).unit
#return getattr(self, self._pipeline[-1][0].name).unit
return self._pipeline[-1].frame.unit
except AttributeError:
return None
else:
Expand All @@ -436,7 +440,8 @@ def unit(self):
def output_frame(self):
"""Return the output coordinate frame."""
if self._pipeline:
frame = self._pipeline[-1][0]
#frame = self._pipeline[-1][0]
frame = self._pipeline[-1].frame
if not isinstance(frame, str):
frame = frame.name
return getattr(self, frame)
Expand All @@ -447,7 +452,8 @@ def output_frame(self):
def input_frame(self):
"""Return the input coordinate frame."""
if self._pipeline:
frame = self._pipeline[0][0]
#frame = self._pipeline[0][0]
frame = self._pipeline[0].frame
if not isinstance(frame, str):
frame = frame.name
return getattr(self, frame)
Expand Down Expand Up @@ -533,10 +539,12 @@ def _get_axes_indices(self):

def __str__(self):
from astropy.table import Table
col1 = [item[0] for item in self._pipeline]
#col1 = [item[0] for item in self._pipeline]
col1 = [step.frame for step in self._pipeline]
col2 = []
for item in self._pipeline[: -1]:
model = item[1]
#model = item[1]
model = item.transform
if model.name is not None:
col2.append(model.name)
else:
Expand Down Expand Up @@ -931,13 +939,48 @@ class Step:
The transform of the last step should be ``None``.
"""
def __init__(self, frame, transform=None):
self._frame = frame
self._transform = transform
self.frame = frame
self.transform = transform

@property
def frame(self):
return self._frame

@frame.setter
def frame(self, val):
if not isinstance(val, (cf.CoordinateFrame, str)):
raise TypeError('"frame" should be an instance of CoordinateFrame or a string.')

self._frame = val

@property
def transform(self):
return self._transform

@transform.setter
def transform(self, val):
if val is not None and not isinstance(val, (Model)):
raise TypeError('"transform" should be an instance of astropy.modeling.Model.')
self._transform = val

@property
def frame_name(self):
if isinstance(self.frame, str):
return self.frame
return self.frame.name

def __getitem__(self, ind):
warnings.warn("Indexing a WCS.pipeline step is deprecated. "
"Use the `frame` and `transform` attributes instead.", DeprecationWarning)
if ind not in (0, 1):
raise IndexError("Allowed inices are 0 (frame) and 1 (transform).")
if ind == 0:
return self.frame
return self.transform

def __str__(self):
return f"{self.frame_name}\t {getattr(self.transform, 'name', 'None') or self.transform.__class__.__name__}"

def __repr__(self):
return f"Step(frame={self.frame_name}, \
transform={getattr(self.transform, 'name', 'None') or self.transform.__class__.__name__})"

0 comments on commit de6cb8d

Please sign in to comment.