In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType, IntegerType, StringType
from datetime import datetime
import logging
import os

In [None]:
import time
import random
import re
import grpc
from google.api_core.exceptions import GoogleAPICallError, RetryError, ServiceUnavailable
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from pyspark.errors import PySparkValueError

_QUOTA_PAT = re.compile(r"insufficient.*cpus.*quota", re.IGNORECASE)

def _is_quota_error(exc: Exception) -> bool:
    msg = str(exc) or ""
    return bool(_QUOTA_PAT.search(msg))

def get_spark_with_retry(
    gentle_delay: float = 10.0,     # fixed delay for the first few tries
    gentle_retries: int = 3,        # how many tries use gentle delay
    max_backoff: float = 60.0,      # cap for exponential backoff
    quota_wait: float = 90.0,       # wait longer when we hit a QUOTA error
    jitter: bool = True,
    verify_ready: bool = True,
    max_attempts: int | None = None # None = infinite
):
    """
    Keep retrying until DataprocSparkSession is usable.
    - QUOTA errors: wait `quota_wait` seconds and retry (lets other jobs finish).
    - Other errors: gentle linear delay for first `gentle_retries`, then exponential backoff.
    """
    attempt = 0
    backoff = gentle_delay

    while True:
        attempt += 1
        try:
            spark = DataprocSparkSession.builder.getOrCreate()
            # Let backend settle a moment before the test query
            time.sleep(3)

            if verify_ready:
                _ = spark.range(1).count()

            print(f" Spark connected on attempt {attempt}")
            return spark

        except (ServiceUnavailable, GoogleAPICallError, RetryError,
                grpc.RpcError, PySparkValueError, RuntimeError) as e:
            # QUOTA-specific path: someone else might be using CPUs. Wait longer, then retry.
            if _is_quota_error(e):
                wait = quota_wait + (random.uniform(0, quota_wait/3) if jitter else 0.0)
                print(f"⏳ Quota limited (likely other sessions using CPUs). "
                      f"Sleeping {wait:.1f}s, then retrying…\nDetails: {e}")
                time.sleep(wait)
            else:
                # Non-quota transient errors
                if attempt <= gentle_retries:
                    wait = gentle_delay
                else:
                    # exponential backoff capped
                    backoff = min(backoff * 2, max_backoff)
                    wait = backoff
                if jitter:
                    wait += random.uniform(0, wait/2)
                print(f"  Attempt {attempt} failed: {e!r}\n   Sleeping {wait:.1f}s before retry…")
                time.sleep(wait)

            if max_attempts and attempt >= max_attempts:
                raise RuntimeError(f" Failed after {attempt} attempts") from e

        except KeyboardInterrupt:
            print(" Stopped by user.")
            raise

# --- Usage ---
# Blocks until a session is available (default arguments retry indefinitely).
# `spark` is the notebook-wide session used by the processing and shutdown cells.
spark = get_spark_with_retry()

# Safe to use:
# Small smoke test: build a 5-row range and display it.
df = spark.range(5)
df.show()


## -- config

In [None]:
# Google Cloud project: passed to the BigQuery client and used as the first
# component of fully qualified table names.
PROJECT_ID = "nyctaxi-467111"
# Bucket that receives the uploaded run log.
BUCKET_NAME = "nyc_raw_data_bucket"
# gs:// URI of the bucket (not referenced by the cells visible in this section).
RAW_BUCKET = f"gs://{BUCKET_NAME}"
# BigQuery dataset the raw tables are read from.
BRONZE_DATASET_NAME = "RawBronze"
# BigQuery dataset the cleaned tables are written to.
SILVER_DATASET_NAME = "CleanSilver"
# Taxi type names (not referenced by the cells visible in this section, which
# only process tables with the "green_" prefix).
TAXI_TYPES = ["yellow", "green", "fhv", "fhvhv"]

# -------------------------
# Logger Setup (runtime log file)
# -------------------------

In [None]:
# One log file per run of this cell, named after the start timestamp.
run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"Clean_green_{run_time}.log"
local_log_path = f"/tmp/{log_filename}"
gcs_log_path = f"logs/{log_filename}"
print(f" Logs will also be uploaded to: {gcs_log_path}")

logger = logging.getLogger("green_cleaning")
logger.setLevel(logging.INFO)

# Re-running this cell must not stack handlers (duplicate log lines). Handlers
# are closed as well as detached so the previous run's log file handle is
# released. Only this logger's own handlers are inspected; `hasHandlers()` would
# also report handlers attached to ancestor loggers.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

# File handler: the local copy that is later uploaded to GCS.
fh = logging.FileHandler(local_log_path)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)

# Stream handler: echoes the same records to the notebook output.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

