Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot column pair utility method #224

Merged
merged 3 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdmetrics/reports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Reports for sdmetrics."""
from sdmetrics.reports.utils import get_column_plot
from sdmetrics.reports.utils import get_column_pair_plot, get_column_plot

# Plotting helpers re-exported as the public API of ``sdmetrics.reports``.
__all__ = [
'get_column_pair_plot',
'get_column_plot',
]
180 changes: 175 additions & 5 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdmetrics.utils import is_datetime

# Brand colors used by the plots in this module: the dark color is mapped to
# the 'Real' data series and the light color to the 'Synthetic' data series.
DATACEBO_DARK = '#000036'
DATACEBO_LIGHT = '#01E0C9'
# Background color applied to the plot area (``plot_bgcolor``) of the figures.
BACKGROUND_COLOR = '#F5F5F8'
# sdtypes grouped by the kind of plot they are rendered with.
CONTINUOUS_SDTYPES = ['numerical', 'datetime']
DISCRETE_SDTYPES = ['categorical', 'boolean']


def make_discrete_column_plot(real_column, synthetic_column, sdtype):
Expand Down Expand Up @@ -79,7 +85,7 @@ def make_discrete_column_plot(real_column, synthetic_column, sdtype):
title=f"Real vs. Synthetic Data for column '{column_name}'",
xaxis_title='Category*' if show_missing_values else 'Category',
yaxis_title='Frequency',
plot_bgcolor='#F5F5F8',
plot_bgcolor=BACKGROUND_COLOR,
annotations=annotations,
)

Expand Down Expand Up @@ -153,7 +159,7 @@ def make_continuous_column_plot(real_column, synthetic_column, sdtype):
title=f'Real vs. Synthetic Data for column {column_name}',
xaxis_title='Value*' if show_missing_values else 'Value',
yaxis_title='Frequency',
plot_bgcolor='#F5F5F8',
plot_bgcolor=BACKGROUND_COLOR,
annotations=annotations,
)

Expand All @@ -169,19 +175,183 @@ def get_column_plot(real_column, synthetic_column, sdtype):
synthetic_column (pandas.Series):
The synthetic data for the desired column.
sdtype (str):
The data type of the column.
The data type of the column. Must be one of
('numerical', 'datetime', 'categorical', or 'boolean').

Returns:
plotly.graph_objects._figure.Figure
"""
if sdtype == 'numerical' or sdtype == 'datetime':
if sdtype in CONTINUOUS_SDTYPES:
fig = make_continuous_column_plot(real_column, synthetic_column, sdtype)
elif sdtype == 'categorical' or sdtype == 'boolean':
elif sdtype in DISCRETE_SDTYPES:
fig = make_discrete_column_plot(real_column, synthetic_column, sdtype)
else:
raise ValueError(f'sdtype of {sdtype} not recognized.')

return fig


def convert_to_datetime(column_data):
    """Return the given column as pandas datetime data.

    A column that ``is_datetime`` already recognizes as datetime is handed
    back untouched. Otherwise the values are rendered as strings, a single
    datetime format is guessed from them, and the column is parsed with that
    format.

    Args:
        column_data (pandas.Series):
            The column data.

    Returns:
        pandas.Series:
            The converted column data.
    """
    if not is_datetime(column_data):
        # Guess one format from the string form of the values and reuse it
        # for the whole column.
        values_as_text = column_data.astype(str).to_numpy()
        guessed_format = _guess_datetime_format_for_array(values_as_text)
        column_data = pd.to_datetime(column_data, format=guessed_format)

    return column_data


