In [None]:
import time
import random
import re
import grpc
from google.api_core.exceptions import GoogleAPICallError, RetryError, ServiceUnavailable
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from pyspark.errors import PySparkValueError

_QUOTA_PAT = re.compile(r"insufficient.*cpus.*quota", re.IGNORECASE)

def _is_quota_error(exc: Exception) -> bool:
    msg = str(exc) or ""
    return bool(_QUOTA_PAT.search(msg))

def get_spark_with_retry(
    gentle_delay: float = 10.0,     # fixed delay for the first few tries
    gentle_retries: int = 3,        # how many tries use gentle delay
    max_backoff: float = 60.0,      # cap for exponential backoff
    quota_wait: float = 90.0,       # wait longer when we hit a QUOTA error
    jitter: bool = True,
    verify_ready: bool = True,
    max_attempts: int | None = None, # None = infinite
    settle_delay: float = 3.0,      # pause after getOrCreate, before the test query
):
    """
    Keep retrying until DataprocSparkSession is usable.

    - QUOTA errors: wait `quota_wait` seconds (+ up to a third more as jitter) and retry
      (lets other jobs finish).
    - Other errors: gentle fixed delay for the first `gentle_retries` attempts, then
      exponential backoff (doubling, capped at `max_backoff`), with up to +50% jitter.

    Args:
        gentle_delay: Delay for early non-quota failures; also the backoff starting point.
        gentle_retries: Number of attempts that use `gentle_delay`.
        max_backoff: Upper bound on the exponential backoff delay.
        quota_wait: Base delay after a quota error.
        jitter: Add random jitter to each delay.
        verify_ready: Run a tiny `spark.range(1).count()` to prove the session works.
        max_attempts: Give up after this many attempts; None (or 0) retries forever.
        settle_delay: Seconds to pause after `getOrCreate()` before verifying.

    Returns:
        The connected Spark session.

    Raises:
        RuntimeError: after `max_attempts` failed attempts, chained from the last error.
        KeyboardInterrupt: re-raised when the user interrupts.
    """
    attempt = 0
    backoff = gentle_delay

    while True:
        attempt += 1
        try:
            spark = DataprocSparkSession.builder.getOrCreate()
            # Let backend settle a moment before the test query
            time.sleep(settle_delay)

            if verify_ready:
                _ = spark.range(1).count()

            print(f" Spark connected on attempt {attempt}")
            return spark

        except (ServiceUnavailable, GoogleAPICallError, RetryError,
                grpc.RpcError, PySparkValueError, RuntimeError) as e:
            # Check the attempt budget before sleeping: waiting after the last allowed
            # attempt would only delay the error by up to quota_wait / max_backoff.
            if max_attempts and attempt >= max_attempts:
                raise RuntimeError(f" Failed after {attempt} attempts") from e

            # QUOTA-specific path: someone else might be using CPUs. Wait longer, then retry.
            if _is_quota_error(e):
                wait = quota_wait + (random.uniform(0, quota_wait/3) if jitter else 0.0)
                print(f"⏳ Quota limited (likely other sessions using CPUs). "
                      f"Sleeping {wait:.1f}s, then retrying…\nDetails: {e}")
                time.sleep(wait)
            else:
                # Non-quota transient errors
                if attempt <= gentle_retries:
                    wait = gentle_delay
                else:
                    # exponential backoff capped
                    backoff = min(backoff * 2, max_backoff)
                    wait = backoff
                if jitter:
                    wait += random.uniform(0, wait/2)
                print(f"  Attempt {attempt} failed: {e!r}\n   Sleeping {wait:.1f}s before retry…")
                time.sleep(wait)

        except KeyboardInterrupt:
            print(" Stopped by user.")
            raise

# --- Usage ---
# Acquire a verified session for the rest of the notebook.
spark = get_spark_with_retry()

# Safe to use: a tiny DataFrame (ids 0..4) confirms the session end to end.
df = spark.range(0, 5)
df.show()


In [None]:
import logging
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StringType, IntegerType, DoubleType, TimestampType,
    StructType, StructField
)

# -----------------------------
# CONFIG
# -----------------------------

In [None]:
# GCP project plus the storage locations this notebook reads from and writes to.
PROJECT_ID: str = "nyctaxi-467111"
BUCKET_NAME: str = "nyc_raw_data_bucket"      # GCS bucket (run logs are uploaded here)
BRONZE_DATASET_NAME: str = "RawBronze"        # BigQuery dataset holding raw monthly tables
SILVER_DATASET_NAME: str = "CleanSilver"      # BigQuery dataset receiving cleaned tables
TABLE_PREFIX: str = "fhvhv"                   # monthly tables are named <prefix>_<YYYY_MM>

# -----------------------------
# SCHEMA
# -----------------------------

