Skip to content

Commit

Permalink
Merge pull request #188 from yashrsharma44/plot2D-cube-with-bugfix
Browse files Browse the repository at this point in the history
Support 2D plotting in NDCube using APE14
  • Loading branch information
DanRyanIrish committed Aug 20, 2019
2 parents 8ff1440 + 591cb59 commit 3837614
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 100 deletions.
17 changes: 10 additions & 7 deletions ndcube/mixins/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None
axes_coord_check = False
if axes_coord_check:
# Build slice list for WCS for initializing WCSAxes object.
if self.wcs.naxis != 2:
if self.wcs.pixel_n_dim != 2:
slice_list = []
index = 0
for bool_ in self.missing_axes:
Expand All @@ -228,16 +228,17 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None
else:
ax = wcsaxes_compat.gca_wcs(self.wcs)
# Set axis labels
x_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[0],
self.missing_axes)

x_wcs_axis = utils.cube.data_axis_to_wcs_ape14(plot_axis_indices[0], utils.wcs._pixel_keep(self.wcs),
self.wcs.pixel_n_dim)
ax.set_xlabel("{0} [{1}]".format(
self.world_axis_physical_types[plot_axis_indices[0]],
self.wcs.wcs.cunit[x_wcs_axis]))
y_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[1],
self.missing_axes)
self.wcs.world_axis_units[x_wcs_axis]))
y_wcs_axis = utils.cube.data_axis_to_wcs_ape14(plot_axis_indices[1], utils.wcs._pixel_keep(self.wcs),
self.wcs.pixel_n_dim)
ax.set_ylabel("{0} [{1}]".format(
self.world_axis_physical_types[plot_axis_indices[1]],
self.wcs.wcs.cunit[y_wcs_axis]))
self.wcs.world_axis_units[y_wcs_axis]))
# Plot data
ax.imshow(data, **kwargs)
else:
Expand Down Expand Up @@ -429,6 +430,7 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units, data_shape, edg
new_axes_units = []
default_labels = []
default_label_text = ""

for i, axis_coordinate in enumerate(axes_coordinates):
# If axis coordinate is None, derive axis values from WCS.
if axis_coordinate is None:
Expand Down Expand Up @@ -464,6 +466,7 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units, data_shape, edg
new_axis_unit = new_axis_coordinate.unit
new_axis_coordinate = new_axis_coordinate.value
else:

new_axis_unit = axes_units[i]
new_axis_coordinate = new_axis_coordinate.to(new_axis_unit).value
elif isinstance(new_axis_coordinate[0], datetime.datetime):
Expand Down
185 changes: 92 additions & 93 deletions ndcube/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,54 +207,53 @@ def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error):
with pytest.raises(expected_error):
output = test_input.plot(**test_kwargs)


# @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
# (cube[0], {},
# (np.ma.masked_array(cube[0].data, cube[0].mask), "time [min]", "em.wl [m]",
# (0.4, 1.6, 2e-11, 6e-11))),

# (cube[0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]},
# (np.ma.masked_array(cube[0].data, cube[0].mask), "bye [m]", "em.wl [cm]",
# (0.0, 3.0, 2e-9, 6e-9))),

# (cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]),
# u.Quantity(np.arange(10, 10+cube[0].data.shape[0]), unit=u.m)],
# "axes_units": [None, u.cm]},
# (np.ma.masked_array(cube[0].data, cube[0].mask), " [None]", " [cm]", (10, 13, 1000, 1200))),

# (cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]),
# u.Quantity(np.arange(10, 10+cube[0].data.shape[0]), unit=u.m)]},
# (np.ma.masked_array(cube[0].data, cube[0].mask), " [None]", " [m]", (10, 13, 10, 12))),

