Skip to content

Commit

Permalink
perf(eda): improve progress bar performance
Browse files Browse the repository at this point in the history
And added option to disable progress bar.
  • Loading branch information
dovahcrow committed Sep 4, 2020
1 parent 8ebe9cc commit 64be889
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 25 deletions.
15 changes: 9 additions & 6 deletions dataprep/eda/correlation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def plot_correlation(
*,
value_range: Optional[Tuple[float, float]] = None,
k: Optional[int] = None,
progress: bool = True,
) -> Report:
"""
This function is designed to calculate the correlation between columns
Expand All @@ -32,15 +33,17 @@ def plot_correlation(
Parameters
----------
df
The pandas data_frame for which plots are calculated for each column
The pandas data_frame for which plots are calculated for each column.
x
A valid column name of the data frame
A valid column name of the data frame.
y
A valid column name of the data frame
A valid column name of the data frame.
value_range
Range of value
Range of value.
k
Choose top-k element
Choose top-k element.
progress
Enable the progress bar.
Examples
--------
Expand All @@ -61,7 +64,7 @@ def plot_correlation(
This function only supports numerical or categorical data,
and it is better to drop None, NaN and Null values before using it.
"""
with ProgressBar(minimum=1):
with ProgressBar(minimum=1, disable=not progress):
intermediate = compute_correlation(df, x=x, y=y, value_range=value_range, k=k)
figure = render_correlation(intermediate)

Expand Down
6 changes: 5 additions & 1 deletion dataprep/eda/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def plot(
yscale: str = "linear",
tile_size: Optional[float] = None,
dtype: Optional[DTypeDef] = None,
progress: bool = True,
) -> Union[Report, Container]:
"""Generates plots for exploratory data analysis.
Expand Down Expand Up @@ -133,6 +134,9 @@ def plot(
E.g. dtype = {"a": Continuous, "b": "Nominal"} or
dtype = {"a": Continuous(), "b": "nominal"}
or dtype = Continuous() or dtype = "Continuous"
progress
Enable the progress bar.
Examples
--------
>>> import pandas as pd
Expand All @@ -144,7 +148,7 @@ def plot(
"""
# pylint: disable=too-many-locals,line-too-long

