Skip to content

Commit

Permalink
fix(eda):clean config code and fix scatter sample param
Browse files Browse the repository at this point in the history
  • Loading branch information
jinglinpeng committed Sep 18, 2021
1 parent a868c50 commit 8ab27f9
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 57 deletions.
167 changes: 122 additions & 45 deletions dataprep/eda/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

# pylint: disable=too-many-lines,no-self-use,blacklisted-name,no-else-raise,too-many-branches,no-name-in-module

# pylint: disable = protected-access
from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -642,32 +642,74 @@ class Scatter(BaseModel):
"""
enable: bool, default True
Whether to create this element
sample_size: int, default 1000
Number of points to randomly sample per partition
sample_size: int, optional, default=1000
Number of points to randomly sample per partition.
Cannot be used with sample_rate.
sample_rate: float, optional, default None
sample rate per partition. Cannot be used with
sample_size. Set it to 1.0 for no sampling.
height: int, default "auto"
Height of the plot
width: int, default "auto"
Width of the plot
"""

enable: bool = True
sample_size: int = 1000
sample_size: Optional[int] = 1000
sample_rate: Optional[float] = None
height: Union[int, None] = None
width: Union[int, None] = None
# Used internally for parameter checking. It seems that pydantic treats
# an internal (underscore-prefixed) parameter as a class attribute,
# hence we need to initialize it in __init__.
_user_input_params: Dict[str, Any]

def __init__(self) -> None:
super().__init__()
object.__setattr__(self, "_user_input_params", {})

def how_to_guide(self, height: int, width: int) -> List[Tuple[str, str]]:
"""
how-to guide
"""
vals = [self.sample_size, height, width]
names = ["scatter.sample_size", "height", "width"]
if self.sample_size is not None:
para_val: Union[int, float, None] = self.sample_size
para_name = "scatter.sample_size"
para_desc = "Number of points to randomly sample per partition"
else:
para_val = self.sample_rate
para_name = "scatter.sample_rate"
para_desc = "Sample rate to randomly sample per partition"

vals = [para_val, height, width]
names = [para_name, "height", "width"]
descs = [
"Number of points to randomly sample per partition",
para_desc,
"Height of the plot",
"Width of the plot",
]
return [(f"'{name}': {val}", desc) for name, val, desc in zip(names, vals, descs)]

def _check_and_correct_param(self) -> None:
    """Validate the scatter sampling parameters and reconcile them.

    ``sample_size`` and ``sample_rate`` are mutually exclusive. Only values
    recorded in ``_user_input_params`` (i.e. explicitly supplied by the user)
    count as "set"; the class-level default of ``sample_size`` does not.

    When the user supplied ``sample_rate``, ``sample_size`` is reset to
    ``None`` so that consumers, which test ``sample_size is not None``
    first, fall through to rate-based sampling.

    Raises
    ------
    AttributeError
        If the user explicitly set both ``sample_size`` and ``sample_rate``,
        or if, after correction, both parameters are ``None`` and there is
        therefore nothing to sample by.
    """
    user_set_sample_size = self._user_input_params.get("sample_size")
    user_set_sample_rate = self._user_input_params.get("sample_rate")

    if (user_set_sample_size is not None) and (user_set_sample_rate is not None):
        raise AttributeError(
            f"Scatter plot set sample size {user_set_sample_size} and "
            + f"sample rate {user_set_sample_rate}, please only set one of them."
        )

    if user_set_sample_rate is not None:
        # The default sample_size must not shadow an explicit sample_rate.
        self.sample_size = None

    # Reject the state in which neither parameter is usable (e.g. the user
    # passed sample_size=None without a sample_rate). Otherwise the problem
    # only surfaces later, when the samplers are handed frac=None.
    if self.sample_size is None and self.sample_rate is None:
        raise AttributeError(
            "Scatter plot has neither sample size nor sample rate set, "
            + "please set exactly one of them."
        )


