In [0]:
from pyspark.sql.functions import col, when, lit, concat ,rtrim ,upper, trim , current_date
from pyspark.sql.types import IntegerType, DoubleType, TimestampType
from pyspark.sql.functions import split, regexp_replace, col
from pyspark.sql.functions import split, regexp_replace, col



In [0]:
# Declare every notebook parameter as a text widget with an empty default.
for _widget_name in (
    'sfURL',
    'sfUser',
    'sfDatabase',
    'sourceSchema',
    'sfWarehouse',
    'sfRole',
    'targetSchema',
    'srcTable',
    'targetTable',
):
    dbutils.widgets.text(_widget_name, '')

# Read the parameter values back. The password is never a widget: it is
# pulled from the "snowflake-secrets" secret scope instead.
_read_widget = dbutils.widgets.get

sfURL = _read_widget('sfURL')
sfUser = _read_widget('sfUser')
sfPassword = dbutils.secrets.get(scope="snowflake-secrets", key="sf-password")
sfDatabase = _read_widget('sfDatabase')
sourceSchema = _read_widget('sourceSchema')
srcTable = _read_widget('srcTable')
sfWarehouse = _read_widget('sfWarehouse')
sfRole = _read_widget('sfRole')
targetSchema = _read_widget('targetSchema')
targetTable = _read_widget('targetTable')

In [0]:
# Connection options used when reading from the source schema.
sfOptions = dict(
    sfURL=sfURL,
    sfUser=sfUser,
    sfPassword=sfPassword,
    sfDatabase=sfDatabase,
    sfSchema=sourceSchema,
    sfWarehouse=sfWarehouse,
    sfRole=sfRole,
)

# Connection options used when writing: identical to the source options
# except for the schema. Overriding an existing key keeps its position,
# so the key order matches sfOptions.
sfOptionsTarget = {**sfOptions, "sfSchema": targetSchema}

In [0]:
def _read_source_table(table_name):
    """Load a table through the Snowflake connector using ``sfOptions``.

    Args:
        table_name: Value passed as the connector's ``dbtable`` option.

    Returns:
        The Spark DataFrame returned by the reader.
    """
    return (
        spark.read.format("snowflake")
        .options(**sfOptions)
        .option("dbtable", table_name)
        .load()
    )


def _write_target_table(df, table_name):
    """Write ``df`` through the Snowflake connector using ``sfOptionsTarget``.

    The write uses ``mode("overwrite")``.

    Args:
        df: Spark DataFrame to persist.
        table_name: Value passed as the connector's ``dbtable`` option.
    """
    (
        df.write.format("snowflake")
        .options(**sfOptionsTarget)
        .option("dbtable", table_name)
        .mode("overwrite")
        .save()
    )


