## Test Cases for Apache Spark SQL Data Retrieval Statements
**Spark SQL 3.3 Data Retrieval Statements Reference** 
https://spark.apache.org/docs/3.3.0/sql-ref-syntax.html#data-retrieval-statements

To persist the test results, configure the <mark>**storage account and container**</mark> below; if either is left empty, the results are only printed.

In [None]:
# Install the third-party packages used below: unittest-xml-reporting provides
# the `xmlrunner` XML test runner, and xmltodict parses the XML report it writes.
!pip install unittest-xml-reporting xmltodict

## Configure Result Storage Location

In [None]:
# Destination for the result parquet files. Leave either value empty to only
# print the results instead of writing them to storage.
storage_account = ""   # Azure storage account name (<account>.dfs.core.windows.net)
result_container = ""  # container inside that storage account

## Initialize Common Variables for the test run

In [None]:
import time

# Fixed identifiers for this suite.
# Don't change these variables
TEST_SUITE = "SPARK_SQL_DML"
RESULT_FILE_NAME = "data_retrieval_test_result.parquet"
RAW_RESULT_FILE_NAME = "raw_data_retrieval_test_result.parquet"

# Unique identifier for this run: current epoch time in milliseconds.
TEST_RUN_ID = round(time.time() * 1000)

# Label of the platform under test. Table names are built as
# <PREFIX>_<name>_<SUFFIX>, so the platform label doubles as the prefix and the
# run id as the suffix.
PLATFORM = "nameoftheplatform"
PREFIX, SUFFIX = PLATFORM, TEST_RUN_ID

# Shorthand used by every test to run a Spark SQL statement.
sql = spark.sql

### Set Common Spark Configurations

In [None]:
# Allow fully dynamic partition inserts: the setup INSERT statements supply the
# partition column (state / year) as the last VALUES field instead of using a
# static PARTITION (...) clause. NOTE(review): this setting matters for
# Hive-format tables (the CREATE TABLEs have no USING clause) — confirm it is
# still required on the target platform.
sql("set hive.exec.dynamic.partition.mode=nonstrict")

## DML - Insert Table

In [None]:
import unittest
import numpy as np

