Skip to content

Commit

Permalink
feat(eda):add categorical interaction in create_report
Browse files Browse the repository at this point in the history
  • Loading branch information
jinglinpeng committed Nov 24, 2021
1 parent c91014f commit 7f13cd5
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 41 deletions.
5 changes: 5 additions & 0 deletions dataprep/eda/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,9 +1109,14 @@ class Interactions(BaseModel):
"""
enable: bool, default True
Whether to create this element
cat_enable: bool, default False
where enable categorical column in interactions. By default it is False,
which means only numerical interactions are computed. If set to True, cat-cat
and cat-num interactions will be computed.
"""

enable: bool = True
cat_enable: bool = False


class Correlations(BaseModel):
Expand Down
93 changes: 65 additions & 28 deletions dataprep/eda/correlation/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def render_correlation(itmdt: Intermediate, cfg: Config) -> Any:
elif itmdt.visual_type == "correlation_scatter":
visual_elem = render_scatter(itmdt, plot_width, plot_height, cfg)
elif itmdt.visual_type == "correlation_crossfilter":
visual_elem = render_crossfilter(itmdt, plot_width, plot_height)
visual_elem = render_crossfilter(itmdt, plot_width, plot_height, cfg)
else:
raise NotImplementedError(f"Unknown visual type {itmdt.visual_type}")

Expand Down Expand Up @@ -372,25 +372,34 @@ def render_scatter(


######### Interactions for report #########
def render_crossfilter(itmdt: Intermediate, plot_width: int, plot_height: int) -> column:
def render_crossfilter(
itmdt: Intermediate, plot_width: int, plot_height: int, cfg: Config
) -> column:
"""
Render crossfilter scatter plot with a regression line.
"""

# pylint: disable=too-many-locals, too-many-function-args
df = itmdt["data"]
df["__x__"] = df[df.columns[0]]
df["__y__"] = df[df.columns[0]]
source_scatter = ColumnDataSource(df)
source_xy_value = ColumnDataSource({"x": [df.columns[0]], "y": [df.columns[0]]})
var_list = list(df.columns[:-2])

if cfg.interactions.cat_enable:
all_cols = itmdt["all_cols"]
else:
all_cols = itmdt["num_cols"]

scatter_df = itmdt["scatter_source"]
# all other plots except for scatter plot, used for cat-cat and cat-num interactions.
other_plots = itmdt["other_plots"]
scatter_df["__x__"] = scatter_df[scatter_df.columns[0]]
scatter_df["__y__"] = scatter_df[scatter_df.columns[0]]
source_scatter = ColumnDataSource(scatter_df)
source_xy_value = ColumnDataSource({"x": [scatter_df.columns[0]], "y": [scatter_df.columns[0]]})
var_list = list(all_cols)

xcol = source_xy_value.data["x"][0]
ycol = source_xy_value.data["y"][0]

tooltips = [("X-Axis: ", "@__x__"), ("Y-Axis: ", "@__y__")]

fig = Figure(
scatter_fig = Figure(
plot_width=plot_width,
plot_height=plot_height,
toolbar_location=None,
Expand All @@ -399,10 +408,12 @@ def render_crossfilter(itmdt: Intermediate, plot_width: int, plot_height: int) -
x_axis_label=xcol,
y_axis_label=ycol,
)
scatter = fig.scatter("__x__", "__y__", source=source_scatter)
scatter = scatter_fig.scatter("__x__", "__y__", source=source_scatter)

hover = HoverTool(tooltips=tooltips, renderers=[scatter])
fig.add_tools(hover)
scatter_fig.add_tools(hover)

fig_all_in_one = column(scatter_fig, sizing_mode="stretch_width")

x_select = Select(title="X-Axis", value=xcol, options=var_list, width=150)
y_select = Select(title="Y-Axis", value=ycol, options=var_list, width=150)
Expand All @@ -413,19 +424,31 @@ def render_crossfilter(itmdt: Intermediate, plot_width: int, plot_height: int) -
args=dict(
scatter=source_scatter,
xy_value=source_xy_value,
x_axis=fig.xaxis[0],
fig_all_in_one=fig_all_in_one,
scatter_plot=scatter_fig,
x_axis=scatter_fig.xaxis[0],
other_plots=other_plots,
),
code="""
let currentSelect = this.value;
let xyValueData = xy_value.data;
let scatterData = scatter.data;
xyValueData['x'][0] = currentSelect;
scatterData['__x__'] = scatterData[currentSelect];
x_axis.axis_label = currentSelect;
scatter.change.emit();
xy_value.change.emit();
const children = []
let ycol = xyValueData['y'][0];
let col = currentSelect + '_' + ycol
if (col in other_plots) {
children.push(other_plots[col])
}
else {
scatterData['__x__'] = scatterData[currentSelect];
x_axis.axis_label = currentSelect;
scatter.change.emit();
children.push(scatter_plot)
}
fig_all_in_one.children = children;
""",
),
)
Expand All @@ -435,25 +458,39 @@ def render_crossfilter(itmdt: Intermediate, plot_width: int, plot_height: int) -
args=dict(
scatter=source_scatter,
xy_value=source_xy_value,
y_axis=fig.yaxis[0],
fig_all_in_one=fig_all_in_one,
scatter_plot=scatter_fig,
y_axis=scatter_fig.yaxis[0],
other_plots=other_plots,
),
code="""
let currentSelect = this.value;
let ycol = this.value;
let xyValueData = xy_value.data;
let scatterData = scatter.data;
xyValueData['y'][0] = currentSelect;
scatterData['__y__'] = scatterData[currentSelect];
y_axis.axis_label = currentSelect;
scatter.change.emit();
xyValueData['y'][0] = ycol;
xy_value.change.emit();
const children = []
let xcol = xyValueData['x'][0];
let col = xcol + '_' + ycol;
if (col in other_plots) {
children.push(other_plots[col])
}
else {
scatterData['__y__'] = scatterData[ycol];
y_axis.axis_label = ycol;
scatter.change.emit();
children.push(scatter_plot)
}
fig_all_in_one.children = children;
""",
),
)

