From f68761c12f042f2f8a85d58fa0731b9ffee12dac Mon Sep 17 00:00:00 2001 From: Aahil Mehta Date: Fri, 24 Sep 2021 15:12:00 +0100 Subject: [PATCH] Update attribute plots for binary variables (#119) * update binary plots * adapt seaborn.countplot for normalization * remove kwargs from _countplot * update args for plotter * add tests for distr_plot * Update ci.yml Co-authored-by: George-Bogdan Surdu <51715053+bogdansurdu@users.noreply.github.com> Co-authored-by: Simon Swan --- .github/workflows/ci.yml | 4 +- src/fairlens/plot/distr.py | 86 ++++++++++++++++++++++++++++++++++++++ tests/test_plot.py | 11 ++++- 3 files changed, 98 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c89665f..3acd8eb1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,7 +46,7 @@ jobs: SYNTHESIZED_KEY: ${{ secrets.SYNTHESIZED_KEY }} - name: Upload Codecov report - if: ${{ matrix.python-version }} == '3.7' + if: ${{ matrix.python-version == 3.7 }} uses: codecov/codecov-action@v1.5.2 with: files: coverage-reports/cobertura.xml @@ -54,7 +54,7 @@ jobs: fail_ci_if_error: false - name: SonarCloud Scan - if: ${{ matrix.python-version }} == '3.7' + if: ${{ matrix.python-version == 3.7 }} uses: SonarSource/sonarcloud-github-action@master with: projectBaseDir: . diff --git a/src/fairlens/plot/distr.py b/src/fairlens/plot/distr.py index c8881fb5..6225affd 100644 --- a/src/fairlens/plot/distr.py +++ b/src/fairlens/plot/distr.py @@ -11,9 +11,12 @@ import pandas as pd import seaborn as sns from matplotlib.axes import Axes +from seaborn.categorical import _CountPlotter from .. 
def _countplot(
    x: Any = None,
    y: Any = None,
    hue: Any = None,
    data: Any = None,
    normalize: bool = False,
    order: Optional[List[str]] = None,
    hue_order: Optional[List[str]] = None,
    orient: Optional[str] = None,
    color: Any = None,
    palette: Any = None,
    saturation: float = 0.75,
    dodge: bool = True,
    ax: Optional[Axes] = None,
) -> Axes:
    """Adaptation of seaborn.countplot that can plot proportions instead of counts.

    Exactly one of ``x`` / ``y`` is expected; it is used both as the grouping
    variable and as the values handed to the estimator.

    Args:
        x: Values (or, with ``data``, a column name) to count along the x axis.
            Selecting ``x`` forces a vertical orientation.
        y: Values (or, with ``data``, a column name) to count along the y axis.
            Selecting ``y`` forces a horizontal orientation.
        hue: Optional grouping variable forwarded to the seaborn plotter.
        data: Optional container that ``x`` / ``y`` / ``hue`` names refer to.
        normalize: When True each bar height is the size of its group divided
            by the total number of observations, and the value axis is
            labelled "probability"; otherwise heights are raw counts and the
            label is "count".
        order: Forwarded to the seaborn plotter.
        hue_order: Forwarded to the seaborn plotter.
        orient: Overwritten with "v" or "h" whenever exactly one of
            ``x`` / ``y`` is given.
        color: Forwarded to the seaborn plotter.
        palette: Forwarded to the seaborn plotter.
        saturation: Forwarded to the seaborn plotter.
        dodge: Forwarded to the seaborn plotter.
        ax: Axes to draw on. Defaults to the current matplotlib axes.

    Returns:
        The axes the bars were drawn on.

    Raises:
        ValueError: If both ``x`` and ``y`` are provided.
    """

    if x is not None and y is not None:
        raise ValueError("Cannot pass values for both `x` and `y`")

    # Mirror the count variable onto both axes, as seaborn.countplot does.
    if x is None and y is not None:
        orient = "h"
        x = y
    elif y is None and x is not None:
        orient = "v"
        y = x

    # Resolve the observations used as the normalisation denominator. When the
    # variable was given as a column name, len() of the name itself would be
    # the length of the string rather than the number of rows, so look the
    # column up in `data` first.
    observations = x
    if data is not None and isinstance(x, str):
        try:
            observations = data[x]
        except (KeyError, IndexError, TypeError, ValueError):
            observations = x

    def prob(group: Any) -> float:
        return len(group) / len(observations)

    estimator = prob if normalize else len

    # Count plots carry no error bars, so every confidence-interval related
    # option of the underlying bar plotter is disabled.
    plotter = _CountPlotter(
        x=x,
        y=y,
        hue=hue,
        data=data,
        order=order,
        hue_order=hue_order,
        estimator=estimator,
        ci=None,
        n_boot=0,
        units=None,
        seed=None,
        orient=orient,
        color=color,
        palette=palette,
        saturation=saturation,
        errcolor=None,
        errwidth=None,
        capsize=None,
        dodge=dodge,
    )

    plotter.value_label = "probability" if normalize else "count"

    if ax is None:
        ax = plt.gca()

    plotter.plot(ax, {})

    return ax