2222import matplotlib as mpl
2323import numpy as np
2424
25+ from pandas ._libs import lib
2526from pandas .errors import AbstractMethodError
2627from pandas .util ._decorators import cache_readonly
2728from pandas .util ._exceptions import find_stack_level
@@ -1221,13 +1222,6 @@ def __init__(self, data, x, y, **kwargs) -> None:
12211222 if is_integer (y ) and not self .data .columns ._holds_integer ():
12221223 y = self .data .columns [y ]
12231224
1224- # Scatter plot allows to plot objects data
1225- if self ._kind == "hexbin" :
1226- if len (self .data [x ]._get_numeric_data ()) == 0 :
1227- raise ValueError (self ._kind + " requires x column to be numeric" )
1228- if len (self .data [y ]._get_numeric_data ()) == 0 :
1229- raise ValueError (self ._kind + " requires y column to be numeric" )
1230-
12311225 self .x = x
12321226 self .y = y
12331227
@@ -1269,14 +1263,30 @@ class ScatterPlot(PlanePlot):
12691263 def _kind (self ) -> Literal ["scatter" ]:
12701264 return "scatter"
12711265
1272- def __init__ (self , data , x , y , s = None , c = None , ** kwargs ) -> None :
1266+ def __init__ (
1267+ self ,
1268+ data ,
1269+ x ,
1270+ y ,
1271+ s = None ,
1272+ c = None ,
1273+ * ,
1274+ colorbar : bool | lib .NoDefault = lib .no_default ,
1275+ norm = None ,
1276+ ** kwargs ,
1277+ ) -> None :
12731278 if s is None :
12741279 # hide the matplotlib default for size, in case we want to change
12751280 # the handling of this argument later
12761281 s = 20
12771282 elif is_hashable (s ) and s in data .columns :
12781283 s = data [s ]
1279- super ().__init__ (data , x , y , s = s , ** kwargs )
1284+ self .s = s
1285+
1286+ self .colorbar = colorbar
1287+ self .norm = norm
1288+
1289+ super ().__init__ (data , x , y , ** kwargs )
12801290 if is_integer (c ) and not self .data .columns ._holds_integer ():
12811291 c = self .data .columns [c ]
12821292 self .c = c
@@ -1292,6 +1302,44 @@ def _make_plot(self, fig: Figure):
12921302 )
12931303
12941304 color = self .kwds .pop ("color" , None )
1305+ c_values = self ._get_c_values (color , color_by_categorical , c_is_column )
1306+ norm , cmap = self ._get_norm_and_cmap (c_values , color_by_categorical )
1307+ cb = self ._get_colorbar (c_values , c_is_column )
1308+
1309+ if self .legend :
1310+ label = self .label
1311+ else :
1312+ label = None
1313+ scatter = ax .scatter (
1314+ data [x ].values ,
1315+ data [y ].values ,
1316+ c = c_values ,
1317+ label = label ,
1318+ cmap = cmap ,
1319+ norm = norm ,
1320+ s = self .s ,
1321+ ** self .kwds ,
1322+ )
1323+ if cb :
1324+ cbar_label = c if c_is_column else ""
1325+ cbar = self ._plot_colorbar (ax , fig = fig , label = cbar_label )
1326+ if color_by_categorical :
1327+ n_cats = len (self .data [c ].cat .categories )
1328+ cbar .set_ticks (np .linspace (0.5 , n_cats - 0.5 , n_cats ))
1329+ cbar .ax .set_yticklabels (self .data [c ].cat .categories )
1330+
1331+ if label is not None :
1332+ self ._append_legend_handles_labels (scatter , label )
1333+
1334+ errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
1335+ errors_y = self ._get_errorbars (label = y , index = 0 , xerr = False )
1336+ if len (errors_x ) > 0 or len (errors_y ) > 0 :
1337+ err_kwds = dict (errors_x , ** errors_y )
1338+ err_kwds ["ecolor" ] = scatter .get_facecolor ()[0 ]
1339+ ax .errorbar (data [x ].values , data [y ].values , linestyle = "none" , ** err_kwds )
1340+
1341+ def _get_c_values (self , color , color_by_categorical : bool , c_is_column : bool ):
1342+ c = self .c
12951343 if c is not None and color is not None :
12961344 raise TypeError ("Specify exactly one of `c` and `color`" )
12971345 if c is None and color is None :
@@ -1304,7 +1352,10 @@ def _make_plot(self, fig: Figure):
13041352 c_values = self .data [c ].values
13051353 else :
13061354 c_values = c
1355+ return c_values
13071356
1357+ def _get_norm_and_cmap (self , c_values , color_by_categorical : bool ):
1358+ c = self .c
13081359 if self .colormap is not None :
13091360 cmap = mpl .colormaps .get_cmap (self .colormap )
13101361 # cmap is only used if c_values are integers, otherwise UserWarning.
@@ -1323,65 +1374,49 @@ def _make_plot(self, fig: Figure):
13231374 cmap = colors .ListedColormap ([cmap (i ) for i in range (cmap .N )])
13241375 bounds = np .linspace (0 , n_cats , n_cats + 1 )
13251376 norm = colors .BoundaryNorm (bounds , cmap .N )
1377+ # TODO: warn that we are ignoring self.norm if user specified it?
1378+ # Doesn't happen in any tests 2023-11-09
13261379 else :
1327- norm = self .kwds .pop ("norm" , None )
1380+ norm = self .norm
1381+ return norm , cmap
1382+
1383+ def _get_colorbar (self , c_values , c_is_column : bool ) -> bool :
13281384 # plot colorbar if
13291385 # 1. colormap is assigned, and
13301386 # 2.`c` is a column containing only numeric values
13311387 plot_colorbar = self .colormap or c_is_column
1332- cb = self .kwds .pop ("colorbar" , is_numeric_dtype (c_values ) and plot_colorbar )
1333-
1334- if self .legend and hasattr (self , "label" ):
1335- label = self .label
1336- else :
1337- label = None
1338- scatter = ax .scatter (
1339- data [x ].values ,
1340- data [y ].values ,
1341- c = c_values ,
1342- label = label ,
1343- cmap = cmap ,
1344- norm = norm ,
1345- ** self .kwds ,
1346- )
1347- if cb :
1348- cbar_label = c if c_is_column else ""
1349- cbar = self ._plot_colorbar (ax , fig = fig , label = cbar_label )
1350- if color_by_categorical :
1351- cbar .set_ticks (np .linspace (0.5 , n_cats - 0.5 , n_cats ))
1352- cbar .ax .set_yticklabels (self .data [c ].cat .categories )
1353-
1354- if label is not None :
1355- self ._append_legend_handles_labels (scatter , label )
1356- else :
1357- self .legend = False
1358-
1359- errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
1360- errors_y = self ._get_errorbars (label = y , index = 0 , xerr = False )
1361- if len (errors_x ) > 0 or len (errors_y ) > 0 :
1362- err_kwds = dict (errors_x , ** errors_y )
1363- err_kwds ["ecolor" ] = scatter .get_facecolor ()[0 ]
1364- ax .errorbar (data [x ].values , data [y ].values , linestyle = "none" , ** err_kwds )
1388+ cb = self .colorbar
1389+ if cb is lib .no_default :
1390+ return is_numeric_dtype (c_values ) and plot_colorbar
1391+ return cb
13651392
13661393
13671394class HexBinPlot (PlanePlot ):
13681395 @property
13691396 def _kind (self ) -> Literal ["hexbin" ]:
13701397 return "hexbin"
13711398
1372- def __init__ (self , data , x , y , C = None , ** kwargs ) -> None :
1399+ def __init__ (self , data , x , y , C = None , * , colorbar : bool = True , * *kwargs ) -> None :
13731400 super ().__init__ (data , x , y , ** kwargs )
13741401 if is_integer (C ) and not self .data .columns ._holds_integer ():
13751402 C = self .data .columns [C ]
13761403 self .C = C
13771404
1405+ self .colorbar = colorbar
1406+
1407+ # Scatter plot allows to plot objects data
1408+ if len (self .data [self .x ]._get_numeric_data ()) == 0 :
1409+ raise ValueError (self ._kind + " requires x column to be numeric" )
1410+ if len (self .data [self .y ]._get_numeric_data ()) == 0 :
1411+ raise ValueError (self ._kind + " requires y column to be numeric" )
1412+
13781413 def _make_plot (self , fig : Figure ) -> None :
13791414 x , y , data , C = self .x , self .y , self .data , self .C
13801415 ax = self .axes [0 ]
13811416 # pandas uses colormap, matplotlib uses cmap.
13821417 cmap = self .colormap or "BuGn"
13831418 cmap = mpl .colormaps .get_cmap (cmap )
1384- cb = self .kwds . pop ( " colorbar" , True )
1419+ cb = self .colorbar
13851420
13861421 if C is None :
13871422 c_values = None
0 commit comments