### Import necessary libraries and run the ingestion notebook

In [0]:
%run "./Assignment - Ingestion"

In [0]:
from pyspark.sql import SparkSession

In [0]:
    # Get the active SparkSession for these checks, creating one if none exists.
spark = (
    SparkSession.builder
    .appName("pytest")
    .getOrCreate()
)

In [0]:
### Validate Product Schema
# Expected layout of the ingested product DataFrame; every column is nullable.
schema = StructType([
    StructField("Product ID", StringType(), True),
    StructField("Category", StringType(), True),
    StructField("Sub-Category", StringType(), True),
    StructField("Product Name", StringType(), True),
    StructField("State", StringType(), True),
    StructField("Price per product", DoubleType(), True),
])

productData = [("P001", "Electronics", "Mobile", "Samsung S23", "New York", 299.99),
               ("P002", "Furniture", "Chair", "Featherlite", "California", 99.99)]

df = spark.createDataFrame(productData, schema)

# Fail if the ingested product DataFrame does not have exactly the expected schema.
assert df.schema == product_df.schema, "product schema mismatch"

In [0]:
### Validate record count in Product bronze table
# raw.raw_product must hold exactly as many rows as the ingested product DataFrame.
schemaName = "raw"
test_product_raw_df = fetchTableData(schemaName, "raw_product")

raw_product_count = test_product_raw_df.count()
source_product_count = product_df.count()
assert raw_product_count == source_product_count, "Not all records have been loaded in raw table for product file"

In [0]:
### Validate record count in Product enriched table
# enriched.enriched_product must hold exactly as many rows as the ingested product DataFrame.

schemaName = "enriched"
test_product_enriched_df = fetchTableData(schemaName, "enriched_product")

assert test_product_enriched_df.count() == product_df.count(), "Not all records have been loaded in enriched table for product file"

In [0]:
### Validate Customer Schema
# Expected customer layout as (column name, Spark type) pairs; every field is nullable.
customer_columns = [
    ("Customer ID", StringType()),
    ("Customer Name", StringType()),
    ("email", StringType()),
    ("phone", StringType()),
    ("address", StringType()),
    ("Segment", StringType()),
    ("Country", StringType()),
    ("City", StringType()),
    ("State", StringType()),
    # NOTE(review): a double postal code drops leading zeros; presumably mirrors how ingestion reads it — confirm.
    ("Postal Code", DoubleType()),
    ("Region", StringType()),
]
expectedCustomerSchema = StructType(
    [StructField(column_name, column_type, True) for column_name, column_type in customer_columns]
)

first_customer = ("C001", "Ross Specter", "abc@def.com", "+91 89876765", "5th Street", "Consumer",
                  "India", "Mumbai", "Maharashtra", 12345.0, "West")
second_customer = ("C002", "Daniel Pinto", "ghi@jkl.com", "+91 6543567", "7th Main", "Home Office",
                   "India", "Bengaluru", "Karnataka", 10034.0, "South")
data = [first_customer, second_customer]

df = spark.createDataFrame(data, expectedCustomerSchema)

# The ingested customer DataFrame must match the expected schema field for field.
assert df.schema == customer_df.schema, "customer schema mismatch"

In [0]:
### Validate record count in Customer raw (bronze) table
# raw.raw_customer must hold exactly as many rows as the ingested customer DataFrame.
schemaName = "raw"
customer_raw_df = fetchTableData(schemaName, "raw_customer")

raw_customer_count = customer_raw_df.count()
source_customer_count = customer_df.count()
assert raw_customer_count == source_customer_count, "Not all records have been loaded in raw table for customer file"

In [0]:
### Validate Customer enriched table count
# enriched.enriched_customer must hold exactly as many rows as the ingested customer DataFrame.
schemaName = "enriched"
test_customer_enriched_df = fetchTableData(schemaName, "enriched_customer")

source_customer_count = customer_df.count()
enriched_customer_count = test_customer_enriched_df.count()
assert source_customer_count == enriched_customer_count, "Not all records have been loaded in enriched table for customer file"

In [0]:
### Validate Orders Schema
# Expected layout of the ingested orders DataFrame; every column is nullable.
expectedOrderSchema = StructType([
    StructField("Row ID", IntegerType(), True),
    StructField("Order ID", StringType(), True),
    StructField("Order Date", StringType(), True),
    StructField("Ship Date", StringType(), True),
    StructField("Ship Mode", StringType(), True),
    StructField("Customer ID", StringType(), True),
    StructField("Product ID", StringType(), True),
    StructField("Quantity", IntegerType(), True),
    StructField("Price", StringType(), True),
    StructField("Discount", DoubleType(), True),
    StructField("Profit", DoubleType(), True)
])

