Skip to content

Commit

Permalink
enables simple schema-transformations with add/remove columns (#144)
Browse files Browse the repository at this point in the history
* enables simple schema transformations

Co-authored-by: Niels Bantilan <niels.bantilan@gmail.com>
  • Loading branch information
mastersplinter and cosmicBboy committed Jan 13, 2020
1 parent 7bf2fe8 commit dcc40f4
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 1 deletion.
64 changes: 64 additions & 0 deletions docs/source/dataframe_schemas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,67 @@ Some examples of where this can be provided to pandas are:
column1 column2 column3
a 1 valueA True
b 1 valueB True



DataFrameSchema Transformations
-------------------------------

Pandera supports transforming a schema using ``.add_columns`` and ``.remove_columns``.

``.add_columns`` expects a ``Dict[str, Any]``, i.e. the same as when defining ``Columns`` in a ``DataFrameSchema``:

.. testcode:: add_columns

from pandera import DataFrameSchema, Column, Int, Check, String, Object

schema = DataFrameSchema({
"col1": Column(Int, Check(lambda s: s >= 0)),
}, strict=True)

new_schema = schema.add_columns({
"col2": Column(String, Check(lambda x: x <= 0)),
"col3": Column(Object, Check(lambda x: x == 0))
})

expected_schema = schema.add_columns({
"col1": Column(Int, Check(lambda s: s >= 0)),
"col2": Column(String, Check(lambda x: x <= 0)),
"col3": Column(Object, Check(lambda x: x == 0))
})

print(new_schema == expected_schema)

.. testoutput:: add_columns
:options: +NORMALIZE_WHITESPACE

True

``.remove_columns`` expects a list of one or more Column names:

.. testcode:: remove_columns

from pandera import DataFrameSchema, Column, Int, Check, String, Object

schema = DataFrameSchema({
"col1": Column(Int, Check(lambda s: s >= 0)),
"col2": Column(String, Check(lambda x: x <= 0)),
"col3": Column(Object, Check(lambda x: x == 0))
}, strict=True)

new_schema = schema.remove_columns(["col2", "col3"])

print(new_schema)

.. testoutput:: remove_columns
:options: +NORMALIZE_WHITESPACE

DataFrameSchema(
columns={
"col1": "<Schema Column: 'col1' type=int64>"
},
index=None,
transformer=None,
coerce=False,
strict=True
)
33 changes: 32 additions & 1 deletion pandera/schemas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Core pandera schema class definitions."""

import json
from typing import List, Optional, Union, Dict
import copy
from typing import List, Optional, Union, Dict, Any


import pandas as pd

Expand Down Expand Up @@ -287,6 +289,35 @@ def __str__(self):
def __eq__(self, other):
return self.__dict__ == other.__dict__

def add_columns(self,
extra_schema_cols: Dict[str, Any]) -> 'DataFrameSchema':
"""Create a new DataFrameSchema with extra Columns
:param extra_schema_cols: Additional columns of the format
:type extra_schema_cols: DataFrameSchema
:returns: a new DataFrameSchema with the extra_schema_cols added
"""
schema_copy = copy.deepcopy(self)
schema_copy.columns = {**schema_copy.columns,
**DataFrameSchema(extra_schema_cols).columns}
return schema_copy

def remove_columns(self,
cols_to_remove: List) -> 'DataFrameSchema':
"""Removes a column from a DataFrameSchema and returns a new
DataFrameSchema.
:param cols_to_remove: Columns to be removed from the DataFrameSchema
:type cols_to_remove: List
:returns: a new DataFrameSchema without the cols_to_remove
"""
schema_copy = copy.deepcopy(self)
for col in cols_to_remove:
schema_copy.columns.pop(col)

return schema_copy


class SeriesSchemaBase():
Expand Down
51 changes: 51 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,54 @@ def test_schema_equality_operators():
assert series_schema != not_equal_schema
assert series_schema_base == copy.deepcopy(series_schema_base)
assert series_schema_base != not_equal_schema


def test_add_and_remove_columns():
"""Check that adding and removing columns works as expected and doesn't
modify the original underlying DataFrameSchema."""
schema1 = DataFrameSchema({
"col1": Column(Int, Check(lambda s: s >= 0)),
}, strict=True)

schema1_exact_copy = copy.deepcopy(schema1)

# test that add_columns doesn't modify schema1 after add_columns:
schema2 = schema1.add_columns({
"col2": Column(String, Check(lambda x: x <= 0)),
"col3": Column(Object, Check(lambda x: x == 0))
})

schema2_exact_copy = copy.deepcopy(schema2)

assert schema1 == schema1_exact_copy

# test that add_columns changed schema1 into schema2:
expected_schema_2 = DataFrameSchema({
"col1": Column(Int, Check(lambda s: s >= 0)),
"col2": Column(String, Check(lambda x: x <= 0)),
"col3": Column(Object, Check(lambda x: x == 0))
}, strict=True)

assert schema2 == expected_schema_2

# test that remove_columns doesn't modify schema2:
schema3 = schema2.remove_columns(["col2"])

assert schema2 == schema2_exact_copy

# test that remove_columns has removed the changes as expected:
expected_schema_3 = DataFrameSchema({
"col1": Column(Int, Check(lambda s: s >= 0)),
"col3": Column(Object, Check(lambda x: x == 0))
}, strict=True)

assert schema3 == expected_schema_3

# test that remove_columns can remove two columns:
schema4 = schema2.remove_columns(["col2", "col3"])

expected_schema_4 = DataFrameSchema({
"col1": Column(Int, Check(lambda s: s >= 0))
}, strict=True)

assert schema4 == expected_schema_4 == schema1

0 comments on commit dcc40f4

Please sign in to comment.