Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

287 moving columns to multiindex #319

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions pandera/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from .error_handlers import SchemaErrorHandler
from .hypotheses import Hypothesis
from .schema_components import MultiIndex, Index, Column
cosmicBboy marked this conversation as resolved.
Show resolved Hide resolved

N_INDENT_SPACES = 4

Expand Down Expand Up @@ -689,6 +690,181 @@ def to_yaml(self, fp: Union[str, Path] = None):

return pandera.io.to_yaml(self, fp)

def set_index(self, keys: List[str], drop: bool = True, append: bool = False, inplace: bool = False):
"""
A method for setting the :class:`Index` of a :class:`DataFrameSchema`,
via an existing :class:`Column` or list of :class:`Column`s.

:param keys: list of labels
:param drop: bool, default True
:param append: bool, default False
:param inplace: bool, default False
:return: a new :class:`DataFrameSchema` with specified column(s) in the index.
"""

# first check if should be done to self or make copy
if inplace:
new_schema = self
else:
new_schema = copy.deepcopy(self)

if not isinstance(keys, list):
keys = [keys]

# ensure all specified keys are present in the columns
try:
not_in_cols: List[str] = [x for x in keys if x not in new_schema.columns.keys()]
assert not_in_cols == []
except AssertionError:
raise Exception(f"Keys {not_in_cols} not found in schema columns!")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these exceptions should be errors.SchemaInitError and use the from keyword so that full error trace is available. assert statements are also nice when quickly scripting things, but in this case it might be more appropriate to use a conditional then raise here.

not_in_cols: List[str] = [
    x for x in keys_temp if x not in new_schema.columns.keys()
]
if not not_in_cols:
     raise errors.SchemaInitError(f"Keys {not_in_cols} not found in schema columns!")


# ensure no duplicates
try:
dup_cols:List[str] = [x for x in set(keys) if keys.count(x) > 1]
assert dup_cols == []
except AssertionError:
raise Exception(f"Keys {dup_cols} are duplicated!")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment above


# make a list if not None
if (keys is not None) and (isinstance(keys, list) is False):
cosmicBboy marked this conversation as resolved.
Show resolved Hide resolved
keys = list(keys)


# if there is already an index, append or replace according to parameters
if new_schema.index is not None:
if isinstance(new_schema.index, MultiIndex) and append:
ind_list: List = list(new_schema.index.columns.values())
elif isinstance(new_schema.index, Index) and append:
ind_list: List = [new_schema.index]
else:
ind_list: list = []
cosmicBboy marked this conversation as resolved.
Show resolved Hide resolved
# if there is no index, then create from columns
else:
ind_list: list = []

for col in keys:
ind_list.append(Index(pandas_dtype = new_schema.columns[col].dtype,
cosmicBboy marked this conversation as resolved.
Show resolved Hide resolved
name = col,
checks = new_schema.columns[col].checks,
nullable = new_schema.columns[col].nullable,
allow_duplicates = new_schema.columns[col].allow_duplicates,
coerce = new_schema.columns[col].coerce))


if len(ind_list) == 1:
new_schema.index = ind_list[0]
elif len(ind_list) > 1:
new_schema.index = MultiIndex(ind_list)

# if drop is True as defaulted, drop the columns moved into the index
if drop:
new_schema.columns = new_schema.remove_columns(keys).columns
cosmicBboy marked this conversation as resolved.
Show resolved Hide resolved

if not inplace:
return new_schema
else:
self.columns = new_schema.columns
self.index = new_schema.index

def reset_index(self, level: List[str] = None, drop: bool = False, inplace: bool = False):
"""
A method for reseting the :class:`Index` of a :class:`DataFrameSchema`.

:param level: list of labels
:param drop: bool, default True
:param append: bool, default False
:param inplace: bool, default False
:return: a new :class:`DataFrameSchema` with specified column(s) in the index.

"""
# first check if should be done to self or make copy
if inplace:
new_schema = self
else:
new_schema = copy.deepcopy(self)
try:
assert new_schema.index is not None
except AssertionError:
raise Exception('There is currently no index set for this schema.')

# ensure all specified keys are present in the index
try:
if isinstance(new_schema.index, MultiIndex) and (level is not None):
not_in_cols: List[str] = [x for x in level if x not in list(new_schema.index.columns.keys())]
elif isinstance(new_schema.index, Index) and (level is not None):
not_in_cols: List[str] = [] if ([new_schema.index.name] == level) else level
else:
not_in_cols:list = []
assert not_in_cols == []
except AssertionError:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment on line 717

raise Exception(f"Keys {not_in_cols} not found in schema columns!")

# make a list if not None
if (level is not None) and (isinstance(level, list) is False):
level = list(level)

# ensure no duplicates
if level is not None:
try:
dup_cols:List[str] = [x for x in set(level) if level.count(x) > 1]
assert dup_cols == []
except AssertionError:
raise Exception(f"Keys {dup_cols} are duplicated!")

