In [0]:
# libs
from pyspark.sql import SparkSession

# create session
spark = SparkSession.builder\
        .master("local")\
        .appName("Colab")\
        .config('spark.ui.port', '4050')\
        .getOrCreate()

class DataPrep:
    def __init__(self,dataframe_dict):
        self.dataframe_dict=dataframe_dict
    
    
    #--------------------------------------# CREATE AND STORAGE #--------------------------------------#
    
    def storage_df(self):
        
        """
        Creates and stores a dataframe.

        Returns:
            df (pyspark.sql.DataFrame): The persisted dataframe.
        """
        
        from pyspark import StorageLevel
        
        dataframe_dict = self.dataframe_dict

        # create dataframe
        if isinstance(dataframe_dict['schema'], str):
            df = spark.table(dataframe_dict['schema'])
            df = df.persist(StorageLevel.DISK_ONLY)
        else:
            df = dataframe_dict['schema']

        return df
    
    
    #--------------------------------------# PREP: ARRAY STRING TO STRUCT #--------------------------------------#
    
    def json_string_column_to_struct(self):
        
        from pyspark.sql.functions import from_json
        
        # create objects
        dataframe=self.storage_df()
        json_list=self.dataframe_dict['json_list']
        
        # transform
        for column in json_list:
            json_schema = spark.read.json(dataframe.rdd.map(lambda row: row[str(column)])).schema
            dataframe = dataframe.withColumn(f'{column}_struct', from_json(str(column), json_schema))
            dataframe = dataframe.drop(column).withColumnRenamed(f'{column}_struct', f'{column}')
        
        return dataframe
    
    
    #--------------------------------------# SELECT AND RENAME #--------------------------------------#
    
    def basic_prep(self):
        
        from pyspark.sql.functions import col
        
        # create objects
        dataframe_dict = self.dataframe_dict
        dataframe = self.json_string_column_to_struct()
        
        # transform
        df = (
            dataframe
            .select(*dataframe_dict['columns'].keys())
            .toDF(*dataframe_dict['columns'].values())
        )
        
        for col_name in dataframe_dict['pk_list']:
            df = df.filter(col(col_name).isNotNull())
            
        return df
    
    #--------------------------------------# COMPOSED KEYS #--------------------------------------#
    
    def composed_keys(self):
        
        from pyspark.sql.functions import col, concat_ws, lower

        # create objects
        dataframe_dict = self.dataframe_dict
        key=dataframe_dict['key'] 
        pk_name=dataframe_dict['pk_name']
        pk_list=dataframe_dict['pk_list']
        fk_name=dataframe_dict['fk_name']
        fk_list=dataframe_dict['fk_list']
        dataframe = self.basic_prep()

        # pk transform
        df = (
            dataframe
            .withColumn(
                f'{pk_name}',
                concat_ws('_', col(str(pk_list[0])), col(str(pk_list[1])))
            )
        )
        
        # fk transform
        for name, list_cols in zip(fk_name, fk_list):
            df = (
                df
                .withColumn(
                    f'{name}',
                    concat_ws('_', lower(col(str(list_cols[0]))), lower(col(str(list_cols[1]))))
                )
            )
        
        return df
    
    
    #--------------------------------------# CREATE FLAGS #--------------------------------------#
    
    def flag_function(self):
        from pyspark.sql.functions import col, when

        # create objects
        dataframe_dict = self.dataframe_dict
        drop_list=dataframe_dict['drop_list']
        df = self.composed_keys()
        
        for column in dataframe_dict['create_flag']:
            df = df.withColumn(f'flag_{column}', when((col(column).isNotNull()) & (col(column) != 0), 1).otherwise(0))
        
        df = df.drop(*drop_list)

        return df
    
    
    #--------------------------------------# GROUP BY: FOREIGN KEYS #--------------------------------------#
    
    def groupby_function(self):
        
        # create objects
        dataframe_dict = self.dataframe_dict
        dataframe = self.flag_function()
    
        if dataframe_dict['key'] == 'fk':
            df = (
                dataframe
                .groupBy(dataframe_dict['fk_name']).max()
            )

            for column in [column for column in dataframe.columns if column != dataframe_dict['fk_name']]:
                df = df.withColumnRenamed(f"max({column})", f"{column}") 
                
        elif dataframe_dict['key'] == 'pk':
            df = self.flag_function()

        return df
    
    
    #--------------------------------------# RANKING: FOREIGN KEYS #--------------------------------------#
    
    def ranking(self):
        from pyspark.sql.functions import col, lower, row_number
        from pyspark.sql.window import Window
        
        # create objects
        dataframe_dict = self.dataframe_dict
        dataframe = self.groupby_function()
        
        tmp_dict = {}

        if dataframe_dict['ranking'] == True:
            # ranking
            df = (
                dataframe
                .select(dataframe_dict['partition'], dataframe_dict['col_rank'])
                .withColumn('rank', row_number().over(Window.partitionBy(dataframe_dict['partition']).orderBy(col(dataframe_dict['col_rank']).desc())))
                .groupBy(dataframe_dict['partition']).agg({'rank':'max'})
                .withColumnRenamed(f"max(rank)", 'rank')
                .join(
                    (dataframe.withColumn('rank', row_number().over(Window.partitionBy(dataframe_dict['partition']).orderBy(dataframe_dict['col_rank']))))
                    ,on = [dataframe_dict['partition'], 'rank']
                )
                .drop('rank', 'date')
            )
        else:
            df = self.groupby_function()

        return df
    