# createDataFrame verifies each value against its declared field type, so Price
# (StringType above) must be supplied as a string; a float here raises TypeError.
# NOTE(review): confirm Price really is ingested as a string; if ingestion casts it
# to double, change the field type above instead of this value.
data = [(9992, "CA-2014-140662", "17/11/2014", "19/11/2014", "First Class", "TS-21205", "OFF-AP-10001242", 3, "241.44", 0.0, 72.432)]

df = spark.createDataFrame(data, expectedOrderSchema)

assert df.schema == orders_df.schema, "Orders Schema mismatch"

In [0]:
### Validate orders table record count
# raw.raw_orders must hold exactly as many rows as the ingested orders DataFrame.
schemaName = "raw"
orders_raw_df = fetchTableData(schemaName, "raw_orders")

raw_orders_count = orders_raw_df.count()
source_orders_count = orders_df.count()
assert raw_orders_count == source_orders_count, "Not all records have been loaded in raw table for orders file"

In [0]:
### Validate orders table record count in enriched table
# enriched.enriched_orders must hold exactly as many rows as the ingested orders DataFrame.
schemaName = "enriched"
test_orders_enriched_table_df = fetchTableData(schemaName, "enriched_orders")

source_orders_count = orders_df.count()
enriched_orders_count = test_orders_enriched_table_df.count()
assert source_orders_count == enriched_orders_count, "Not all records have been loaded in enriched table for orders file"

In [0]:
### Validate Profit columns' precision
# Two checks on enriched_orders.profit:
#   1. every value is already rounded to 2 decimal places;
#   2. the values are not over-rounded to 1 decimal place.
schemaName = "enriched"
test_orders_enriched_table_df = fetchTableData(schemaName, "enriched_orders")

rounded_df = test_orders_enriched_table_df.withColumn("rounded_profit", F.round(F.col("profit"), 2))

# Filter on the cluster and fetch at most one offending row, instead of collect()-ing
# the whole table to the driver. Null profits drop out of the filter (null != x is
# null), so they pass, as they did in a row-by-row None == None comparison.
first_unrounded = (
    rounded_df
    .filter(F.col("profit") != F.col("rounded_profit"))
    .select("profit")
    .first()
)
assert first_unrounded is None, f"Profit value {first_unrounded['profit']} is not rounded to 2 decimal places."

rounded_to_one_df = test_orders_enriched_table_df.withColumn("rounded_to_one_profit", F.round(F.col("profit"), 1))

incorrectly_rounded = rounded_to_one_df.filter(F.col("profit") != F.col("rounded_to_one_profit"))

# At least one value must keep a non-zero second decimal; otherwise everything looks
# rounded to 1 place. first() stops at the first match instead of counting every row.
# NOTE(review): this also fails if no profit legitimately has a second decimal — confirm intent.
assert incorrectly_rounded.first() is not None, "All 'profit' values are rounded to 1 decimal place; expected 2-decimal precision"

In [0]:
### Validate if required customer and product columns exist in the orders enriched table
# The enriched orders table must expose these customer and product attributes by name
# (exact, case-sensitive match against DataFrame.columns).
schemaName = "enriched"
test_orders_enriched_table_df = fetchTableData(schemaName, "enriched_orders")

required_columns = {"customer_name", "country", "category", "sub_category"}
enriched_order_columns = set(test_orders_enriched_table_df.columns)

# An empty set difference means every required column is present.
assert not (required_columns - enriched_order_columns), f"Some required columns are missing: {required_columns.difference(test_orders_enriched_table_df.columns)}"

In [0]:

### Validate gold profit table
# schemaName is set explicitly so this cell does not depend on a value left by an earlier cell.
# NOTE(review): gold_profit is read from the "enriched" schema here — confirm it does not
# live in a separate gold schema.
schemaName = "enriched"
test_profit_gold_table_df = fetchTableData(schemaName, "gold_profit")
test_orders_enriched_table_df = fetchTableData(schemaName, "enriched_orders")

# TODO: reconcile aggregated profit in gold_profit against test_orders_enriched_table_df
# once the gold table's columns are confirmed; for now only check that it was populated.
assert test_profit_gold_table_df.count() > 0, "No records have been loaded in gold table for profit"