# (cube_unit[0], {"plot_axis_indices": [0, 1], "axes_coordinates": [None, "bye"],
# "data_unit": u.erg},
# (np.ma.masked_array((cube_unit[0].data * cube_unit[0].unit).to(u.erg).value,
# cube_unit[0].mask).transpose(),
# "em.wl [m]", "bye [m]", (2e-11, 6e-11, 0.0, 3.0)))
# ])
# def test_cube_plot_2D(test_input, test_kwargs, expected_values):
# # Unpack expected properties.
# expected_data, expected_xlabel, expected_ylabel, expected_extent = \
# expected_values
# # Run plot method.
# output = test_input.plot(**test_kwargs)
# # Check plot properties are correct.
# assert isinstance(output, matplotlib.axes.Axes)
# np.testing.assert_array_equal(output.images[0].get_array(), expected_data)
# assert output.axes.xaxis.get_label_text() == expected_xlabel
# assert output.axes.yaxis.get_label_text() == expected_ylabel
# assert np.allclose(output.images[0].get_extent(), expected_extent)
@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
(cube[0, 0], {},
(np.ma.masked_array(cube[0, 0].data, cube[0, 0].mask), "time [min]", "em.wl [m]",
(-0.5, 3.5, 2.5, -0.5))),
# (cube[0, 0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]},
# (np.ma.masked_array(cube[0, 0].data, cube[0, 0].mask), "bye [m]", "em.wl [cm]",
# (0.0, 3.0, 2e-9, 6e-9))),
(cube[0, 0], {"axes_coordinates": [np.arange(10, 10+cube[0, 0].data.shape[1]),
u.Quantity(np.arange(10, 10+cube[0, 0].data.shape[0]), unit=u.m)],
"axes_units": [None, u.cm]},
(np.ma.masked_array(cube[0, 0].data, cube[0, 0].mask), " [None]", " [cm]", (10, 13, 1000, 1200))),
(cube[0, 0], {"axes_coordinates": [np.arange(10, 10+cube[0, 0].data.shape[1]),
u.Quantity(np.arange(10, 10+cube[0, 0].data.shape[0]), unit=u.m)]},
(np.ma.masked_array(cube[0, 0].data, cube[0, 0].mask), " [None]", " [m]", (10, 13, 10, 12))),
# (cube_unit[0], {"plot_axis_indices": [0, 1], "axes_coordinates": [None, "bye"],
# "data_unit": u.erg},
# (np.ma.masked_array((cube_unit[0].data * cube_unit[0].unit).to(u.erg).value,
# cube_unit[0].mask).transpose(),
# "em.wl [m]", "bye [m]", (2e-11, 6e-11, 0.0, 3.0)))
])
def test_cube_plot_2D(test_input, test_kwargs, expected_values):
# Unpack expected properties.
expected_data, expected_xlabel, expected_ylabel, expected_extent = \
expected_values
# Run plot method.
output = test_input.plot(**test_kwargs)
# Check plot properties are correct.
assert isinstance(output, matplotlib.axes.Axes)
np.testing.assert_array_equal(output.images[0].get_array(), expected_data)
assert output.axes.xaxis.get_label_text() == expected_xlabel
assert output.axes.yaxis.get_label_text() == expected_ylabel
assert np.allclose(output.images[0].get_extent(), expected_extent)


# @pytest.mark.parametrize("test_input, test_kwargs, expected_error", [
# (cube[0], {"axes_coordinates": ["array coord", None], "axes_units": [u.cm, None]}, TypeError),
# (cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]), None],
# "axes_units": [u.cm, None]}, TypeError),
# (cube[0], {"data_unit": u.cm}, TypeError)
# ])
# def test_cube_plot_2D_errors(test_input, test_kwargs, expected_error):
# with pytest.raises(expected_error):
# output = test_input.plot(**test_kwargs)
@pytest.mark.parametrize("test_input, test_kwargs, expected_error", [
(cube[0, 0], {"axes_coordinates": ["array coord", None], "axes_units": [u.cm, None]}, TypeError),
(cube[0, 0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]), None],
"axes_units": [u.cm, None]}, TypeError),
(cube[0, 0], {"data_unit": u.cm}, TypeError)
])
def test_cube_plot_2D_errors(test_input, test_kwargs, expected_error):
with pytest.raises(expected_error):
output = test_input.plot(**test_kwargs)


# @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
Expand All @@ -273,53 +272,53 @@ def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error):
# assert output.axes.yaxis.get_label_text() == expected_ylabel