In [None]:
# Declared column layout for the fhvhv trip tables (name -> Spark type); every column nullable.
# NOTE(review): none of the visible cells reference this schema; the BigQuery reads use the
# table's own schema. Confirm whether it is still needed.
hvfhv_schema = StructType([
    StructField(name, spark_type, True)
    for name, spark_type in [
        # identifiers
        ("hvfhs_license_num", StringType()),
        ("dispatching_base_num", StringType()),
        ("originating_base_num", StringType()),
        # timestamps
        ("request_datetime", TimestampType()),
        ("on_scene_datetime", TimestampType()),
        ("pickup_datetime", TimestampType()),
        ("dropoff_datetime", TimestampType()),
        # locations
        ("PULocationID", IntegerType()),
        ("DOLocationID", IntegerType()),
        # trip measures
        ("trip_miles", DoubleType()),
        ("trip_time", IntegerType()),
        # money
        ("base_passenger_fare", DoubleType()),
        ("tolls", DoubleType()),
        ("bcf", DoubleType()),
        ("sales_tax", DoubleType()),
        ("congestion_surcharge", DoubleType()),
        ("airport_fee", DoubleType()),
        ("tips", DoubleType()),
        ("driver_pay", DoubleType()),
        # flags
        ("shared_request_flag", StringType()),
        ("shared_match_flag", StringType()),
        ("access_a_ride_flag", StringType()),
        ("wav_request_flag", StringType()),
        ("wav_match_flag", StringType()),
        ("cbd_congestion_fee", DoubleType()),
    ]
])


# -----------------------------
# LOGGER (single per run)
# -----------------------------

In [None]:

# One log file per run, timestamped so reruns do not overwrite each other's logs.
run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
# NOTE(review): the literal "02d" prefix looks like a leftover format spec (e.g. "{n:02d}");
# kept unchanged so log names stay stable. Confirm the intended name.
log_filename = f"02d{TABLE_PREFIX}_cleaning_{run_time}.log"
local_log_path = f"/tmp/{log_filename}"
gcs_log_path = f"logs/{log_filename}"

# Configure the ROOT logger (getLogger() with no name), so INFO+ records from any
# library that propagates to root also land in this file and on the console.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Close, not just detach, any existing handlers. Re-running this cell would otherwise
# leave the previous run's FileHandler (and its open file descriptor) dangling.
for _old_handler in list(logger.handlers):
    logger.removeHandler(_old_handler)
    _old_handler.close()

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

fh = logging.FileHandler(local_log_path)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

def log_and_upload(msg: str, upload: bool = True, upload_timeout: float = 120.0):
    """Log `msg` at INFO and, by default, copy the whole local log file to GCS.

    The upload is best-effort: any failure (gsutil missing, non-zero exit, timeout)
    is logged as an error and never raised, so logging cannot abort the pipeline.

    Args:
        msg: Message to log.
        upload: When False, only log locally and skip the `gsutil` subprocess.
            Each upload is a separate blocking process call.
        upload_timeout: Seconds to wait for `gsutil` before giving up on this upload.
    """
    import subprocess

    logger.info(msg)
    if not upload:
        return
    try:
        # `-q` suppresses gsutil's per-copy progress output. The timeout prevents a hung
        # upload (auth/network) from stalling the pipeline indefinitely.
        subprocess.run(
            ["gsutil", "-q", "cp", local_log_path, f"gs://{BUCKET_NAME}/{gcs_log_path}"],
            check=True,
            timeout=upload_timeout,
        )
    except Exception as e:
        logger.error(f" Failed to upload log: {e}")

# -----------------------------
# HELPER: LIST TABLES
# -----------------------------

In [None]:
# Display the active table prefix (cell output) as a quick sanity check before listing tables.
TABLE_PREFIX

In [None]:
def list_tables(dataset):
    """Return the `table_name` of every entry in `PROJECT_ID.dataset.INFORMATION_SCHEMA.TABLES`.

    NOTE(review): loading a SQL query through the BigQuery connector usually needs
    `viewsEnabled` / `materializationDataset` to be configured; confirm for this environment.
    """
    info_schema = f"`{PROJECT_ID}.{dataset}.INFORMATION_SCHEMA.TABLES`"
    rows = spark.read.format("bigquery").load(f"SELECT table_name FROM {info_schema}").collect()
    names = []
    for r in rows:
        names.append(r.table_name)
    return names

bronze_tables = list_tables(BRONZE_DATASET_NAME)
silver_tables = list_tables(SILVER_DATASET_NAME)

# Month tables are named "<TABLE_PREFIX>_YYYY_MM". Match exactly that shape so other
# tables sharing the prefix (e.g. "<prefix>_backup", "<prefix>2_2024_01") are ignored
# rather than yielding month strings the cleaning loop cannot split into year/month.
_MONTH_TABLE_PAT = re.compile(rf"{re.escape(TABLE_PREFIX)}_(\d{{4}}_\d{{2}})")

def _month_suffixes(tables):
    """Return the set of "YYYY_MM" suffixes from the names in `tables` that are month tables."""
    return {m.group(1) for t in tables if (m := _MONTH_TABLE_PAT.fullmatch(t))}

bronze_months = _month_suffixes(bronze_tables)
silver_months = _month_suffixes(silver_tables)