def log_and_upload(msg: str):
    """Log `msg` at INFO level, then copy the local log file to GCS (best effort).

    The copy runs `gsutil cp <local_log_path> gs://<BUCKET_NAME>/<gcs_log_path>`
    after every message, so the remote log always reflects the latest state.
    An upload failure never raises: it is reported as a WARNING on the same
    logger so the pipeline keeps running and the problem is still visible.

    Args:
        msg: Message to log.
    """
    # Local import keeps this cell self-contained; the notebook's import cells
    # do not bring in subprocess.
    import subprocess

    logger.info(msg)

    destination = f"gs://{BUCKET_NAME}/{gcs_log_path}"
    try:
        # Argument list instead of a shell string: the paths are passed to
        # gsutil verbatim and are never interpreted by a shell.
        result = subprocess.run(
            ["gsutil", "cp", local_log_path, destination],
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError as exc:
        # e.g. gsutil is not installed or not on PATH.
        logger.warning("Log upload skipped, could not run gsutil: %s", exc)
        return

    if result.returncode != 0:
        logger.warning(
            "Log upload to %s failed (exit code %s): %s",
            destination,
            result.returncode,
            (result.stderr or "").strip(),
        )


## - data to process

In [None]:
from google.cloud import bigquery

# Notebook-wide BigQuery client bound to PROJECT_ID; `list_tables` reads this
# global to enumerate the tables of a dataset.
bq_client = bigquery.Client(project=PROJECT_ID)

def list_tables(dataset_name, prefix="yellow_"):
    """
    List the year-months for which a dataset holds a `<prefix>yyyy_mm` table.

    Only tables whose name is exactly `prefix` followed by a four-digit year,
    an underscore and a one- or two-digit month are reported. Other tables that
    merely share the prefix (lookup tables, backups, ...) are skipped, because
    callers split the returned strings on "-" and convert both parts to int.

    Args:
        dataset_name: Dataset passed to `bq_client.list_tables`.
        prefix: Table-name prefix to select, e.g. "green_".

    Returns:
        Sorted list of year_month strings with "_" replaced by "-"
        (e.g. table `yellow_2022_01` yields "2022-01").
    """
    year_month_suffix = re.compile(r"\d{4}_\d{1,2}")

    ym_list = []
    for table in bq_client.list_tables(dataset_name):
        table_name = table.table_id
        if not table_name.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence of the same text inside the name.
        suffix = table_name[len(prefix):]
        if year_month_suffix.fullmatch(suffix):
            # Example: yellow_2022_01 → 2022-01
            ym_list.append(suffix.replace("_", "-"))
    return sorted(ym_list)


In [None]:
# Year-month strings of the "green_" tables found in each dataset; both lists
# are compared in the next cell to decide what still has to be processed.
bronze_list = list_tables(BRONZE_DATASET_NAME, prefix="green_")
silver_list = list_tables(SILVER_DATASET_NAME, prefix="green_")

# Show both inventories so the reader can sanity-check them before processing.
print("Bronze available:", bronze_list)
print("Silver available:", silver_list)


In [None]:
# Work list: year-months that have a Bronze table but no Silver table, sorted.
missing_in_silver = sorted(set(bronze_list).difference(silver_list))

print(" Missing in Silver:", missing_in_silver)

# -- main

In [None]:
# Columns cast to a fixed type when present. Defined once, outside the loop,
# because the list does not depend on the month being processed.
categorical_columns = ['VendorID','RatecodeID','store_and_fwd_flag','PULocationID','DOLocationID','payment_type','trip_type']


def _filter_and_log(frame, condition, rows_before, dropped_what):
    """Keep the rows of `frame` matching `condition` and log how many were dropped.

    Args:
        frame: Spark DataFrame to filter.
        condition: Column expression passed to `DataFrame.filter`.
        rows_before: Row count of `frame` before filtering.
        dropped_what: Text appended after "Dropped <n> " in the log line.

    Returns:
        Tuple `(filtered_frame, rows_after)`. Returning the new count lets the
        caller pass it to the next step instead of counting the same frame twice.
    """
    kept = frame.filter(condition)
    rows_after = kept.count()
    log_and_upload(f"Dropped {rows_before - rows_after} {dropped_what}")
    return kept, rows_after


# Clean every month that exists in Bronze but not yet in Silver. A failure in
# one month is logged and the loop moves on to the next month.
for year_month in missing_in_silver:
    try:
        # Parsed inside the try so a malformed entry is logged and skipped
        # instead of aborting all remaining months.
        year, month = map(int, year_month.split("-"))

        log_and_upload(f" Starting Green processing for {year_month}")

        # -------------------------
        # Load from RawBronze
        # -------------------------
        table_name = f"green_{year_month.replace('-', '_')}"
        rawbronze_table = f"{PROJECT_ID}.{BRONZE_DATASET_NAME}.{table_name}"
        df = spark.read.format("bigquery").option("table", rawbronze_table).load()
        # Running row count; each filter step below updates it.
        rows = df.count()
        log_and_upload(f"Loaded {rows} rows from {rawbronze_table}")

        # -------------------------
        # Type casting
        # -------------------------
        # Cast the timestamp columns only when they are not already timestamps.
        for ts_col in ("lpep_pickup_datetime", "lpep_dropoff_datetime"):
            if not isinstance(df.schema[ts_col].dataType, TimestampType):
                df = df.withColumn(ts_col, F.col(ts_col).cast(TimestampType()))

        # store_and_fwd_flag is kept as a string; every other listed column
        # becomes an integer. Columns missing from this month's table are skipped.
        for col_name in categorical_columns:
            if col_name in df.columns:
                target_type = StringType() if col_name == 'store_and_fwd_flag' else IntegerType()
                df = df.withColumn(col_name, F.col(col_name).cast(target_type))

        log_and_upload("Completed type casting")

        # -------------------------
        # Drop null passenger_count
        # -------------------------
        df, rows = _filter_and_log(
            df, F.col("passenger_count").isNotNull(), rows,
            "rows with null passenger_count")

        # -------------------------
        # Filter invalid times
        # -------------------------
        df, rows = _filter_and_log(
            df, F.col("lpep_pickup_datetime") < F.col("lpep_dropoff_datetime"), rows,
            "rows with invalid pickup/dropoff")

        # -------------------------
        # Filter invalid distances
        # -------------------------
        df, rows = _filter_and_log(
            df, F.col("trip_distance") > 0, rows,
            "rows with invalid distances")

        # -------------------------
        # Passenger count filter
        # -------------------------
        df, rows = _filter_and_log(
            df, F.col("passenger_count") > 0, rows,
            "rows with invalid passenger_count")

        # -------------------------
        # Fare & total_amount checks
        # -------------------------
        df, rows = _filter_and_log(
            df, F.col("fare_amount") > 0, rows,
            "rows with invalid fare_amount")

        df, rows = _filter_and_log(
            df, F.col("total_amount") > 0, rows,
            "rows with invalid total_amount")

        # -------------------------
        # Outlier Removal (IQR trip_distance)
        # -------------------------
        # Approximate quartiles (relative error 0.01); keep values within
        # 1.5 * IQR of the quartiles.
        Q1, Q3 = df.approxQuantile("trip_distance", [0.25, 0.75], 0.01)
        IQR = Q3 - Q1
        lb, ub = Q1 - 1.5*IQR, Q3 + 1.5*IQR
        df, rows = _filter_and_log(
            df, (F.col("trip_distance") >= lb) & (F.col("trip_distance") <= ub), rows,
            "trip_distance outliers")

        # -------------------------
        # Trip duration filter
        # -------------------------
        # trip_duration is in seconds; keep trips lasting 5 s to 7200 s inclusive.
        df = df.withColumn("trip_duration",
                           F.unix_timestamp("lpep_dropoff_datetime") - F.unix_timestamp("lpep_pickup_datetime"))
        df, rows = _filter_and_log(
            df, (F.col("trip_duration") >= 5) & (F.col("trip_duration") <= 7200), rows,
            "rows with invalid trip_duration")

        # -------------------------
        # Ensure only correct year/month
        # -------------------------
        df = df.withColumn("year", F.year("lpep_pickup_datetime")) \
               .withColumn("pickup_month", F.month("lpep_pickup_datetime"))
        df = df.filter((F.col("year") == year) & (F.col("pickup_month") == month))
        rows = df.count()
        log_and_upload(f"Filtered to {year_month}, rows remaining: {rows}")

        # -------------------------
        # Write Cleaned → Silver
        # -------------------------
        # The Silver table has the same name as the Bronze table; an existing
        # table of that name is overwritten.
        clean_table = f"green_{year_month.replace('-', '_')}"
        df.write.format("bigquery") \
          .option("table", f"{PROJECT_ID}.{SILVER_DATASET_NAME}.{clean_table}") \
          .mode("overwrite") \
          .save()
        log_and_upload(f" Written to Silver table {SILVER_DATASET_NAME}.{clean_table}")

    except Exception as e:
        # Deliberate best effort: record the failure and carry on with the next month.
        log_and_upload(f" ERROR processing {year_month}: {str(e)}")

In [None]:
# `time` is also imported in an earlier cell; repeating the import lets this
# cell run on its own.
import time

# Stop the Spark session gracefully
# A failure to stop is only reported, so the wait below always runs.
try:
    spark.stop()
    print(" Spark session stopped")
except Exception as e:
    print(f" Error while stopping Spark session: {e}")

# Sleep for 60 seconds to allow quota/resources to free up
# (fixed pause; nothing here checks that the resources were actually released)
print(" Waiting 60s for resources to be released...")
time.sleep(60)
print(" Done waiting. You can safely start a new session now.")
