77from collections .abc import (
88 Hashable ,
99 Iterable ,
10+ Iterator ,
1011 Sequence ,
1112)
1213from typing import (
@@ -431,17 +432,15 @@ def _validate_color_args(self):
431432 )
432433
433434 @final
434- def _iter_data (self , data = None , keep_index : bool = False , fillna = None ):
435- if data is None :
436- data = self .data
437- if fillna is not None :
438- data = data .fillna (fillna )
439-
435+ @staticmethod
436+ def _iter_data (
437+ data : DataFrame | dict [Hashable , Series | DataFrame ]
438+ ) -> Iterator [tuple [Hashable , np .ndarray ]]:
440439 for col , values in data .items ():
441- if keep_index is True :
442- yield col , values
443- else :
444- yield col , values .values
440+ # This was originally written to use values.values before EAs
441+ # were implemented; adding np.asarray(...) to keep consistent
442+ # typing.
443+ yield col , np . asarray ( values .values )
445444
446445 @property
447446 def nseries (self ) -> int :
@@ -480,7 +479,7 @@ def _has_plotted_object(ax: Axes) -> bool:
480479 return len (ax .lines ) != 0 or len (ax .artists ) != 0 or len (ax .containers ) != 0
481480
482481 @final
483- def _maybe_right_yaxis (self , ax : Axes , axes_num : int ):
482+ def _maybe_right_yaxis (self , ax : Axes , axes_num : int ) -> Axes :
484483 if not self .on_right (axes_num ):
485484 # secondary axes may be passed via ax kw
486485 return self ._get_ax_layer (ax )
@@ -656,11 +655,7 @@ def _compute_plot_data(self):
656655
657656 numeric_data = data .select_dtypes (include = include_type , exclude = exclude_type )
658657
659- try :
660- is_empty = numeric_data .columns .empty
661- except AttributeError :
662- is_empty = not len (numeric_data )
663-
658+ is_empty = numeric_data .shape [- 1 ] == 0
664659 # no non-numeric frames or series allowed
665660 if is_empty :
666661 raise TypeError ("no numeric data to plot" )
@@ -682,7 +677,7 @@ def _add_table(self) -> None:
682677 tools .table (ax , data )
683678
684679 @final
685- def _post_plot_logic_common (self , ax , data ):
680+ def _post_plot_logic_common (self , ax : Axes , data ) -> None :
686681 """Common post process for each axes"""
687682 if self .orientation == "vertical" or self .orientation is None :
688683 self ._apply_axis_properties (ax .xaxis , rot = self .rot , fontsize = self .fontsize )
@@ -701,7 +696,7 @@ def _post_plot_logic_common(self, ax, data):
701696 raise ValueError
702697
703698 @abstractmethod
704- def _post_plot_logic (self , ax , data ) -> None :
699+ def _post_plot_logic (self , ax : Axes , data ) -> None :
705700 """Post process for each axes. Overridden in child classes"""
706701
707702 @final
@@ -1056,7 +1051,7 @@ def _get_colors(
10561051 )
10571052
10581053 @final
1059- def _parse_errorbars (self , label , err ):
1054+ def _parse_errorbars (self , label : str , err ):
10601055 """
10611056 Look for error keyword arguments and return the actual errorbar data
10621057 or return the error DataFrame/dict
@@ -1137,7 +1132,10 @@ def match_labels(data, e):
11371132 err = np .tile (err , (self .nseries , 1 ))
11381133
11391134 elif is_number (err ):
1140- err = np .tile ([err ], (self .nseries , len (self .data )))
1135+ err = np .tile (
1136+ [err ], # pyright: ignore[reportGeneralTypeIssues]
1137+ (self .nseries , len (self .data )),
1138+ )
11411139
11421140 else :
11431141 msg = f"No valid { label } detected"
@@ -1418,14 +1416,14 @@ def _make_plot(self, fig: Figure) -> None:
14181416
14191417 x = data .index # dummy, not used
14201418 plotf = self ._ts_plot
1421- it = self . _iter_data ( data = data , keep_index = True )
1419+ it = data . items ( )
14221420 else :
14231421 x = self ._get_xticks (convert_period = True )
14241422 # error: Incompatible types in assignment (expression has type
14251423 # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
14261424 # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
14271425 plotf = self ._plot # type: ignore[assignment]
1428- it = self ._iter_data ()
1426+ it = self ._iter_data (data = self . data )
14291427
14301428 stacking_id = self ._get_stacking_id ()
14311429 is_errorbar = com .any_not_none (* self .errors .values ())
@@ -1434,7 +1432,12 @@ def _make_plot(self, fig: Figure) -> None:
14341432 for i , (label , y ) in enumerate (it ):
14351433 ax = self ._get_ax (i )
14361434 kwds = self .kwds .copy ()
1437- style , kwds = self ._apply_style_colors (colors , kwds , i , label )
1435+ style , kwds = self ._apply_style_colors (
1436+ colors ,
1437+ kwds ,
1438+ i ,
1439+ label , # pyright: ignore[reportGeneralTypeIssues]
1440+ )
14381441
14391442 errors = self ._get_errorbars (label = label , index = i )
14401443 kwds = dict (kwds , ** errors )
@@ -1446,7 +1449,7 @@ def _make_plot(self, fig: Figure) -> None:
14461449 newlines = plotf (
14471450 ax ,
14481451 x ,
1449- y ,
1452+ y , # pyright: ignore[reportGeneralTypeIssues]
14501453 style = style ,
14511454 column_num = i ,
14521455 stacking_id = stacking_id ,
@@ -1465,7 +1468,14 @@ def _make_plot(self, fig: Figure) -> None:
14651468 # error: Signature of "_plot" incompatible with supertype "MPLPlot"
14661469 @classmethod
14671470 def _plot ( # type: ignore[override]
1468- cls , ax : Axes , x , y , style = None , column_num = None , stacking_id = None , ** kwds
1471+ cls ,
1472+ ax : Axes ,
1473+ x ,
1474+ y : np .ndarray ,
1475+ style = None ,
1476+ column_num = None ,
1477+ stacking_id = None ,
1478+ ** kwds ,
14691479 ):
14701480 # column_num is used to get the target column from plotf in line and
14711481 # area plots
@@ -1492,7 +1502,7 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
14921502 decorate_axes (ax .right_ax , freq , kwds )
14931503 ax ._plot_data .append ((data , self ._kind , kwds ))
14941504
1495- lines = self ._plot (ax , data .index , data .values , style = style , ** kwds )
1505+ lines = self ._plot (ax , data .index , np . asarray ( data .values ) , style = style , ** kwds )
14961506 # set date formatter, locators and rescale limits
14971507 # error: Argument 3 to "format_dateaxis" has incompatible type "Index";
14981508 # expected "DatetimeIndex | PeriodIndex"
@@ -1520,7 +1530,9 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
15201530
15211531 @final
15221532 @classmethod
1523- def _get_stacked_values (cls , ax : Axes , stacking_id , values , label ):
1533+ def _get_stacked_values (
1534+ cls , ax : Axes , stacking_id : int | None , values : np .ndarray , label
1535+ ) -> np .ndarray :
15241536 if stacking_id is None :
15251537 return values
15261538 if not hasattr (ax , "_stacker_pos_prior" ):
@@ -1540,7 +1552,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
15401552
15411553 @final
15421554 @classmethod
1543- def _update_stacker (cls , ax : Axes , stacking_id , values ) -> None :
1555+ def _update_stacker (cls , ax : Axes , stacking_id : int | None , values ) -> None :
15441556 if stacking_id is None :
15451557 return
15461558 if (values >= 0 ).all ():
@@ -1618,7 +1630,7 @@ def _plot( # type: ignore[override]
16181630 cls ,
16191631 ax : Axes ,
16201632 x ,
1621- y ,
1633+ y : np . ndarray ,
16221634 style = None ,
16231635 column_num = None ,
16241636 stacking_id = None ,
@@ -1744,7 +1756,7 @@ def _plot( # type: ignore[override]
17441756 cls ,
17451757 ax : Axes ,
17461758 x ,
1747- y ,
1759+ y : np . ndarray ,
17481760 w ,
17491761 start : int | npt .NDArray [np .intp ] = 0 ,
17501762 log : bool = False ,
@@ -1763,7 +1775,8 @@ def _make_plot(self, fig: Figure) -> None:
17631775 pos_prior = neg_prior = np .zeros (len (self .data ))
17641776 K = self .nseries
17651777
1766- for i , (label , y ) in enumerate (self ._iter_data (fillna = 0 )):
1778+ data = self .data .fillna (0 )
1779+ for i , (label , y ) in enumerate (self ._iter_data (data = data )):
17671780 ax = self ._get_ax (i )
17681781 kwds = self .kwds .copy ()
17691782 if self ._is_series :
@@ -1842,7 +1855,14 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
18421855
18431856 self ._decorate_ticks (ax , self ._get_index_name (), str_index , s_edge , e_edge )
18441857
1845- def _decorate_ticks (self , ax : Axes , name , ticklabels , start_edge , end_edge ) -> None :
1858+ def _decorate_ticks (
1859+ self ,
1860+ ax : Axes ,
1861+ name : str | None ,
1862+ ticklabels : list [str ],
1863+ start_edge : float ,
1864+ end_edge : float ,
1865+ ) -> None :
18461866 ax .set_xlim ((start_edge , end_edge ))
18471867
18481868 if self .xticks is not None :
@@ -1876,7 +1896,7 @@ def _plot( # type: ignore[override]
18761896 cls ,
18771897 ax : Axes ,
18781898 x ,
1879- y ,
1899+ y : np . ndarray ,
18801900 w ,
18811901 start : int | npt .NDArray [np .intp ] = 0 ,
18821902 log : bool = False ,
@@ -1887,7 +1907,14 @@ def _plot( # type: ignore[override]
18871907 def _get_custom_index_name (self ):
18881908 return self .ylabel
18891909
1890- def _decorate_ticks (self , ax : Axes , name , ticklabels , start_edge , end_edge ) -> None :
1910+ def _decorate_ticks (
1911+ self ,
1912+ ax : Axes ,
1913+ name : str | None ,
1914+ ticklabels : list [str ],
1915+ start_edge : float ,
1916+ end_edge : float ,
1917+ ) -> None :
18911918 # horizontal bars
18921919 ax .set_ylim ((start_edge , end_edge ))
18931920 ax .set_yticks (self .tick_pos )
@@ -1921,7 +1948,7 @@ def _make_plot(self, fig: Figure) -> None:
19211948 colors = self ._get_colors (num_colors = len (self .data ), color_kwds = "colors" )
19221949 self .kwds .setdefault ("colors" , colors )
19231950
1924- for i , (label , y ) in enumerate (self ._iter_data ()):
1951+ for i , (label , y ) in enumerate (self ._iter_data (data = self . data )):
19251952 ax = self ._get_ax (i )
19261953 if label is not None :
19271954 label = pprint_thing (label )
0 commit comments