Skip to content

Commit 50d7eb3

Browse files
Merge pull request #605 from alexpvpmindustry/master
API for adding labels: `mpf.make_addplot(..., label="myLabel")`
2 parents 46dcc89 + cbda0af commit 50d7eb3

File tree

4 files changed

+626
-9
lines changed

4 files changed

+626
-9
lines changed

examples/addplot_legends.ipynb

+576
Large diffs are not rendered by default.

src/mplfinance/_arg_validators.py

+14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import matplotlib as mpl
88
import warnings
99

10+
1011
def _check_and_prepare_data(data, config):
1112
'''
1213
Check and Prepare the data input:
@@ -94,6 +95,19 @@ def _check_and_prepare_data(data, config):
9495

9596
return dates, opens, highs, lows, closes, volumes
9697

98+
99+
def _label_validator(label_value):
100+
''' Validates the input of [legend] label for added plots.
101+
label_value may be a str or a sequence of str.
102+
'''
103+
if isinstance(label_value,str):
104+
return True
105+
if isinstance(label_value,(list,tuple,np.ndarray)):
106+
if all([isinstance(v,str) for v in label_value]):
107+
return True
108+
return False
109+
110+
97111
def _get_valid_plot_types(plottype=None):
98112

99113
_alias_types = {

src/mplfinance/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version_info = (0, 12, 9, 'beta', 9)
1+
version_info = (0, 12, 10, 'beta', 0)
22

33
_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}
44

src/mplfinance/plotting.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from mplfinance import _styles
3434

35-
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator
35+
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator
3636
from mplfinance._arg_validators import _get_valid_plot_types, _fill_between_validator
3737
from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict
3838
from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation
@@ -765,6 +765,8 @@ def plot( data, **kwargs ):
765765

766766
elif not _list_of_dict(addplot):
767767
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))
768+
769+
contains_legend_label=[] # a list of axes that contains legend labels
768770

769771
for apdict in addplot:
770772

@@ -788,10 +790,28 @@ def plot( data, **kwargs ):
788790
else:
789791
havedf = False # must be a single series or array
790792
apdata = [apdata,] # make it iterable
793+
if havedf and apdict['label']:
794+
if not isinstance(apdict['label'],(list,tuple,np.ndarray)):
795+
nlabels = 1
796+
else:
797+
nlabels = len(apdict['label'])
798+
ncolumns = len(apdata.columns)
799+
#print('nlabels=',nlabels,'ncolumns=',ncolumns)
800+
if nlabels < ncolumns:
801+
warnings.warn('\n =======================================\n'+
802+
' addplot MISMATCH between data and labels:\n'+
803+
' have '+str(ncolumns)+' columns to plot \n'+
804+
' BUT '+str(nlabels)+' labels for them.\n')
805+
colcount = 0
791806
for column in apdata:
792807
ydata = apdata.loc[:,column] if havedf else column
793-
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
808+
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount)
794809
_addplot_apply_supplements(ax,apdict,xdates)
810+
colcount += 1
811+
if apdict['label']: # not supported for aptype == 'ohlc' or 'candle'
812+
contains_legend_label.append(ax)
813+
for ax in set(contains_legend_label): # there might be duplicates
814+
ax.legend()
795815

796816
# fill_between is NOT supported for external_axes_mode
797817
# (caller can easily call ax.fill_between() themselves).
@@ -1079,7 +1099,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
10791099
ax.autoscale_view()
10801100
return ax
10811101

1082-
def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
1102+
def _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount):
10831103
external_axes_mode = apdict['ax'] is not None
10841104
if not external_axes_mode:
10851105
secondary_y = False
@@ -1101,6 +1121,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
11011121
ax = apdict['ax']
11021122

11031123
aptype = apdict['type']
1124+
if isinstance(apdict['label'],(list,tuple,np.ndarray)):
1125+
label = apdict['label'][colcount]
1126+
else: # isinstance(...,str)
1127+
label = apdict['label']
11041128
if aptype == 'scatter':
11051129
size = apdict['markersize']
11061130
mark = apdict['marker']
@@ -1111,27 +1135,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
11111135

11121136
if isinstance(mark,(list,tuple,np.ndarray)):
11131137
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1114-
else:
1115-
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1138+
else:
1139+
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths,label=label)
11161140
elif aptype == 'bar':
11171141
width = 0.8 if apdict['width'] is None else apdict['width']
11181142
bottom = apdict['bottom']
11191143
color = apdict['color']
11201144
alpha = apdict['alpha']
1121-
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
1145+
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha,label=label)
11221146
elif aptype == 'line':
11231147
ls = apdict['linestyle']
11241148
color = apdict['color']
11251149
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11261150
alpha = apdict['alpha']
1127-
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1151+
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
11281152
elif aptype == 'step':
11291153
stepwhere = apdict['stepwhere']
11301154
ls = apdict['linestyle']
11311155
color = apdict['color']
11321156
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11331157
alpha = apdict['alpha']
1134-
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1158+
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
11351159
else:
11361160
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')
11371161

@@ -1384,6 +1408,9 @@ def _valid_addplot_kwargs():
13841408
'fill_between': { 'Default' : None, # added by Wen
13851409
'Description' : " fill region",
13861410
'Validator' : _fill_between_validator },
1411+
'label' : { 'Default' : None,
1412+
'Description' : 'Label for the added plot. One per added plot.',
1413+
'Validator' : _label_validator },
13871414

13881415
}
13891416

0 commit comments

Comments
 (0)