Add StructType and DDL extraction from Pandera schemas (#1570)
* organize tests and add multiple inheritance test for model

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* organize tests and add multiple inheritance test for model

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* organize tests and add multiple inheritance test for model

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* fix test format

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* add nested structure test

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* add read test case using CSV with wrong schema inference

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* add read test case using CSV with wrong schema inference

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* accept Abhishek's suggestion

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* skip read test on Windows platform

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* skip read test on Windows platform

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

---------

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>
filipeo2-mck committed Apr 27, 2024
1 parent 1148867 commit cf09ae2
Showing 4 changed files with 546 additions and 2 deletions.
31 changes: 30 additions & 1 deletion pandera/api/pyspark/container.py
@@ -8,7 +8,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, cast, overload

-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import StructType, StructField

from pandera import errors
from pandera.api.base.schema import BaseSchema
@@ -563,6 +564,34 @@ def to_json(

        return pandera.io.to_json(self, target, **kwargs)

    def to_structtype(self) -> StructType:
        """Recover the fields of a DataFrameSchema as a PySpark StructType object.

        Because the output of this method is meant to be passed as a read schema
        in PySpark (avoiding automatic schema inference), `nullable=False`
        properties are ignored here; nullability is checked by the Pandera
        validations after the dataset has been read.

        :returns: StructType object with the current schema fields.
        """
        fields = [
            StructField(column, self.columns[column].dtype.type, True)
            for column in self.columns
        ]
        return StructType(fields)

    def to_ddl(self) -> str:
        """Recover the fields of a DataFrameSchema as a PySpark DDL string.

        :returns: String with the current schema fields, in compact DDL format.
        """
        # `StructType.toDDL()` is only available on the internal Java classes
        spark = SparkSession.builder.getOrCreate()
        # Create an empty dataframe just to reach the underlying Java schema object
        empty_df_with_schema = spark.createDataFrame([], self.to_structtype())

        return empty_df_with_schema._jdf.schema().toDDL()


def _validate_columns(
column_dict: dict[Any, "pandera.api.pyspark.components.Column"], # type: ignore [name-defined]
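
For context, a minimal usage sketch of the new schema-level methods. The schema, column names, and file path below are illustrative assumptions, not part of this commit:

# Illustrative sketch only: schema, column names, and file path are assumptions.
import pandera.pyspark as pa
import pyspark.sql.types as T
from pyspark.sql import SparkSession

schema = pa.DataFrameSchema(
    {
        "product": pa.Column(T.StringType()),
        "price": pa.Column(T.IntegerType(), nullable=False),
    }
)

spark = SparkSession.builder.getOrCreate()

# Pass the extracted StructType as an explicit read schema, bypassing
# PySpark's automatic inference; nullable=False is relaxed to nullable=True
# in the StructType and enforced later by Pandera validation instead.
df = spark.read.csv("products.csv", header=True, schema=schema.to_structtype())

# The reader also accepts the compact DDL string form,
# e.g. something like "product STRING,price INT".
df = spark.read.csv("products.csv", header=True, schema=schema.to_ddl())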
18 changes: 18 additions & 0 deletions pandera/api/pyspark/model.py
@@ -1,4 +1,5 @@
"""Class-based api for pyspark models."""

# pylint:disable=abstract-method
import copy
import inspect
@@ -22,6 +23,7 @@
)

import pyspark.sql as ps
from pyspark.sql.types import StructType

from pandera.api.base.model import BaseModel
from pandera.api.checks import Check
@@ -271,6 +273,22 @@ def to_yaml(cls, stream: Optional[os.PathLike] = None):
        """
        return cls.to_schema().to_yaml(stream)

    @classmethod
    def to_structtype(cls) -> StructType:
        """Recover the fields of a DataFrameModel as a PySpark StructType object.

        :returns: StructType object with the current model fields.
        """
        return cls.to_schema().to_structtype()

    @classmethod
    def to_ddl(cls) -> str:
        """Recover the fields of a DataFrameModel as a PySpark DDL string.

        :returns: String with the current model fields, in compact DDL format.
        """
        return cls.to_schema().to_ddl()

    @classmethod
    @docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__)
    def validate(
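
A corresponding sketch for the class-based API; the model and its fields are illustrative assumptions:

# Illustrative sketch only: the model and its fields are assumptions.
import pandera.pyspark as pa
import pyspark.sql.types as T

class ProductModel(pa.DataFrameModel):
    product: T.StringType() = pa.Field()
    price: T.IntegerType() = pa.Field()

# Both classmethods delegate to the DataFrameSchema built by to_schema().
struct_type = ProductModel.to_structtype()  # pyspark.sql.types.StructType
ddl_string = ProductModel.to_ddl()          # e.g. "product STRING,price INT"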
