In [24]:
%load_ext autoreload
%autoreload 2

from spalah.dataframe import slice_dataframe, script_dataframe, flatten_schema
from pyspark.sql import SparkSession, DataFrame

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
# Attach to the active Spark session (or start one) for all examples below.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [26]:
from datetime import datetime, date

from pyspark.sql import Row

# Small flat frame: int, float, string, date and datetime values in columns a..e.
df = spark.createDataFrame(
    data=[
        Row(a=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0)),
        Row(a=2, b=3., c='string2', d=date(2000, 2, 1), e=datetime(2000, 1, 2, 12, 0)),
    ]
)
df.show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
+---+---+-------+----------+-------------------+



In [41]:
from spalah.dataframe import slice_dataframe

# No include/exclude lists are passed here; the recorded output shows every column unchanged.
slice_dataframe(
    nullify_only=False,
    input_dataframe=df,
).show()

b
a
+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
+---+---+-------+----------+-------------------+



In [None]:
# One-row DataFrame with nested structs (column_c -> column_c_2 -> c_2_*), used below
# to exercise slicing of nested columns.
nested_df = spark.sql("""
select 
    1 as column_a
,   2.0 as column_b
,   struct(
        "c1" as column_c_1
    ,   struct(
            "c_2_1" as c_2_1,
            "c_2_2" as c_2_2,
            "c_2_3" as c_2_3
    ) as column_c_2
) as column_c
""")

In [None]:
# Inspect the nested struct layout created in the previous cell.
nested_df.printSchema()

In [None]:
# Intended to keep column_c while excluding its column_c_1 field (per the include/exclude
# arguments), then expand the remaining struct fields for display.
(
    slice_dataframe(
        input_dataframe=nested_df,
        columns_to_include=["column_c"],
        columns_to_exclude=["column_c.column_c_1"],
        nullify_only=False,
        debug=False,
    )
    .select("column_c.*")
    .show(10, False)
)

In [223]:
from datetime import datetime, date

from pyspark.sql import Row

# Source and target frames differ on purpose: 'b' is a string in df1 but a float in df2,
# 'a' vs 'A' differ only by case, and 'q' / 'f' exist on one side only.
df1 = spark.createDataFrame(
    [Row(a=1, b="2.", c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0), q="bcd")]
)
df1.show()

df2 = spark.createDataFrame(
    [Row(A=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0), f="abc")]
)
df2.show()

# Flatten both schemas; later cells treat the result as (column_name, data_type) pairs.
schema1, schema2 = (flatten_schema(frame.schema, True) for frame in (df1, df2))

+---+---+-------+----------+-------------------+---+
|  a|  b|      c|         d|                  e|  q|
+---+---+-------+----------+-------------------+---+
|  1| 2.|string1|2000-01-01|2000-01-01 12:00:00|bcd|
+---+---+-------+----------+-------------------+---+

+---+---+-------+----------+-------------------+---+
|  A|  b|      c|         d|                  e|  f|
+---+---+-------+----------+-------------------+---+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|abc|
+---+---+-------+----------+-------------------+---+



In [189]:
def get_matched_by_name_and_type(source, target):
    """Return the (column_name, data_type) pairs present in both ``source`` and ``target``."""
    return source & target


# Working copies of the flattened schemas; later cells remove columns from them step by step.
_source, _target = set(schema1.copy()), set(schema2.copy())

# Case 1: columns identical in name and type on both sides, kept in a separate set.
list_matched = get_matched_by_name_and_type(_source, _target)
print(list_matched)



{('e', 'TimestampType'), ('c', 'StringType'), ('d', 'DateType')}


In [190]:
def remove_by_name_and_type(base_value, subtract_value):
    """Return ``base_value`` without the exact (column_name, data_type) pairs in ``subtract_value``."""
    return base_value - subtract_value


# Drop the fully matched columns from both sides before the looser checks below.
_source, _target = (
    remove_by_name_and_type(_source, list_matched),
    remove_by_name_and_type(_target, list_matched),
)

