# Spark Application 

In [449]:
class SparkApp:
    '''
        Builds a per-customer, per-product purchase-count report with PySpark.

        load_output() is the entry point; the remaining methods are small
        DataFrame helpers it composes. Every pyspark name is imported inside the
        method that uses it, so no method depends on imports made elsewhere.
    '''

    def load_data(self, customers, products, transactions):
        '''
            customers: file name in form of string and in csv format (e.g-'customers.csv')
            products: file name in form of string and in csv format (e.g-'products.csv')
            transactions: file name in form of string and in json format (e.g-'transactions.json')

            The files should be in following schema:
            customers:root
                         |-- customer_id: string (nullable = true)
                         |-- loyalty_score: integer (nullable = true)
            products:root
                         |-- product_id: string (nullable = true)
                         |-- product_category: integer (nullable = true)
            transactions:root
                         |-- basket: array (nullable = true)
                         |    |-- element: struct (containsNull = true)
                         |    |    |-- price: long (nullable = true)
                         |    |    |-- product_id: string (nullable = true)
                         |-- customer_id: string (nullable = true)
                         |-- date_of_purchase: string (nullable = true)

            gets the spark session through _create_a_SparkSession() (getOrCreate)
            reads both csv files with a header row and schema inference, and the
            json file in multiLine mode
            returns the list of three dataframes [df_customers,df_products,df_transactions]
        '''
        spark = self._create_a_SparkSession()
        df_customers = spark.read.csv(customers, inferSchema=True, header=True)
        df_products = spark.read.csv(products, inferSchema=True, header=True)
        df_transactions = spark.read.json(transactions, multiLine=True)

        return [df_customers, df_products, df_transactions]

    def _create_a_SparkSession(self):
        '''
            returns the SparkSession named 'Customer Details' obtained with getOrCreate()
        '''
        # Only SparkSession is needed here; the previously unused pyspark and
        # findspark imports were dropped so findspark is no longer required.
        from pyspark.sql import SparkSession

        return SparkSession.builder.appName('Customer Details').getOrCreate()

    def flatten(self, df, column):
        '''
            df: dataframe holding an array column
            column: name of the array column in form of string

            returns the dataframe with `column` replaced by explode(column)
        '''
        # Imported here: the import inside load_output() is local to that method
        # and is not visible from this one.
        from pyspark.sql.functions import explode

        return df.withColumn(column, explode(column))

    def count_duplicate_rows(self, df):
        '''
            df: dataframe whose rows are to be counted

            groups the dataframe by all of its columns and counts each group
            returns the grouped dataframe with an extra 'count' column, keeping
            only the groups whose count is greater than 0

            Note: rows that occur only once are kept too (with count 1);
            load_output() uses that count as the purchase count.
        '''
        # Imported here for the same scoping reason as in flatten().
        import pyspark.sql.functions as f

        grouped = df.groupBy(df.columns).count()
        return grouped.where(f.col('count') > 0)

    def rename_column(self, df, original, new):
        '''
            returns the dataframe with the column `original` renamed to `new`
        '''
        return df.withColumnRenamed(original, new)

    def outer_join(self, df1, df2, col):
        '''
            returns the outer join of df1 and df2 on the column `col`
        '''
        return df1.join(df2, on=col, how='outer')

    def sort(self, df, col_1, col_2):
        '''
            returns the dataframe sorted first on `col_1` and then on `col_2`
        '''
        # Imported here for the same scoping reason as in flatten().
        import pyspark.sql.functions as f

        return df.sort(f.col(col_1), f.col(col_2))

    def load_output(self, customers, products, transactions):
        '''
            customers: file name in form of string and in csv format (e.g-'customers.csv')
            products: file name in form of string and in csv format (e.g-'products.csv')
            transactions: file name in form of string and in json format (e.g-'transactions.json')

            The files should be in following schema:
            customers:root
                         |-- customer_id: string (nullable = true)
                         |-- loyalty_score: integer (nullable = true)
            products:root
                         |-- product_id: string (nullable = true)
                         |-- product_category: integer (nullable = true)
            transactions:root
                         |-- basket: array (nullable = true)
                         |    |-- element: struct (containsNull = true)
                         |    |    |-- price: long (nullable = true)
                         |    |    |-- product_id: string (nullable = true)
                         |-- customer_id: string (nullable = true)
                         |-- date_of_purchase: string (nullable = true)

            returns a new dataframe (df_output) with columns in order = ['customer_id','loyalty_score','product_id','product_category','purchase_count_per_product_id']
            sorted on customer_id and then on product_id

        '''
        # loading the data first
        df_customers, df_products, df_transactions = self.load_data(customers, products, transactions)

        # cleaning the df_products
        df_products = df_products.drop('product_description')

        # flattening the basket array of df_transactions: one row per basket element
        df_trans_flat = self.flatten(df_transactions, 'basket').drop('date_of_purchase')

        # expanding the basket struct into columns and keeping customer_id / product_id only
        df_trans_flat = df_trans_flat.select('customer_id', 'basket.*').drop('price')

        # counting how many times each (customer_id, product_id) row occurs
        df_trans_flat = self.count_duplicate_rows(df_trans_flat)

        # renaming the column count to purchase_count_per_product_id
        df_trans_flat = self.rename_column(df_trans_flat, 'count', 'purchase_count_per_product_id')

        # intermediate dataframe: outer merge of df_customers and df_trans_flat on customer_id
        df_customer_trans = self.outer_join(df_customers, df_trans_flat, 'customer_id')

        # final output dataframe: outer merge of the intermediate dataframe and df_products on product_id
        df_output = self.outer_join(df_customer_trans, df_products, 'product_id')

        # rearranging the output columns, then sorting first on customer_id and then on product_id
        df_output = df_output.select('customer_id', 'loyalty_score', 'product_id', 'product_category', 'purchase_count_per_product_id')

        return self.sort(df_output, 'customer_id', 'product_id')



