In [0]:
from databricks.labs.dqx.engine import DQEngine
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.config import InputConfig, OutputConfig
from databricks.labs.dqx.rule import DQRowRule
from databricks.labs.dqx.check_funcs import (
    is_not_null,
    is_in_list,
    is_in_range
)
from typing import List

In [0]:
def _not_null_rule(column: str) -> "DQRowRule":
    """Build an error-level ``is_not_null`` rule for ``column``.

    The rule is named ``<column>_is_null``, the naming convention used by
    every null check in this notebook.
    """
    return DQRowRule(
        criticality="error",
        check_func=is_not_null,
        column=column,
        name=f"{column}_is_null",
    )


def _allowed_values_rule(column: str, allowed: list, name: str) -> "DQRowRule":
    """Build an error-level ``is_in_list`` rule restricting ``column`` to ``allowed``."""
    return DQRowRule(
        criticality="error",
        check_func=is_in_list,
        column=column,
        check_func_kwargs={"allowed": allowed},
        name=name,
    )


def _range_warning_rule(column: str, min_limit, max_limit, name: str) -> "DQRowRule":
    """Build a warn-level ``is_in_range`` rule bounding ``column`` by ``min_limit``/``max_limit``."""
    return DQRowRule(
        criticality="warn",
        check_func=is_in_range,
        column=column,
        check_func_kwargs={"min_limit": min_limit, "max_limit": max_limit},
        name=name,
    )


def get_quality_checks(
    max_passenger_count: int = 4,
    max_trip_distance: float = 40,
) -> "List[DQRowRule]":
    """
    Returns a list of DQX quality checks for NYC Yellow Taxi Trip data.

    Every listed column gets an error-level not-null check. Selected
    categorical columns additionally get an error-level allowed-values check.
    ``passenger_count`` and ``trip_distance`` get warn-level range checks.

    Args:
        max_passenger_count: Upper limit passed to the ``passenger_count``
            range check (lower limit is 1). Defaults to 4.
        max_trip_distance: Upper limit passed to the ``trip_distance`` range
            check (lower limit is 0). Defaults to 40.

    Returns:
        List[DQRowRule]: List of data quality check rules, in a fixed order.
    """
    # Fare / fee / surcharge columns that only receive a not-null check.
    # The tuple order is the emission order of these rules.
    charge_columns = (
        "fare_amount",
        "extra",
        "mta_tax",
        "tip_amount",
        "tolls_amount",
        "improvement_surcharge",
        "total_amount",
        "congestion_surcharge",
        "Airport_fee",
        "cbd_congestion_fee",
    )

    # Rule order is kept stable in case downstream consumers rely on it.
    checks = [
        # VendorID checks
        _not_null_rule("VendorID"),
        _allowed_values_rule("VendorID", [1, 2, 6, 7], "VendorID_unknown_value"),

        # Pickup / dropoff timestamps
        _not_null_rule("tpep_pickup_datetime"),
        _not_null_rule("tpep_dropoff_datetime"),

        # passenger_count checks
        # NOTE(review): unlike the other rules, this name describes the
        # expectation rather than the failure; kept as-is for compatibility
        # with anything that filters on rule names.
        _not_null_rule("passenger_count"),
        _range_warning_rule(
            "passenger_count", 1, max_passenger_count, "passenger_count_is_in_range"
        ),

        # trip_distance checks
        _not_null_rule("trip_distance"),
        _range_warning_rule(
            "trip_distance", 0, max_trip_distance, "trip_distance_is_in_range"
        ),

        # RatecodeID checks
        _not_null_rule("RatecodeID"),
        _allowed_values_rule("RatecodeID", [1, 2, 3, 4, 5, 6], "RatecodeID_other_value"),

        # store_and_fwd_flag checks
        _not_null_rule("store_and_fwd_flag"),
        _allowed_values_rule("store_and_fwd_flag", ["Y", "N"], "store_and_fwd_flag_other_value"),

        # Pickup / dropoff location IDs
        _not_null_rule("PULocationID"),
        _not_null_rule("DOLocationID"),

        # payment_type checks
        _not_null_rule("payment_type"),
        _allowed_values_rule("payment_type", [0, 1, 2, 3, 4, 5, 6], "payment_type_other_value"),
    ]
    checks.extend(_not_null_rule(column) for column in charge_columns)

    return checks

In [0]:
# Source table that is checked and destination table the annotated result
# is written to (overwritten on every run).
SOURCE_TABLE = "nprod.tlc.yellow_trip"
TARGET_TABLE = "nprod.tlc.yellow_trip_valid"

workspace_client = WorkspaceClient()
dq_engine = DQEngine(workspace_client)

input_df = spark.read.table(SOURCE_TABLE)
quality_checks = get_quality_checks()
output_df = dq_engine.apply_checks(input_df, quality_checks)

# NOTE(review): per the DQX docs, apply_checks keeps every input row and only
# adds result columns, so this "*_valid" table presumably also holds rows that
# failed error-level checks. If only passing rows belong here,
# DQEngine.apply_checks_and_split may be the intended API — confirm.
output_df.write.saveAsTable(TARGET_TABLE, mode="overwrite")