Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

Add kind='line' or 'bar' argument to DataFrame.plot #348

Closed
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+46 −16
Split
View
@@ -2805,8 +2805,8 @@ def clip_lower(self, threshold):
#----------------------------------------------------------------------
# Plotting
- def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
- figsize=None, grid=True, legend=True, ax=None, **kwds):
+ def plot(self, kind='line', subplots=False, sharex=True, sharey=False, use_index=True,
+ figsize=None, grid=True, legend=True, rot=30, ax=None, **kwds):
"""
Make line plot of DataFrame's series with the index on the x-axis using
matplotlib / pylab.
@@ -2842,22 +2842,49 @@ def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
else:
fig = ax.get_figure()
- if use_index:
- x = self.index
- else:
- x = range(len(self))
+ if kind == 'line':
+ if use_index:
+ x = self.index
+ else:
+ x = range(len(self))
+
+ for i, col in enumerate(_try_sort(self.columns)):
+ empty = self[col].count() == 0
+ y = self[col].values if not empty else np.zeros(x.shape)
+ if subplots:
+ ax = axes[i]
+ ax.plot(x, y, 'k', label=str(col), **kwds)
+ ax.legend(loc='best')
+ else:
+ ax.plot(x, y, label=str(col), **kwds)
+
+ ax.grid(grid)
+ elif kind == 'bar':
+ N = len(self)
+ M = len(self.columns)
+ xinds = np.arange(N) + 0.25
+ colors = ['red', 'green', 'blue', 'yellow', 'black']
+ rects = []
+ labels = []
+ for i, col in enumerate(_try_sort(self.columns)):
+ empty = self[col].count() == 0
+ y = self[col].values if not empty else np.zeros(x.shape)
+ if subplots:
+ ax = axes[i]
+ ax.bar(xinds, y, 0.5,
+ bottom=np.zeros(N), linewidth=1, **kwds)
+ ax.set_title(col)
+ else:
+ rects.append(ax.bar(xinds+i*0.5/M,y,0.5/M,bottom=np.zeros(N),color=colors[i % len(colors)], **kwds))
+ labels.append(col)
- for i, col in enumerate(_try_sort(self.columns)):
- empty = self[col].count() == 0
- y = self[col].values if not empty else np.zeros(x.shape)
- if subplots:
- ax = axes[i]
- ax.plot(x, y, 'k', label=str(col), **kwds)
- ax.legend(loc='best')
+ if N < 10:
+ fontsize = 12
else:
- ax.plot(x, y, label=str(col), **kwds)
+ fontsize = 10
- ax.grid(grid)
+ ax.set_xticks(xinds + 0.25)
+ ax.set_xticklabels(self.index, rotation=rot, fontsize=fontsize)
# try to make things prettier
try:
@@ -2866,7 +2893,10 @@ def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
pass
if legend and not subplots:
- ax.legend(loc='best')
+ if kind == 'line':
+ ax.legend(loc='best')
+ else:
+ ax.legend([r[0] for r in rects],labels,loc='best')
plt.draw_if_interactive()