Skip to content

Commit

Permalink
ENH: Add grouped_bar() method
Browse files Browse the repository at this point in the history
This is a WIP to implement matplotlib#24313. It will be updated incrementally. As
a first step, I've designed the data and label input API. Feedback is
welcome.
  • Loading branch information
timhoffm committed Jul 17, 2024
1 parent 9b8a8c7 commit 1212500
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 0 deletions.
104 changes: 104 additions & 0 deletions galleries/examples/lines_bars_and_markers/grouped_bar_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
=================
Grouped bar chart
=================
This example serves to develop and discuss the API. It's geared towards illustrating
API usage and design decisions only through the development phase. It's not intended
to go into the final PR in this form.
Case 1: multiple separate datasets
----------------------------------
"""
import matplotlib.pyplot as plt
import numpy as np

x = ['A', 'B']
data1 = [1, 1.2]
data2 = [2, 2.4]
data3 = [3, 3.6]


fig, axs = plt.subplots(1, 2)

# current solution: manual positioning with multiple bar() calls
label_pos = np.array([0, 1])
# the three bars of one group together span 0.8 of the unit distance between labels
bar_width = 0.8 / 3
# per-dataset offsets (-bar_width, 0, +bar_width) relative to the label position
data_shift = -1*bar_width + np.array([0, bar_width, 2*bar_width])
axs[0].bar(label_pos + data_shift[0], data1, width=bar_width, label="data1")
axs[0].bar(label_pos + data_shift[1], data2, width=bar_width, label="data2")
axs[0].bar(label_pos + data_shift[2], data3, width=bar_width, label="data3")
axs[0].set_xticks(label_pos, x)
axs[0].legend()

# grouped_bar() with list of datasets
# note also that this is a straight-forward generalization of the single-dataset case:
# bar(x, data1, label="data1")
axs[1].grouped_bar(x, [data1, data2, data3], dataset_labels=["data1", "data2", "data3"])


# %%
# Case 1b: multiple datasets as dict
# ----------------------------------
# Instead of carrying a list of datasets and a list of dataset labels, users may
# want to organize their datasets in a dict.

datasets = {
'data1': data1,
'data2': data2,
'data3': data3,
}

# %%
# While you can feed keys and values into the above API, it may be convenient to pass
# the whole dict as "data" and automatically extract the labels from the keys:

fig, axs = plt.subplots(1, 2)

# explicitly extract values and labels from a dict and feed to grouped_bar():
axs[0].grouped_bar(x, datasets.values(), dataset_labels=datasets.keys())
# accepting a dict as input: the keys are used as the dataset labels
axs[1].grouped_bar(x, datasets)

# %%
# Case 2: 2D array data
# ---------------------
# When receiving a 2D array, we interpret the data as
#
# .. code-block:: none
#
#                 dataset_0  dataset_1  dataset_2
#     x[0]='A'    ds0_a      ds1_a      ds2_a
#     x[1]='B'    ds0_b      ds1_b      ds2_b
#
# This is consistent with the standard data science interpretation of instances
# on the vertical and features on the horizontal. It also matches how pandas
# interprets rows and columns.
#
# Note that a list of individual datasets and a 2D array behave structurally
# differently, i.e. when turning a list into a numpy array, you have to transpose that
# array to get the correct representation. Those two behave the same::
#
#     grouped_bar(x, [data1, data2])
#     grouped_bar(x, np.array([data1, data2]).T)
#
# This is a conscious decision, because the commonly understood dimension ordering
# semantics of a "list of datasets" and a 2D array of datasets are different.

x = ['A', 'B']
# one row per entry of x, one column per dataset
data = np.array([
[1, 2, 3],
[1.2, 2.4, 3.6],
])
columns = ["data1", "data2", "data3"]

fig, ax = plt.subplots()
ax.grouped_bar(x, data, dataset_labels=columns)

# %%
# This creates the same plot as pandas (code cannot be executed because pandas
# is not a doc dependency)::
#
#     df = pd.DataFrame(data, index=x, columns=columns)
#     df.plot.bar()
88 changes: 88 additions & 0 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3000,6 +3000,94 @@ def broken_barh(self, xranges, yrange, **kwargs):

return col

def grouped_bar(self, x, heights, dataset_labels=None):
    """
    Make a grouped bar plot.

    For every entry of *x* one group of bars is drawn, containing one bar
    per dataset. The groups are centered on the positions
    ``0, 1, ..., len(x) - 1`` and the x-axis ticks at these positions are
    labeled with *x*.

    Parameters
    ----------
    x : array-like of str
        The labels.

    heights : list of array-like or dict of array-like or 2D array
        The heights for all x and groups. One of:

        - list of array-like: A list of datasets, each dataset must have
          ``len(x)`` elements. Any iterable of array-likes is accepted; the
          iteration runs over the datasets.

          .. code-block:: none

              x = ['a', 'b']
              dataset_labels = ['ds0', 'ds1', 'ds2']

              # dataset_labels:  ds0        ds1        ds2
              heights =         [dataset_0, dataset_1, dataset_2]

              #            x[0]   x[1]
              dataset_0 = [ds0_a, ds0_b]

              #           x[0]   x[1]
              heights = [[ds0_a, ds0_b],  # dataset_0
                         [ds1_a, ds1_b],  # dataset_1
                         [ds2_a, ds2_b],  # dataset_2
                        ]

        - dict of array-like: A mapping of names to datasets, each dataset
          (dict value) must have ``len(x)`` elements. This is equivalent
          to passing::

              dataset_labels = heights.keys()
              heights = heights.values()

        - a 2D array: rows map to *x*, columns are the different datasets.

          .. code-block:: none

                          dataset_0  dataset_1  dataset_2
              x[0]='a'    ds0_a      ds1_a      ds2_a
              x[1]='b'    ds0_b      ds1_b      ds2_b

          Note that this is consistent with pandas. These two calls produce
          the same bar plot structure::

              grouped_bar(x, array, dataset_labels=dataset_labels)
              pd.DataFrame(array, index=x, columns=dataset_labels).plot.bar()

    dataset_labels : array-like of str, optional
        The labels of the datasets. Must not be given if *heights* is a
        mapping, because the labels are then taken from its keys.

    Raises
    ------
    ValueError
        If *dataset_labels* is given together with a mapping *heights*, if
        a dataset does not have ``len(x)`` elements, or if the number of
        *dataset_labels* does not match the number of datasets.
    """
    if hasattr(heights, 'keys'):
        if dataset_labels is not None:
            raise ValueError(
                "'dataset_labels' cannot be used if 'heights' are a mapping")
        dataset_labels = heights.keys()
        heights = heights.values()
    elif hasattr(heights, 'shape'):
        # A 2D array has one column per dataset; transpose so that iterating
        # yields one dataset at a time, like for a list of datasets.
        heights = heights.T

    # Materialize so that generic iterables (generators, dict views) can be
    # measured with len() and iterated more than once.
    heights = list(heights)
    num_labels = len(x)
    num_datasets = len(heights)

    # Explicit exceptions instead of assert: asserts vanish under ``-O``.
    for i, dataset in enumerate(heights):
        if len(dataset) != num_labels:
            raise ValueError(
                f"'x' has {num_labels} elements, but dataset {i} has "
                f"{len(dataset)} elements")

    if dataset_labels is None:
        dataset_labels = [None] * num_datasets
    else:
        dataset_labels = list(dataset_labels)
        if len(dataset_labels) != num_datasets:
            raise ValueError(
                f"Got {len(dataset_labels)} dataset labels, but "
                f"{num_datasets} datasets")

    block_centers = np.arange(num_labels)

    # Without datasets there is nothing to draw (and no bar width to compute).
    if num_datasets:
        # Each group occupies a unit-wide slot; *margin* is left free on
        # either side and the rest is shared equally among the datasets.
        margin = 0.1
        bar_width = (1 - 2 * margin) / num_datasets

        for i, (hs, dataset_label) in enumerate(zip(heights, dataset_labels)):
            lefts = block_centers - 0.5 + margin + i * bar_width
            self.bar(lefts, hs, width=bar_width, align="edge",
                     label=dataset_label)

    self.xaxis.set_ticks(block_centers, labels=x)

    # TODO: does not return anything for now

@_preprocess_data()
def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0,
label=None, orientation='vertical'):
Expand Down

0 comments on commit 1212500

Please sign in to comment.