class DataRetrievalInsertTableTest(unittest.TestCase):
    """Test cases for Spark SQL data retrieval statements.

    ``setUpClass`` creates a customer table partitioned by ``state``, a
    transaction table partitioned by ``year`` and a view joining them;
    ``tearDownClass`` drops all three. Each test converts any exception into a
    failure whose message is a dict with ``command`` and ``status`` keys.
    """

    customer_table_name=f"{PREFIX}_customer_insert_table_{SUFFIX}"
    tx_table_name=f"{PREFIX}_transaction_insert_table_{SUFFIX}"
    view_name = f"{PREFIX}_testtx_customer_{SUFFIX}"


    @classmethod
    def setUpClass(cls):
        """Create and populate the customer/transaction tables and the join view.

        Raises:
            AssertionError (``cls.failureException``), chained to the original
            error, if any statement fails. Objects created before the failure
            are dropped first.
        """
        customer_table_name_sql = f"CREATE TABLE {cls.customer_table_name} (cust_id INT, name STRING, age INT) \
                        PARTITIONED BY (state STRING)"
        tx_table_name_sql = f"CREATE TABLE {cls.tx_table_name} (id INT, cust_id INT, tx_date STRING, total_amount Float) \
                        PARTITIONED BY (year INT)"
        try:
            sql(customer_table_name_sql)
            sql(tx_table_name_sql)
            sql(f"INSERT INTO {cls.customer_table_name} VALUES \
                 (1,'abcd',20,'TX'),(2,'bbcd',18,'CA'),(3,'c',25,'TX'),(4,'d',20,'WA'),(5,'e',18,'AK'),(6,'f_cd',NULL,'CA');")
            sql(f"INSERT INTO {cls.tx_table_name} VALUES \
                 (1,1,'2019-01-01',40.56,2019),(2,2,'2022-11-01',10.56,2022),(3,1,'2023-12-05',10.34,2023),(4,3,'2021-11-05',8.00,2021),(5,6,'2019-10-12',12.45,2019),(6,2,'2023-10-12',15.45,2023);")
            
            sql(f"CREATE OR REPLACE VIEW {cls.view_name} AS SELECT A.*,B.name, b.state FROM {cls.tx_table_name} A \
                                INNER JOIN {cls.customer_table_name} B ON A.cust_id=B.cust_id")
        except Exception as ex:
            # unittest does not call tearDownClass when setUpClass raises, so
            # drop partially-created objects here (every DROP uses IF EXISTS).
            try:
                cls.tearDownClass()
            except Exception:
                # Best-effort cleanup: keep the original setup error visible.
                pass
            msg={'command':'InsertTable Setup failed','status':'fail'}
            # TestCase.fail is an instance method and cannot be called on the
            # class; raise the exception type it would raise instead.
            raise cls.failureException(f"{msg}") from ex
    

    def test_dataretrieval_001_cte_join_query(self):
        """Data retrieval using CTE and INNER JOIN"""
        sql_cmd = f"WITH t(custid, custname) AS (SELECT cust_id, name FROM {self.customer_table_name} WHERE state='TX') \
                        SELECT custname,total_amount,year FROM {self.tx_table_name} INNER JOIN t ON t.custid= {self.tx_table_name}.cust_id;"
        try:
            df=sql(sql_cmd)
            # TX customers 1 and 3 own transactions 1, 3 and 4.
            self.assertEqual(df.count(),3)
            self.assertEqual(len(df.columns),3)
        except Exception:
            msg={'command':'INNER JOIN CTE (Common Table Expression)','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_002_cte_view_query(self):
        """Data retrieval using CTE with View statement"""
        view_name = f"{PREFIX}_cte_view_{SUFFIX}"
        sql_cmd = f"CREATE VIEW {view_name} AS WITH t(custid, custname) AS (SELECT cust_id, name FROM {self.customer_table_name} WHERE state='TX') \
                        SELECT custid,custname FROM t"
        try:
            sql(sql_cmd)
            df=sql(f"SELECT * FROM {view_name}")
            self.assertEqual(df.count(),2)
            self.assertEqual(len(df.columns),2)
        except Exception:
            msg={'command':'Create View CTE (Common Table Expression)','status':'fail'}
            self.fail(f"{msg}")

        finally:
            sql(f"DROP VIEW IF EXISTS {view_name}")

    def test_dataretrieval_003_cluster_by_query(self):
        """Data retrieval using CLUSTER BY statement"""
        sql_cmd_without_cluster = f"SELECT * FROM {self.customer_table_name}"
        sql_cmd_with_cluster = f"SELECT * FROM {self.customer_table_name} CLUSTER BY cust_id"
        # Remember the session setting so the finally block can restore it.
        current_partition_conf =  spark.conf.get("spark.sql.shuffle.partitions",'200')
        try:
            # without cluster
            sql("SET spark.sql.shuffle.partitions = 2;")
            df_without_cluster=sql(sql_cmd_without_cluster)
            columns_without_cluster=df_without_cluster.toPandas()
            customerIds_without_cluster=list(columns_without_cluster['cust_id'])
            # with cluster The rows are sorted based on cust_id within each partition
            df_with_cluster=sql(sql_cmd_with_cluster)
            columns_with_cluster=df_with_cluster.toPandas()
            customerIds_with_cluster=list(columns_with_cluster['cust_id'])
            self.assertNotEqual(customerIds_with_cluster,customerIds_without_cluster)
            # NOTE(review): CLUSTER BY only sorts within each shuffle partition; a
            # fully sorted result presumes the rows land in a single (e.g.
            # AQE-coalesced) partition or a favourable hash split — confirm.
            self.assertEqual(customerIds_with_cluster,[1,2,3,4,5,6])
        except Exception:
            msg={'command':'SELECT with Cluster By','status':'fail'}
            self.fail(f"{msg}")

        finally:
            # rollback configuration
            spark.conf.set("spark.sql.shuffle.partitions",current_partition_conf)

    def test_dataretrieval_004_group_filter_by_query(self):
        """Data retrieval using GROUP BY FILTER WHERE statement"""
        sql_cmd = f"SELECT cust_id,sum(total_amount) FILTER (\
                WHERE year IN (2019,2021) \
              ) AS `sum(amount)` FROM {self.tx_table_name} GROUP by cust_id ORDER BY cust_id"
        try:
            df=sql(sql_cmd)
            self.assertGreater(df.count(),1)
        except Exception:
            msg={'command':'GROUP BY FILTER','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_005_first_last_query(self):
        """Data retrieval using FIRST AND LAST"""
        sql_cmd = f"SELECT FIRST(id) as firstid, LAST(id) as lastId FROM {self.tx_table_name}"
        try:
            df=sql(sql_cmd)
            self.assertEqual(df.count(),1)
            self.assertEqual(len(df.columns),2)
        except Exception:
            # Report the statement actually under test (was a copy of test 004's label).
            msg={'command':'FIRST and LAST','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_006_group_having_query(self):
        """Data retrieval using GROUP BY Having"""
        sql_cmd = f"SELECT year,sum(total_amount) as sumtotal  FROM {self.tx_table_name} GROUP by Year Having Year>2019 order by year"
        try:
            df=sql(sql_cmd)
            pandas_df = df.toPandas()
            self.assertEqual(list(pandas_df['year']),[2021,2022,2023])
            self.assertEqual(df.count(),3)
        except Exception:
            msg={'command':'GROUP BY Having Order By','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_007_inline_table(self):
        """Data retrieval Inline Table"""
        sql_cmd = f"SELECT country, {self.customer_table_name}.state,name FROM {self.customer_table_name} \
                  INNER JOIN VALUES ('USA', 'TX'),('USA', 'CA'), ('USA','WA'), ('USA','AK') as countrystate(country, state) ON countrystate.state={self.customer_table_name}.state"
        try:
            df=sql(sql_cmd)
            self.assertEqual(len(df.columns),3)
        except Exception:
            msg={'command':'Inline Table Join','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_008_file_query(self):
        """Data retrieval From file query"""
        # NOTE(review): the parquet file written here is not removed afterwards;
        # the run-specific SUFFIX keeps reruns from colliding.
        file_name = f"Files/test_dataretrieval_008_file_query/{PREFIX}_parquet_query_read_{SUFFIX}.parquet"
        sql_cmd = f"SELECT * FROM parquet.`{file_name}`"
        try:
            df = sql(f"SELECT * FROM {self.customer_table_name}")
            df.write.parquet(file_name)
            read_file_df=sql(sql_cmd)
            self.assertEqual(df.count(), read_file_df.count())
        except Exception:
            msg={'command':'Query a file with a parquet format ','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_009_like_predicate1(self):
        """Data retrieval Like predicate filter"""
        try:
            df = sql(f"SELECT * FROM {self.customer_table_name} WHERE state like 'T%'")
            self.assertEqual(df.count(), 2)
        except Exception:
            msg={'command':'Like predicate filter search pattern %','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_009_like_predicate2(self):
        """Data retrieval Like predicate filter"""
        try:
            df = sql(f"SELECT * FROM {self.customer_table_name} WHERE name like '_bcd' ")
            self.assertEqual(df.count(), 2)
        except Exception:
            msg={'command':'Like predicate filter search pattern underscore','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_009_like_predicate3(self):
        """Data retrieval Like predicate filter"""
        try:
            # '\\_' reaches SQL as '\_': an escaped, literal underscore.
            df = sql(f"SELECT * FROM {self.customer_table_name} WHERE name like '%\\_%' ")
            self.assertEqual(df.count(), 1)
        except Exception:
            msg={'command':'Like predicate filter esc_char','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_010_tablesample(self):
        """Data retrieval Table Sample"""
        try:
            # PERCENT and BUCKET sampling select rows at random, so with only six
            # rows an empty sample is a legitimate outcome; bound the size only.
            df = sql(f"SELECT * FROM {self.customer_table_name} TABLESAMPLE (50 PERCENT)")
            row_count=df.count()
            self.assertLessEqual(row_count, 6)
            # A row-count sample is expected to return exactly the requested rows.
            df = sql(f"SELECT * FROM {self.customer_table_name} TABLESAMPLE (5 ROWS)")
            row_count=df.count()
            self.assertEqual(row_count, 5)
            df = sql(f"SELECT * FROM {self.customer_table_name} TABLESAMPLE (BUCKET 2 OUT OF 6)")
            row_count=df.count()
            self.assertLessEqual(row_count, 6)
        except Exception:
            msg={'command':'TABLESAMPLE query','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_011_orderby_null_first_last(self):
        """Data retrieval Order By NULL FIRST and NULL LAST"""
        try:
            # Customer 6 has a NULL age, which pandas surfaces as NaN.
            df_null_first = sql(f"SELECT * FROM {self.customer_table_name} ORDER BY age DESC NULLS FIRST ")
            df_null_first_p=df_null_first.toPandas()
            age_null_first=list(df_null_first_p['age'])
            self.assertTrue(np.isnan(age_null_first[0]))
            df_null_last = sql(f"SELECT * FROM {self.customer_table_name} ORDER BY age DESC NULLS LAST ")
            df_null_last_p=df_null_last.toPandas()
            age_null_last=list(df_null_last_p['age'])
            self.assertTrue(np.isnan(age_null_last[-1]))
        except Exception:
            msg={'command':'Order By NULL FIRST and NULL LAST','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_012_limit(self):
        """Data retrieval LIMIT ALL"""
        try:
            df_limit_2 = sql(f"SELECT * FROM {self.customer_table_name} LIMIT 2")
            self.assertEqual(df_limit_2.count(),2)
            df_limit_all = sql(f"SELECT * FROM {self.customer_table_name} LIMIT ALL")
            self.assertEqual(df_limit_all.count(),6)
        except Exception:
            msg={'command':'LIMIT and LIMIT ALL','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_013_window_dense_rank(self):
        """Window functions - DENSE_RANK"""
        try:
            # Test dense rank. The outer ORDER BY makes the row order
            # deterministic; without it SQL guarantees no order, so the
            # state/rank list comparison below could fail spuriously.
            df_dense_rank = sql(f"SELECT name, state, total_amount, DENSE_RANK() OVER (PARTITION BY state ORDER BY total_amount \
                                        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS dense_rank FROM {self.view_name} \
                                        ORDER BY state, total_amount")
            df_dense_rank_p=df_dense_rank.toPandas()
            self.assertEqual(list(df_dense_rank_p['state']),['CA','CA','CA','TX','TX','TX'])
            self.assertEqual(list(df_dense_rank_p['dense_rank']),[1,2,3,1,2,3])
        except Exception:
            msg={'command':'Windows DENSE_RANK','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_014_window_cum_dist(self):
        """Window functions - CUME_DIST"""
        try:
            df = sql(f"SELECT name, state, total_amount, CUME_DIST() OVER (PARTITION BY state ORDER BY total_amount \
                                        RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cume_dist FROM {self.view_name}")
            df_p=df.toPandas()
            # Two states with three rows each: 1/3 + 2/3 + 1 per state. Compare
            # approximately because the float sum depends on summation order.
            self.assertAlmostEqual(sum(list(df_p['cume_dist'])),4)
            self.assertEqual(len(df_p),6)
        except Exception:
            msg={'command':'Windows CUME_DIST','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_014_window_min_over(self):
        """Window functions - MIN OVER"""
        try:
            # Explicit ORDER BY so the expected CA-then-TX sequence is guaranteed.
            df = sql(f"SELECT name, state, total_amount, Min(total_amount) OVER (PARTITION BY state ORDER BY total_amount) AS min FROM {self.view_name} \
                       ORDER BY state, total_amount")
            df_p=df.toPandas()
            l_min = list(df_p['min'])
            self.assertEqual(list(map(lambda x:round(x,2), l_min)),[10.56,10.56,10.56,8,8,8])
        except Exception:
            msg={'command':'Windows MIN OVER','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_015_window_lag_lead(self):
        """Window functions - LAG AND LEAD"""
        try:
            df = sql(f"SELECT name, state, total_amount, LAG(total_amount) OVER (PARTITION BY state ORDER BY total_amount) AS lag,\
                                   LEAD(total_amount,1,0) OVER (PARTITION BY state ORDER BY total_amount) as lead  FROM {self.view_name}")
            df_p=df.toPandas()
            # LAG has no default (NULL -> NaN); LEAD defaults to 0 past the end.
            self.assertTrue(np.isnan(list(df_p['lag'])[0]))
            self.assertEqual(list(df_p['lead'])[-1],0.0)
        except Exception:
            msg={'command':'Windows LAG LEAD','status':'fail'}
            self.fail(f"{msg}")

    def test_dataretrieval_016_window_lateral_view(self):
        """LATERAL VIEW"""
        try:
            df = sql(f"SELECT * FROM {self.customer_table_name} LATERAL VIEW EXPLODE(ARRAY('Y', 'N')) tableName AS test_flag")
            # total 6 rows multiply by 2
            self.assertEqual(df.count(),12)
            # An empty array yields no rows unless OUTER keeps the source row.
            df = sql(f"SELECT * FROM {self.customer_table_name} LATERAL VIEW EXPLODE(ARRAY()) AS empty_col")
            self.assertEqual(df.count(),0)
            df = sql(f"SELECT * FROM {self.customer_table_name} LATERAL VIEW OUTER EXPLODE(ARRAY()) AS empty_col2")
            self.assertEqual(df.count(),6)
        except Exception:
            msg={'command':'LATERAL VIEW ','status':'fail'}
            self.fail(f"{msg}")

    @classmethod
    def tearDownClass(cls):
        """Drop the tables and view created by setUpClass (if they exist)."""
        sql(f"DROP TABLE IF EXISTS {cls.customer_table_name}")
        sql(f"DROP TABLE IF EXISTS {cls.tx_table_name}")
        sql(f"DROP VIEW IF EXISTS {cls.view_name}")

# TODO: add test cases for DISTRIBUTE BY, hints, and PIVOT

### Execute Test Cases

In [None]:
import io
import xmlrunner

# Gather every test method of the data-retrieval test class into one suite.
loader = unittest.TestLoader()
suite = unittest.TestSuite(loader.loadTestsFromTestCase(DataRetrievalInsertTableTest))

# Run the suite with an XML runner that writes its report into an in-memory
# buffer (`out`); the report cell parses that buffer afterwards.
out = io.BytesIO()
runner = xmlrunner.XMLTestRunner(output=out)
result = runner.run(suite)

## Test Report

In [None]:
from pyspark.sql.functions import col, explode,isnull,from_json, expr, to_json, coalesce, lit, when
from pyspark.sql.types import StructType,StructField,StringType
import json
import xmltodict

# Convert the runner's XML report into plain dicts/lists.
# - attr_prefix='' keeps XML attributes as plain keys, rather than stripping
#   every '@' from the serialized text, which would also alter any failure
#   message or traceback containing '@'.
# - force_list keeps 'testcase' a list even when a suite holds a single test,
#   so explode() below always receives an array.
dict_result = xmltodict.parse(out.getvalue(), attr_prefix='', force_list=('testcase',))
json_result = json.loads(json.dumps(dict_result))
test_suites = json_result['testsuites']['testsuite']

# Hand Spark real JSON text: an RDD of dicts would be str()-ed by PySpark,
# producing Python literals (None/True/False) that are not valid JSON.
df = spark.read.json(sc.parallelize([json.dumps(test_suites)]))
fail_schema = StructType([
  StructField("command", StringType(), True),
  StructField("status", StringType(),  True)
])

test_cases_df = df.withColumn('ts', explode('testcase')).drop(col('testcase'))

# Failed assertions are reported as <failure> and unexpected exceptions as
# <error>; each field only exists in the inferred schema when at least one
# test produced it.
ts_fields = test_cases_df.schema['ts'].dataType.fieldNames()
failure_message_cols = [col(f"ts.{kind}.message") for kind in ("failure", "error") if kind in ts_fields]
fail_message = coalesce(*failure_message_cols) if failure_message_cols else lit(None).cast(StringType())
explode_df = test_cases_df.withColumn('fail', from_json(fail_message, fail_schema))

# A test with a failure/error message that is not in the command/status dict
# format still counts as failed instead of being reported as 'pass'.
status_col = coalesce(col("fail.status"), when(fail_message.isNotNull(), lit("fail")), lit("pass"))

df_test_result=explode_df.select(col("errors").alias("errorInSuite"),col("failures").alias("failedInSuite"),col("name").alias("suitename"),\
      "skipped",col("tests").alias("totalTest"), col("timestamp").alias("executionTime"),col("ts.name").alias("testCaseName"), \
       col("ts.time").alias("testCaseTime"),coalesce(col("fail.command"), lit("")).alias("failcommand"),status_col.alias("status"))

if storage_account and result_container:
    # save result to storage
    storage_path = f"abfs://{result_container}@{storage_account}.dfs.core.windows.net/{TEST_RUN_ID}/{PLATFORM}/{TEST_SUITE}"
    # write raw results
    df.write.parquet(f"{storage_path}/{RAW_RESULT_FILE_NAME}")
    # write transformed results
    df_test_result.write.parquet(f"{storage_path}/{RESULT_FILE_NAME}")
else:
    print("configure storage path to store results")
    df_test_result.show(200,False)