Skip to content

Commit

Permalink
Add plot column pair utility method and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Sep 14, 2022
1 parent 4a79665 commit 77d446b
Show file tree
Hide file tree
Showing 3 changed files with 509 additions and 8 deletions.
3 changes: 2 additions & 1 deletion sdmetrics/reports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Reports for sdmetrics."""
from sdmetrics.reports.utils import get_column_plot
from sdmetrics.reports.utils import get_column_pair_plot, get_column_plot

__all__ = [
'get_column_pair_plot',
'get_column_plot',
]
180 changes: 175 additions & 5 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdmetrics.utils import is_datetime

DATACEBO_DARK = '#000036'
DATACEBO_LIGHT = '#01E0C9'
BACKGROUND_COLOR = '#F5F5F8'
CONTINUOUS_SDTYPES = ['numerical', 'datetime']
DISCRETE_SDTYPES = ['categorical', 'boolean']


def make_discrete_column_plot(real_column, synthetic_column, sdtype):
Expand Down Expand Up @@ -79,7 +85,7 @@ def make_discrete_column_plot(real_column, synthetic_column, sdtype):
title=f"Real vs. Synthetic Data for column '{column_name}'",
xaxis_title='Category*' if show_missing_values else 'Category',
yaxis_title='Frequency',
plot_bgcolor='#F5F5F8',
plot_bgcolor=BACKGROUND_COLOR,
annotations=annotations,
)

Expand Down Expand Up @@ -153,7 +159,7 @@ def make_continuous_column_plot(real_column, synthetic_column, sdtype):
title=f'Real vs. Synthetic Data for column {column_name}',
xaxis_title='Value*' if show_missing_values else 'Value',
yaxis_title='Frequency',
plot_bgcolor='#F5F5F8',
plot_bgcolor=BACKGROUND_COLOR,
annotations=annotations,
)

Expand All @@ -169,19 +175,183 @@ def get_column_plot(real_column, synthetic_column, sdtype):
synthetic_column (pandas.Series):
The synthetic data for the desired column.
sdtype (str):
The data type of the column.
The data type of the column. Must be one of
('numerical', 'datetime', 'categorical', or 'boolean').
Returns:
plotly.graph_objects._figure.Figure
"""
if sdtype == 'numerical' or sdtype == 'datetime':
if sdtype in CONTINUOUS_SDTYPES:
fig = make_continuous_column_plot(real_column, synthetic_column, sdtype)
elif sdtype == 'categorical' or sdtype == 'boolean':
elif sdtype in DISCRETE_SDTYPES:
fig = make_discrete_column_plot(real_column, synthetic_column, sdtype)
else:
raise ValueError(f'sdtype of {sdtype} not recognized.')

return fig


def convert_to_datetime(column_data):
"""Convert a column data to pandas datetime.
Args:
column_data (pandas.Series):
The column data
Returns:
pandas.Series:
The converted column data.
"""
if is_datetime(column_data):
return column_data

dt_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy())
return pd.to_datetime(column_data, format=dt_format)


def make_continuous_column_pair_plot(real_data, synthetic_data):
"""Make a column pair plot for continuous data.
Args:
real_data (pandas.DataFrame):
The real data for the desired column pair.
synthetic_column (pandas.Dataframe):
The synthetic data for the desired column pair.
Returns:
plotly.graph_objects._figure.Figure
"""
columns = real_data.columns
real_data = real_data.copy()
real_data['Data'] = 'Real'
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)

fig = px.scatter(
all_data,
x=columns[0],
y=columns[1],
color='Data',
color_discrete_map={'Real': DATACEBO_DARK, 'Synthetic': DATACEBO_LIGHT},
symbol='Data'
)

fig.update_layout(
title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
plot_bgcolor=BACKGROUND_COLOR,
)

return fig


def make_discrete_column_pair_plot(real_data, synthetic_data):
"""Make a column pair plot for discrete data.
Args:
real_data (pandas.DataFrame):
The real data for the desired column pair.
synthetic_column (pandas.Dataframe):
The synthetic data for the desired column pair.
Returns:
plotly.graph_objects._figure.Figure
"""
columns = real_data.columns
real_data = real_data.copy()
real_data['Data'] = 'Real'
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)

fig = px.density_heatmap(
all_data,
x=columns[0],
y=columns[1],
facet_col='Data',
histnorm='probability'
)

fig.update_layout(
title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
coloraxis={'colorscale': [DATACEBO_DARK, DATACEBO_LIGHT]},
)

fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1] + ' Data'))

return fig


def make_mixed_column_pair_plot(real_data, synthetic_data):
"""Make a column pair plot for mixed discrete and continuous column data.
Args:
real_data (pandas.DataFrame):
The real data for the desired column pair.
synthetic_column (pandas.Dataframe):
The synthetic data for the desired column pair.
Returns:
plotly.graph_objects._figure.Figure
"""
columns = real_data.columns
real_data = real_data.copy()
real_data['Data'] = 'Real'
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)

fig = px.box(
all_data,
x=columns[0],
y=columns[1],
color='Data',
color_discrete_map={'Real': DATACEBO_DARK, 'Synthetic': DATACEBO_LIGHT},
)

fig.update_layout(
title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
plot_bgcolor=BACKGROUND_COLOR
)

return fig


def get_column_pair_plot(real_data, synthetic_data, sdtypes):
"""Return a plot of the real and synthetic data for a given column pair.
Args:
real_data (pandas.DataFrame):
The real data for the desired column pair.
synthetic_column (pandas.Dataframe):
The synthetic data for the desired column pair.
sdtypes (list[string]):
The data type of the column pair. The data type string must be one of
('numerical', 'datetime', 'categorical', or 'boolean').
Returns:
plotly.graph_objects._figure.Figure
"""
all_sdtypes = CONTINUOUS_SDTYPES + DISCRETE_SDTYPES
if not sdtypes[0] in all_sdtypes:
raise ValueError(f'sdtype of {sdtypes[0]} not recognized.')
if not sdtypes[1] in all_sdtypes:
raise ValueError(f'sdtype of {sdtypes[1]} not recognized.')

if all([t in DISCRETE_SDTYPES for t in sdtypes]):
return make_discrete_column_pair_plot(real_data, synthetic_data)

if sdtypes[0] == 'datetime':
real_data.iloc[:, 0] = convert_to_datetime(real_data.iloc[:, 0])
if sdtypes[1] == 'datetime':
real_data.iloc[:, 1] = convert_to_datetime(real_data.iloc[:, 1])

if all([t in CONTINUOUS_SDTYPES for t in sdtypes]):
return make_continuous_column_pair_plot(real_data, synthetic_data)
else:
return make_mixed_column_pair_plot(real_data, synthetic_data)


def discretize_table_data(real_data, synthetic_data, metadata):
"""Create a copy of the real and synthetic data with discretized data.
Expand Down

0 comments on commit 77d446b

Please sign in to comment.