# @pytest.mark.parametrize("input_values, expected_values", [
# ((None, None, None, None, {"image_axes": [-1, -2],
# "axis_ranges": [np.arange(3), np.arange(3)],
# "unit_x_axis": "km",
# "unit_y_axis": u.s,
# "unit": u.W}),
# ([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {})),
# (([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {}),
# ([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {})),
# (([-1], None, None, None, {"unit_x_axis": "km"}),
# ([-1], None, "km", None, {})),
# (([-1, -2], None, None, None, {"unit_x_axis": "km"}),
# (([-1, -2], None, ["km", None], None, {}))),
# (([-1, -2], None, None, None, {"unit_y_axis": "km"}),
# (([-1, -2], None, [None, "km"], None, {})))
# ])
# def test_support_101_plot_API(input_values, expected_values):
# # Define expected values.
# expected_plot_axis_indices, expected_axes_coordinates, expected_axes_units, \
# expected_data_unit, expected_kwargs = expected_values
# # Run function
# output_plot_axis_indices, output_axes_coordinates, output_axes_units, \
# output_data_unit, output_kwargs = plotting._support_101_plot_API(*input_values)
# # Check values are correct
# assert output_plot_axis_indices == expected_plot_axis_indices
# if expected_axes_coordinates is None:
# assert output_axes_coordinates == expected_axes_coordinates
# elif type(expected_axes_coordinates) is list:
# for i, ac in enumerate(output_axes_coordinates):
# np.testing.assert_array_equal(ac, expected_axes_coordinates[i])
# assert output_axes_units == expected_axes_units
# assert output_data_unit == expected_data_unit
# assert output_kwargs == expected_kwargs


# @pytest.mark.parametrize("input_values", [
# ([0, 1], None, None, None, {"image_axes": [-1, -2]}),
# (None, [np.arange(1, 4), np.arange(1, 4)], None, None,
# {"axis_ranges": [np.arange(3), np.arange(3)]}),
# (None, None, [u.s, "km"], None, {"unit_x_axis": u.W}),
# (None, None, [u.s, "km"], None, {"unit_y_axis": u.W}),
# (None, None, None, u.s, {"unit": u.W}),
# ([0, 1, 2], None, None, None, {"unit_x_axis": [u.s, u.km, u.W]}),
# ])
# def test_support_101_plot_API_errors(input_values):
# with pytest.raises(ValueError):
# output = plotting._support_101_plot_API(*input_values)
@pytest.mark.parametrize("input_values, expected_values", [
((None, None, None, None, {"image_axes": [-1, -2],
"axis_ranges": [np.arange(3), np.arange(3)],
"unit_x_axis": "km",
"unit_y_axis": u.s,
"unit": u.W}),
([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {})),
(([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {}),
([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {})),
(([-1], None, None, None, {"unit_x_axis": "km"}),
([-1], None, "km", None, {})),
(([-1, -2], None, None, None, {"unit_x_axis": "km"}),
(([-1, -2], None, ["km", None], None, {}))),
(([-1, -2], None, None, None, {"unit_y_axis": "km"}),
(([-1, -2], None, [None, "km"], None, {})))
])
def test_support_101_plot_API(input_values, expected_values):
# Define expected values.
expected_plot_axis_indices, expected_axes_coordinates, expected_axes_units, \
expected_data_unit, expected_kwargs = expected_values
# Run function
output_plot_axis_indices, output_axes_coordinates, output_axes_units, \
output_data_unit, output_kwargs = plotting._support_101_plot_API(*input_values)
# Check values are correct
assert output_plot_axis_indices == expected_plot_axis_indices
if expected_axes_coordinates is None:
assert output_axes_coordinates == expected_axes_coordinates
elif type(expected_axes_coordinates) is list:
for i, ac in enumerate(output_axes_coordinates):
np.testing.assert_array_equal(ac, expected_axes_coordinates[i])
assert output_axes_units == expected_axes_units
assert output_data_unit == expected_data_unit
assert output_kwargs == expected_kwargs


@pytest.mark.parametrize("input_values", [
([0, 1], None, None, None, {"image_axes": [-1, -2]}),
(None, [np.arange(1, 4), np.arange(1, 4)], None, None,
{"axis_ranges": [np.arange(3), np.arange(3)]}),
(None, None, [u.s, "km"], None, {"unit_x_axis": u.W}),
(None, None, [u.s, "km"], None, {"unit_y_axis": u.W}),
(None, None, None, u.s, {"unit": u.W}),
([0, 1, 2], None, None, None, {"unit_x_axis": [u.s, u.km, u.W]}),
])
def test_support_101_plot_API_errors(input_values):
with pytest.raises(ValueError):
output = plotting._support_101_plot_API(*input_values)


# @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
Expand Down

0 comments on commit 3837614

Please sign in to comment.