class Hexbin(BaseModel):
"""
Expand Down Expand Up @@ -1146,6 +1188,71 @@ class Config(BaseModel):
missingvalues: MissingValues = Field(default_factory=MissingValues)
diff: Diff = Field(default_factory=Diff)

def _set_enable_for_plots(self, display: List[str]) -> None:
    """Disable every plot that is not requested in ``display``.

    Used by the ``from_dict`` constructor. The names in ``display`` are first
    resolved through ``DISPLAY_MAP``. If that raises ``KeyError``, the names
    are resolved through ``DISPLAY_REPORT_MAP`` instead (report config), and
    only the sections listed in that map are candidates for disabling.

    Parameters
    ----------
    display
        Display names of the plots/sections that should stay enabled.
    """
    never_disabled = {"plot", "diff"}
    every_plot = set(vars(self))
    try:
        requested = {DISPLAY_MAP[shown] for shown in display}
        # Plot and Diff are left untouched regardless of the display list.
        for hidden in every_plot - requested - never_disabled:
            getattr(self, hidden).enable = False
    except KeyError:
        # handle report config
        requested = {DISPLAY_REPORT_MAP[shown] for shown in display}
        for hidden in set(DISPLAY_REPORT_MAP.values()) - requested:
            getattr(self, hidden).enable = False

def _set_param_for_plot(
    self, plot_name: str, param: str, val: Any, raise_error_if_not_exists: bool
) -> None:
    """Assign ``val`` to parameter ``param`` of the plot named ``plot_name``.

    Shared by the global and local parameter setters. If the plot object
    carries a ``_user_input_params`` mapping, the assignment is also recorded
    there.

    Parameters
    ----------
    plot_name
        Attribute name of the plot on this config object.
    param
        Name of the parameter to set on that plot.
    val
        Value to assign.
    raise_error_if_not_exists
        Whether a plot lacking ``param`` is an error (True) or is silently
        skipped (False).

    Raises
    ------
    AttributeError
        If ``plot_name`` is not an attribute of this config, or if the plot
        has no ``param`` and ``raise_error_if_not_exists`` is True.
    """
    if plot_name not in vars(self):
        raise AttributeError(f"plot {plot_name} does not exist")

    target = getattr(self, plot_name)
    if not hasattr(target, param):
        if raise_error_if_not_exists:
            raise AttributeError(f"{plot_name} plot does not have parameter {param}")
        return

    setattr(target, param, val)
    if hasattr(target, "_user_input_params"):
        target._user_input_params[param] = val

def _set_global_param_for_plots(self, global_params: Dict[str, Any]) -> None:
    """Apply each global parameter to every plot that has it.

    Used by the ``from_dict`` constructor. A parameter is global only if it
    is an attribute of ``self.plot``. Plots that lack the parameter are
    skipped without error.

    Parameters
    ----------
    global_params
        Mapping from global parameter name to the value to apply.

    Raises
    ------
    AttributeError
        If a key of ``global_params`` is not an attribute of ``self.plot``.
    """
    recognised = vars(self.plot)
    plot_names = vars(self)
    for name, value in global_params.items():
        if name not in recognised:
            raise AttributeError(f"{name} is not a global parameter")

        if name == "ngroups":
            # ngroups applies to "bars" and "slices" for the bar and pie charts
            self.bar.bars = value
            self.pie.slices = value

        for plot_name in plot_names:
            self._set_param_for_plot(plot_name, name, value, raise_error_if_not_exists=False)

def _set_local_param_for_plots(self, local_params: Dict[str, Any]) -> None:
    """Apply plot-specific parameters given as ``"<plot>.<param>"`` keys.

    Used by the ``from_dict`` constructor. Each key is split at its first
    dot into the plot name and the parameter path. Any further dots in the
    path are replaced by double underscores to form the attribute name.
    A plot that lacks the parameter is treated as an error by the setter.

    Parameters
    ----------
    local_params
        Mapping from dotted ``"<plot>.<param>"`` key to the value to set.
    """
    for dotted_key, setting in local_params.items():
        pieces = dotted_key.split(".", 1)
        plot_name, param_path = pieces
        attr_name = param_path.replace(".", "__")
        self._set_param_for_plot(plot_name, attr_name, setting, raise_error_if_not_exists=True)

def _check_and_correct_params_for_plots(self) -> None:
    """Run ``_check_and_correct_param`` on every plot that defines it.

    Used by the ``from_dict`` constructor. A plot's
    ``_check_and_correct_param`` validates its parameters and handles
    parameters that must not be set at the same time, e.g. the sample size
    and sample rate of the scatter plot. Plots without such a hook are
    skipped.
    """
    for plot_name in vars(self):
        candidate = getattr(self, plot_name)
        if not hasattr(candidate, "_check_and_correct_param"):
            continue
        candidate._check_and_correct_param()

@classmethod
def from_dict(
cls, display: Optional[List[str]] = None, config: Optional[Dict[str, Any]] = None
Expand All @@ -1154,45 +1261,15 @@ def from_dict(
Converts an dictionary instance into a config class
"""
cfg = cls()
if display:
try:
display = [DISPLAY_MAP[disp] for disp in display]
# set all plots not in display list to enable=False except for Plot and Diff class
for plot in set(vars(cfg).keys()) - set(display) - {"plot"} - {"diff"}:
setattr(getattr(cfg, plot), "enable", False)
except KeyError:
display = [DISPLAY_REPORT_MAP[disp] for disp in display]
for plot in set(DISPLAY_REPORT_MAP.values()) - set(display):
setattr(getattr(cfg, plot), "enable", False)

