Skip to content

Commit

Permalink
Add update_column and add_column to SingleTableMetadata (#915)
Browse files Browse the repository at this point in the history
* Implement update / add columns. *WIP

* Add docstrings

* Fix call functions

* WIP: Unit tests in progress

* Finish unit tests

* Address comments

* Fix multiple python version erroring

* fix exception

* Fix error msg

* Address comments

* Address comments about frozensets

* Bump macos version.
  • Loading branch information
pvk-developer committed Jul 25, 2022
1 parent 028ebc3 commit 72ba076
Show file tree
Hide file tree
Showing 2 changed files with 721 additions and 5 deletions.
170 changes: 165 additions & 5 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Single Table Metadata."""

import copy
import json
import re
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import pandas as pd
Expand All @@ -12,16 +14,92 @@
class SingleTableMetadata:
"""Single Table Metadata class."""

_EXPECTED_KWARGS = {
'numerical': frozenset(['representation']),
'datetime': frozenset(['datetime_format']),
'categorical': frozenset(['order', 'order_by']),
'boolean': frozenset([]),
'text': frozenset(['regex_format']),
}

_DTYPES_TO_SDTYPES = {
'i': 'numerical',
'f': 'numerical',
'O': 'categorical',
'b': 'boolean',
'M': 'datetime',
}
KEYS = ['columns', 'primary_key', 'alternate_keys', 'constraints', 'SCHEMA_VERSION']

_NUMERICAL_REPRESENTATIONS = frozenset([
'int', 'int64', 'int32', 'int16', 'int8',
'uint', 'uint64', 'uint32', 'uint16', 'uint8',
'float', 'float64', 'float32', 'float16', 'float8',
])
_KEYS = frozenset([
'columns',
'primary_key',
'alternate_keys',
'constraints',
'SCHEMA_VERSION'
])
SCHEMA_VERSION = 'SINGLE_TABLE_V1'

def _validate_numerical(self, column_name, **kwargs):
representation = kwargs.get('representation')
if representation and representation not in self._NUMERICAL_REPRESENTATIONS:
raise ValueError(
f"Invalid value for 'representation' {representation} for column '{column_name}'.")

@staticmethod
def _validate_datetime(column_name, **kwargs):
datetime_format = kwargs.get('datetime_format')
if datetime_format:
try:
formated_date = datetime.now().strftime(datetime_format)
except Exception as exception:
raise ValueError(
f"Invalid datetime format string '{datetime_format}' "
f"for datetime column '{column_name}'."
) from exception

matches = re.findall('(%.)|(%)', formated_date)
if matches:
raise ValueError(
f"Invalid datetime format string '{datetime_format}' "
f"for datetime column '{column_name}'."
)

@staticmethod
def _validate_categorical(column_name, **kwargs):
order = kwargs.get('order')
order_by = kwargs.get('order_by')
if order and order_by:
raise ValueError(
f"Categorical column '{column_name}' has both an 'order' and 'order_by' "
'attribute. Only 1 is allowed.'
)
if order_by and order_by not in ('numerical_value', 'alphabetical'):
raise ValueError(
f"Unknown ordering method '{order_by}' provided for categorical column "
f"'{column_name}'. Ordering method must be 'numerical_value' or 'alphabetical'."
)
if (isinstance(order, list) and not len(order)) or\
(not isinstance(order, list) and order is not None):
raise ValueError(
f"Invalid order value provided for categorical column '{column_name}'. "
"The 'order' must be a list with 1 or more elements."
)

@staticmethod
def _validate_text(column_name, **kwargs):
regex = kwargs.get('regex_format')
try:
re.compile(regex)
except Exception as exception:
raise ValueError(
f"Invalid regex format string '{regex}' for text column '{column_name}'."
) from exception

def __init__(self):
self._columns = {}
self._primary_key = None
Expand All @@ -36,6 +114,88 @@ def __init__(self):
'SCHEMA_VERSION': self.SCHEMA_VERSION
}

def _validate_unexpected_kwargs(self, column_name, sdtype, **kwargs):
expected_kwargs = self._EXPECTED_KWARGS.get(sdtype, ['pii'])
unexpected_kwargs = set(list(kwargs)) - set(expected_kwargs)
if unexpected_kwargs:
unexpected_kwargs = list(unexpected_kwargs)
unexpected_kwargs.sort()
unexpected_kwargs = ', '.join(unexpected_kwargs)
raise ValueError(
f"Invalid values '({unexpected_kwargs})' for {sdtype} column '{column_name}'.")

def _validate_column_exists(self, column_name):
if column_name not in self._columns:
raise ValueError(
f"Column name ('{column_name}') does not exist in the table. "
"Use 'add_column' to add new column."
)

def _validate_column(self, column_name, sdtype, **kwargs):
self._validate_unexpected_kwargs(column_name, sdtype, **kwargs)
if sdtype == 'categorical':
self._validate_categorical(column_name, **kwargs)
elif sdtype == 'numerical':
self._validate_numerical(column_name, **kwargs)
elif sdtype == 'datetime':
self._validate_datetime(column_name, **kwargs)
elif sdtype == 'text':
self._validate_text(column_name, **kwargs)

def add_column(self, column_name, **kwargs):
"""Add a column to the ``SingleTableMetadata``.
Args:
column_name (str):
The column name to be added.
kwargs (type):
Any additional key word arguments for the column, where ``sdtype`` is required.
Raises:
- ``ValueError`` if the column already exists.
- ``ValueError`` if the ``kwargs`` do not contain ``sdtype``.
- ``ValueError`` if the column has unexpected values or ``kwargs`` for the given
``sdtype``.
"""
if column_name in self._columns:
raise ValueError(
f"Column name '{column_name}' already exists. Use 'update_column' "
'to update an existing column.'
)

sdtype = kwargs.get('sdtype')
if sdtype is None:
raise ValueError(f"Please provide a 'sdtype' for column '{column_name}'.")

self._validate_column(column_name, **kwargs)
self._columns[column_name] = deepcopy(kwargs)

def update_column(self, column_name, **kwargs):
"""Update an existing column in the ``SingleTableMetadata``.
Args:
column_name (str):
The column name to be updated.
**kwargs (type):
Any key word arguments that describe metadata for the column.
Raises:
- ``ValueError`` if the column doesn't already exist in the ``SingleTableMetadata``.
- ``ValueError`` if the column has unexpected values or ``kwargs`` for the current
``sdtype``.
"""
self._validate_column_exists(column_name)
_kwargs = deepcopy(kwargs)
if 'sdtype' in kwargs:
sdtype = kwargs.pop('sdtype')
else:
sdtype = self._columns[column_name]['sdtype']
_kwargs['sdtype'] = sdtype

self._validate_column(column_name, sdtype, **kwargs)
self._columns[column_name] = _kwargs

def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.
Expand Down Expand Up @@ -99,7 +259,7 @@ def to_dict(self):
elif value:
metadata[key] = value

return copy.deepcopy(metadata)
return deepcopy(metadata)

def _set_metadata_dict(self, metadata):
"""Set a ``metadata`` dictionary to the current instance.
Expand All @@ -109,8 +269,8 @@ def _set_metadata_dict(self, metadata):
Python dictionary representing a ``SingleTableMetadata`` object.
"""
self._metadata = {}
for key in self.KEYS:
value = copy.deepcopy(metadata.get(key))
for key in self._KEYS:
value = deepcopy(metadata.get(key))
if key == 'constraints' and value:
value = [Constraint.from_dict(constraint_dict) for constraint_dict in value]

Expand Down
Loading

0 comments on commit 72ba076

Please sign in to comment.