In [0]:
%scala
import java.net.URL
import java.io.File
import java.nio.file.{Files, Paths, StandardCopyOption}

// -----------------------------
// Read the period (YYYYMM) from the notebook's "period" widget
// NOTE(review): this cell only reads the widget, it does not create it;
// presumably it must already exist (e.g. via dbutils.widgets.text) — confirm.
// -----------------------------
val periodParam = dbutils.widgets.get("period")  // value comes from widget

// -----------------------------
// Helper: Parse period parameter (YYYYMM)
// -----------------------------
/**
 * Parses a period string in `YYYYMM` form into its year and month.
 *
 * @param period six-digit period, e.g. "202507"
 * @return `(year, month)` as integers, with month guaranteed to be in 1..12
 * @throws IllegalArgumentException if `period` is null, is not exactly six
 *                                  digits, or its month part is outside 01-12
 */
def parsePeriod(period: String): (Int, Int) = {
  require(period != null && period.matches("""\d{6}"""), "Period must be in format YYYYMM, e.g., 202507")
  val year = period.take(4).toInt
  val month = period.drop(4).toInt
  // Six digits alone would let "202513" or "202500" through and only fail
  // later as a confusing download error, so reject impossible months here.
  require(month >= 1 && month <= 12, s"Period month must be between 01 and 12, got: ${period.drop(4)}")
  (year, month)
}

// Year and month of the requested period; parsePeriod validates the value
// with require, so a malformed period stops the notebook here.
val (tyear, tmonth) = parsePeriod(periodParam)

// -----------------------------
// Function to download a Parquet file
// -----------------------------
/**
 * Downloads one month of Yellow Taxi trip data as a Parquet file.
 *
 * The file is named `yellow_tripdata_YYYY-MM.parquet` and fetched from the
 * CloudFront trip-data URL. If a file with that name already exists in
 * `destFolder` it is reused and nothing is downloaded.
 *
 * @param year       four-digit year used in the file name
 * @param month      month number; zero-padded to two digits in the file name
 * @param destFolder local directory to store the file in (created if missing)
 * @return `Some(path)` of the local file on success or when it already
 *         existed, `None` if the download failed (the reason is printed)
 */
def downloadTlcMonth(year: Int, month: Int, destFolder: String): Option[String] = {
  import scala.util.Try
  import scala.util.control.NonFatal

  val filename = f"yellow_tripdata_$year-$month%02d.parquet"
  val url = s"https://d37ci6vzurychx.cloudfront.net/trip-data/$filename"
  val destPath = Paths.get(destFolder, filename)

  if (Files.exists(destPath)) {
    println(s"File already exists: $destPath")
    Some(destPath.toString)
  } else {
    // Download into a sibling ".part" file and rename only once the copy has
    // finished. Writing straight to destPath would leave a truncated file
    // behind on failure, and the exists-check above would then treat that
    // corrupt file as a finished download on the next run.
    val partPath = destPath.resolveSibling(filename + ".part")
    try {
      // Create folder if it does not exist
      Option(destPath.getParent).foreach(dir => Files.createDirectories(dir))

      println(s"Downloading $filename ...")
      val in = new URL(url).openStream()
      try Files.copy(in, partPath, StandardCopyOption.REPLACE_EXISTING)
      finally in.close() // close the stream even when the copy fails

      Files.move(partPath, destPath, StandardCopyOption.REPLACE_EXISTING)
      println(s"Download complete: $destPath")
      Some(destPath.toString)
    } catch {
      case NonFatal(e) =>
        // Best-effort cleanup; a failure here must not mask the original error.
        Try(Files.deleteIfExists(partPath))
        println(s"Failed to download $filename: ${e.getMessage}")
        None
    }
  }
}

// -----------------------------
// Download the selected month
// -----------------------------
// The file is written with java.nio under the local path /dbfs/..., and the
// same file is later addressed by Spark as dbfs:/... (see sparkPath below).
val destFolder = "/dbfs/tmp/yellow"
val filePathOpt = downloadTlcMonth(tyear, tmonth, destFolder)

// Stop the notebook if nothing could be downloaded; every later step needs the file.
val filePath = filePathOpt.getOrElse(
  throw new RuntimeException(s"Failed to download Yellow Taxi data for $tyear-$tmonth")
)

// -----------------------------
// Convert to Spark DBFS path
// -----------------------------
// "/dbfs/tmp/yellow/x.parquet" -> "dbfs:/tmp/yellow/x.parquet"
val sparkPath = "dbfs:/" + filePath.stripPrefix("/dbfs/")

// -----------------------------
// Read the month into Spark DataFrame
// -----------------------------
val df_month = spark.read.parquet(sparkPath)

// The f-interpolator is required for %02d to be applied; with s"..." the
// format specifier was printed literally (e.g. "2025-7%02d").
println(f"\n=== Showing first 10 rows for $tyear-$tmonth%02d ===")
df_month.show(10, truncate=false)
();