if config:
# get the global parameters from config
if display is not None:
cfg._set_enable_for_plots(display)

if config is not None:
# get the global and local parameters from config
global_params = {key: config[key] for key in config if "." not in key}
for param, val in global_params.items():
# set the parameter to the specified value for each plot that
# has this parameter
if param not in vars(cfg.plot).keys():
raise Exception(param + " does not exist")
else:
for plot in vars(cfg).keys():
if hasattr(getattr(cfg, plot), param):
setattr(getattr(cfg, plot), param, val)

# ngroups applies to "bars" and "slices" for the bar and pie charts
if param == "ngroups":
setattr(getattr(cfg, "bar"), "bars", val)
setattr(getattr(cfg, "pie"), "slices", val)

# get the local parameters from config
local_params = {key: config[key] for key in config if key not in global_params}
for key, value in local_params.items():
plot, rest = key.split(".", 1)
param = rest.replace(".", "__")
if plot not in vars(cfg).keys():
raise Exception(plot + " does not exist")
elif not hasattr(getattr(cfg, plot), param):
raise Exception(key.replace(f"{plot}.", "") + " does not exist")
else:
setattr(getattr(cfg, plot), param, value)

cfg._set_global_param_for_plots(global_params)
cfg._set_local_param_for_plots(local_params)
cfg._check_and_correct_params_for_plots()
return cfg
7 changes: 6 additions & 1 deletion dataprep/eda/correlation/compute/bivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@ def scatter_with_regression(
(coeffa, coeffb), _, _, _ = da.linalg.lstsq(arr[:, [0, 2]], arr[:, 1])

df = df.drop(columns=["ones"])
if cfg.scatter.sample_size is not None:
sample_func = lambda x: x.sample(n=min(cfg.scatter.sample_size, x.shape[0]))
else:
sample_func = lambda x: x.sample(frac=cfg.scatter.sample_rate)
df_smp = df.map_partitions(
lambda x: x.sample(min(cfg.scatter.sample_size, x.shape[0])), meta=df
sample_func,
meta=df,
)
# TODO influences should not be computed on a sample
influences = (
Expand Down
8 changes: 7 additions & 1 deletion dataprep/eda/create_report/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,15 @@ def basic_computations(
data["num_cols"] = df_num.columns
# interactions
if cfg.interactions.enable:
if cfg.scatter.sample_size is not None:
sample_func = lambda x: x.sample(n=min(cfg.scatter.sample_size, x.shape[0]))
else:
sample_func = lambda x: x.sample(frac=cfg.scatter.sample_rate)
data["scat"] = df_num.frame.map_partitions(
lambda x: x.sample(min(1000, x.shape[0])), meta=df_num.frame
sample_func,
meta=df_num.frame,
)

# correlations
if cfg.correlations.enable:
data.update(zip(("cordx", "cordy", "corrs"), correlation_nxn(df_num, cfg)))
Expand Down
8 changes: 5 additions & 3 deletions dataprep/eda/distribution/compute/bivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,11 @@ def compute_bivariate(
data: Dict[str, Any] = {}
if cfg.scatter.enable:
# scatter plot data
data["scat"] = tmp_df.map_partitions(
lambda x: x.sample(min(cfg.scatter.sample_size, x.shape[0])), meta=tmp_df
)
if cfg.scatter.sample_size is not None:
sample_func = lambda x: x.sample(n=min(cfg.scatter.sample_size, x.shape[0]))
else:
sample_func = lambda x: x.sample(frac=cfg.scatter.sample_rate)
data["scat"] = tmp_df.map_partitions(sample_func, meta=tmp_df)
if cfg.hexbin.enable:
# hexbin plot data
data["hex"] = tmp_df
Expand Down
25 changes: 22 additions & 3 deletions dataprep/eda/distribution/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,15 +914,25 @@ def scatter_viz(
df: pd.DataFrame,
x: str,
y: str,
spl_sz: int,
sample_sr_and_name: Tuple[Union[int, float], str],
plot_width: int,
plot_height: int,
) -> Any:
"""
Render a scatter plot
"""
# pylint: disable=too-many-arguments
title = f"{y} by {x}" if len(df) < spl_sz else f"{y} by {x} (sample size {spl_sz})"
if sample_sr_and_name[1] == "sample size":
title = (
f"{y} by {x}"
if len(df) < sample_sr_and_name[0]
else f"{y} by {x} (sample size {sample_sr_and_name[0]})"
)
elif sample_sr_and_name[1] == "sample rate":
title = f"{y} by {x} (sample rate {sample_sr_and_name[0]})"
else:
raise RuntimeError("parameter name should be either 'sample size' or 'sample rate'")

tooltips = [("(x, y)", f"(@{{{x}}}, @{{{y}}})")]
fig = figure(
tools="hover",
Expand Down Expand Up @@ -2143,12 +2153,21 @@ def render_two_num(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:

if cfg.scatter.enable:
# scatter plot
if cfg.scatter.sample_size is not None:
sample_sr_and_name: Tuple[Union[int, float], str] = (
cfg.scatter.sample_size,
"sample size",
)
elif cfg.scatter.sample_rate is not None:
sample_sr_and_name = (cfg.scatter.sample_rate, "sample rate")
else:
raise RuntimeError("In scatter plot, sample size and sample rate are both not None")
tabs.append(
scatter_viz(
data["scat"],
x,
y,
cfg.scatter.sample_size,
sample_sr_and_name,
plot_width,
plot_height,
)
Expand Down
12 changes: 8 additions & 4 deletions dataprep/tests/eda/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
import pytest

from ...eda import plot, plot_correlation, plot_missing
from ...eda import plot, plot_correlation, plot_missing, create_report
from ...eda.utils import to_dask


Expand Down Expand Up @@ -45,6 +45,10 @@ def test_sanity_compute_2(simpledf: dd.DataFrame) -> None:


def test_sanity_compute_3(simpledf: dd.DataFrame) -> None:
for _ in range(5):
sample_size = random.randint(200, 1000)
plot_correlation(simpledf, "a", "b", config={"scatter.sample_size": sample_size})
plot_correlation(simpledf, "a", "b", config={"scatter.sample_rate": 0.1})
plot_correlation(simpledf, "a", "b", config={"scatter.sample_size": 10000})
plot_correlation(simpledf, "a", "b", config={"scatter.sample_size": 100})


def test_report(simpledf: dd.DataFrame) -> None:
    """Smoke test: build a report limited to the Overview and Interactions sections."""
    sections = ["Overview", "Interactions"]
    create_report(simpledf, display=sections)

1 comment on commit 8ab27f9

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataPrep.EDA Benchmarks

Benchmark suite Current: 8ab27f9 Previous: a868c50 Ratio
dataprep/tests/benchmarks/eda.py::test_create_report 0.16845244429142248 iter/sec (stddev: 0.1629340202527945) 0.1622125604402684 iter/sec (stddev: 0.11301872409362092) 0.96

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.