def make_continuous_column_pair_plot(real_data, synthetic_data):
    """Build a scatter plot comparing a pair of continuous columns.

    Real and synthetic rows are stacked into one table, tagged by origin in a
    ``Data`` column, and drawn with a distinct color and marker symbol per
    origin. The first column is plotted on the x axis, the second on the y axis.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    pair = real_data.columns
    x_name, y_name = pair[0], pair[1]

    # ``assign`` returns a new frame, so the caller's data is left untouched.
    tagged_frames = [
        real_data.assign(Data='Real'),
        synthetic_data.assign(Data='Synthetic'),
    ]
    combined = pd.concat(tagged_frames, axis=0, ignore_index=True)

    palette = {'Real': DATACEBO_DARK, 'Synthetic': DATACEBO_LIGHT}
    scatter = px.scatter(
        combined,
        x=x_name,
        y=y_name,
        color='Data',
        color_discrete_map=palette,
        symbol='Data'
    )

    heading = f"Real vs. Synthetic Data for columns '{x_name}' and '{y_name}'"
    scatter.update_layout(title=heading, plot_bgcolor=BACKGROUND_COLOR)

    return scatter


def make_discrete_column_pair_plot(real_data, synthetic_data):
    """Build side-by-side density heatmaps for a pair of discrete columns.

    Real and synthetic rows are stacked into one table, tagged by origin in a
    ``Data`` column, and rendered as one heatmap facet per origin with
    probability-normalized cell values.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    pair = real_data.columns
    x_name, y_name = pair[0], pair[1]

    # ``assign`` returns a new frame, so the caller's data is left untouched.
    tagged_frames = [
        real_data.assign(Data='Real'),
        synthetic_data.assign(Data='Synthetic'),
    ]
    combined = pd.concat(tagged_frames, axis=0, ignore_index=True)

    heatmap = px.density_heatmap(
        combined,
        x=x_name,
        y=y_name,
        facet_col='Data',
        histnorm='probability'
    )

    heading = f"Real vs Synthetic Data for columns '{x_name}' and '{y_name}'"
    heatmap.update_layout(
        title_text=heading,
        coloraxis={'colorscale': [DATACEBO_DARK, DATACEBO_LIGHT]},
    )

    def _relabel_facet(annotation):
        # Keep only what follows the last '=' in the facet label and
        # append ' Data' to it.
        origin = annotation.text.split('=')[-1]
        return annotation.update(text=origin + ' Data')

    heatmap.for_each_annotation(_relabel_facet)

    return heatmap


def make_mixed_column_pair_plot(real_data, synthetic_data):
    """Build a box plot comparing a mixed discrete/continuous column pair.

    Real and synthetic rows are stacked into one table, tagged by origin in a
    ``Data`` column, and drawn as box plots colored by origin. The first column
    is plotted on the x axis, the second on the y axis.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    pair = real_data.columns
    x_name, y_name = pair[0], pair[1]

    # ``assign`` returns a new frame, so the caller's data is left untouched.
    tagged_frames = [
        real_data.assign(Data='Real'),
        synthetic_data.assign(Data='Synthetic'),
    ]
    combined = pd.concat(tagged_frames, axis=0, ignore_index=True)

    palette = {'Real': DATACEBO_DARK, 'Synthetic': DATACEBO_LIGHT}
    boxes = px.box(
        combined,
        x=x_name,
        y=y_name,
        color='Data',
        color_discrete_map=palette,
    )

    heading = f"Real vs. Synthetic Data for columns '{x_name}' and '{y_name}'"
    boxes.update_layout(title=heading, plot_bgcolor=BACKGROUND_COLOR)

    return boxes


def get_column_pair_plot(real_data, synthetic_data, sdtypes):
    """Return a plot of the real and synthetic data for a given column pair.

    The plot is selected from the pair of sdtypes: two discrete columns are
    drawn with ``make_discrete_column_pair_plot``, two continuous columns with
    ``make_continuous_column_pair_plot``, and any other combination with
    ``make_mixed_column_pair_plot``. Columns declared as ``'datetime'`` are
    converted to pandas datetimes in both the real and the synthetic data
    before plotting. The input DataFrames are not modified.

    Args:
        real_data (pandas.DataFrame):
            The real data for the desired column pair.
        synthetic_data (pandas.DataFrame):
            The synthetic data for the desired column pair.
        sdtypes (list[string]):
            The data type of the column pair. The data type string must be one of
            ('numerical', 'datetime', 'categorical', or 'boolean').

    Returns:
        plotly.graph_objects._figure.Figure

    Raises:
        ValueError:
            If either of the two sdtypes is not recognized.
    """
    all_sdtypes = CONTINUOUS_SDTYPES + DISCRETE_SDTYPES
    for sdtype in (sdtypes[0], sdtypes[1]):
        if sdtype not in all_sdtypes:
            raise ValueError(f'sdtype of {sdtype} not recognized.')

    if all(t in DISCRETE_SDTYPES for t in sdtypes):
        return make_discrete_column_pair_plot(real_data, synthetic_data)

    # Convert on copies so the caller's DataFrames are never mutated.
    real_data = real_data.copy()
    synthetic_data = synthetic_data.copy()
    for position in (0, 1):
        if sdtypes[position] != 'datetime':
            continue

        # Both tables must be converted; otherwise the combined plot data
        # would mix parsed datetimes with raw, unparsed values.
        for data in (real_data, synthetic_data):
            column_name = data.columns[position]
            # Assigning by label replaces the whole column, so the datetime
            # dtype is kept instead of being written into the old array.
            data[column_name] = convert_to_datetime(data.iloc[:, position])

    if all(t in CONTINUOUS_SDTYPES for t in sdtypes):
        return make_continuous_column_pair_plot(real_data, synthetic_data)

    return make_mixed_column_pair_plot(real_data, synthetic_data)


def discretize_table_data(real_data, synthetic_data, metadata):
"""Create a copy of the real and synthetic data with discretized data.

Expand Down