Skip to content

Commit

Permalink
Merge pull request #4331 from Karl-Krauth/imshow-xarray-bugfixes
Browse files Browse the repository at this point in the history
Fix facet_col and animation_frame in px.imshow for xarrays.
  • Loading branch information
alexcjohnson committed Aug 22, 2023
2 parents 83e5cfa + d833fd5 commit c838fec
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/).

## UNRELEASED
- Fixed two issues with px.imshow: [[#4330](https://github.com/plotly/plotly.py/issues/4330)] when facet_col is an earlier dimension than animation_frame for xarrays and [[#4329](https://github.com/plotly/plotly.py/issues/4329)] when facet_col has string coordinates in xarrays [[#4331](https://github.com/plotly/plotly.py/pull/4331)]

## [5.16.1] - 2023-08-16

### Fixed
Expand Down
10 changes: 7 additions & 3 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,18 @@ def imshow(
if xarray_imported and isinstance(img, xarray.DataArray):
dims = list(img.dims)
img_is_xarray = True
pop_indexes = []
if facet_col is not None:
facet_slices = img.coords[img.dims[facet_col]].values
_ = dims.pop(facet_col)
pop_indexes.append(facet_col)
facet_label = img.dims[facet_col]
if animation_frame is not None:
animation_slices = img.coords[img.dims[animation_frame]].values
_ = dims.pop(animation_frame)
pop_indexes.append(animation_frame)
animation_label = img.dims[animation_frame]
# Remove indices in sorted order.
for index in sorted(pop_indexes, reverse=True):
_ = dims.pop(index)
y_label, x_label = dims[0], dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
Expand Down Expand Up @@ -541,7 +545,7 @@ def imshow(
slice_label = (
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
)
col_labels = ["%s=%d" % (slice_label, i) for i in facet_slices]
col_labels = [f"{slice_label}={i}" for i in facet_slices]
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
for attr_name in ["height", "width"]:
if args[attr_name]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,40 @@ def test_imshow_xarray_slicethrough():
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))


def test_imshow_xarray_facet_col_string():
img = np.random.random((3, 4, 5))
da = xr.DataArray(
img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]}
)
fig = px.imshow(da, facet_col="str_dim")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_2"
assert fig.layout.yaxis.title.text == "dim_1"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))


def test_imshow_xarray_animation_frame_string():
img = np.random.random((3, 4, 5))
da = xr.DataArray(
img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]}
)
fig = px.imshow(da, animation_frame="str_dim")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_2"
assert fig.layout.yaxis.title.text == "dim_1"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))


def test_imshow_xarray_animation_facet_slicethrough():
img = np.random.random((3, 4, 5, 6))
da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2", "dim_3"])
fig = px.imshow(da, facet_col="dim_0", animation_frame="dim_1")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_3"
assert fig.layout.yaxis.title.text == "dim_2"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_3"]))


def test_imshow_labels_and_ranges():
fig = px.imshow(
[[1, 2], [3, 4], [5, 6]],
Expand Down

0 comments on commit c838fec

Please sign in to comment.