#-------------------------------------------------# DATAFRAMES #-------------------------------------------------#
import datetime
from datetime import timedelta
from random import choice, random, randrange

columns1 = ["col_1","col_2","col_4","col_flag","col_date","bool1","bool2",]
columns2 = ["col_1","col_2","col_3","jsoncolumn","col_flag2","col_date","bool1","bool2",]

data_list1, data_list2 = [], []

start_date = datetime.datetime.now() - timedelta(days=365)
end_date = start_date + timedelta(days=365)

for i in range(1,10):
    a = (f'label{i}', f'abc{i}', randrange(200), randrange(0,3), (start_date + (end_date - start_date) * random()), choice([True, False]), choice([True, False]))
    data_list1.append(a)
    
for i in range(1,5):
    b = (f'label{i}', f'abc{i}', f'def{i}', "{'col_5':222, 'col_X':'wrong'}", randrange(0,3), (start_date + (end_date - start_date) * random()), choice([True, False]), choice([True, False]))
    data_list2.append(b)

df1 = spark.createDataFrame(data=data_list1,schema=columns1)
df2 = spark.createDataFrame(data=data_list2,schema=columns2)


#-------------------------------------------------# PARAMETERS (DICTS) #-------------------------------------------------#

dict1 = {
    'schema':df1,   # or schema
    'key':'pk',
    'json_list':[],
    'columns':{
        'col_1':'col1',
        'col_2':'col2',
        'col_4':'col4',
        'col_flag':'col_flag',
        'col_date':'date',
    },
    'pk_name' : 'col1_col2',
    'pk_list':['col1', 'col2'],
    'fk_name':[],
    'fk_list':[],
    'create_flag':['col_flag'],
    'drop_list':['col2', 'col_flag'],
    'ranking':True,
    'partition':'col1_col2',
    'col_rank':'date',
}

dict2 = {
    'schema':df2,   #or schema
    'key':'fk',
    'json_list':['jsoncolumn'],
    'columns':{
        'col_1':'col1',
        'col_2':'col2',
        'col_3':'col3',
        'jsoncolumn.col_5':'col5',
        'col_flag2':'col_flag2',
        'col_date':'date',
    },
    'pk_name' : 'col1_col3',
    'pk_list':['col1', 'col3'],
    'fk_name':['col1_col2'],
    'fk_list':[['col1', 'col2']],
    'create_flag':['col_flag2'],
    'drop_list':['col2','col3', 'col_flag2,'],
    'ranking':False,
    
}
#-------------------------------------------------# DICT OF SOURCES DICTS #-------------------------------------------------#

dict_list = {
    'table1':{
        'dataframe':'dict1_name',
        'dict':dict1
    },
    'table2':{
        'dataframe':'dict2_name',
        'dict':dict2
    },
}

join_key = 'col1_col2'
composed_keys = ['col1', 'col2']


#-------------------------------------------------# RUN #-------------------------------------------------#

from pyspark.sql.types import StructType, StructField, StringType, IntegerType


dfs_dict = {}

for source in dict_list:
    df_class = DataPrep(dataframe_dict=dict_list[source]['dict'])
    dfs_dict[dict_list[source]['dataframe']] = df_class.ranking()

schema = StructType([
    StructField(join_key, StringType(), True),
  ])

df = spark.createDataFrame([], schema)

for dataframes in dfs_dict.values():
    df = df.join(dataframes, on=[join_key], how='fullouter')


#-------------------------------------------------# DISPLAYS #-------------------------------------------------#

display(dfs_dict['dict1_name'])
display(dfs_dict['dict2_name'])
display(df)

col1_col2,col1,col4,flag_col_flag
label1_abc1,label1,199,1
label2_abc2,label2,146,1
label3_abc3,label3,1,0
label4_abc4,label4,20,1
label5_abc5,label5,57,1
label6_abc6,label6,119,0
label7_abc7,label7,197,1
label8_abc8,label8,26,1
label9_abc9,label9,149,0


col1_col2,col5,col_flag2,flag_col_flag2
label1_abc1,222,2,1
label2_abc2,222,2,1
label3_abc3,222,1,1
label4_abc4,222,2,1


col1_col2,col1,col4,flag_col_flag,col5,col_flag2,flag_col_flag2
label1_abc1,label1,199,1,222.0,2.0,1.0
label2_abc2,label2,146,1,222.0,2.0,1.0
label3_abc3,label3,1,0,222.0,1.0,1.0
label4_abc4,label4,20,1,222.0,2.0,1.0
label5_abc5,label5,57,1,,,
label6_abc6,label6,119,0,,,
label7_abc7,label7,197,1,,,
label8_abc8,label8,26,1,,,
label9_abc9,label9,149,0,,,