fig = column(row(x_select, y_select, align="center"), fig, sizing_mode="stretch_width")
return fig
interaction_fig = column(
row(x_select, y_select, align="center"), fig_all_in_one, sizing_mode="stretch_width"
)
return interaction_fig


# ######### Interactions for report #########
Expand Down
50 changes: 37 additions & 13 deletions dataprep/eda/create_report/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..configs import Config
from ..correlation import render_correlation
from ..correlation.compute.overview import correlation_nxn
from ..distribution import render
from ..distribution import compute, render
from ..utils import _calc_line_dt
from ..distribution.compute.overview import calc_stats
from ..distribution.compute.univariate import calc_stats_dt, cont_comps, nom_comps
Expand Down Expand Up @@ -146,16 +146,21 @@ def _format_variables(df: EDAFrame, cfg: Config, data: Dict[str, Any]) -> Dict[s
def _format_interaction(data: Dict[str, Any], cfg: Config) -> Dict[str, Any]:
"""Format of Interaction section"""
res: Dict[str, Any] = {}
if len(data["num_cols"]) > 0:
# interactions
if cfg.interactions.enable:
res["has_interaction"] = True
itmdt = Intermediate(data=data["scat"], visual_type="correlation_crossfilter")
rndrd = render_correlation(itmdt, cfg)
rndrd.sizing_mode = "stretch_width"
res["interactions"] = components(rndrd)
else:
res["has_interaction"] = False
# interactions
if cfg.interactions.enable:
res["has_interaction"] = True
itmdt = Intermediate(
scatter_source=data["interaction.scatter_source"],
other_plots=data["interaction.other_plots"],
num_cols=data["num_cols"],
all_cols=data["all_cols"],
visual_type="correlation_crossfilter",
)
rndrd = render_correlation(itmdt, cfg)
rndrd.sizing_mode = "stretch_width"
res["interactions"] = components(rndrd)
else:
res["has_interaction"] = False
return res


Expand Down Expand Up @@ -372,25 +377,44 @@ def basic_computations(
cfg
The config dict user passed in. E.g. config = {"hist.bins": 20}
Without user's specifications, the default is "auto"
""" # pylint: disable=too-many-branches
"""
# pylint: disable=too-many-branches, protected-access, too-many-locals

variables_data = _compute_variables(df, cfg)
overview_data = _compute_overview(df, cfg)
data: Dict[str, Any] = {**variables_data, **overview_data}

df_num = df.select_num_columns()
num_columns = df_num.columns
cat_columns = [col for col in df.columns if col not in num_columns]
data["num_cols"] = df_num.columns
data["all_cols"] = df.columns
# interactions
if cfg.interactions.enable:
if cfg.scatter.sample_size is not None:
sample_func = lambda x: x.sample(n=min(cfg.scatter.sample_size, x.shape[0]))
else:
sample_func = lambda x: x.sample(frac=cfg.scatter.sample_rate)
data["scat"] = df_num.frame.map_partitions(
data["interaction.scatter_source"] = df_num.frame.map_partitions(
sample_func,
meta=df_num.frame,
)

data["interaction.other_plots"] = {}
if cfg.interactions.cat_enable:
other_plots = {}
curr_cfg = Config.from_dict(display=["Box Plot", "Nested Bar Chart"])
for cat_col in cat_columns:
for other_col in df.columns:
if (cat_col == other_col) or (cat_col + "_" + other_col in other_plots):
continue
# print(f"cat col:{cat_col}, other col:{other_col}")
imdt = compute(df._ddf, cat_col, other_col, cfg=curr_cfg)
box_plot = render(imdt, curr_cfg)["layout"][0]
other_plots[cat_col + "_" + other_col] = box_plot
other_plots[other_col + "_" + cat_col] = box_plot
data["interaction.other_plots"] = other_plots

# correlations
if cfg.correlations.enable:
data.update(zip(("cordx", "cordy", "corrs"), correlation_nxn(df_num, cfg)))
Expand Down
1 change: 1 addition & 0 deletions dataprep/tests/eda/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ def test_sanity_compute_3(simpledf: dd.DataFrame) -> None:

def test_report(simpledf: dd.DataFrame) -> None:
create_report(simpledf, display=["Overview", "Interactions"])
create_report(simpledf, display=["Interactions"], config={"interactions.cat_enable": True})

1 comment on commit 7f13cd5

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataPrep.EDA Benchmarks

Benchmark suite Current: 7f13cd5 Previous: c91014f Ratio
dataprep/tests/benchmarks/eda.py::test_create_report 0.16625453704234128 iter/sec (stddev: 0.33579210286508854) 0.18085867440398032 iter/sec (stddev: 0.05296779792059889) 1.09

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.