Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions _doc/api/tools/pandas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@
teachpyx.tools.pandas
=====================

Example
-------

.. plot::

import pandas
import matplotlib.pyplot as plt
from teachpyx.tools.pandas import plot_waterfall

plt.close("all")

df = pandas.DataFrame({"name": ["A", "B", "C"], "delta": [10, -3, 5]})
ax, _ = plot_waterfall(df, "delta", "name", total_label="TOTAL")
ax.set_title("Example waterfall")
plt.xticks(rotation=30, ha="right")
plt.tight_layout()
plt.show()

.. automodule:: teachpyx.tools.pandas
:members:
:no-undoc-members:
37 changes: 36 additions & 1 deletion _unittests/ut_tools/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest
from matplotlib.axes import Axes
import pandas
from teachpyx.ext_test_case import ExtTestCase
from teachpyx.tools.pandas import read_csv_cached
from teachpyx.tools.pandas import plot_waterfall, read_csv_cached


class TestPandas(ExtTestCase):
Expand All @@ -14,6 +16,39 @@ def test_read_csv_cached(self):
self.assertEqual(df.shape, df2.shape)
self.assertEqual(list(df.columns), list(df2.columns))

def test_plot_waterfall(self):
df = pandas.DataFrame(
{
"name": ["A", "B", "C"],
"delta": [10, -3, 5],
}
)
ax, plot_df = plot_waterfall(df, "delta", "name", total_label="TOTAL")
self.assertIsInstance(ax, Axes)
self.assertEqual(list(plot_df["label"]), ["A", "B", "C", "TOTAL"])
self.assertEqual(list(plot_df["start"]), [0.0, 10.0, 7.0, 0.0])
self.assertEqual(list(plot_df["end"]), [10.0, 7.0, 12.0, 12.0])
Comment on lines 1 to +30
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plot_waterfall creates a new Matplotlib figure when ax is None, and this unit test will exercise that path. In headless CI environments, importing/using matplotlib.pyplot can fail or pick an interactive backend unless a non-interactive backend (e.g., Agg) is forced by the test runner/environment. Consider explicitly forcing a non-interactive backend for this test suite (or in this test) and closing the created figure to avoid accumulating open figures across tests.

Copilot uses AI. Check for mistakes.

def test_plot_waterfall_missing_column(self):
df = pandas.DataFrame({"name": ["A"], "delta": [1]})
with self.assertRaises(ValueError):
plot_waterfall(df, "missing", "name")

def test_plot_waterfall_missing_label_column(self):
df = pandas.DataFrame({"name": ["A"], "delta": [1]})
with self.assertRaises(ValueError):
plot_waterfall(df, "delta", "missing")

def test_plot_waterfall_bad_colors(self):
df = pandas.DataFrame({"name": ["A"], "delta": [1]})
with self.assertRaises(ValueError):
plot_waterfall(df, "delta", "name", colors=("r",))

def test_plot_waterfall_not_numeric(self):
df = pandas.DataFrame({"name": ["A"], "delta": ["x"]})
with self.assertRaises(ValueError):
plot_waterfall(df, "delta", "name")


if __name__ == "__main__":
unittest.main()
81 changes: 81 additions & 0 deletions teachpyx/tools/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import re
from pathlib import Path
from typing import Optional, Tuple
from urllib.parse import urlparse, unquote
import pandas

Expand Down Expand Up @@ -46,3 +47,83 @@ def read_csv_cached(
df = pandas.read_csv(filepath_or_buffer, **kwargs)
df.to_csv(cache_name, index=False)
return df


def plot_waterfall(
data: pandas.DataFrame,
value_column: str,
label_column: Optional[str] = None,
total_label: str = "total",
ax=None,
colors: Tuple[str, str, str] = ("#2ca02c", "#d62728", "#1f77b4"),
):
"""
Draws a waterfall chart from a dataframe.

:param data: dataframe containing increments
:param value_column: column with increments
:param label_column: column with labels, index is used if None
:param total_label: label used for the final total
:param ax: existing axis or None to create one
:param colors: positive, negative, total colors
:return: axis, computed dataframe used to draw the chart

.. versionadded:: 0.6.1
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring marks this helper as .. versionadded:: 0.6.1, but the project version in pyproject.toml is currently 0.6.0. Please align the versionadded tag with the release version that will actually contain this function (either bump the package version as part of this change, or adjust the tag) to avoid misleading API docs.

Suggested change
.. versionadded:: 0.6.1
.. versionadded:: 0.6.0

Copilot uses AI. Check for mistakes.
"""
if value_column not in data.columns:
raise ValueError(f"Unable to find column {value_column!r} in dataframe.")
if label_column is not None and label_column not in data.columns:
raise ValueError(f"Unable to find column {label_column!r} in dataframe.")
if len(colors) != 3:
raise ValueError(f"colors must contain 3 values, not {len(colors)}.")

values = pandas.to_numeric(data[value_column], errors="raise").astype(float)
labels = data[label_column] if label_column is not None else data.index
labels = labels.astype(str)

starts = values.cumsum().shift(1, fill_value=0.0)
plot_df = pandas.DataFrame(
{
"label": labels,
"value": values,
"start": starts,
"end": starts + values,
"kind": "variation",
}
)

total = float(values.sum())
total_row = pandas.DataFrame(
{
"label": [total_label],
"value": [total],
"start": [0.0],
"end": [total],
"kind": ["total"],
}
)
plot_df = pandas.concat([plot_df, total_row], axis=0, ignore_index=True)

if ax is None:
import matplotlib.pyplot as plt

_, ax = plt.subplots(1, 1)

bar_colors = [
colors[2]
if kind == "total"
else (colors[0] if value >= 0 else colors[1])
for value, kind in zip(plot_df["value"], plot_df["kind"])
]
ax.bar(
plot_df["label"],
plot_df["value"],
bottom=plot_df["start"],
color=bar_colors,
)

ax.axhline(0, color="black", linewidth=0.8)
ax.set_ylabel(value_column)
ax.set_xlabel(label_column or "index")

return ax, plot_df