Skip to content

Commit

Permalink
Add plot column pair utility method (#224)
Browse files Browse the repository at this point in the history
* Add plot column pair utility method and tests

* address cr

* styling change
  • Loading branch information
katxiao committed Sep 14, 2022
1 parent 4a79665 commit 62e6b73
Show file tree
Hide file tree
Showing 3 changed files with 515 additions and 14 deletions.
3 changes: 2 additions & 1 deletion sdmetrics/reports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Reports for sdmetrics."""
from sdmetrics.reports.utils import get_column_plot
from sdmetrics.reports.utils import get_column_pair_plot, get_column_plot

__all__ = [
'get_column_pair_plot',
'get_column_plot',
]
192 changes: 181 additions & 11 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdmetrics.utils import is_datetime

DATACEBO_DARK = '#000036'
DATACEBO_LIGHT = '#01E0C9'
BACKGROUND_COLOR = '#F5F5F8'
CONTINUOUS_SDTYPES = ['numerical', 'datetime']
DISCRETE_SDTYPES = ['categorical', 'boolean']


def make_discrete_column_plot(real_column, synthetic_column, sdtype):
Expand Down Expand Up @@ -65,8 +71,8 @@ def make_discrete_column_plot(real_column, synthetic_column, sdtype):
{
'xref': 'paper',
'yref': 'paper',
'x': -0.08,
'y': -0.2,
'x': 1.0,
'y': 1.05,
'showarrow': False,
'text': (
f'*Missing Values: Real Data ({missing_data_real}%), '
Expand All @@ -77,9 +83,9 @@ def make_discrete_column_plot(real_column, synthetic_column, sdtype):

fig.update_layout(
title=f"Real vs. Synthetic Data for column '{column_name}'",
xaxis_title='Category*' if show_missing_values else 'Category',
xaxis_title='Category',
yaxis_title='Frequency',
plot_bgcolor='#F5F5F8',
plot_bgcolor=BACKGROUND_COLOR,
annotations=annotations,
)

Expand Down Expand Up @@ -139,8 +145,8 @@ def make_continuous_column_plot(real_column, synthetic_column, sdtype):
{
'xref': 'paper',
'yref': 'paper',
'x': -0.08,
'y': -0.2,
'x': 1.0,
'y': 1.05,
'showarrow': False,
'text': (
f'*Missing Values: Real Data ({missing_data_real}%), '
Expand All @@ -151,9 +157,9 @@ def make_continuous_column_plot(real_column, synthetic_column, sdtype):

fig.update_layout(
title=f'Real vs. Synthetic Data for column {column_name}',
xaxis_title='Value*' if show_missing_values else 'Value',
xaxis_title='Value',
yaxis_title='Frequency',
plot_bgcolor='#F5F5F8',
plot_bgcolor=BACKGROUND_COLOR,
annotations=annotations,
)

Expand All @@ -169,19 +175,183 @@ def get_column_plot(real_column, synthetic_column, sdtype):
synthetic_column (pandas.Series):
The synthetic data for the desired column.
sdtype (str):
The data type of the column.
The data type of the column. Must be one of
('numerical', 'datetime', 'categorical', or 'boolean').
Returns:
plotly.graph_objects._figure.Figure
"""
if sdtype == 'numerical' or sdtype == 'datetime':
if sdtype in CONTINUOUS_SDTYPES:
fig = make_continuous_column_plot(real_column, synthetic_column, sdtype)
elif sdtype == 'categorical' or sdtype == 'boolean':
elif sdtype in DISCRETE_SDTYPES:
fig = make_discrete_column_plot(real_column, synthetic_column, sdtype)
else:
raise ValueError(f"sdtype of type '{sdtype}' not recognized.")

return fig


def convert_to_datetime(column_data):
    """Return the given column as pandas datetime data.

    A column that ``is_datetime`` already accepts is handed back unchanged.
    Otherwise a datetime format is guessed from the string form of the values
    and passed to ``pandas.to_datetime``.

    Args:
        column_data (pandas.Series):
            The column data.

    Returns:
        pandas.Series:
            The converted column data.
    """
    if not is_datetime(column_data):
        # The format guesser works on an array of strings, so stringify first.
        values_as_text = column_data.astype(str).to_numpy()
        guessed_format = _guess_datetime_format_for_array(values_as_text)
        column_data = pd.to_datetime(column_data, format=guessed_format)

    return column_data


def make_continuous_column_pair_plot(real_data, synthetic_data):
    """Build a scatter plot of real vs. synthetic data for a continuous column pair.

    The first column of ``real_data`` is drawn on the x axis and the second on
    the y axis. Rows are colored and given a marker symbol according to whether
    they come from the real or the synthetic data.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    x_column, y_column = real_data.columns[0], real_data.columns[1]

    # ``assign`` returns a new frame, so the caller's inputs are left untouched.
    labeled_frames = [
        real_data.assign(Data='Real'),
        synthetic_data.assign(Data='Synthetic'),
    ]
    combined = pd.concat(labeled_frames, axis=0, ignore_index=True)

    palette = {'Real': DATACEBO_DARK, 'Synthetic': DATACEBO_LIGHT}
    fig = px.scatter(
        combined,
        x=x_column,
        y=y_column,
        color='Data',
        color_discrete_map=palette,
        symbol='Data',
    )

    plot_title = f"Real vs. Synthetic Data for columns '{x_column}' and '{y_column}'"
    fig.update_layout(title=plot_title, plot_bgcolor=BACKGROUND_COLOR)

    return fig


def make_discrete_column_pair_plot(real_data, synthetic_data):
    """Build side-by-side density heatmaps for a discrete column pair.

    The real and synthetic rows are stacked into one frame and drawn as one
    heatmap facet each, using probability normalization. The first column of
    ``real_data`` is the x axis and the second is the y axis.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    x_column, y_column = real_data.columns[0], real_data.columns[1]

    # ``assign`` returns a new frame, so the caller's inputs are left untouched.
    labeled_frames = [
        real_data.assign(Data='Real'),
        synthetic_data.assign(Data='Synthetic'),
    ]
    combined = pd.concat(labeled_frames, axis=0, ignore_index=True)

    fig = px.density_heatmap(
        combined,
        x=x_column,
        y=y_column,
        facet_col='Data',
        histnorm='probability',
    )

    plot_title = f"Real vs Synthetic Data for columns '{x_column}' and '{y_column}'"
    fig.update_layout(
        title_text=plot_title,
        coloraxis={'colorscale': [DATACEBO_DARK, DATACEBO_LIGHT]},
    )

    def _relabel_facet(annotation):
        # Keep only the text after the last '=' in the facet label and append ' Data'.
        facet_value = annotation.text.split('=')[-1]
        return annotation.update(text=facet_value + ' Data')

    fig.for_each_annotation(_relabel_facet)

    return fig


def make_mixed_column_pair_plot(real_data, synthetic_data):
    """Build a box plot for a column pair mixing discrete and continuous data.

    The first column of ``real_data`` is drawn on the x axis and the second on
    the y axis, with one box color for the real rows and one for the synthetic
    rows.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    x_column, y_column = real_data.columns[0], real_data.columns[1]

    # ``assign`` returns a new frame, so the caller's inputs are left untouched.
    labeled_frames = [
        real_data.assign(Data='Real'),
        synthetic_data.assign(Data='Synthetic'),
    ]
    combined = pd.concat(labeled_frames, axis=0, ignore_index=True)

    palette = {'Real': DATACEBO_DARK, 'Synthetic': DATACEBO_LIGHT}
    fig = px.box(
        combined,
        x=x_column,
        y=y_column,
        color='Data',
        color_discrete_map=palette,
    )

    plot_title = f"Real vs. Synthetic Data for columns '{x_column}' and '{y_column}'"
    fig.update_layout(title=plot_title, plot_bgcolor=BACKGROUND_COLOR)

    return fig


def get_column_pair_plot(real_data, synthetic_data, sdtypes):
    """Return a plot of the real and synthetic data for a given column pair.

    The plot type depends on the sdtypes of the pair: two discrete columns go
    to ``make_discrete_column_pair_plot``, two continuous columns go to
    ``make_continuous_column_pair_plot``, and one of each goes to
    ``make_mixed_column_pair_plot``. Columns declared as ``'datetime'`` are
    passed through ``convert_to_datetime`` first, in both the real and the
    synthetic data. The conversion is done on copies, so the frames passed in
    by the caller are not modified.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.
        sdtypes (list[string]):
            The data type of the column pair. The data type string must be one of
            ('numerical', 'datetime', 'categorical', or 'boolean').

    Returns:
        plotly.graph_objects._figure.Figure

    Raises:
        ValueError:
            If either of the first two entries of ``sdtypes`` is not a
            recognized sdtype.
    """
    all_sdtypes = CONTINUOUS_SDTYPES + DISCRETE_SDTYPES
    for position in (0, 1):
        sdtype = sdtypes[position]
        if sdtype not in all_sdtypes:
            raise ValueError(f"sdtype of type '{sdtype}' not recognized.")

    if all(sdtype in DISCRETE_SDTYPES for sdtype in sdtypes):
        return make_discrete_column_pair_plot(real_data, synthetic_data)

    datetime_positions = [position for position in (0, 1) if sdtypes[position] == 'datetime']
    if datetime_positions:
        # Work on copies so the conversion never writes into the caller's frames.
        real_data = real_data.copy()
        synthetic_data = synthetic_data.copy()
        for position in datetime_positions:
            # Both frames are concatenated by the plot helpers, so both must be
            # converted for the shared axis to hold consistent datetime values.
            real_data.iloc[:, position] = convert_to_datetime(real_data.iloc[:, position])
            synthetic_data.iloc[:, position] = convert_to_datetime(
                synthetic_data.iloc[:, position]
            )

    if all(sdtype in CONTINUOUS_SDTYPES for sdtype in sdtypes):
        return make_continuous_column_pair_plot(real_data, synthetic_data)

    return make_mixed_column_pair_plot(real_data, synthetic_data)


def discretize_table_data(real_data, synthetic_data, metadata):
"""Create a copy of the real and synthetic data with discretized data.
Expand Down

0 comments on commit 62e6b73

Please sign in to comment.