Skip to content

Commit

Permalink
Merge pull request #693 from nabobalis/wcs
Browse files Browse the repository at this point in the history
Check more things in test helpers
  • Loading branch information
DanRyanIrish committed Apr 26, 2024
2 parents 3001645 + bbdfba7 commit 4858041
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions ndcube/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numpy.testing import assert_equal

import astropy
from astropy.wcs.wcsapi import BaseHighLevelWCS
from astropy.wcs.wcsapi.fitswcs import SlicedFITSWCS
from astropy.wcs.wcsapi.low_level_api import BaseLowLevelWCS
from astropy.wcs.wcsapi.wrappers.sliced_wcs import sanitize_slices
Expand Down Expand Up @@ -95,9 +96,11 @@ def assert_metas_equal(test_input, expected_output):
assert test_input[key] == expected_output[key]


def assert_cubes_equal(test_input, expected_cube):
def assert_cubes_equal(test_input, expected_cube, check_data=True):
assert isinstance(test_input, type(expected_cube))
assert np.all(test_input.mask == expected_cube.mask)
if check_data:
np.testing.assert_array_equal(test_input.data, expected_cube.data)
assert_wcs_are_equal(test_input.wcs, expected_cube.wcs)
if test_input.uncertainty:
assert test_input.uncertainty.array.shape == expected_cube.uncertainty.array.shape
Expand All @@ -110,12 +113,12 @@ def assert_cubes_equal(test_input, expected_cube):
assert_extra_coords_equal(test_input.extra_coords, expected_cube.extra_coords)


def assert_cubesequences_equal(test_input, expected_sequence):
def assert_cubesequences_equal(test_input, expected_sequence, check_data=True):
assert isinstance(test_input, type(expected_sequence))
assert_metas_equal(test_input.meta, expected_sequence.meta)
assert test_input._common_axis == expected_sequence._common_axis
for i, cube in enumerate(test_input.data):
assert_cubes_equal(cube, expected_sequence.data[i])
assert_cubes_equal(cube, expected_sequence.data[i], check_data=check_data)


def assert_wcs_are_equal(wcs1, wcs2):
Expand All @@ -140,7 +143,12 @@ def assert_wcs_are_equal(wcs1, wcs2):
assert wcs1.world_axis_units == wcs2.world_axis_units
assert_equal(wcs1.axis_correlation_matrix, wcs2.axis_correlation_matrix)
assert wcs1.pixel_bounds == wcs2.pixel_bounds

if wcs1.pixel_shape is not None:
random_idx = np.random.randint(wcs1.pixel_shape,size=[10,wcs1.pixel_n_dim])
# SlicedLowLevelWCS vs BaseHighLevelWCS don't have the same pixel_to_world method
low_level_wcs1 = wcs1.low_level_wcs if isinstance(wcs1, BaseHighLevelWCS) else wcs1
low_level_wcs2 = wcs2.low_level_wcs if isinstance(wcs2, BaseHighLevelWCS) else wcs2
np.testing.assert_array_equal(low_level_wcs1.pixel_to_world_values(*random_idx.T), low_level_wcs2.pixel_to_world_values(*random_idx.T))

def create_sliced_wcs(wcs, item, dim):
"""
Expand All @@ -152,15 +160,15 @@ def create_sliced_wcs(wcs, item, dim):
return SlicedFITSWCS(wcs, item)


def assert_collections_equal(collection1, collection2):
def assert_collections_equal(collection1, collection2, check_data=True):
assert collection1.keys() == collection2.keys()
assert collection1.aligned_axes == collection2.aligned_axes
for cube1, cube2 in zip(collection1.values(), collection2.values()):
# Check cubes are same type.
assert type(cube1) is type(cube2)
if isinstance(cube1, NDCube):
assert_cubes_equal(cube1, cube2)
assert_cubes_equal(cube1, cube2, check_data=check_data)
elif isinstance(cube1, NDCubeSequence):
assert_cubesequences_equal(cube1, cube2)
assert_cubesequences_equal(cube1, cube2, check_data=check_data)
else:
raise TypeError(f"Unsupported Type in NDCollection: {type(cube1)}")

0 comments on commit 4858041

Please sign in to comment.