Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_column_plot #455

Merged
merged 4 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions sdmetrics/reports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Reports for sdmetrics."""
from sdmetrics.reports.utils import get_column_pair_plot, get_column_plot
from sdmetrics.reports.utils import get_column_pair_plot

__all__ = [
'get_column_pair_plot',
'get_column_plot',
]
47 changes: 0 additions & 47 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,53 +224,6 @@ def make_continuous_column_plot(real_column, synthetic_column, sdtype):
return fig


def get_column_plot(real_data, synthetic_data, column_name, metadata):
"""Return a plot of the real and synthetic data for a given column.

Args:
real_data (pandas.DataFrame):
The real table data.
synthetic_data (pandas.DataFrame):
The synthetic table data.
column_name (str):
The name of the column.
metadata (dict):
The table metadata.

Returns:
plotly.graph_objects._figure.Figure
"""
columns = get_columns_from_metadata(metadata)
if column_name not in columns:
raise ValueError(f"Column '{column_name}' not found in metadata.")
elif 'sdtype' not in columns[column_name]:
raise ValueError(f"Metadata for column '{column_name}' missing 'type' information.")
if column_name not in real_data.columns:
raise ValueError(f"Column '{column_name}' not found in real table data.")
if column_name not in synthetic_data.columns:
raise ValueError(f"Column '{column_name}' not found in synthetic table data.")

column_meta = columns[column_name]
sdtype = get_type_from_column_meta(columns[column_name])
if sdtype == 'datetime':
real_column, synthetic_column = convert_datetime_columns(
real_data[column_name],
synthetic_data[column_name],
column_meta
)
else:
real_column = real_data[column_name]
synthetic_column = synthetic_data[column_name]
if sdtype in CONTINUOUS_SDTYPES:
fig = make_continuous_column_plot(real_column, synthetic_column, sdtype)
elif sdtype in DISCRETE_SDTYPES:
fig = make_discrete_column_plot(real_column, synthetic_column, sdtype)
else:
raise ValueError(f"sdtype of type '{sdtype}' not recognized.")

return fig


def make_continuous_column_pair_plot(real_data, synthetic_data):
"""Make a column pair plot for continuous data.

Expand Down
70 changes: 60 additions & 10 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas.api.types import is_datetime64_dtype

from sdmetrics.reports.utils import PlotConfig
from sdmetrics.utils import get_missing_percentage
from sdmetrics.utils import get_missing_percentage, is_datetime


def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
Expand All @@ -25,7 +25,7 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
plotly.graph_objects._figure.Figure
"""
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
default_histogram_kwargs = {
histogram_kwargs = {
'x': 'values',
'color': 'Data',
'barmode': 'group',
Expand All @@ -34,9 +34,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
'pattern_shape_sequence': ['', '/'],
'histnorm': 'probability density',
}
histogram_kwargs.update(plot_kwargs)
fig = px.histogram(
all_data,
**{**default_histogram_kwargs, **plot_kwargs}
**histogram_kwargs
)

return fig
Expand Down Expand Up @@ -105,25 +106,28 @@ def _generate_column_plot(real_column,

column_name = real_column.name if hasattr(real_column, 'name') else ''

real_data = pd.DataFrame({'values': real_column.copy()})
missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)

real_data = pd.DataFrame({'values': real_column.copy().dropna()})
real_data['Data'] = 'Real'
synthetic_data = pd.DataFrame({'values': synthetic_column.copy()})
synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
synthetic_data['Data'] = 'Synthetic'

is_datetime_sdtype = False
if is_datetime64_dtype(real_column.dtype):
is_datetime_sdtype = True
real_data = real_data.astype('int64')
synthetic_data = synthetic_data.astype('int64')

missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)
real_data['values'] = real_data['values'].astype('int64')
synthetic_data['values'] = synthetic_data['values'].astype('int64')

trace_args = {}

if plot_type == 'bar':
fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
elif plot_type == 'distplot':
if x_label is None:
x_label = 'Value'

fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs)
trace_args = {'fill': 'tozeroy'}

Expand Down Expand Up @@ -259,3 +263,49 @@ def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_tab
)

return fig


def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
"""Return a plot of the real and synthetic data for a given column.

Args:
real_data (pandas.DataFrame):
The real table data.
synthetic_data (pandas.DataFrame):
The synthetic table data.
column_name (str):
The name of the column.
plot_type (str or None):
The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None`
select between ``distplot`` or ``bar`` depending on the data that the column contains:
``distplot`` for datetime and numerical values and ``bar`` for categorical.
Defaults to ``None``.

Returns:
plotly.graph_objects._figure.Figure
"""
if plot_type not in ['bar', 'distplot', None]:
raise ValueError(
f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]."
)

if column_name not in real_data.columns:
raise ValueError(f"Column '{column_name}' not found in real table data.")
if column_name not in synthetic_data.columns:
raise ValueError(f"Column '{column_name}' not found in synthetic table data.")

real_column = real_data[column_name]
if plot_type is None:
column_is_datetime = is_datetime(real_data[column_name])
dtype = real_column.dropna().infer_objects().dtype.kind
if column_is_datetime or dtype in ('i', 'f'):
plot_type = 'distplot'
else:
plot_type = 'bar'

real_column = real_data[column_name]
synthetic_column = synthetic_data[column_name]

fig = _generate_column_plot(real_column, synthetic_column, plot_type)

return fig