In [450]:
# Application object whose load_output() is called in the next cell.
cust_details = SparkApp()

In [451]:
# Build the report dataframe from the three input files, then print it.
customer_report = cust_details.load_output('customers.csv', 'products.csv', 'transactions.json')
customer_report.show()

+-----------+-------------+----------+----------------+-----------------------------+
|customer_id|loyalty_score|product_id|product_category|purchase_count_per_product_id|
+-----------+-------------+----------+----------------+-----------------------------+
|         C1|           10|        P1|               1|                            2|
|         C1|           10|        P3|               3|                            1|
|         C2|          232|        P1|               1|                            1|
|         C2|          232|        P2|               2|                            1|
|         C3|           23|        P1|               1|                            1|
|         C3|           23|        P2|               2|                            1|
|         C4|           14|        P1|               1|                            1|
|         C4|           14|        P2|               2|                            1|
|         C5|           52|        P1|               1

# Unit Test cases with py.test

In [323]:
import unittest
from pyspark.sql import SparkSession

In [324]:
class PySparkUnitTestBase(unittest.TestCase):
    '''
        Base test case that creates one local SparkSession for the whole test
        class and stops it once the class is finished.
    '''

    @classmethod
    def setUpClass(cls):
        '''
            stores a local[*] SparkSession on the class as `spark`

            unittest invokes this hook on the class itself, so it has to be a
            classmethod; as a plain method the call fails for a missing argument.
        '''
        cls.spark = (
            SparkSession.builder
            .appName('Unit Testing in Pyspark Application')
            .master('local[*]')
            .getOrCreate()
        )

    @classmethod
    def tearDownClass(cls):
        '''
            stops the SparkSession stored by setUpClass
        '''
        cls.spark.stop()
        

In [448]:
class PysparkUnitTest(PySparkUnitTestBase):
    '''
        Compares the dataframe produced by SparkApp.load_output() with an
        expected-output csv file.
    '''

    def _create_a_SparkSession(self):
        '''
            returns the SparkSession named 'Customer Details' obtained with getOrCreate()
        '''
        from pyspark.sql import SparkSession

        return SparkSession.builder.appName('Customer Details').getOrCreate()

    def test_load_data_case(self, customer, product, transactions, output):
        '''
            customer: customers csv file name in form of string
            product: products csv file name in form of string
            transactions: transactions json file name in form of string
            output: csv file name (with header) holding the expected rows

            asserts that the collected rows of SparkApp().load_output(...) equal
            the collected rows of the expected csv, in the same order

            NOTE(review): this method needs four arguments, so it has to be
            called directly; a test runner calls test methods with none.
        '''
        # Imported here: an import made inside _create_a_SparkSession() is local
        # to that method and would leave `f` undefined in this one.
        import pyspark.sql.functions as f

        spark = self._create_a_SparkSession()
        expected_output_df = spark.read.csv(output, inferSchema=True, header=True)
        output_df = SparkApp().load_output(customer, product, transactions)

        # presumably cast so the count column matches the integer type inferred
        # from the expected csv -- confirm against the fixture files
        output_df = output_df.withColumn(
            'purchase_count_per_product_id',
            f.round(output_df["purchase_count_per_product_id"]).cast('integer'),
        )

        self.assertEqual(output_df.collect(), expected_output_df.collect())
        


In [446]:
# Run the comparison against the first set of fixture files.
PysparkUnitTest().test_load_data_case(
    customer='customers_test-1.csv',
    product='products_test-1.csv',
    transactions='transactions_test-1.json',
    output='output_test-1.csv',
)

+-----------+-------------+----------+----------------+-----------------------------+
|customer_id|loyalty_score|product_id|product_category|purchase_count_per_product_id|
+-----------+-------------+----------+----------------+-----------------------------+
|         C1|           12|        P1|               1|                            1|
|         C1|           12|        P2|               2|                            1|
|         C1|           12|        P3|               3|                            1|
|         C2|           23|        P1|               1|                            1|
|         C2|           23|        P2|               2|                            1|
|         C3|          133|        P3|               3|                            1|
|         C3|          133|        P4|               4|                            1|
|         C4|           94|        P1|               1|                            1|
|         C4|           94|        P4|               4

In [447]:
# Run the comparison against the second set of fixture files.
PysparkUnitTest().test_load_data_case(
    customer='customers_test-2.csv',
    product='products_test-2.csv',
    transactions='transactions_test-2.json',
    output='output_test-2.csv',
)

+-----------+-------------+----------+----------------+-----------------------------+
|customer_id|loyalty_score|product_id|product_category|purchase_count_per_product_id|
+-----------+-------------+----------+----------------+-----------------------------+
|         C1|           12|        P1|               1|                            2|
|         C1|           12|        P3|               3|                            1|
|         C2|           23|        P2|               2|                            2|
|         C3|          133|        P3|               3|                            2|
|         C3|          133|        P4|               4|                            1|
|         C4|           94|      null|            null|                         null|
|         C5|          100|        P1|               1|                            2|
|         C5|          100|        P4|               4|                            1|
+-----------+-------------+----------+----------------