with ProgressBar(minimum=1):
with ProgressBar(minimum=1, disable=not progress):
intermediate = compute(
df,
x=x,
Expand Down
17 changes: 10 additions & 7 deletions dataprep/eda/missing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def plot_missing(
bins: int = 30,
ndist_sample: int = 100,
dtype: Optional[DTypeDef] = None,
progress: bool = True,
) -> Report:
"""
This function is designed to deal with missing values
Expand All @@ -33,20 +34,22 @@ def plot_missing(
Parameters
----------
df
the pandas data_frame for which plots are calculated for each column
the pandas data_frame for which plots are calculated for each column.
x
a valid column name of the data frame
a valid column name of the data frame.
y
a valid column name of the data frame
a valid column name of the data frame.
bins
The number of rows in the figure
The number of rows in the figure.
ndist_sample
The number of sample points
The number of sample points.
dtype: str or DType or dict of str or dict of DType, default None
Specify Data Types for designated column or all columns.
E.g. dtype = {"a": Continuous, "b": "Nominal"} or
dtype = {"a": Continuous(), "b": "nominal"}
or dtype = Continuous() or dtype = "Continuous" or dtype = Continuous()
or dtype = Continuous() or dtype = "Continuous" or dtype = Continuous().
progress
Enable the progress bar.
Examples
----------
Expand All @@ -57,7 +60,7 @@ def plot_missing(
>>> plot_missing(df, "HDI_for_year", "population")
"""

with ProgressBar(minimum=1):
with ProgressBar(minimum=1, disable=not progress):
itmdt = compute_missing(
df, x, y, dtype=dtype, bins=bins, ndist_sample=ndist_sample
)
Expand Down
87 changes: 76 additions & 11 deletions dataprep/eda/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""ProgressBar shows how many dask tasks have finished/remain using tqdm."""

from typing import Any, Optional, Dict, Tuple, Union
import sys
from time import time
from typing import Any, Dict, Optional, Tuple, Union

from dask.callbacks import Callback

Expand All @@ -15,18 +16,21 @@
# pylint: disable=method-hidden,too-many-instance-attributes
class ProgressBar(Callback): # type: ignore
"""A progress bar for DataPrep.EDA.
Not thread safe.
Parameters
----------
minimum : int, optional
Minimum time threshold in seconds before displaying a progress bar.
Default is 0 (always display)
_min_tasks : int, optional
min_tasks : int, optional
Minimum graph size to show a progress bar, default is 5
width : int, optional
Width of the bar. None means auto width.
interval : float, optional
Update resolution in seconds, default is 0.1 seconds
Update resolution in seconds, default is 0.1 seconds.
disable : bool, optional
Disable the progress bar.
"""

_minimum: float = 0
Expand All @@ -38,50 +42,75 @@ class ProgressBar(Callback): # type: ignore
_state: Optional[Dict[str, Any]] = None
_started: Optional[float] = None
_last_task: Optional[str] = None # in case we initialize the pbar in _finish
_pbar_runtime: float = 0
_last_updated: Optional[float] = None
_disable: bool = False

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
minimum: float = 0,
min_tasks: int = 5,
width: Optional[int] = None,
interval: float = 0.1,
disable: bool = False,
) -> None:
super().__init__()
self._minimum = minimum
self._min_tasks = min_tasks
self._width = width
self._interval = interval
self._disable = disable

def _start(self, _dsk: Any) -> None:
"""A hook to start this callback."""

def _start_state(self, _dsk: Any, state: Dict[str, Any]) -> None:
"""A hook called before every task gets executed."""
self._started = time()
if self._disable:
return

then = time()

self._last_updated = self._started = time()

self._state = state
_, ntasks = self._count_tasks()

if ntasks > self._min_tasks:
self._init_bar()
self._init_pbar()

self._pbar_runtime += time() - then

def _pretask(
self, key: Union[str, Tuple[str, ...]], _dsk: Any, _state: Dict[str, Any]
) -> None:
"""A hook called before one task gets executed."""
if self._disable:
return

then = time()

if self._started is None:
raise ValueError("ProgressBar not started properly")

if self._pbar is None and time() - self._started > self._minimum:
self._init_bar()
self._init_pbar()

if isinstance(key, tuple):
key = key[0]

if self._pbar is not None:
self._pbar.set_description(f"Computing {key}")
if self._last_updated is None:
raise ValueError("ProgressBar not started properly")

if time() - self._last_updated > self._interval:
self._pbar.set_description(f"Computing {key}")
self._last_updated = time()
else:
self._last_task = key

self._pbar_runtime += time() - then

def _posttask(
self,
_key: str,
Expand All @@ -92,21 +121,47 @@ def _posttask(
) -> None:
"""A hook called after one task gets executed."""

if self._disable:
return

then = time()

if self._pbar is not None:
self._update_bar()
if self._last_updated is None:
raise ValueError("ProgressBar not started properly")

if time() - self._last_updated > self._interval:
self._update_bar()
self._last_updated = time()

self._pbar_runtime += time() - then

def _finish(self, _dsk: Any, _state: Dict[str, Any], _errored: bool) -> None:
"""A hook called after all tasks get executed."""
if self._disable:
return

then = time()

if self._started is None:
raise ValueError("ProgressBar not started properly")

if self._pbar is None and time() - self._started > self._minimum:
self._init_bar()
self._init_pbar()

if self._pbar is not None:
self._update_bar()
self._pbar.close()

self._pbar_runtime += time() - then

if self._pbar_runtime / (time() - self._started) > 0.3:
print(
"[ProgressBar] ProgressBar takes additional 10%+ of the computation time,"
" consider disable it by passing 'progress=False' to the plot function.",
file=sys.stderr,
)

self._state = None
self._started = None
self._pbar = None
Expand All @@ -118,7 +173,7 @@ def _update_bar(self) -> None:

self._pbar.update(max(0, ndone - self._pbar.n))

def _init_bar(self) -> None:
def _init_pbar(self) -> None:
if self._pbar is not None:
raise ValueError("ProgressBar already initialized.")
ndone, ntasks = self._count_tasks()
Expand Down Expand Up @@ -157,3 +212,13 @@ def _count_tasks(self) -> Tuple[int, int]:
ntasks = sum(len(state[k]) for k in ["ready", "waiting", "running"]) + ndone

return ndone, ntasks

def register(self) -> None:
    """Refuse to register this progress bar as a global callback.

    The progress bar is not thread safe, so global registration is
    rejected unconditionally. Use it as a context manager instead
    (``with ProgressBar(...):``).

    Raises
    ------
    ValueError
        Always.
    """
    raise ValueError(
        "ProgressBar is not thread safe thus cannot be registered globally"
    )

def unregister(self) -> None:
    """Refuse to unregister this progress bar as a global callback.

    The progress bar is not thread safe, so global unregistration is
    rejected unconditionally.

    Raises
    ------
    ValueError
        Always.
    """
    raise ValueError(
        "ProgressBar is not thread safe thus cannot be unregistered globally"
    )

0 comments on commit 64be889

Please sign in to comment.