Skip to content

Commit

Permalink
Update attribute plots for binary variables (#119)
Browse files Browse the repository at this point in the history
* update binary plots

* adapt seaborn.countplot for normalization

* remove kwargs from _countplot

* update args for plotter

* add tests for distr_plot

* Update ci.yml

Co-authored-by: George-Bogdan Surdu <51715053+bogdansurdu@users.noreply.github.com>
Co-authored-by: Simon Swan <simon@synthesized.io>
  • Loading branch information
3 people committed Sep 24, 2021
1 parent 735165c commit f68761c
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ jobs:
SYNTHESIZED_KEY: ${{ secrets.SYNTHESIZED_KEY }}

- name: Upload Codecov report
if: ${{ matrix.python-version }} == '3.7'
if: ${{ matrix.python-version == 3.7 }}
uses: codecov/codecov-action@v1.5.2
with:
files: coverage-reports/cobertura.xml
flags: unittests
fail_ci_if_error: false

- name: SonarCloud Scan
if: ${{ matrix.python-version }} == '3.7'
if: ${{ matrix.python-version == 3.7 }}
uses: SonarSource/sonarcloud-github-action@master
with:
projectBaseDir: .
Expand Down
86 changes: 86 additions & 0 deletions src/fairlens/plot/distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from seaborn.categorical import _CountPlotter

from .. import utils

LABEL_THRESH = 5


def distr_plot(
df: pd.DataFrame,
Expand Down Expand Up @@ -188,6 +191,9 @@ def attr_distr_plot(

col = utils.infer_dtype(df_[attr])

if distr_type is None:
distr_type = utils.infer_distr_type(df_[target_attr]).value

if attr_distr_type is None:
attr_distr_type = utils.infer_distr_type(col).value

Expand Down Expand Up @@ -229,6 +235,16 @@ def attr_distr_plot(

return None

if distr_type == "binary":
_countplot(x=df_[attr], hue=df_[target_attr], palette=cmap, normalize=normalize)
plt.title(attr)

if df_[attr].nunique() > LABEL_THRESH:
plt.xticks(rotation=45)
plt.tight_layout()

return ax

distr_plot(
df_,
target_attr,
Expand Down Expand Up @@ -347,3 +363,73 @@ def _shade_area(ax: Axes, cmap: Sequence[Tuple[float, float, float]], alpha: flo
for line in ax.lines:
xy = line.get_xydata()
ax.fill_between(xy[:, 0], xy[:, 1], color=next(palette), alpha=alpha)


def _countplot(
x: Any = None,
y: Any = None,
hue: Any = None,
data: Any = None,
normalize: bool = False,
order: List[str] = None,
hue_order: List[str] = None,
orient: Optional[str] = None,
color: Any = None,
palette: Any = None,
saturation: float = 0.75,
dodge: bool = True,
ax: Axes = None,
) -> Axes:
"""Adaptation of seaborn.countplot"""

def prob(a):
return len(a) / len(x)

estimator = prob if normalize else len
ci = None
n_boot = 0
units = None
seed = None
errcolor = None
errwidth = None
capsize = None

if x is None and y is not None:
orient = "h"
x = y
elif y is None and x is not None:
orient = "v"
y = x
elif x is not None and y is not None:
raise ValueError("Cannot pass values for both `x` and `y`")

plotter = _CountPlotter(
x=x,
y=y,
hue=hue,
data=data,
order=order,
hue_order=hue_order,
estimator=estimator,
ci=ci,
n_boot=n_boot,
units=units,
seed=seed,
orient=orient,
color=color,
palette=palette,
saturation=saturation,
errcolor=errcolor,
errwidth=errwidth,
capsize=capsize,
dodge=dodge,
)

plotter.value_label = "probability" if normalize else "count"

if ax is None:
ax = plt.gca()

plotter.plot(ax, {})

return ax
11 changes: 10 additions & 1 deletion tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@

def test_distr_plot():
distr_plot(dfc, "RawScore", [{"Sex": ["Male"]}, {"Sex": ["Female"]}, {"Ethnicity": ["Asian"]}])
distr_plot(dfc, "RawScore", [{"Sex": ["Male"]}, dfc["Sex"] == "Female"], cmap=sns.color_palette())

groups = [{"Sex": ["Male"]}, dfc["Sex"] == "Female"]
distr_plot(dfc, "RawScore", groups, cmap=sns.color_palette())
distr_plot(dfc, "RawScore", groups, show_curve=None)
distr_plot(dfc, "RawScore", groups, show_hist=True, show_curve=False)
distr_plot(dfc, "RawScore", groups, show_hist=False, show_curve=True)
distr_plot(dfc, "RawScore", groups, normalize=True)
distr_plot(dfc, "RawScore", groups, normalize=True, distr_type="continuous")
distr_plot(dfc, "DateOfBirth", groups, normalize=True, distr_type="datetime")


def test_attr_distr_plot():
attr_distr_plot(dfc, "RawScore", "Sex")
attr_distr_plot(dfc, "RawScore", "Sex", distr_type="continuous", attr_distr_type="binary")
attr_distr_plot(dfc, "RawScore", "Ethnicity", separate=True)


Expand Down

0 comments on commit f68761c

Please sign in to comment.