From 12125009ab177ab50d005de69551c1c288118e15 Mon Sep 17 00:00:00 2001 From: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Date: Sat, 13 Jul 2024 10:53:22 +0200 Subject: [PATCH] ENH: Add grouped_bar() method This is a WIP to implement #24313. It will be updated incrementally. As a first step, I've designed the data and label input API. Feedback is welcome. --- .../grouped_bar_chart.py | 104 ++++++++++++++++++ lib/matplotlib/axes/_axes.py | 88 +++++++++++++++ 2 files changed, 192 insertions(+) create mode 100644 galleries/examples/lines_bars_and_markers/grouped_bar_chart.py diff --git a/galleries/examples/lines_bars_and_markers/grouped_bar_chart.py b/galleries/examples/lines_bars_and_markers/grouped_bar_chart.py new file mode 100644 index 000000000000..d5ffdce5e6c4 --- /dev/null +++ b/galleries/examples/lines_bars_and_markers/grouped_bar_chart.py @@ -0,0 +1,104 @@ +""" +================= +Grouped bar chart +================= + +This example serves to develop and discuss the API. It's geared towards illustrating +API usage and design decisions only through the development phase. It's not intended +to go into the final PR in this form. 
+ +Case 1: multiple separate datasets +---------------------------------- + +""" +import matplotlib.pyplot as plt +import numpy as np + +x = ['A', 'B'] +data1 = [1, 1.2] +data2 = [2, 2.4] +data3 = [3, 3.6] + + +fig, axs = plt.subplots(1, 2) + +# current solution: manual positioning with multiple bar() calls +label_pos = np.array([0, 1]) +bar_width = 0.8 / 3 +data_shift = -1*bar_width + np.array([0, bar_width, 2*bar_width]) +axs[0].bar(label_pos + data_shift[0], data1, width=bar_width, label="data1") +axs[0].bar(label_pos + data_shift[1], data2, width=bar_width, label="data2") +axs[0].bar(label_pos + data_shift[2], data3, width=bar_width, label="data3") +axs[0].set_xticks(label_pos, x) +axs[0].legend() + +# grouped_bar() with list of datasets +# note also that this is a straight-forward generalization of the single-dataset case: +# bar(x, data1, label="data1") +axs[1].grouped_bar(x, [data1, data2, data3], dataset_labels=["data1", "data2", "data3"]) + + +# %% +# Case 1b: multiple datasets as dict +# ---------------------------------- +# instead of carrying a list of datasets and a list of dataset labels, users may +# want to organize their datasets in a dict. + +datasets = { + 'data1': data1, + 'data2': data2, + 'data3': data3, +} + +# %% +# While you can feed keys and values into the above API, it may be convenient to pass +# the whole dict as "data" and automatically extract the labels from the keys: + +fig, axs = plt.subplots(1, 2) + +# explicitly extract values and labels from a dict and feed to grouped_bar(): +axs[0].grouped_bar(x, datasets.values(), dataset_labels=datasets.keys()) +# accepting a dict as input +axs[1].grouped_bar(x, datasets) + +# %% +# Case 2: 2D array data +# --------------------- +# When receiving a 2D array, we interpret the data as +# +# .. 
code-block:: none
+#
+#               dataset_0   dataset_1   dataset_2
+#     x[0]='A'  ds0_a       ds1_a       ds2_a
+#     x[1]='B'  ds0_b       ds1_b       ds2_b
+#
+# This is consistent with the standard data science interpretation of instances
+# on the vertical and features on the horizontal. And also matches how pandas is
+# interpreting rows and columns.
+#
+# Note that a list of individual datasets and a 2D array behave structurally different,
+# i.e. when turning a list into a numpy array, you have to transpose that array to get
+# the correct representation. Those two behave the same::
+#
+#     grouped_bar(x, [data1, data2])
+#     grouped_bar(x, np.array([data1, data2]).T)
+#
+# This is a conscious decision, because the commonly understood dimension ordering
+# semantics of "list of datasets" and 2D array of datasets is different.
+
+x = ['A', 'B']
+data = np.array([
+    [1, 2, 3],
+    [1.2, 2.4, 3.6],
+])
+columns = ["data1", "data2", "data3"]
+
+fig, ax = plt.subplots()
+ax.grouped_bar(x, data, dataset_labels=columns)
+
+# %%
+# This creates the same plot as pandas (code cannot be executed because pandas
+# is not a doc dependency)::
+#
+#     df = pd.DataFrame(data, index=x, columns=columns)
+#     df.plot.bar()
diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py
index 5d5248951314..b751cb7458af 100644
--- a/lib/matplotlib/axes/_axes.py
+++ b/lib/matplotlib/axes/_axes.py
@@ -3000,6 +3000,115 @@ def broken_barh(self, xranges, yrange, **kwargs):
 
         return col
 
+    def grouped_bar(self, x, heights, dataset_labels=None):
+        """
+        Make a grouped bar plot.
+
+        Parameters
+        ----------
+        x : array-like of str
+            The labels.
+        heights : list of array-like or dict of array-like or 2D array
+            The heights for all x and datasets. One of:
+
+            - list of array-like: A list of datasets, each dataset must have
+              ``len(x)`` elements. The iteration runs over the datasets; each
+              individual array-like holds the heights of that dataset for all
+              labels in *x*.
+
+              .. code-block:: none
+
+                  x = ['a', 'b']
+                  dataset_labels = ['ds0', 'ds1', 'ds2']
+
+                  # dataset_labels:  ds0        ds1        ds2
+                  heights =      [dataset_0, dataset_1, dataset_2]
+
+                  #            x[0]   x[1]
+                  dataset_0 = [ds0_a, ds0_b]
+
+                  #           x[0]   x[1]
+                  heights = [[ds0_a, ds0_b],  # dataset_0
+                             [ds1_a, ds1_b],  # dataset_1
+                             [ds2_a, ds2_b],  # dataset_2
+                            ]
+
+            - dict of array-like: A mapping of names to datasets, each dataset
+              (dict value) must have ``len(x)`` elements. The dict keys are
+              used as dataset labels:
+
+              .. code-block:: none
+
+                  dataset_labels = heights.keys()
+                  heights = heights.values()
+
+            - a 2D array: rows map to *x*, columns are the different datasets.
+
+              .. code-block:: none
+
+                            dataset_0   dataset_1   dataset_2
+                  x[0]='a'  ds0_a       ds1_a       ds2_a
+                  x[1]='b'  ds0_b       ds1_b       ds2_b
+
+              Note that this is consistent with pandas. These two calls produce
+              the same bar plot structure::
+
+                  grouped_bar(x, array, dataset_labels=dataset_labels)
+                  pd.DataFrame(array, index=x, columns=dataset_labels).plot.bar()
+
+        dataset_labels : array-like of str, optional
+            The labels of the datasets. Must not be given if *heights* is a
+            dict.
+
+        Raises
+        ------
+        ValueError
+            If *dataset_labels* is passed together with a mapping *heights*,
+            if a dataset does not have ``len(x)`` elements, or if the number
+            of *dataset_labels* does not match the number of datasets.
+        """
+        if hasattr(heights, 'keys'):
+            if dataset_labels is not None:
+                raise ValueError(
+                    "'dataset_labels' cannot be used if 'heights' are a mapping")
+            dataset_labels = heights.keys()
+            heights = heights.values()
+        elif hasattr(heights, 'shape'):
+            # 2D arrays hold one dataset per column; transpose so that
+            # iterating yields one dataset at a time, as for lists and dicts.
+            heights = heights.T
+
+        num_labels = len(x)
+        num_datasets = len(heights)
+
+        for i, dataset in enumerate(heights):
+            if len(dataset) != num_labels:
+                raise ValueError(
+                    f"'x' has {num_labels} elements, but dataset {i} has "
+                    f"{len(dataset)} elements")
+
+        if dataset_labels is None:
+            dataset_labels = [None] * num_datasets
+        elif len(dataset_labels) != num_datasets:
+            raise ValueError(
+                f"'dataset_labels' has {len(dataset_labels)} elements, but "
+                f"'heights' contains {num_datasets} datasets")
+
+        # Each label occupies a block of width 1 centered on its integer
+        # position; the bars of all datasets share that block minus a margin
+        # on either side.
+        margin = 0.1
+        bar_width = (1 - 2 * margin) / num_datasets
+        block_centers = np.arange(num_labels)
+
+        for i, (hs, dataset_label) in enumerate(zip(heights, dataset_labels)):
+            lefts = block_centers - 0.5 + margin + i * bar_width
+            self.bar(lefts, hs, width=bar_width, align="edge", label=dataset_label)
+
+        self.xaxis.set_ticks(block_centers, labels=x)
+
+        # TODO: does not return anything for now
+
     @_preprocess_data()
     def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0,
              label=None, orientation='vertical'):