# Months present in Bronze but not yet cleaned into Silver, oldest first.
missing_months = sorted(bronze_months - silver_months)
log_and_upload(f"Missing months to process: {missing_months}")

In [None]:
# Display the sorted list of months still to be cleaned (cell output) before running the loop.
missing_months

# -----------------------------
# CLEANING LOOP
# -----------------------------

In [None]:
def _filter_and_log(frame, condition, reason, rows_before):
    """Filter `frame` by `condition`, log how many rows were dropped, and return the result.

    Args:
        frame: Spark DataFrame to filter.
        condition: Boolean Column; rows where it is not true (false or null) are removed.
        reason: Text completing the message "Dropped N rows with <reason>".
        rows_before: Row count of `frame`, passed in so it is not recounted.

    Returns:
        Tuple of (filtered DataFrame, its row count).
    """
    filtered = frame.filter(condition)
    rows_after = filtered.count()
    log_and_upload(f"Dropped {rows_before - rows_after} rows with {reason}")
    return filtered, rows_after


# Y/N flag columns; any value other than "Y"/"N" is replaced with null.
flag_cols = ["shared_request_flag", "shared_match_flag",
             "access_a_ride_flag", "wav_request_flag", "wav_match_flag"]

for year_month in missing_months:
    table_name = f"{TABLE_PREFIX}_{year_month.replace('-', '_')}"

    try:
        # Parse inside the try so one malformed month name is logged and skipped
        # instead of aborting every remaining month.
        year, month = map(int, year_month.split("_"))
        log_and_upload(f" Starting processing for {year_month}")

        # --- Load from Raw Bronze ---
        df = (
            spark.read.format("bigquery")
            .option("table", f"{PROJECT_ID}.{BRONZE_DATASET_NAME}.{table_name}")
            .load()
        )
        n_rows = df.count()
        log_and_upload(f"Loaded {n_rows} rows from {BRONZE_DATASET_NAME}.{table_name}")

        # --- Cleaning Rules ---
        # Each rule uses the previous rule's count as its "before" value. Every count()
        # re-executes the lazy plan from the BigQuery read, so recounting would double the scans.
        df, n_rows = _filter_and_log(
            df,
            F.col("pickup_datetime").isNotNull() & F.col("dropoff_datetime").isNotNull(),
            "null pickup/dropoff",
            n_rows,
        )
        df, n_rows = _filter_and_log(
            df,
            F.col("pickup_datetime") < F.col("dropoff_datetime"),
            "invalid pickup/dropoff order",
            n_rows,
        )

        for loc_col in ["PULocationID", "DOLocationID"]:
            if loc_col in df.columns:
                df, n_rows = _filter_and_log(
                    df,
                    F.col(loc_col).isNotNull() & (F.col(loc_col) >= 0),
                    f"invalid {loc_col}",
                    n_rows,
                )

        # trip_time sanity check (hvfhv provides directly): keep 60..86400 inclusive
        # (presumably seconds, i.e. 1 minute to 24 hours; confirm units).
        df, n_rows = _filter_and_log(
            df,
            (F.col("trip_time") >= 60) & (F.col("trip_time") <= 86400),
            "invalid trip_time",
            n_rows,
        )

        # trip_miles > 0
        df, n_rows = _filter_and_log(df, F.col("trip_miles") > 0, "invalid trip_miles", n_rows)

        # Flag cleanup: keep "Y"/"N" as-is; anything else (including other casings) becomes null.
        for flag_col in flag_cols:
            if flag_col in df.columns:
                df = df.withColumn(
                    flag_col,
                    F.when(F.col(flag_col).isin("Y", "N"), F.col(flag_col)).otherwise(None),
                )

        # Restrict to the target year/month by pickup time. The helper columns
        # `year` and `pickup_month` stay in the output written to Silver.
        df = (
            df.withColumn("year", F.year("pickup_datetime"))
              .withColumn("pickup_month", F.month("pickup_datetime"))
              .filter((F.col("year") == year) & (F.col("pickup_month") == month))
        )
        rows_in_month = df.count()
        log_and_upload(
            f"Filtered to {year_month}, remaining {rows_in_month} rows "
            f"(dropped {n_rows - rows_in_month} outside the month)"
        )

        # --- Write to Clean Silver ---
        df.write.format("bigquery").option(
            "table", f"{PROJECT_ID}.{SILVER_DATASET_NAME}.{table_name}"
        ).mode("overwrite").save()
        log_and_upload(f" Written to {SILVER_DATASET_NAME}.{table_name}")

    except Exception as e:
        # Best-effort per month: record the failure and move on to the next month.
        log_and_upload(f" ERROR processing {year_month}: {str(e)}")

In [None]:
import time

# --- Teardown: stop the Spark session, then pause before any new session is requested ---
try:
    spark.stop()
except Exception as e:
    print(f" Error while stopping Spark session: {e}")
else:
    print(" Spark session stopped")

# Fixed 60-second pause so quota/resources can be released before the next session.
print(" Waiting 60s for resources to be released...")
time.sleep(60)
print(" Done waiting. You can safely start a new session now.")
