Skip to content

Commit

Permalink
Merge pull request #199 from jenshnielsen/fix_rgba_error
Browse files Browse the repository at this point in the history
Fix rgba error
  • Loading branch information
astafan8 committed Jun 7, 2021
2 parents 3c23ca6 + 68b31f0 commit f60efe3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
7 changes: 3 additions & 4 deletions plottr/plot/mpl/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
"""

from enum import Enum, auto, unique
from typing import Optional, Tuple, Any, Union
from typing import Any, Optional, Tuple, Union

import numpy as np
from matplotlib import colors, rcParams
from matplotlib.axes import Axes
from matplotlib.image import AxesImage

from plottr.utils import num
from plottr.utils.num import interp_meshgrid_2d, centers2edges_2d

from plottr.utils.num import centers2edges_2d, interp_meshgrid_2d

__author__ = 'Wolfgang Pfaff'
__license__ = 'MIT'
Expand Down Expand Up @@ -125,7 +124,7 @@ def colorplot2d(ax: Axes,
elif plotType is PlotType.colormesh:
im = ppcolormesh_from_meshgrid(ax, x, y, z, cmap=cmap, **kw)
elif plotType is PlotType.scatter2d:
im = ax.scatter(x, y, c=z, cmap=cmap, **kw)
im = ax.scatter(x.ravel(), y.ravel(), c=z.ravel(), cmap=cmap, **kw)
else:
im = None

Expand Down
30 changes: 30 additions & 0 deletions test/pytest/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import matplotlib.pyplot as plt
import numpy as np
from plottr.plot.mpl.plotting import PlotType, colorplot2d


def test_colorplot2d_scatter_rgba_error():
"""
Check that scatter plots are not trying to plot 1x3 and 1x4
z arrays as rgb(a) colors.
"""
fig, ax = plt.subplots(1, 1)
x = np.array([[0.0, 11.11111111, 22.22222222, 33.33333333]])
y = np.array(
[
[
0.0,
0.0,
0.0,
0.0,
]
]
)
z = np.array([[5.08907021, 4.93923391, 5.11400073, 5.0925613]])
colorplot2d(ax, x, y, z, PlotType.scatter2d)

x = np.array([[0.0, 11.11111111, 22.22222222]])
y = np.array([[0.0, 0.0, 0.0]])
z = np.array([[5.08907021, 4.93923391, 5.11400073]])
colorplot2d(ax, x, y, z, PlotType.scatter2d)

0 comments on commit f60efe3

Please sign in to comment.