32
32
33
33
from mplfinance import _styles
34
34
35
- from mplfinance ._arg_validators import _check_and_prepare_data , _mav_validator
35
+ from mplfinance ._arg_validators import _check_and_prepare_data , _mav_validator , _label_validator
36
36
from mplfinance ._arg_validators import _get_valid_plot_types , _fill_between_validator
37
37
from mplfinance ._arg_validators import _process_kwargs , _validate_vkwargs_dict
38
38
from mplfinance ._arg_validators import _kwarg_not_implemented , _bypass_kwarg_validation
@@ -765,6 +765,8 @@ def plot( data, **kwargs ):
765
765
766
766
elif not _list_of_dict (addplot ):
767
767
raise TypeError ('addplot must be `dict`, or `list of dict`, NOT ' + str (type (addplot )))
768
+
769
+ contains_legend_label = [] # a list of axes that contains legend labels
768
770
769
771
for apdict in addplot :
770
772
@@ -788,10 +790,28 @@ def plot( data, **kwargs ):
788
790
else :
789
791
havedf = False # must be a single series or array
790
792
apdata = [apdata ,] # make it iterable
793
+ if havedf and apdict ['label' ]:
794
+ if not isinstance (apdict ['label' ],(list ,tuple ,np .ndarray )):
795
+ nlabels = 1
796
+ else :
797
+ nlabels = len (apdict ['label' ])
798
+ ncolumns = len (apdata .columns )
799
+ #print('nlabels=',nlabels,'ncolumns=',ncolumns)
800
+ if nlabels < ncolumns :
801
+ warnings .warn ('\n =======================================\n ' +
802
+ ' addplot MISMATCH between data and labels:\n ' +
803
+ ' have ' + str (ncolumns )+ ' columns to plot \n ' +
804
+ ' BUT ' + str (nlabels )+ ' labels for them.\n ' )
805
+ colcount = 0
791
806
for column in apdata :
792
807
ydata = apdata .loc [:,column ] if havedf else column
793
- ax = _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config )
808
+ ax = _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config , colcount )
794
809
_addplot_apply_supplements (ax ,apdict ,xdates )
810
+ colcount += 1
811
+ if apdict ['label' ]: # not supported for aptype == 'ohlc' or 'candle'
812
+ contains_legend_label .append (ax )
813
+ for ax in set (contains_legend_label ): # there might be duplicates
814
+ ax .legend ()
795
815
796
816
# fill_between is NOT supported for external_axes_mode
797
817
# (caller can easily call ax.fill_between() themselves).
@@ -1079,7 +1099,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
1079
1099
ax .autoscale_view ()
1080
1100
return ax
1081
1101
1082
- def _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config ):
1102
+ def _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config , colcount ):
1083
1103
external_axes_mode = apdict ['ax' ] is not None
1084
1104
if not external_axes_mode :
1085
1105
secondary_y = False
@@ -1101,6 +1121,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
1101
1121
ax = apdict ['ax' ]
1102
1122
1103
1123
aptype = apdict ['type' ]
1124
+ if isinstance (apdict ['label' ],(list ,tuple ,np .ndarray )):
1125
+ label = apdict ['label' ][colcount ]
1126
+ else : # isinstance(...,str)
1127
+ label = apdict ['label' ]
1104
1128
if aptype == 'scatter' :
1105
1129
size = apdict ['markersize' ]
1106
1130
mark = apdict ['marker' ]
@@ -1111,27 +1135,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
1111
1135
1112
1136
if isinstance (mark ,(list ,tuple ,np .ndarray )):
1113
1137
_mscatter (xdates , ydata , ax = ax , m = mark , s = size , color = color , alpha = alpha , edgecolors = edgecolors , linewidths = linewidths )
1114
- else :
1115
- ax .scatter (xdates , ydata , s = size , marker = mark , color = color , alpha = alpha , edgecolors = edgecolors , linewidths = linewidths )
1138
+ else :
1139
+ ax .scatter (xdates , ydata , s = size , marker = mark , color = color , alpha = alpha , edgecolors = edgecolors , linewidths = linewidths , label = label )
1116
1140
elif aptype == 'bar' :
1117
1141
width = 0.8 if apdict ['width' ] is None else apdict ['width' ]
1118
1142
bottom = apdict ['bottom' ]
1119
1143
color = apdict ['color' ]
1120
1144
alpha = apdict ['alpha' ]
1121
- ax .bar (xdates ,ydata ,width = width ,bottom = bottom ,color = color ,alpha = alpha )
1145
+ ax .bar (xdates ,ydata ,width = width ,bottom = bottom ,color = color ,alpha = alpha , label = label )
1122
1146
elif aptype == 'line' :
1123
1147
ls = apdict ['linestyle' ]
1124
1148
color = apdict ['color' ]
1125
1149
width = apdict ['width' ] if apdict ['width' ] is not None else 1.6 * config ['_width_config' ]['line_width' ]
1126
1150
alpha = apdict ['alpha' ]
1127
- ax .plot (xdates ,ydata ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha )
1151
+ ax .plot (xdates ,ydata ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha , label = label )
1128
1152
elif aptype == 'step' :
1129
1153
stepwhere = apdict ['stepwhere' ]
1130
1154
ls = apdict ['linestyle' ]
1131
1155
color = apdict ['color' ]
1132
1156
width = apdict ['width' ] if apdict ['width' ] is not None else 1.6 * config ['_width_config' ]['line_width' ]
1133
1157
alpha = apdict ['alpha' ]
1134
- ax .step (xdates ,ydata ,where = stepwhere ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha )
1158
+ ax .step (xdates ,ydata ,where = stepwhere ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha , label = label )
1135
1159
else :
1136
1160
raise ValueError ('addplot type "' + str (aptype )+ '" NOT yet supported.' )
1137
1161
@@ -1384,6 +1408,9 @@ def _valid_addplot_kwargs():
1384
1408
'fill_between' : { 'Default' : None , # added by Wen
1385
1409
'Description' : " fill region" ,
1386
1410
'Validator' : _fill_between_validator },
1411
+ 'label' : { 'Default' : None ,
1412
+ 'Description' : 'Label for the added plot. One per added plot.' ,
1413
+ 'Validator' : _label_validator },
1387
1414
1388
1415
}
1389
1416
0 commit comments