diff --git a/pygridtools/core.py b/pygridtools/core.py index 3a0c9f5..e0519af 100644 --- a/pygridtools/core.py +++ b/pygridtools/core.py @@ -261,18 +261,25 @@ def _plot_nodes(self, boundary=None, engine='mpl', ax=None, **kwargs): def plotCells(self, engine='mpl', ax=None, usemask=True, river=None, islands=None, boundary=None, - bxcol='x', bycol='y', **kwargs): + bxcol='x', bycol='y', **kwargs): # pragma: no cover if usemask: mask = self.cell_mask.copy() else: mask = None + fig, ax = viz._check_ax(ax) if boundary is not None: - fg = viz.plotReachDF(boundary, bxcol, bycol) - - fig, ax = viz.plotCells(self.xn, self.yn, engine=engine, - ax=fg.axes[0, 0], mask=mask, **kwargs) + fig = viz.plotReachDF(boundary, bxcol, bycol, ax=ax) + + fig = viz.plotCells( + self.xn, + self.yn, + engine=engine, + ax=ax, + mask=mask, + **kwargs + ) if river is not None or islands is not None: fig, ax = viz.plotBoundaries(river=river, islands=islands, diff --git a/pygridtools/viz.py b/pygridtools/viz.py index 6c5ba2d..5bd3304 100644 --- a/pygridtools/viz.py +++ b/pygridtools/viz.py @@ -203,7 +203,10 @@ def _plot_cells_mpl(nodes_x, nodes_y, mask=None, ax=None): rows, cols = nodes_x.shape if mask is None: - mask = np.zeros(nodes_x.shape) + if hasattr(nodes_x, 'mask'): + mask = nodes_x.mask + else: + mask = np.zeros(nodes_x.shape) for jj in range(rows - 1): for ii in range(cols - 1): diff --git a/tests/test_core.py b/tests/test_core.py index e7a0354..6d7683e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -617,6 +617,17 @@ def test_writeGEFDCGridextFiles(self): known_filename ) + @nptest.dec.skipif(True) + def test_plotCells_basic(self): + fig, ax = self.mg.plotCells() + + @nptest.dec.skipif(True) + def test_plotCells_boundary(self): + fig, ax = self.mg.plotCells( + boundary=testing.makeSimpleBoundary(), + usemask=True + ) + class test_makeGrid(object): def setup(self):