Skip to content

Commit

Permalink
Add support for numeric x
Browse files Browse the repository at this point in the history
  • Loading branch information
timhoffm committed Jul 17, 2024
1 parent 1212500 commit 7241d40
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
16 changes: 16 additions & 0 deletions galleries/examples/lines_bars_and_markers/grouped_bar_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@
#
# df = pd.DataFrame(data, index=x, columns=columns)
# df.plot.bar()

# %%
# Numeric x values
# ----------------
# In the most common case, one will want to pass categorical labels as *x*.
# Additionally, we allow numeric values for *x*, as with `~.Axes.bar()`.
# But for simplicity and clarity, we require that these are equidistant.

x = [0, 2, 4]
data = {
'data1': [1, 2, 3],
'data2': [1.2, 2.2, 3.2],
}

fig, ax = plt.subplots()
ax.grouped_bar(x, data)
35 changes: 26 additions & 9 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,8 +3004,12 @@ def grouped_bar(self, x, heights, dataset_labels=None):
"""
Parameters
-----------
x : array-like of str
The labels.
x : array-like or list of str
The center positions of the bar groups. If these are numeric values,
they have to be equidistant. As with `~.Axes.bar`, you can provide
categorical labels, which will be used at integer numeric positions
``range(x)``.
heights : list of array-like or dict of array-like or 2D array
The heights for all x and groups. One of:
Expand Down Expand Up @@ -3064,27 +3068,40 @@ def grouped_bar(self, x, heights, dataset_labels=None):
elif hasattr(heights, 'shape'):
heights = heights.T

num_labels = len(x)
num_groups = len(x)
num_datasets = len(heights)

for dataset in heights:
assert len(dataset) == num_labels
if isinstance(x[0], str):
tick_labels = x
group_centers = np.arange(num_groups)
else:
if num_groups > 1:
d = np.diff(x)
if not np.allclose(d, d.mean()):
raise ValueError("'x' must be equidistant")
group_centers = np.asarray(x)
tick_labels = None

for i, dataset in enumerate(heights):
if len(dataset) != num_groups:
raise ValueError(
f"'x' indicates {num_groups} groups, but dataset {i} "
f"has {len(dataset)} groups"
)

margin = 0.1
bar_width = (1 - 2 * margin) / num_datasets
block_centers = np.arange(num_labels)

if dataset_labels is None:
dataset_labels = [None] * num_datasets
else:
assert len(dataset_labels) == num_datasets

for i, (hs, dataset_label) in enumerate(zip(heights, dataset_labels)):
lefts = block_centers - 0.5 + margin + i * bar_width
print(i, x, lefts, hs, dataset_label)
lefts = group_centers - 0.5 + margin + i * bar_width
self.bar(lefts, hs, width=bar_width, align="edge", label=dataset_label)

self.xaxis.set_ticks(block_centers, labels=x)
self.xaxis.set_ticks(group_centers, labels=tick_labels)

# TODO: does not return anything for now

Expand Down

0 comments on commit 7241d40

Please sign in to comment.