print(_source)
print(_target)

{('b', 'StringType'), ('a', 'LongType'), ('q', 'StringType')}
{('f', 'StringType'), ('b', 'DoubleType'), ('A', 'LongType')}


In [191]:
def column_names_lower(base_value):
    """Return ``base_value`` as a set of (column_name, data_type) pairs with lower-cased names."""
    return {(name.lower(), data_type) for (name, data_type) in base_value}


def remove_by_name(base_value, subtract_value):
    """Drop every (column_name, data_type) pair whose name appears in ``subtract_value``.

    Names are compared case-insensitively and the data type is ignored.

    Args:
        base_value: Iterable of (column_name, data_type) pairs to filter.
        subtract_value: Iterable of tuples whose first element is a column name to remove.

    Returns:
        List of the remaining (column_name, data_type) pairs, in the iteration order of
        ``base_value``.
    """
    # Lower-case BOTH sides: a mixed-case name in subtract_value (e.g. 'B') must still match.
    # Built once as a set instead of re-scanning subtract_value for every column.
    names_to_remove = {column[0].lower() for column in subtract_value}
    return [
        (name, data_type)
        for (name, data_type) in base_value
        if name.lower() not in names_to_remove
    ]

# Case 2: columns whose names differ only by letter case but share a data type.
_source_lowered, _target_lowered = column_names_lower(_source), column_names_lower(_target)
list_not_matched_by_case = get_matched_by_name_and_type(_source_lowered, _target_lowered)
print(list_not_matched_by_case)

# Remove those columns (by name, case-insensitively) from both sides.
_source, _target = (
    remove_by_name(_source, list_not_matched_by_case),
    remove_by_name(_target, list_not_matched_by_case),
)
print(_source)
print(_target)

{('a', 'LongType')}
[('b', 'StringType'), ('q', 'StringType')]
[('f', 'StringType'), ('b', 'DoubleType')]


In [192]:

def get_matched_by_name_but_not_type(source, target):
    """Find columns that share a name in both schemas but have different data types.

    Args:
        source: Iterable of (column_name, data_type) pairs from the source schema.
        target: Iterable of (column_name, data_type) pairs from the target schema.

    Returns:
        List of (column_name, "<source_type> <=> <target_type>") tuples, ordered by the
        first appearance of each name in ``source``. Names are compared case-sensitively.
    """
    source_types = dict(source)
    target_types = dict(target)

    mismatched = []
    for name, source_type in source_types.items():
        if name not in target_types:
            continue
        target_type = target_types[name]
        if source_type != target_type:
            mismatched.append((name, f"{source_type} <=> {target_type}"))
    return mismatched


In [193]:
# Case 3: same column name on both sides, but a different data type.
list_not_matched_by_type = get_matched_by_name_but_not_type(
    source=_source,
    target=_target,
)
print(list_not_matched_by_type)

[('b', 'StringType <=> DoubleType')]


In [194]:
# Drop the type-mismatched columns; whatever remains has no counterpart on the other side.
_source, _target = (
    remove_by_name(_source, list_not_matched_by_type),
    remove_by_name(_target, list_not_matched_by_type),
)
print(_source)
print(_target)

[('q', 'StringType')]
[('f', 'StringType')]


In [195]:
# Case 4a: columns left over on the source side only.
if len(_source) > 0:
    print("exists only the source:")
    print(_source)

exists only the source:
[('q', 'StringType')]


In [196]:
# Case 4b: columns left over on the target side only.
if len(_target) > 0:
    print("exists only the target:")
    print(_target)

exists only the target:
[('f', 'StringType')]


In [354]:
from collections import namedtuple
from pyspark.sql import types as T
from typing import Set
from pyspark.sql import DataFrame

# Result records produced by SchemaComparer, one per compared column.
MatchedColumn = namedtuple('MatchedColumn', 'name data_type')
NotMatchedColumn = namedtuple('NotMatchedColumn', 'name data_type reason')