#
additional_columns: list = []
new_index = new_schema.index
if level is None:
cosmicBboy marked this conversation as resolved.
Show resolved Hide resolved
new_index = None
print(new_index)
if not drop:
if isinstance(new_schema.index, MultiIndex):
additional_columns: List[str] = additional_columns + [ind for ind in list(new_schema.index.columns.keys())]
else:
additional_columns.append(new_schema.index.name)
else:
if isinstance(new_schema.index, MultiIndex):
new_index = new_schema.index.remove_columns(level)
if len(list(new_index.columns.keys())) == 1:
ind_key = list(new_index.columns.keys())[0]
ind_obj = new_index.columns[ind_key]
new_index:Index = Index(pandas_dtype=ind_obj.dtype,
checks=ind_obj.checks,
nullable=ind_obj.nullable,
allow_duplicates=ind_obj.allow_duplicates,
coerce=ind_obj.coerce,
name=ind_obj.name)
elif len(list(new_index.columns.keys())) == 0:
new_index: list = None

if not drop:
additional_columns = additional_columns + [ind for ind in level]

else:
new_index = None
if not drop:
additional_columns.append(level)

if not drop:
additional_columns: dict = {col: new_schema.index.columns.get(col) for col in additional_columns} \
if isinstance(new_schema.index, MultiIndex) \
else {additional_columns[0]: new_schema.index}

new_schema = new_schema.add_columns(
{k: Column(pandas_dtype=v.dtype,
checks=v.checks,
nullable=v.nullable,
allow_duplicates=v.allow_duplicates,
coerce=v.coerce,
name=v.name) for (k, v) in additional_columns.items()})

new_schema.index:Index = new_index
if not inplace:
return new_schema
else:
self.columns = new_schema.columns
self.index = new_schema.index


class SeriesSchemaBase:
"""Base series validator object."""
Expand Down
106 changes: 106 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,3 +1082,109 @@ def test_schema_coerce_inplace_validation(inplace, from_dtype, to_dtype):
else:
# not inplace preserves original dataframe type
assert df["column"].dtype == from_dtype




def get_schema_simple():
schema = DataFrameSchema(
columns = {'col1': Column(pandas_dtype = Int), 'col2': Column(pandas_dtype = Float)},
index = Index(pandas_dtype = String, name = 'ind0'))
return schema

def get_schema_multi():
schema = DataFrameSchema(
columns={'col1': Column(pandas_dtype=Int), 'col2': Column(pandas_dtype=Float)},
index=MultiIndex([Index(pandas_dtype=String, name='ind0'),Index(pandas_dtype=String, name='ind1')]))
return schema


@pytest.mark.parametrize("inplace", [True, False])
def test_set_index_inplace(inplace):
test_schema = get_schema_simple()
temp_schema = copy.deepcopy(test_schema)
test_schema.set_index(keys = ['col1'], inplace = inplace) # inplace should be defaulted to False
if inplace is True:
assert isinstance(test_schema.index, Index) # b/c append is defaulted to False
assert test_schema.index.name == 'col1'
assert test_schema.index.dtype == pa.Int
assert len(test_schema.columns) == 1
else:
assert test_schema == temp_schema

@pytest.mark.parametrize("drop", [True, False])
def test_set_index_drop(drop):
test_schema = get_schema_simple()
test_schema = test_schema.set_index(keys = ['col1'], drop = drop)
if drop is True:
assert len(test_schema.columns) == 1
assert list(test_schema.columns.keys()) == ['col2']
else:
assert len(test_schema.columns) == 2
assert list(test_schema.columns.keys()) == ['col1','col2']
assert test_schema.index.name == 'col1'

@pytest.mark.parametrize("append", [True, False])
def test_set_index_append(append):
temp_schema = get_schema_simple()
test_schema = temp_schema.set_index(keys = ['col1'], append = append)
if append is True:
assert isinstance(test_schema.index, MultiIndex)
assert list(test_schema.index.columns.keys()) == ['ind0', 'col1']
assert test_schema.index.columns['col1'].dtype == temp_schema.columns['col1'].dtype
else:
assert isinstance(test_schema.index, Index)
assert test_schema.index.name == 'col1'

### reset_index tests

def test_reset_index_copy():
test_schema = get_schema_simple()
temp_schema = copy.deepcopy(test_schema)
temp_schema.reset_index() # inplace should be defaulted to False
assert test_schema == temp_schema

@pytest.mark.parametrize("inplace", [True, False])
def test_reset_index_inplace(inplace):
test_schema = get_schema_simple()
temp_schema = copy.deepcopy(test_schema)
test_schema.reset_index(inplace = inplace) # inplace should be defaulted to False
if inplace is True:
#print(test_schema.index)
assert test_schema.index is None
else:
assert test_schema == temp_schema


@pytest.mark.parametrize("drop", [True, False])
def test_reset_index_drop(drop):
test_schema = get_schema_simple()
test_schema = test_schema.reset_index(drop=drop)
if drop:
assert len(test_schema.columns) == 2
assert list(test_schema.columns.keys()) == ['col1', 'col2']
else:
assert len(test_schema.columns) == 3
assert list(test_schema.columns.keys()) == ['col1', 'col2','ind0']
assert test_schema.index is None

def test_reset_index_level():
temp_schema = get_schema_multi()

test_schema = temp_schema.reset_index(level=['ind0'])
#print(test_schema.index)
assert test_schema.index.name == 'ind1'
assert isinstance(test_schema.index, Index)

test_schema = temp_schema.reset_index(level=['ind0','ind1'])
assert test_schema.index is None
assert list(test_schema.columns.keys()) == ['col1', 'col2', 'ind0', 'ind1']


## general functionality
def test_invalid_keys():
test_schema = get_schema_simple()
with pytest.raises(Exception):
test_schema.set_index(['foo', 'bar'])
with pytest.raises(Exception):
test_schema.reset_index(['foo', 'bar'])