In [0]:
from pyspark.sql.functions import col, count, to_date, when
from pyspark.sql.types import IntegerType, DateType, DoubleType, StringType

# Declare the notebook/job parameters (widgets) together with their defaults.
dbutils.widgets.text("sales_table_name", "bronze_sales")
dbutils.widgets.text("database_name", "module2_db")
dbutils.widgets.dropdown("force_valid_status_for_testing", "False", ["True", "False"])

# Resolve the parameter values. The dropdown yields the string "True"/"False",
# so it is converted to a real bool by comparison.
sales_table, db_name = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("sales_table_name", "database_name")
)
force_valid_status = dbutils.widgets.get("force_valid_status_for_testing") == "True"

# Fully qualified "<database>.<table>" name used for reading the source.
full_sales_table_name = ".".join((db_name, sales_table))

# Load the source table. The column list is read inside the try so that the
# table is actually resolved here: with lazily analysed DataFrames (e.g. Spark
# Connect) a missing table/database might otherwise only fail later, at the
# first schema access, and the ERROR_READING_TABLE status would never be set.
try:
    df_sales_raw = spark.table(full_sales_table_name)
    _ = df_sales_raw.columns  # force analysis of the table reference
except Exception as e:  # notebook boundary: record the failure for downstream tasks, then stop
    dbutils.jobs.taskValues.set(key="validation_status", value="ERROR_READING_TABLE")
    dbutils.jobs.taskValues.set(key="validation_error_message", value=str(e))
    dbutils.notebook.exit(f"Failed to read table: {full_sales_table_name}")

# Verify the source schema BEFORE referencing any column. Building the casted
# columns first would raise an AnalysisException for a missing column, and the
# INVALID_SCHEMA status below would never be reported.
required_source_cols = ["SalesOrderNumber", "OrderDate", "CustomerID", "Item", "Quantity", "UnitPrice"]
available_source_cols = set(df_sales_raw.columns)  # note: exact, case-sensitive name match
missing_cols_source = [c for c in required_source_cols if c not in available_source_cols]

if missing_cols_source:
    error_message = f"Missing source columns: {', '.join(missing_cols_source)}"
    dbutils.jobs.taskValues.set(key="validation_status", value="INVALID_SCHEMA")
    dbutils.jobs.taskValues.set(key="validation_error_message", value=error_message)
    dbutils.notebook.exit(error_message)

# Cast columns to the expected types. Values that cannot be cast become null
# and are reported by the data-quality checks that follow.
df_sales = (
    df_sales_raw
    .withColumn("OrderDate_casted", col("OrderDate").cast(DateType()))
    .withColumn("Quantity_casted", col("Quantity").cast(IntegerType()))
    .withColumn("UnitPrice_casted", col("UnitPrice").cast(DoubleType()))
)

# Data-quality rules: (message fragment, condition that flags a bad row).
# The list order defines the order of fragments in the error message.
quality_rules = [
    ("Null SalesOrderNumber. ", col("SalesOrderNumber").isNull()),
    ("Null/invalid OrderDate. ", col("OrderDate_casted").isNull()),
    ("Null CustomerID. ", col("CustomerID").isNull()),
    ("Null Item. ", col("Item").isNull()),
    ("Null/non-positive Quantity. ", col("Quantity_casted").isNull() | (col("Quantity_casted") <= 0)),
    ("Null/negative UnitPrice. ", col("UnitPrice_casted").isNull() | (col("UnitPrice_casted") < 0)),
]

# Compute every violation count plus the total row count in ONE aggregation
# pass, instead of launching a separate Spark job per check.
# count(when(cond, True)) counts rows where cond is true: when() yields null
# otherwise (including when cond itself is null), and count() skips nulls --
# the same rows filter(cond).count() would count.
quality_counts = df_sales.agg(
    count("*").alias("total_rows"),
    *[
        count(when(condition, True)).alias(f"violations_{idx}")
        for idx, (_, condition) in enumerate(quality_rules)
    ],
).first()

error_message = "".join(
    message
    for idx, (message, _) in enumerate(quality_rules)
    if quality_counts[f"violations_{idx}"] > 0
)
validation_status_actual = "INVALID" if error_message else "VALID"

# The testing override reports VALID regardless of the actual checks.
final_validation_status = "VALID" if force_valid_status else validation_status_actual
final_error_message = "Forced VALID for testing" if force_valid_status else error_message

dbutils.jobs.taskValues.set(key="validation_status", value=final_validation_status)
# withColumn never changes the row count, so total_rows equals df_sales_raw.count().
dbutils.jobs.taskValues.set(key="source_table_record_count", value=quality_counts["total_rows"])
if final_validation_status == "INVALID":
    dbutils.jobs.taskValues.set(key="validation_error_message", value=final_error_message)
dbutils.notebook.exit(final_validation_status)