diff --git a/_doc/api/tools/pandas.rst b/_doc/api/tools/pandas.rst index 58b06a0..352314a 100644 --- a/_doc/api/tools/pandas.rst +++ b/_doc/api/tools/pandas.rst @@ -2,6 +2,24 @@ teachpyx.tools.pandas ===================== +Example +------- + +.. plot:: + + import pandas + import matplotlib.pyplot as plt + from teachpyx.tools.pandas import plot_waterfall + + plt.close("all") + + df = pandas.DataFrame({"name": ["A", "B", "C"], "delta": [10, -3, 5]}) + ax, _ = plot_waterfall(df, "delta", "name", total_label="TOTAL") + ax.set_title("Example waterfall") + plt.xticks(rotation=30, ha="right") + plt.tight_layout() + plt.show() + .. automodule:: teachpyx.tools.pandas :members: :no-undoc-members: diff --git a/_unittests/ut_tools/test_pandas.py b/_unittests/ut_tools/test_pandas.py index 4833207..d5c3c97 100644 --- a/_unittests/ut_tools/test_pandas.py +++ b/_unittests/ut_tools/test_pandas.py @@ -1,6 +1,8 @@ import unittest +from matplotlib.axes import Axes +import pandas from teachpyx.ext_test_case import ExtTestCase -from teachpyx.tools.pandas import read_csv_cached +from teachpyx.tools.pandas import plot_waterfall, read_csv_cached class TestPandas(ExtTestCase): @@ -14,6 +16,39 @@ def test_read_csv_cached(self): self.assertEqual(df.shape, df2.shape) self.assertEqual(list(df.columns), list(df2.columns)) + def test_plot_waterfall(self): + df = pandas.DataFrame( + { + "name": ["A", "B", "C"], + "delta": [10, -3, 5], + } + ) + ax, plot_df = plot_waterfall(df, "delta", "name", total_label="TOTAL") + self.assertIsInstance(ax, Axes) + self.assertEqual(list(plot_df["label"]), ["A", "B", "C", "TOTAL"]) + self.assertEqual(list(plot_df["start"]), [0.0, 10.0, 7.0, 0.0]) + self.assertEqual(list(plot_df["end"]), [10.0, 7.0, 12.0, 12.0]) + + def test_plot_waterfall_missing_column(self): + df = pandas.DataFrame({"name": ["A"], "delta": [1]}) + with self.assertRaises(ValueError): + plot_waterfall(df, "missing", "name") + + def test_plot_waterfall_missing_label_column(self): + df = pandas.DataFrame({"name": ["A"], "delta": [1]}) + with self.assertRaises(ValueError): + plot_waterfall(df, "delta", "missing") + + def test_plot_waterfall_bad_colors(self): + df = pandas.DataFrame({"name": ["A"], "delta": [1]}) + with self.assertRaises(ValueError): + plot_waterfall(df, "delta", "name", colors=("r",)) + + def test_plot_waterfall_not_numeric(self): + df = pandas.DataFrame({"name": ["A"], "delta": ["x"]}) + with self.assertRaises(ValueError): + plot_waterfall(df, "delta", "name") + if __name__ == "__main__": unittest.main() diff --git a/teachpyx/tools/pandas.py b/teachpyx/tools/pandas.py index ca5f08b..8962704 100644 --- a/teachpyx/tools/pandas.py +++ b/teachpyx/tools/pandas.py @@ -2,6 +2,7 @@ import os import re from pathlib import Path +from typing import Optional, Tuple from urllib.parse import urlparse, unquote import pandas @@ -46,3 +47,83 @@ def read_csv_cached( df = pandas.read_csv(filepath_or_buffer, **kwargs) df.to_csv(cache_name, index=False) return df + + +def plot_waterfall( + data: pandas.DataFrame, + value_column: str, + label_column: Optional[str] = None, + total_label: str = "total", + ax=None, + colors: Tuple[str, str, str] = ("#2ca02c", "#d62728", "#1f77b4"), +): + """ + Draws a waterfall chart from a dataframe. + + :param data: dataframe containing increments + :param value_column: column with increments + :param label_column: column with labels, index is used if None + :param total_label: label used for the final total + :param ax: existing axis or None to create one + :param colors: positive, negative, total colors + :return: axis, computed dataframe used to draw the chart + + .. versionadded:: 0.6.1 + """ + if value_column not in data.columns: + raise ValueError(f"Unable to find column {value_column!r} in dataframe.") + if label_column is not None and label_column not in data.columns: + raise ValueError(f"Unable to find column {label_column!r} in dataframe.") + if len(colors) != 3: + raise ValueError(f"colors must contain 3 values, not {len(colors)}.") + + values = pandas.to_numeric(data[value_column], errors="raise").astype(float) + labels = data[label_column] if label_column is not None else data.index + labels = labels.astype(str) + + starts = values.cumsum().shift(1, fill_value=0.0) + plot_df = pandas.DataFrame( + { + "label": labels, + "value": values, + "start": starts, + "end": starts + values, + "kind": "variation", + } + ) + + total = float(values.sum()) + total_row = pandas.DataFrame( + { + "label": [total_label], + "value": [total], + "start": [0.0], + "end": [total], + "kind": ["total"], + } + ) + plot_df = pandas.concat([plot_df, total_row], axis=0, ignore_index=True) + + if ax is None: + import matplotlib.pyplot as plt + + _, ax = plt.subplots(1, 1) + + bar_colors = [ + colors[2] + if kind == "total" + else (colors[0] if value >= 0 else colors[1]) + for value, kind in zip(plot_df["value"], plot_df["kind"]) + ] + ax.bar( + plot_df["label"], + plot_df["value"], + bottom=plot_df["start"], + color=bar_colors, + ) + + ax.axhline(0, color="black", linewidth=0.8) + ax.set_ylabel(value_column) + ax.set_xlabel(label_column or "index") + + return ax, plot_df