def _upper_median(df, column):
    """Return the upper median of the non-null values of ``column``.

    The values are sorted ascending and the element at index ``n // 2`` is
    returned. All non-null values of the column are collected to the driver.

    Args:
        df: Spark DataFrame containing ``column``.
        column: Name of the column to summarise.

    Returns:
        The median value, or ``None`` when the column has no non-null values
        (indexing an empty list here previously raised ``IndexError``).
    """
    ordered_rows = (
        df.filter(col(column).isNotNull())
        .select(column)
        .orderBy(column)
        .collect()
    )
    if not ordered_rows:
        return None
    return ordered_rows[len(ordered_rows) // 2][0]


def _clean_chicago(df_raw):
    """Clean the Chicago raw inspections DataFrame.

    Steps: drop exact duplicate rows, cast typed columns, replace NULLs with
    sentinels ("Unknown" / -1 / column mean), right-trim ``Violations``,
    normalise ``City``/``State``, keep only CHICAGO / IL rows and append the
    ``DI_JOB_ID`` / ``DI_LOAD_DT`` audit columns.

    Args:
        df_raw: Spark DataFrame as read from the raw table.

    Returns:
        The cleaned Spark DataFrame, ready to be written.
    """
    df_deduped = df_raw.dropDuplicates()
    row_count = df_deduped.count()
    print(f"Total rows after dropping duplicates: {row_count}")

    # Location is deliberately left uncast: the NULL fallback further down
    # builds it as "(lat, long)" text, so a DoubleType cast contradicts that
    # format and would null out every textual value.
    # NOTE(review): confirm the raw Location column holds "(lat, long)" text.
    df_typed = (
        df_deduped
        .withColumn("Inspection_ID", col("Inspection_ID").cast(IntegerType()))
        .withColumn("License_Number", col("License_Number").cast(DoubleType()))
        .withColumn("Zip", col("Zip").cast(IntegerType()))
        .withColumn("Latitude", col("Latitude").cast(DoubleType()))
        .withColumn("Longitude", col("Longitude").cast(DoubleType()))
        .withColumn("Inspection_Date", col("Inspection_Date").cast(TimestampType()))
    )

    # Text columns: NULL -> "Unknown".
    df_cleaned = df_typed
    for text_col in ("Risk", "Violations", "AKA_Name", "Facility_Type"):
        df_cleaned = df_cleaned.withColumn(
            text_col,
            when(col(text_col).isNull(), "Unknown").otherwise(col(text_col)),
        )

    # Numeric identifiers: NULL (and 0 for License_Number) -> -1.
    df_cleaned = (
        df_cleaned
        .withColumn(
            "License_Number",
            when(col("License_Number").isNull() | (col("License_Number") == 0), -1)
            .otherwise(col("License_Number")),
        )
        .withColumn("Zip", when(col("Zip").isNull(), -1).otherwise(col("Zip")))
    )

    # Both means are computed in a single aggregation job. avg() yields NULL
    # when there are no non-null inputs, hence the 0.0 fallback.
    mean_row = df_cleaned.selectExpr(
        "avg(Latitude) AS mean_lat", "avg(Longitude) AS mean_long"
    ).first()
    mean_lat = float(mean_row["mean_lat"]) if mean_row and mean_row["mean_lat"] is not None else 0.0
    mean_long = float(mean_row["mean_long"]) if mean_row and mean_row["mean_long"] is not None else 0.0

    # Replace NULL coordinates with the means; a NULL Location becomes the
    # "(mean_lat, mean_long)" text.
    df_cleaned = (
        df_cleaned
        .withColumn(
            "Latitude",
            when(col("Latitude").isNull(), lit(mean_lat)).otherwise(col("Latitude")),
        )
        .withColumn(
            "Longitude",
            when(col("Longitude").isNull(), lit(mean_long)).otherwise(col("Longitude")),
        )
        .withColumn(
            "Location",
            when(
                col("Location").isNull(),
                concat(lit("("), lit(mean_lat), lit(", "), lit(mean_long), lit(")")),
            ).otherwise(col("Location")),
        )
    )

    # Validation output: remaining NULLs per coordinate column.
    null_lat_count = df_cleaned.filter(col("Latitude").isNull()).count()
    null_long_count = df_cleaned.filter(col("Longitude").isNull()).count()
    null_location_count = df_cleaned.filter(col("Location").isNull()).count()
    print(f"NULL Latitude count: {null_lat_count}")
    print(f"NULL Longitude count: {null_long_count}")
    print(f"NULL Location count: {null_location_count}")

    # Report, then strip, trailing whitespace in Violations.
    trailing_ws_count = df_cleaned.filter(col("Violations").rlike(r"\s+$")).count()
    print(f"Rows with trailing whitespaces in 'Violations': {trailing_ws_count}")
    df_cleaned = df_cleaned.withColumn("Violations", rtrim(col("Violations")))

    # Normalise City / State (upper-case, trimmed), then collapse the known
    # misspellings of Chicago into "CHICAGO".
    df_cleaned = (
        df_cleaned
        .withColumn("City", upper(trim(col("City"))))
        .withColumn("State", upper(trim(col("State"))))
    )
    df_cleaned = df_cleaned.withColumn(
        "City",
        when(
            col("City").rlike("^(CHICAGO|CCHICAGO|CHICAGOO|CHICAGO\\.|CHICAGOC|CHICAGOCHICAGO|CH)$"),
            "CHICAGO",
        ).otherwise(col("City")),
    )

    # Keep only CHICAGO / IL rows and stamp the audit columns.
    df_result = (
        df_cleaned
        .filter((col("City") == "CHICAGO") & (col("State") == "IL"))
        .withColumn("DI_JOB_ID", lit("Amruta_01"))
        .withColumn("DI_LOAD_DT", current_date())
    )

    df_result.select("City", "State").distinct().show()
    return df_result


def _clean_dallas(df_raw):
    """Clean the Dallas raw inspections DataFrame.

    Steps: report and drop exact duplicate rows, split ``LAT_LONG_LOCATION``
    into numeric latitude/longitude, keep ZIP codes in 75201..75398, replace
    missing or out-of-range coordinates with the column median, fill NULLs in
    street and violation columns, cast numeric columns and append the
    ``DI_JOB_ID`` / ``DI_LOAD_DT`` audit columns.

    Args:
        df_raw: Spark DataFrame as read from the raw table.

    Returns:
        The cleaned Spark DataFrame, ready to be written.
    """
    # Number of distinct rows that occur more than once (GROUP BY all columns
    # HAVING count > 1).
    dup_count = (
        df_raw.groupBy(df_raw.columns)
        .count()
        .filter("count > 1")
        .count()
    )
    print(f"Number of duplicate records: {dup_count}")
    df_deduped = df_raw.dropDuplicates()

    # "(lat, long)" -> strip parentheses, split on the comma, cast both parts.
    df_cleaned = (
        df_deduped
        .withColumn("latlong_cleaned", regexp_replace("LAT_LONG_LOCATION", "[()]", ""))
        .withColumn("latlong_split", split(col("latlong_cleaned"), ","))
        .withColumn("Lat_split", col("latlong_split").getItem(0).cast("double"))
        .withColumn("Long_split", col("latlong_split").getItem(1).cast("double"))
        .drop("latlong_cleaned", "latlong_split")
    )
    df_cleaned.select("LAT_LONG_LOCATION", "Lat_split", "Long_split").show(10, truncate=False)

    # Keep the part of the ZIP before any "-" suffix as an integer, then
    # restrict to the accepted range. Unparsable values become NULL and are
    # dropped by the range filter.
    df_cleaned = (
        df_cleaned
        .withColumn("ZIP_CODE", split(col("ZIP_CODE"), "-").getItem(0).cast(IntegerType()))
        .filter((col("ZIP_CODE") >= 75201) & (col("ZIP_CODE") <= 75398))
    )

    # Accepted coordinate window per column; values outside it, and NULLs,
    # are replaced by that column's median.
    coordinate_bounds = {
        "Lat_split": (32.6, 33.0),
        "Long_split": (-97.0, -96.6),
    }
    # Medians are taken before any replacement so neither fill affects them.
    medians = {name: _upper_median(df_cleaned, name) for name in coordinate_bounds}

    df_fixed = df_cleaned
    for coord_col, (lower, upper_bound) in coordinate_bounds.items():
        median_value = medians[coord_col]
        if median_value is None:
            # Nothing to impute from: every value is NULL (or there are no rows).
            print(f"No non-null values in '{coord_col}'; skipping median imputation.")
            continue
        out_of_range = (col(coord_col) < lower) | (col(coord_col) > upper_bound)
        df_fixed = (
            df_fixed
            .fillna({coord_col: median_value})
            .withColumn(
                coord_col,
                when(out_of_range, lit(median_value)).otherwise(col(coord_col)),
            )
        )

    # Street attributes: NULL -> "N/A - Not Available".
    df_fixed = df_fixed.fillna({
        "STREET_DIRECTION": "N/A - Not Available",
        "STREET_TYPE": "N/A - Not Available",
        "STREET_UNIT": "N/A - Not Available",
    })

    # Violation columns come in 25 numbered groups.
    violation_string_cols = []
    violation_points_cols = []
    for i in range(1, 26):
        violation_string_cols.extend([
            f"VIOLATION_DESCRIPTION_{i}",
            f"VIOLATION_DETAILS_{i}",
            f"VIOLATION_MEMO_{i}",
        ])
        violation_points_cols.append(f"VIOLATION_POINTS_{i}")

    # Text violation columns: NULL -> "N/A - Not Applicable"; points: NULL -> 0.
    # The loop variable is named `name` so it does not shadow pyspark's `col`.
    df_fixed = df_fixed.fillna({name: "N/A - Not Applicable" for name in violation_string_cols})
    df_fixed = df_fixed.fillna({name: 0 for name in violation_points_cols})

    # The parsed Lat_split / Long_split columns replace the original text column.
    df_fixed = df_fixed.drop("LAT_LONG_LOCATION")

    df_fixed = (
        df_fixed
        .withColumn("INSPECTION_SCORE", col("INSPECTION_SCORE").cast(IntegerType()))
        .withColumn("STREET_NUMBER", col("STREET_NUMBER").cast(IntegerType()))
    )

    # Lat_split / Long_split were created with cast("double"), so renaming is
    # all that is needed here.
    df_fixed = (
        df_fixed
        .withColumnRenamed("Lat_split", "LATITUDE")
        .withColumnRenamed("Long_split", "LONGITUDE")
    )

    for points_col in violation_points_cols:
        df_fixed = df_fixed.withColumn(points_col, col(points_col).cast(IntegerType()))

    # Audit columns.
    df_fixed = (
        df_fixed
        .withColumn("DI_JOB_ID", lit("SP_001"))
        .withColumn("DI_LOAD_DT", current_date())
    )

    # Make sure the audit fields are the last two columns.
    audit_cols = ["DI_JOB_ID", "DI_LOAD_DT"]
    cols_without_audit = [c for c in df_fixed.columns if c not in audit_cols]
    return df_fixed.select(*(cols_without_audit + audit_cols))


# Supported source tables -> (label used in messages, cleaning function).
_CLEANERS = {
    'CHICAGO_FOOD_INSPECTIONS_RAW': ("Chicago", _clean_chicago),
    'DALLAS_FOOD_INSPECTIONS_RAW': ("Dallas", _clean_dallas),
}

try:
    if srcTable not in _CLEANERS:
        # The message interpolates srcTable; the previous code referenced the
        # undefined name `src_table` and raised NameError instead.
        raise ValueError(f"Unsupported src_table: '{srcTable}'. Only 'DALLAS_FOOD_INSPECTIONS_RAW' or 'CHICAGO_FOOD_INSPECTIONS_RAW' are supported.")

    _city_label, _cleaner = _CLEANERS[srcTable]
    try:
        df_raw = _read_source_table(srcTable)
        df_final = _cleaner(df_raw)
        _write_target_table(df_final, targetTable)
        print(f"{_city_label} data cleaning completed.")
    except Exception as e:
        # Chain explicitly so the original traceback is kept as the cause.
        raise Exception(f"{_city_label} cleaning failed: {str(e)}") from e

except Exception as e:
    print(f"🔥 Error during cleaning: {e}")
    raise



Number of duplicate records: 18931
+-------------------------------+------------+-------------+
|LAT_LONG_LOCATION              |Lat_split   |Long_split   |
+-------------------------------+------------+-------------+
|\r                             |NULL        |NULL         |
|(37.774192712, -89.359825991)\r|37.774192712|-89.359825991|
|\r                             |NULL        |NULL         |
|\r                             |NULL        |NULL         |
|\r                             |NULL        |NULL         |
|\r                             |NULL        |NULL         |
|(32.763332, -96.855978)\r      |32.763332   |-96.855978   |
|\r                             |NULL        |NULL         |
|\r                             |NULL        |NULL         |
|(32.93083, -96.82094)\r        |32.93083    |-96.82094    |
+-------------------------------+------------+-------------+
only showing top 10 rows

Dallas data cleaning completed.