In [0]:
%scala
// Expected column names and types of a Yellow Taxi monthly file, as Spark DDL.
// Spark reports every Parquet column as nullable when reading, so the
// NOT NULL markers document intent; they are not enforced by the read itself.
val yellowTaxiDDL = """
VendorID INT NOT NULL,                      
tpep_pickup_datetime TIMESTAMP NOT NULL,
tpep_dropoff_datetime TIMESTAMP NOT NULL,
passenger_count BIGINT NOT NULL,
trip_distance DOUBLE NOT NULL,
RatecodeID BIGINT,
store_and_fwd_flag STRING,
PULocationID INT,
DOLocationID INT ,
payment_type BIGINT NOT NULL,
fare_amount DOUBLE NOT NULL,
extra DOUBLE,
mta_tax DOUBLE ,
tip_amount DOUBLE ,
tolls_amount DOUBLE ,
improvement_surcharge DOUBLE ,
total_amount DOUBLE NOT NULL,
congestion_surcharge DOUBLE ,
Airport_fee DOUBLE,
cbd_congestion_fee DOUBLE 
"""



import org.apache.spark.sql.types._
val yellowTaxiSchema = StructType.fromDDL(yellowTaxiDDL)

val fileSchema = spark.read.parquet(sparkPath).schema

// Reduce a schema to its ordered (column name, data type) pairs.
def fieldNamesAndTypes(schema: StructType): Seq[(String, DataType)] =
  schema.fields.toSeq.map(field => (field.name, field.dataType))

// Compare names and types only. StructType equality also compares the
// nullable flag, and a schema read from Parquet is always all-nullable, so a
// plain equals against the NOT NULL DDL reported a mismatch on every run.
if (fieldNamesAndTypes(fileSchema) != fieldNamesAndTypes(yellowTaxiSchema)) {
  println(s"Schema mismatch detected for $sparkPath")
  println("Expected schema:")
  yellowTaxiSchema.printTreeString()
  println("File schema:")
  fileSchema.printTreeString()
}

// Re-read the file with the expected schema applied.
val df_month_strict = spark.read.option("mergeSchema", "true").schema(yellowTaxiSchema).parquet(sparkPath)

//df_month_strict.printSchema()
//df_month_strict.show(5)

import org.apache.spark.sql.functions._

// Add one boolean flag column per data-quality check.
// Each check starts with isNull: a comparison against NULL evaluates to NULL,
// which the previous when(...).otherwise(false) turned into "not invalid", so
// rows missing a column the DDL declares NOT NULL were counted as valid.
// Because `true || NULL` is true, every flag below is always true or false,
// never NULL.
val df_flagged = df_month_strict
  .withColumn(
    "invalid_passenger_count",
    $"passenger_count".isNull || $"passenger_count" < 1 || $"passenger_count" > 6
  )
  .withColumn(
    "invalid_trip_distance",
    $"trip_distance".isNull || $"trip_distance" <= 0
  )
  .withColumn(
    "invalid_payment_type",
    $"payment_type".isNull || !$"payment_type".isin(1, 2, 3, 4, 5)
  )


// Show first 10 rows with flags
//df_flagged.show(10, truncate=false)

// A row is quarantined when at least one check failed; otherwise it is valid.
// The two filters are exact complements, so every row lands in exactly one set.
val anyCheckFailed = $"invalid_passenger_count" || $"invalid_trip_distance" || $"invalid_payment_type"

val df_quarantine = df_flagged.filter(anyCheckFailed)

val df_valid = df_flagged.filter(!anyCheckFailed)


// Table names carry the period, e.g. bronze_layer_2025_7_yellow_taxi_valid
// (the month is not zero-padded).
val quarantineTable = s"bronze_layer_${tyear}_${tmonth}_yellow_taxi_quarantine"
val validTable      = s"bronze_layer_${tyear}_${tmonth}_yellow_taxi_valid"
val df_month_strict_name      = s"bronze_layer_${tyear}_${tmonth}_yellow_taxi_full"

// Write tables as Delta; "overwrite" replaces any earlier load of the same period.
df_quarantine.write
  .format("delta")
  .mode("overwrite")  // or "append" depending on your use case
  .saveAsTable(quarantineTable)

df_valid.write
  .format("delta")
  .mode("overwrite")
  .saveAsTable(validTable)

// Unflagged copy of the whole month; overwriteSchema lets an existing table's
// schema be replaced along with its data.
df_month_strict.write
  .format("delta")
  .mode("overwrite")
  .option("overwriteSchema", "true")
  .saveAsTable(df_month_strict_name)

// Report the names that were actually written. The earlier message printed
// fixed "yellow_taxi_valid"/"yellow_taxi_quarantine" labels that match no
// table created above and left out the full table.
println("✅ Tables created successfully:")
Seq(validTable, quarantineTable, df_month_strict_name).foreach(table => println(s" - $table"))

// Count total and quarantined records. Each count() is a full Spark job over
// the source file, so the valid count is derived instead of scanned again:
// df_valid and df_quarantine are exact complements (the flag columns are
// never NULL), hence valid = total - quarantined.
val totalCount = df_flagged.count()
val quarantineCount = df_quarantine.count()
val validCount = totalCount - quarantineCount

println(s"Total records: $totalCount")
println(s"Valid records (non-quarantine): $validCount")
println(s"Quarantined records: $quarantineCount")
// Guard the percentage: with zero rows the division yields NaN and the line
// would read "NaN% valid".
if (totalCount > 0) {
  println(f"Data quality: ${validCount * 100.0 / totalCount}%.2f%% valid")
} else {
  println("Data quality: n/a (no records)")
}