22
33from typing import (
44 TYPE_CHECKING ,
5+ Any ,
56 Literal ,
7+ final ,
68)
79
810import numpy as np
@@ -58,13 +60,15 @@ def __init__(
5860 bottom : int | np .ndarray = 0 ,
5961 * ,
6062 range = None ,
63+ weights = None ,
6164 ** kwargs ,
6265 ) -> None :
6366 if is_list_like (bottom ):
6467 bottom = np .array (bottom )
6568 self .bottom = bottom
6669
6770 self ._bin_range = range
71+ self .weights = weights
6872
6973 self .xlabel = kwargs .get ("xlabel" )
7074 self .ylabel = kwargs .get ("ylabel" )
@@ -96,7 +100,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
96100 @classmethod
97101 def _plot ( # type: ignore[override]
98102 cls ,
99- ax ,
103+ ax : Axes ,
100104 y ,
101105 style = None ,
102106 bottom : int | np .ndarray = 0 ,
@@ -140,7 +144,7 @@ def _make_plot(self, fig: Figure) -> None:
140144 if style is not None :
141145 kwds ["style" ] = style
142146
143- kwds = self ._make_plot_keywords (kwds , y )
147+ self ._make_plot_keywords (kwds , y )
144148
145149 # the bins is multi-dimension array now and each plot need only 1-d and
146150 # when by is applied, label should be columns that are grouped
@@ -149,21 +153,8 @@ def _make_plot(self, fig: Figure) -> None:
149153 kwds ["label" ] = self .columns
150154 kwds .pop ("color" )
151155
152- # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
153- # and each sub-array (10,) will be called in each iteration. If users only
154- # provide 1D array, we assume the same weights is used for all iterations
155- weights = kwds .get ("weights" , None )
156- if weights is not None :
157- if np .ndim (weights ) != 1 and np .shape (weights )[- 1 ] != 1 :
158- try :
159- weights = weights [:, i ]
160- except IndexError as err :
161- raise ValueError (
162- "weights must have the same shape as data, "
163- "or be a single column"
164- ) from err
165- weights = weights [~ isna (y )]
166- kwds ["weights" ] = weights
156+ if self .weights is not None :
157+ kwds ["weights" ] = self ._get_column_weights (self .weights , i , y )
167158
168159 y = reformat_hist_y_given_by (y , self .by )
169160
@@ -175,12 +166,29 @@ def _make_plot(self, fig: Figure) -> None:
175166
176167 self ._append_legend_handles_labels (artists [0 ], label )
177168
178- def _make_plot_keywords (self , kwds , y ) :
169+ def _make_plot_keywords (self , kwds : dict [ str , Any ], y ) -> None :
179170 """merge BoxPlot/KdePlot properties to passed kwds"""
180171 # y is required for KdePlot
181172 kwds ["bottom" ] = self .bottom
182173 kwds ["bins" ] = self .bins
183- return kwds
174+
175+ @final
176+ @staticmethod
177+ def _get_column_weights (weights , i : int , y ):
178+ # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
179+ # and each sub-array (10,) will be called in each iteration. If users only
180+ # provide 1D array, we assume the same weights is used for all iterations
181+ if weights is not None :
182+ if np .ndim (weights ) != 1 and np .shape (weights )[- 1 ] != 1 :
183+ try :
184+ weights = weights [:, i ]
185+ except IndexError as err :
186+ raise ValueError (
187+ "weights must have the same shape as data, "
188+ "or be a single column"
189+ ) from err
190+ weights = weights [~ isna (y )]
191+ return weights
184192
185193 def _post_plot_logic (self , ax : Axes , data ) -> None :
186194 if self .orientation == "horizontal" :
@@ -207,11 +215,14 @@ def _kind(self) -> Literal["kde"]:
207215 def orientation (self ) -> Literal ["vertical" ]:
208216 return "vertical"
209217
210- def __init__ (self , data , bw_method = None , ind = None , ** kwargs ) -> None :
218+ def __init__ (
219+ self , data , bw_method = None , ind = None , * , weights = None , ** kwargs
220+ ) -> None :
211221 # Do not call LinePlot.__init__ which may fill nan
212222 MPLPlot .__init__ (self , data , ** kwargs ) # pylint: disable=non-parent-init-called
213223 self .bw_method = bw_method
214224 self .ind = ind
225+ self .weights = weights
215226
216227 @staticmethod
217228 def _get_ind (y , ind ):
@@ -233,9 +244,10 @@ def _get_ind(y, ind):
233244 return ind
234245
235246 @classmethod
236- def _plot (
247+ # error: Signature of "_plot" incompatible with supertype "MPLPlot"
248+ def _plot ( # type: ignore[override]
237249 cls ,
238- ax ,
250+ ax : Axes ,
239251 y ,
240252 style = None ,
241253 bw_method = None ,
@@ -253,10 +265,9 @@ def _plot(
253265 lines = MPLPlot ._plot (ax , ind , y , style = style , ** kwds )
254266 return lines
255267
256- def _make_plot_keywords (self , kwds , y ) :
268+ def _make_plot_keywords (self , kwds : dict [ str , Any ], y ) -> None :
257269 kwds ["bw_method" ] = self .bw_method
258270 kwds ["ind" ] = self ._get_ind (y , ind = self .ind )
259- return kwds
260271
261272 def _post_plot_logic (self , ax , data ) -> None :
262273 ax .set_ylabel ("Density" )
0 commit comments