In [0]:
%run ../EcommerceProject/serving

In [0]:
import unittest

In [0]:
class TestOrdersEnriched(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.spark = spark

    def setUp(self):
        self.raw_orders = self.spark.table("ecommerceproject.default.raw_orders")
        self.orders_enriched = self.spark.table("ecommerceproject.enriched.orders_enriched")
        self.dim_customers = self.spark.table("ecommerceproject.enriched.dim_customers")
        self.dim_products = self.spark.table("ecommerceproject.enriched.dim_products")

    
    #Data Completeness
    def test_row_count_matches_raw(self):
        self.assertEqual(
            self.raw_orders.count(),
            self.orders_enriched.count(),
            "Row count mismatch expected {} got {}".format(
                self.raw_orders.count(), self.orders_enriched.count()
            )
        )
    #Only Current Customer Record used in Curated Layer
    def test_only_current_customers_joined(self):
        joined_df = self.orders_enriched.join(
            self.dim_customers,
            self.orders_enriched.customer_id == self.dim_customers.customer_id,
            "inner"
        )
        self.assertEqual(joined_df.filter(col("is_current") == False).count(),0,
            "outdated customer records joined"
        )
    #Only Current Products Record used in Curated Layer
    def test_only_current_product_joined(self):
        joineddf = self.orders_enriched.join(
            self.dim_products,
            self.orders_enriched.product_id == self.dim_products.product_id,
            "inner"
        )

        self.assertEqual(joineddf.filter(col("is_current") == False).count(),0,
            "outdated product records joined"
        )
    #Ensure Profit Round up to 2 decimal places
    def test_profit_rounded_to_two_decimals(self):
        invalid = self.orders_enriched.filter(
            col("profit") != round(col("profit"), 2)
        ).count()
        self.assertEqual(invalid, 0, "Profit not rounded to 2 decimals")
    #Null Check Order_id, Product_id, Cutomer_id, Order_Year
    def test_mandatory_columns_not_null(self):
        mandatory_cols = [
            "order_id",
            "product_id",
            "customer_id",
            "order_year"
        ]

        for col_name in mandatory_cols:
            null_count = self.orders_enriched.filter(
                col(col_name).isNull()
            ).count()

            self.assertEqual(
                null_count,
                0,
                f"Null values found in mandatory column: {col_name}"
            )
    #Business Value: Sub Category With Category
    def test_subcategory_without_category_not_allowed(self):
        faulty_df = self.orders_enriched.filter(
            col("category").isNull() & col("sub_category").isNotNull()
        ).count()

        self.assertEqual(
            faulty_df,
            0,
            "Sub-category exists without category"
        )

In [0]:
# Collect every test_* method of TestOrdersEnriched and run them verbosely.
suite = unittest.TestLoader().loadTestsFromTestCase(TestOrdersEnriched)
result = unittest.TextTestRunner(verbosity=2).run(suite)

# TextTestRunner only reports failures; it never raises. Without this check the
# cell (and any job that executes the notebook) finishes successfully even when
# tests fail or error.
if not result.wasSuccessful():
    raise RuntimeError(
        "TestOrdersEnriched failed: {} failure(s), {} error(s)".format(
            len(result.failures), len(result.errors)
        )
    )
