Skip to content

Commit

Permalink
PERF/CLN: Avoid ravel in plotting (#58973)
Browse files Browse the repository at this point in the history
* Avoid ravel in plotting

* Use reshape instead of ravel

* Add type ignore
  • Loading branch information
mroeschke committed Jun 14, 2024
1 parent 3bcc95f commit c1dcd54
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 30 deletions.
9 changes: 2 additions & 7 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,6 @@ def _grouped_plot_by_column(
layout=layout,
)

_axes = flatten_axes(axes)

# GH 45465: move the "by" label based on "vert"
xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None)
if kwargs.get("vert", True):
Expand All @@ -322,8 +320,7 @@ def _grouped_plot_by_column(

ax_values = []

for i, col in enumerate(columns):
ax = _axes[i]
for ax, col in zip(flatten_axes(axes), columns):
gp_col = grouped[col]
keys, values = zip(*gp_col)
re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs)
Expand Down Expand Up @@ -531,10 +528,8 @@ def boxplot_frame_groupby(
figsize=figsize,
layout=layout,
)
axes = flatten_axes(axes)

data = {}
for (key, group), ax in zip(grouped, axes):
for (key, group), ax in zip(grouped, flatten_axes(axes)):
d = group.boxplot(
ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
)
Expand Down
2 changes: 1 addition & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
fig.set_size_inches(self.figsize)
axes = self.ax

axes = flatten_axes(axes)
axes = np.fromiter(flatten_axes(axes), dtype=object)

if self.logx is True or self.loglog is True:
[a.set_xscale("log") for a in axes]
Expand Down
17 changes: 6 additions & 11 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
"""Calculate bins given data"""
nd_values = data.infer_objects()._get_numeric_data()
values = np.ravel(nd_values)
values = nd_values.values
if nd_values.ndim == 2:
values = values.reshape(-1)
values = values[~isna(values)]

hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
return bins
return np.histogram_bin_edges(values, bins=bins, range=self._bin_range)

# error: Signature of "_plot" incompatible with supertype "LinePlot"
@classmethod
Expand Down Expand Up @@ -322,10 +323,7 @@ def _grouped_plot(
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
)

_axes = flatten_axes(axes)

for i, (key, group) in enumerate(grouped):
ax = _axes[i]
for ax, (key, group) in zip(flatten_axes(axes), grouped):
if numeric_only and isinstance(group, ABCDataFrame):
group = group._get_numeric_data()
plotf(group, ax, **kwargs)
Expand Down Expand Up @@ -557,12 +555,9 @@ def hist_frame(
figsize=figsize,
layout=layout,
)
_axes = flatten_axes(axes)

can_set_label = "label" not in kwds

for i, col in enumerate(data.columns):
ax = _axes[i]
for ax, col in zip(flatten_axes(axes), data.columns):
if legend and can_set_label:
kwds["label"] = col
ax.hist(data[col].dropna().values, bins=bins, **kwds)
Expand Down
26 changes: 15 additions & 11 deletions pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
)

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import (
Generator,
Iterable,
)

from matplotlib.axes import Axes
from matplotlib.axis import Axis
Expand Down Expand Up @@ -231,7 +234,7 @@ def create_subplots(
else:
if is_list_like(ax):
if squeeze:
ax = flatten_axes(ax)
ax = np.fromiter(flatten_axes(ax), dtype=object)
if layout is not None:
warnings.warn(
"When passing multiple axes, layout keyword is ignored.",
Expand Down Expand Up @@ -260,7 +263,7 @@ def create_subplots(
if squeeze:
return fig, ax
else:
return fig, flatten_axes(ax)
return fig, np.fromiter(flatten_axes(ax), dtype=object)
else:
warnings.warn(
"To output multiple subplots, the figure containing "
Expand Down Expand Up @@ -439,12 +442,13 @@ def handle_shared_axes(
_remove_labels_from_axis(ax.yaxis)


def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray:
def flatten_axes(axes: Axes | Iterable[Axes]) -> Generator[Axes, None, None]:
if not is_list_like(axes):
return np.array([axes])
yield axes # type: ignore[misc]
elif isinstance(axes, (np.ndarray, ABCIndex)):
return np.asarray(axes).ravel()
return np.array(axes)
yield from np.asarray(axes).reshape(-1)
else:
yield from axes # type: ignore[misc]


def set_ticks_props(
Expand All @@ -456,13 +460,13 @@ def set_ticks_props(
):
for ax in flatten_axes(axes):
if xlabelsize is not None:
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize)
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) # type: ignore[arg-type]
if xrot is not None:
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot)
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) # type: ignore[arg-type]
if ylabelsize is not None:
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize)
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) # type: ignore[arg-type]
if yrot is not None:
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot)
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) # type: ignore[arg-type]
return axes


Expand Down

0 comments on commit c1dcd54

Please sign in to comment.