diff --git a/pandera/io.py b/pandera/io.py index d26484f9b..5d9f8fa13 100644 --- a/pandera/io.py +++ b/pandera/io.py @@ -156,7 +156,9 @@ def _deserialize_schema(serialized_schema): for index_component in serialized_schema["index"] ] - if len(index) == 1: + if index is None: + index = None + elif len(index) == 1: index = Index(**index[0]) else: index = MultiIndex(indexes=[ diff --git a/tests/test_io.py b/tests/test_io.py index 65c4ac2b5..611308a5d 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -125,6 +125,63 @@ def _create_schema(index="single"): """.format(version=pa.__version__) +def _create_schema_null_index(): + + return pa.DataFrameSchema( + columns={ + "float_column": pa.Column( + pa.Float, checks=[ + pa.Check.greater_than(-10), + pa.Check.less_than(20), + pa.Check.in_range(-10, 20), + ] + ), + "str_column": pa.Column( + pa.String, checks=[ + pa.Check.isin(["foo", "bar", "x", "xy"]), + pa.Check.str_length(1, 3) + ] + ), + }, + index=None + ) + + +YAML_SCHEMA_NULL_INDEX = """ +schema_type: dataframe +version: {version} +columns: + float_column: + pandas_dtype: float + nullable: false + checks: + greater_than: -10 + less_than: 20 + in_range: + min_value: -10 + max_value: 20 + str_column: + pandas_dtype: string + nullable: false + checks: + isin: + - foo + - bar + - x + - xy + str_length: + min_value: 1 + max_value: 3 +index: null +coerce: false +strict: false +""".format(version=pa.__version__) + +YAML_VALIDATION_PAIRS = [ + [YAML_SCHEMA, _create_schema], + [YAML_SCHEMA_NULL_INDEX, _create_schema_null_index] +] + @pytest.mark.skipif( PYYAML_VERSION.release < (5, 1, 0), # type: ignore reason="pyyaml >= 5.1.0 required", @@ -162,10 +219,12 @@ def test_to_yaml(): ) def test_from_yaml(): """Test that from_yaml reads yaml string.""" - schema_from_yaml = io.from_yaml(YAML_SCHEMA) - expected_schema = _create_schema() - assert schema_from_yaml == expected_schema - assert expected_schema == schema_from_yaml + + for yml_string, schema_creator in YAML_VALIDATION_PAIRS: + schema_from_yaml = io.from_yaml(yml_string) + expected_schema = schema_creator() + assert schema_from_yaml == expected_schema + assert expected_schema == schema_from_yaml def test_io_yaml_file_obj():