8989
9090 from pandas ._typing import (
9191 IndexLabel ,
92+ NDFrameT ,
9293 PlottingOrientation ,
9394 npt ,
9495 )
9596
96- from pandas import Series
97+ from pandas import (
98+ PeriodIndex ,
99+ Series ,
100+ )
97101
98102
99103def _color_in_style (style : str ) -> bool :
@@ -161,8 +165,6 @@ def __init__(
161165 ) -> None :
162166 import matplotlib .pyplot as plt
163167
164- self .data = data
165-
166168 # if users assign an empty list or tuple, raise `ValueError`
167169 # similar to current `df.box` and `df.hist` APIs.
168170 if by in ([], ()):
@@ -193,9 +195,11 @@ def __init__(
193195
194196 self .kind = kind
195197
196- self .subplots = self ._validate_subplots_kwarg (subplots )
198+ self .subplots = type (self )._validate_subplots_kwarg (
199+ subplots , data , kind = self ._kind
200+ )
197201
198- self .sharex = self ._validate_sharex (sharex , ax , by )
202+ self .sharex = type ( self ) ._validate_sharex (sharex , ax , by )
199203 self .sharey = sharey
200204 self .figsize = figsize
201205 self .layout = layout
@@ -245,10 +249,11 @@ def __init__(
245249 # parse errorbar input if given
246250 xerr = kwds .pop ("xerr" , None )
247251 yerr = kwds .pop ("yerr" , None )
248- self .errors = {
249- kw : self ._parse_errorbars (kw , err )
250- for kw , err in zip (["xerr" , "yerr" ], [xerr , yerr ])
251- }
252+ nseries = self ._get_nseries (data )
253+ xerr , data = type (self )._parse_errorbars ("xerr" , xerr , data , nseries )
254+ yerr , data = type (self )._parse_errorbars ("yerr" , yerr , data , nseries )
255+ self .errors = {"xerr" : xerr , "yerr" : yerr }
256+ self .data = data
252257
253258 if not isinstance (secondary_y , (bool , tuple , list , np .ndarray , ABCIndex )):
254259 secondary_y = [secondary_y ]
@@ -271,7 +276,8 @@ def __init__(
271276 self ._validate_color_args ()
272277
273278 @final
274- def _validate_sharex (self , sharex : bool | None , ax , by ) -> bool :
279+ @staticmethod
280+ def _validate_sharex (sharex : bool | None , ax , by ) -> bool :
275281 if sharex is None :
276282 # if by is defined, subplots are used and sharex should be False
277283 if ax is None and by is None : # pylint: disable=simplifiable-if-statement
@@ -285,8 +291,9 @@ def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
285291 return bool (sharex )
286292
287293 @final
294+ @staticmethod
288295 def _validate_subplots_kwarg (
289- self , subplots : bool | Sequence [Sequence [str ]]
296+ subplots : bool | Sequence [Sequence [str ]], data : Series | DataFrame , kind : str
290297 ) -> bool | list [tuple [int , ...]]:
291298 """
292299 Validate the subplots parameter
@@ -323,18 +330,18 @@ def _validate_subplots_kwarg(
323330 "area" ,
324331 "pie" ,
325332 )
326- if self . _kind not in supported_kinds :
333+ if kind not in supported_kinds :
327334 raise ValueError (
328335 "When subplots is an iterable, kind must be "
329- f"one of { ', ' .join (supported_kinds )} . Got { self . _kind } ."
336+ f"one of { ', ' .join (supported_kinds )} . Got { kind } ."
330337 )
331338
332- if isinstance (self . data , ABCSeries ):
339+ if isinstance (data , ABCSeries ):
333340 raise NotImplementedError (
334341 "An iterable subplots for a Series is not supported."
335342 )
336343
337- columns = self . data .columns
344+ columns = data .columns
338345 if isinstance (columns , ABCMultiIndex ):
339346 raise NotImplementedError (
340347 "An iterable subplots for a DataFrame with a MultiIndex column "
@@ -442,18 +449,22 @@ def _iter_data(
442449 # typing.
443450 yield col , np .asarray (values .values )
444451
445- @property
446- def nseries (self ) -> int :
452+ def _get_nseries (self , data : Series | DataFrame ) -> int :
447453 # When `by` is explicitly assigned, grouped data size will be defined, and
448454 # this will determine number of subplots to have, aka `self.nseries`
449- if self . data .ndim == 1 :
455+ if data .ndim == 1 :
450456 return 1
451457 elif self .by is not None and self ._kind == "hist" :
452458 return len (self ._grouped )
453459 elif self .by is not None and self ._kind == "box" :
454460 return len (self .columns )
455461 else :
456- return self .data .shape [1 ]
462+ return data .shape [1 ]
463+
464+ @final
465+ @property
466+ def nseries (self ) -> int :
467+ return self ._get_nseries (self .data )
457468
458469 @final
459470 def draw (self ) -> None :
@@ -880,10 +891,12 @@ def _get_xticks(self, convert_period: bool = False):
880891 index = self .data .index
881892 is_datetype = index .inferred_type in ("datetime" , "date" , "datetime64" , "time" )
882893
894+ x : list [int ] | np .ndarray
883895 if self .use_index :
884896 if convert_period and isinstance (index , ABCPeriodIndex ):
885897 self .data = self .data .reindex (index = index .sort_values ())
886- x = self .data .index .to_timestamp ()._mpl_repr ()
898+ index = cast ("PeriodIndex" , self .data .index )
899+ x = index .to_timestamp ()._mpl_repr ()
887900 elif is_any_real_numeric_dtype (index .dtype ):
888901 # Matplotlib supports numeric values or datetime objects as
889902 # xaxis values. Taking LBYL approach here, by the time
@@ -1050,8 +1063,12 @@ def _get_colors(
10501063 color = self .kwds .get (color_kwds ),
10511064 )
10521065
1066+ # TODO: tighter typing for first return?
10531067 @final
1054- def _parse_errorbars (self , label : str , err ):
1068+ @staticmethod
1069+ def _parse_errorbars (
1070+ label : str , err , data : NDFrameT , nseries : int
1071+ ) -> tuple [Any , NDFrameT ]:
10551072 """
10561073 Look for error keyword arguments and return the actual errorbar data
10571074 or return the error DataFrame/dict
@@ -1071,32 +1088,32 @@ def _parse_errorbars(self, label: str, err):
10711088 should be in a ``Mx2xN`` array.
10721089 """
10731090 if err is None :
1074- return None
1091+ return None , data
10751092
10761093 def match_labels (data , e ):
10771094 e = e .reindex (data .index )
10781095 return e
10791096
10801097 # key-matched DataFrame
10811098 if isinstance (err , ABCDataFrame ):
1082- err = match_labels (self . data , err )
1099+ err = match_labels (data , err )
10831100 # key-matched dict
10841101 elif isinstance (err , dict ):
10851102 pass
10861103
10871104 # Series of error values
10881105 elif isinstance (err , ABCSeries ):
10891106 # broadcast error series across data
1090- err = match_labels (self . data , err )
1107+ err = match_labels (data , err )
10911108 err = np .atleast_2d (err )
1092- err = np .tile (err , (self . nseries , 1 ))
1109+ err = np .tile (err , (nseries , 1 ))
10931110
10941111 # errors are a column in the dataframe
10951112 elif isinstance (err , str ):
1096- evalues = self . data [err ].values
1097- self . data = self . data [self . data .columns .drop (err )]
1113+ evalues = data [err ].values
1114+ data = data [data .columns .drop (err )]
10981115 err = np .atleast_2d (evalues )
1099- err = np .tile (err , (self . nseries , 1 ))
1116+ err = np .tile (err , (nseries , 1 ))
11001117
11011118 elif is_list_like (err ):
11021119 if is_iterator (err ):
@@ -1108,40 +1125,40 @@ def match_labels(data, e):
11081125 err_shape = err .shape
11091126
11101127 # asymmetrical error bars
1111- if isinstance (self . data , ABCSeries ) and err_shape [0 ] == 2 :
1128+ if isinstance (data , ABCSeries ) and err_shape [0 ] == 2 :
11121129 err = np .expand_dims (err , 0 )
11131130 err_shape = err .shape
1114- if err_shape [2 ] != len (self . data ):
1131+ if err_shape [2 ] != len (data ):
11151132 raise ValueError (
11161133 "Asymmetrical error bars should be provided "
1117- f"with the shape (2, { len (self . data )} )"
1134+ f"with the shape (2, { len (data )} )"
11181135 )
1119- elif isinstance (self . data , ABCDataFrame ) and err .ndim == 3 :
1136+ elif isinstance (data , ABCDataFrame ) and err .ndim == 3 :
11201137 if (
1121- (err_shape [0 ] != self . nseries )
1138+ (err_shape [0 ] != nseries )
11221139 or (err_shape [1 ] != 2 )
1123- or (err_shape [2 ] != len (self . data ))
1140+ or (err_shape [2 ] != len (data ))
11241141 ):
11251142 raise ValueError (
11261143 "Asymmetrical error bars should be provided "
1127- f"with the shape ({ self . nseries } , 2, { len (self . data )} )"
1144+ f"with the shape ({ nseries } , 2, { len (data )} )"
11281145 )
11291146
11301147 # broadcast errors to each data series
11311148 if len (err ) == 1 :
1132- err = np .tile (err , (self . nseries , 1 ))
1149+ err = np .tile (err , (nseries , 1 ))
11331150
11341151 elif is_number (err ):
11351152 err = np .tile (
11361153 [err ], # pyright: ignore[reportGeneralTypeIssues]
1137- (self . nseries , len (self . data )),
1154+ (nseries , len (data )),
11381155 )
11391156
11401157 else :
11411158 msg = f"No valid { label } detected"
11421159 raise ValueError (msg )
11431160
1144- return err
1161+ return err , data # pyright: ignore[reportGeneralTypeIssues]
11451162
11461163 @final
11471164 def _get_errorbars (
@@ -1215,8 +1232,7 @@ def __init__(self, data, x, y, **kwargs) -> None:
12151232 self .y = y
12161233
12171234 @final
1218- @property
1219- def nseries (self ) -> int :
1235+ def _get_nseries (self , data : Series | DataFrame ) -> int :
12201236 return 1
12211237
12221238 @final
0 commit comments