class SchemaComparer():
    """Compare two Spark schemas column by column.

    Both schemas are flattened with ``flatten_schema(schema, True)`` into sets of
    ``(column_name, data_type)`` tuples. ``compare()`` then classifies every column:

    * ``matched``: list of ``MatchedColumn`` for columns with the same name and data type
      on both sides;
    * ``not_matched``: list of ``NotMatchedColumn`` for every other column, each carrying
      a ``reason``.

    ``compare()`` always restarts from the imported schemas. Calling it more than once
    therefore yields the same results instead of failing or duplicating entries. Within
    each category, the results are ordered by column name so the output is reproducible.
    """

    def __init__(self, source_schema: T.StructType, target_schema: T.StructType) -> None:
        """Import and flatten both schemas; the comparison itself runs in ``compare()``.

        Args:
            source_schema (T.StructType): Schema of the source side
            target_schema (T.StructType): Schema of the target side

        Raises:
            TypeError: if either schema is a DataFrame or is not a StructType
        """
        # Immutable snapshots of the flattened schemas. compare() consumes the working
        # sets `_source` / `_target`, so it restores them from these snapshots on each call.
        self.__source_columns = frozenset(self.__import_schema(source_schema))
        self.__target_columns = frozenset(self.__import_schema(target_schema))
        self._source = set(self.__source_columns)
        self._target = set(self.__target_columns)
        self.matched = list()
        self.not_matched = list()

    def __import_schema(self, input_schema: T.StructType) -> Set[tuple]:
        """Import StructType as the flatten set of tuples: (column_name, data_type)

        Args:
            input_schema (T.StructType): Schema to process

        Raises:
            TypeError: if input schema has a type: DataFrame
            TypeError: if input schema hasn't a type: StructType

        Returns:
            Set[tuple]: Set of tuples: (column_name, data_type)
        """
        if isinstance(input_schema, DataFrame):
            raise TypeError(
                "One of 'source_schema' or 'target_schema' passed as a DataFrame. "
                "Use DataFrame.schema instead"
            )
        if not isinstance(input_schema, T.StructType):
            raise TypeError("Parameters 'source_schema' and 'target_schema' must have a type: StructType")

        return set(flatten_schema(input_schema, True))

    def __match_by_name_and_type(self) -> Set[tuple]:
        """Case 1: match columns that have the same name and data type in both schemas.

        Matched columns are recorded in ``matched`` and removed from further processing.

        Returns:
            Set[tuple]: Fully matched columns as a set of tuples: (column_name, data_type)
        """
        result = self._source & self._target

        self.__remove_matched_by_name_and_type(result)
        self.__populate_matched(result)

        return result

    def __remove_matched_by_name_and_type(self, subtract_value: Set[tuple]) -> None:
        """Removes fully matched columns from the further processing

        Args:
            subtract_value (Set[tuple]): Set of matched columns (exact name and type)
        """
        self._source = self._source - subtract_value
        self._target = self._target - subtract_value

    def __remove_matched_by_name(self, subtract_value: Set[tuple]) -> None:
        """Removes columns from further processing by name only (case-insensitive)

        Args:
            subtract_value (Set[tuple]): Tuples whose first element is a column name to remove
        """
        # Build the lookup once instead of re-scanning subtract_value for every column.
        names_to_remove = {column[0].lower() for column in subtract_value}

        def _keep(columns: Set[tuple]) -> Set[tuple]:
            # Keep the working sets as sets so every step sees the same container type.
            return {column for column in columns if column[0].lower() not in names_to_remove}

        self._source = _keep(self._source)
        self._target = _keep(self._target)

    def __lower_column_names(self, base_value: Set[tuple]) -> Set[tuple]:
        """Lower-case all column names of the input set

        Args:
            base_value (Set[tuple]): Input set of columns

        Returns:
            Set[tuple]: Output set of columns with lower-case column names
        """
        return {(name.lower(), data_type) for (name, data_type) in base_value}

    def __match_by_name_type_excluding_case(self) -> None:
        """Case 2: match columns whose names differ only by letter case but share a data type.

        NOTE: the reported name is the lower-cased form, not the original spelling from
        either side.
        """
        result = self.__lower_column_names(self._source) & self.__lower_column_names(self._target)

        # Remove matched values of case 2 from further processing
        self.__remove_matched_by_name(result)

        self.__populate_not_matched(
            result,
            "The column exists in source and target schemas but it's name is case-mismatched"
        )

    def __match_by_name_but_not_type(self) -> None:
        """Case 3: match columns by exact (case-sensitive) name whose data types differ.

        The reported ``data_type`` is ``"<source_type> <=> <target_type>"``.
        NOTE: a column whose name differs by case AND whose type differs is caught by
        neither case 2 nor case 3, so it ends up in case 4 on both sides.
        """
        source_types = dict(self._source)
        target_types = dict(self._target)
        result = [
            (name, f"{source_type} <=> {target_types[name]}")
            for name, source_type in source_types.items()
            if name in target_types and source_type != target_types[name]
        ]

        # Remove matched values of case 3 from further processing
        self.__remove_matched_by_name(result)

        self.__populate_not_matched(
            result,
            "The column exists in source and target schemas but it is not matched by a data type"
        )

    def __process_remaining_non_matched_columns(self) -> None:
        """Case 4: record the columns left over on only one side, then clear the working sets."""
        self.__populate_not_matched(self._source, "The column exists only in the source schema")
        self.__populate_not_matched(self._target, "The column exists only in the target schema")

        # Every column is classified at this point; nothing is left to process.
        self._source = set()
        self._target = set()

    def __populate_matched(self, input_value: Set[tuple]) -> None:
        """Populate class property 'matched' with a list of fully matched columns

        Args:
            input_value (Set[tuple]): Tuples of (column_name, data_type)
        """
        # Sort by name: set iteration order is hash-dependent and would vary between runs.
        for column in sorted(input_value, key=lambda item: item[0]):
            self.matched.append(
                MatchedColumn(name=column[0], data_type=column[1])
            )

    def __populate_not_matched(self, input_value: Set[tuple], reason: str) -> None:
        """Populate class property 'not_matched' with columns that did not match, plus the reason

        Args:
            input_value (Set[tuple]): Tuples of (column_name, data_type)
            reason (str): Reason for not match
        """
        # Sort by name for reproducible output (see __populate_matched).
        for column in sorted(input_value, key=lambda item: item[0]):
            self.not_matched.append(
                NotMatchedColumn(name=column[0], data_type=column[1], reason=reason)
            )

    def compare(self) -> None:
        """Compare the source and target schemas and populate 'matched' and 'not_matched'.

        Results of a previous call are discarded. The result lists are cleared in place,
        so references obtained earlier to ``matched`` / ``not_matched`` see the new results.
        """
        # Restart from the imported schemas; a previous run consumed the working sets.
        self._source = set(self.__source_columns)
        self._target = set(self.__target_columns)
        self.matched.clear()
        self.not_matched.clear()

        # Case 1: find columns that are matched by name and type and remove them from further processing
        self.__match_by_name_and_type()

        # Case 2: find columns whose names differ only by case (ID <-> Id) but whose types match
        self.__match_by_name_type_excluding_case()

        # Case 3: find columns matched by name, but not by data type
        self.__match_by_name_but_not_type()

        # Case 4: find columns that exist only in the source or only in the target
        self.__process_remaining_non_matched_columns()
        


In [357]:
cmp = SchemaComparer(source_schema=df1.schema, target_schema=df2.schema)
cmp.compare()

# Every column that did not match, together with the reason.
cmp.not_matched


[NotMatchedColumn(name='a', data_type='LongType', reason="The column exists in source and target schemas but it's name is case-mismatched"),
 NotMatchedColumn(name='b', data_type='StringType <=> DoubleType', reason='The column exists in source and target schemas but it is not matched by a data type'),
 NotMatchedColumn(name='q', data_type='StringType', reason='The column exists only in the source schema'),
 NotMatchedColumn(name='f', data_type='StringType', reason='The column exists only in the target schema')]