In [None]:
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from google.cloud.dataproc_v1 import Session
from pyspark.sql import functions as F
from pyspark.sql.types import *
import time, datetime, os
import logging
import sys
from google.cloud import storage
import io
from google.cloud import bigquery
from tqdm import tqdm

import pyspark.sql.functions as f

In [None]:
import time
import random
import re
import grpc
from google.api_core.exceptions import GoogleAPICallError, RetryError, ServiceUnavailable
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from pyspark.errors import PySparkValueError

_QUOTA_PAT = re.compile(r"insufficient.*cpus.*quota", re.IGNORECASE)

def _is_quota_error(exc: Exception) -> bool:
    msg = str(exc) or ""
    return bool(_QUOTA_PAT.search(msg))

def get_spark_with_retry(
    gentle_delay: float = 10.0,     # fixed delay for the first few tries
    gentle_retries: int = 3,        # how many tries use gentle delay
    max_backoff: float = 60.0,      # cap for exponential backoff
    quota_wait: float = 90.0,       # wait longer when we hit a QUOTA error
    jitter: bool = True,
    verify_ready: bool = True,
    max_attempts: int | None = None, # None = infinite
    settle_delay: float = 3.0,      # pause after getOrCreate before the test query
):
    """
    Keep retrying until DataprocSparkSession is usable.

    - QUOTA errors: wait `quota_wait` seconds (plus up to a third more as jitter)
      and retry, giving other sessions time to release CPUs.
    - Other errors: fixed `gentle_delay` for the first `gentle_retries` attempts,
      then exponential backoff (doubling, capped at `max_backoff`). With `jitter`,
      up to 50% extra is added to that wait.

    Only the exception types listed in the handler below are retried; any other
    exception propagates immediately.

    Returns:
        The connected Spark session.

    Raises:
        RuntimeError: after `max_attempts` failed attempts (chained to the last error).
            The error is raised right away, without a final pointless sleep.
        KeyboardInterrupt: re-raised after printing a notice, including when the
            interrupt arrives while sleeping between attempts.
    """
    attempt = 0
    backoff = gentle_delay

    try:
        while True:
            attempt += 1
            try:
                spark = DataprocSparkSession.builder.getOrCreate()
                # Let backend settle a moment before the test query
                time.sleep(settle_delay)

                if verify_ready:
                    _ = spark.range(1).count()

                print(f" Spark connected on attempt {attempt}")
                return spark

            except (ServiceUnavailable, GoogleAPICallError, RetryError,
                    grpc.RpcError, PySparkValueError, RuntimeError) as e:
                # Give up before sleeping: waiting after the final attempt only delays the error.
                if max_attempts and attempt >= max_attempts:
                    raise RuntimeError(f" Failed after {attempt} attempts") from e

                if _is_quota_error(e):
                    # QUOTA-specific path: someone else might be using CPUs. Wait longer, then retry.
                    wait = quota_wait + (random.uniform(0, quota_wait / 3) if jitter else 0.0)
                    print(f" Quota limited (likely other sessions using CPUs). "
                          f"Sleeping {wait:.1f}s, then retrying…\nDetails: {e}")
                else:
                    # Non-quota transient errors
                    if attempt <= gentle_retries:
                        wait = gentle_delay
                    else:
                        # exponential backoff capped
                        backoff = min(backoff * 2, max_backoff)
                        wait = backoff
                    if jitter:
                        wait += random.uniform(0, wait / 2)
                    print(f"  Attempt {attempt} failed: {e!r}\n   Sleeping {wait:.1f}s before retry…")

            # Sleep outside the handler. An exception raised inside an `except` body
            # is not caught by sibling clauses of the same `try`. Here, a Ctrl-C during
            # the wait reaches the outer KeyboardInterrupt clause.
            time.sleep(wait)
    except KeyboardInterrupt:
        print(" Stopped by user.")
        raise

# --- Usage: acquire a session (retries indefinitely with the default max_attempts=None) ---
spark = get_spark_with_retry()

# Smoke test: build a 5-row DataFrame and render it.
df = spark.range(0, 5)
df.show()


In [None]:
# Bare expression: the notebook displays the session object as this cell's output.
spark

In [None]:
# ---------------------------
# Global Config
# ---------------------------

# BigQuery project and datasets, plus the GCS bucket that receives the logs.
PROJECT_ID = "nyctaxi-467111"
SILVER_DATASET = "CleanSilver"   # source tables read by the pipeline
PREML_DATASET = "PreMlGold"      # destination for the aggregated tables
BUCKET_NAME = "nyc_raw_data_bucket"

# Timestamp of this run (YYYYmmdd_HHMMSS); used to name the log files below.
RUN_ID = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"

# Taxi types the pipeline recognises.
TAXI_TYPES = ["yellow", "green", "fhv", "fhvhv"]

# Log destinations: lines are appended locally and the file is copied to GCS.
LOG_LOCAL = "./pipeline_run_" + RUN_ID + ".log"
LOG_GCS = "gs://" + BUCKET_NAME + "/logs/preMLpipeline_logs/pipeline_run_" + RUN_ID + ".log"


In [None]:
def log_and_upload(msg: str):
    ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    line = f"[{ts}] {msg}"
    print(line)
    with open(LOG_LOCAL, "a") as f: f.write(line + "\n")
    os.system(f"gsutil cp {LOG_LOCAL} {LOG_GCS}")


# ==============================================================
# Stage 2: Check Silver vs PreML → build missing_map
# ==============================================================

In [None]:
bq_client = bigquery.Client(project=PROJECT_ID)

def list_tables(dataset_name: str):
    """Return the table IDs (names only) in `dataset_name`, resolved against the client's project."""
    return [t.table_id for t in bq_client.list_tables(dataset_name)]

silver_tables = list_tables(SILVER_DATASET)
preml_tables = list_tables(PREML_DATASET)
# Set for O(1) membership tests in the loop below.
preml_table_set = set(preml_tables)

# A CleanSilver table counts as "missing" when ANY of its three PreML outputs is
# absent. This matches validate_and_update_preml, which skips a month only when
# all three exist, so partially written months are flagged too.
missing_map = {}
for tbl in tqdm(silver_tables):
    taxi_type = tbl.split("_")[0]
    if taxi_type not in TAXI_TYPES:
        print(f"skipping {tbl}")
        continue
    if any(f"{tbl}_{suffix}" not in preml_table_set for suffix in ("daily", "hourly", "hotspot")):
        missing_map.setdefault(taxi_type, []).append(tbl)

log_and_upload(f"Identified missing_map = {missing_map}")

In [None]:
# Report how many CleanSilver tables per taxi type still need PreML outputs.
for kind, tables in missing_map.items():
    print(f"{kind}: count {len(tables)}")

# ---------------------------
# Stage 2 - Aggregate Features for PreML
# ---------------------------

In [None]:
# ---------------------------
# Stage 2 - Aggregate Features for PreML
# ---------------------------

from pyspark.sql import functions as F
from pyspark.sql.types import StringType

# Per taxi type: (pickup timestamp column, revenue column). A revenue column of
# None means the source has no fare data (fhv), so revenue is summed as 0.0.
_PREML_SOURCE_COLUMNS = {
    "yellow": ("tpep_pickup_datetime", "total_amount"),
    "green": ("lpep_pickup_datetime", "total_amount"),
    "fhv": ("pickup_datetime", None),
    "fhvhv": ("pickup_datetime", "base_passenger_fare"),
}

# PreML output suffix -> grouping keys. Dict order is the write order.
_PREML_GROUPINGS = {
    "daily": ("pickup_date",),                                   # daily trip counts & revenue
    "hourly": ("pickup_date", "pickup_hour"),                    # hourly demand (anomaly detection / forecasting)
    "hotspot": ("pickup_date", "PULocationID", "DOLocationID"),  # PU/DO location counts (hotspot clustering prep)
}


def _write_preml_table(agg_df, table: str) -> None:
    """Overwrite BigQuery `table` with `agg_df` using the connector's "direct" write method."""
    (agg_df.write.format("bigquery")
        .option("table", table)
        .option("writeMethod", "direct")
        .mode("overwrite")
        .save())


def process_missing_to_preml(missing_map: dict):
    """
    For each taxi_type and missing CleanSilver table,
    aggregate features and save into PreML dataset.

    Args:
        missing_map: taxi_type -> list of CleanSilver table names (callers build
            them as f"{taxi_type}_{year_month}").

    For every table, three PreML tables are overwritten: <tbl>_daily, <tbl>_hourly
    and <tbl>_hotspot. Each has its grouping keys plus `trips`, `revenue` and
    `taxi_type` columns. Errors for one table are logged and processing continues
    with the next; nothing is raised.
    """
    for taxi_type, tbl_list in tqdm(missing_map.items()):
        for tbl in tbl_list:
            silver_table = f"{PROJECT_ID}.{SILVER_DATASET}.{tbl}"
            preml_table = f"{PROJECT_ID}.{PREML_DATASET}.{tbl}"

            log_and_upload(f" Processing {silver_table} → {preml_table}")

            try:
                # Check support before reading, so an unknown type never costs a table scan.
                if taxi_type not in _PREML_SOURCE_COLUMNS:
                    log_and_upload(f" Unsupported taxi_type {taxi_type}, skipping...")
                    continue
                pickup_col, revenue_name = _PREML_SOURCE_COLUMNS[taxi_type]
                revenue_col = F.col(revenue_name) if revenue_name is not None else F.lit(0.0)

                # Load CleanSilver table
                df = spark.read.format("bigquery").option("table", silver_table).load()
                log_and_upload(f"Loaded {df.count()} rows from {silver_table}")

                # Derive date & hour grouping keys from the pickup timestamp.
                df = (
                    df.withColumn("pickup_date", F.to_date(F.col(pickup_col)))
                      .withColumn("pickup_hour", F.hour(F.col(pickup_col)))
                )

                log_and_upload(f"Aggregating  rows from {silver_table}")
                aggregates = {
                    suffix: (
                        df.groupBy(*keys)
                          .agg(F.count("*").alias("trips"),
                               F.sum(revenue_col).alias("revenue"))
                          .withColumn("taxi_type", F.lit(taxi_type))
                    )
                    for suffix, keys in _PREML_GROUPINGS.items()
                }

                log_and_upload(f"Writing  rows from aggregated {silver_table}")
                # One BigQuery table per aggregate: <tbl>_daily, <tbl>_hourly, <tbl>_hotspot.
                for suffix, agg_df in aggregates.items():
                    _write_preml_table(agg_df, f"{preml_table}_{suffix}")

                log_and_upload(f" Saved aggregated PreML tables for {tbl}")

            except Exception as e:
                log_and_upload(f" ERROR processing {tbl}: {str(e)}")


# Validation & Incremental Update


In [None]:
def validate_and_update_preml(taxi_type: str, year_month_list: list):
    """
    Check CleanSilver vs PreML for a given taxi_type and list of yyyy_mm.
    Process only missing PreML partitions.

    For each month, the CleanSilver table `<taxi_type>_<year_month>` is handled as follows:
      - it is skipped if it has zero rows;
      - it is skipped if all three PreML tables (`_daily`, `_hourly`, `_hotspot`) exist;
      - otherwise it is rebuilt via process_missing_to_preml.

    Errors for one month are logged and the loop continues; nothing is raised.
    This includes existence-check failures other than "not found".
    """
    from google.api_core.exceptions import NotFound
    from google.cloud import bigquery
    client = bigquery.Client(project=PROJECT_ID)

    for year_month in tqdm(year_month_list):
        month_table = f"{taxi_type}_{year_month}"
        silver_table = f"{PROJECT_ID}.{SILVER_DATASET}.{month_table}"
        preml_targets = [
            f"{PROJECT_ID}.{PREML_DATASET}.{month_table}_{suffix}"
            for suffix in ("daily", "hourly", "hotspot")
        ]

        try:
            count_rows = client.query(f"SELECT COUNT(*) as cnt FROM `{silver_table}`").result()
            silver_count = next(iter(count_rows)).cnt
            if silver_count == 0:
                log_and_upload(f" Skipping {silver_table} (empty)")
                continue

            # Only NotFound means "absent". Auth and network errors propagate to the
            # handler below instead of being mistaken for a missing table, which would
            # trigger an unnecessary rebuild.
            existing_preml = []
            for target in preml_targets:
                try:
                    client.get_table(target)
                except NotFound:
                    continue
                existing_preml.append(target)

            if len(existing_preml) == len(preml_targets):
                log_and_upload(f" PreML already exists for {month_table}, skipping.")
                continue

            log_and_upload(f" Running PreML aggregation for {month_table}")
            process_missing_to_preml({taxi_type: [month_table]})

        except Exception as e:
            log_and_upload(f" ERROR validation {month_table}: {str(e)}")

# ---------------------------
# Master Orchestration
# ---------------------------

In [None]:
# Preview pass: list the CleanSilver month tables for each taxi type (read-only;
# the processing itself is done by run_master_orchestration).
for taxi_type in tqdm(TAXI_TYPES):
    # Only capture tables like taxiType_YYYY_MM
    query = f"""
        SELECT table_name
        FROM `{PROJECT_ID}.{SILVER_DATASET}.INFORMATION_SCHEMA.TABLES`
        WHERE REGEXP_CONTAINS(table_name, r'^{taxi_type}_[0-9]{{4}}_[0-9]{{2}}$')
    """
    # Run that SQL against BigQuery via the Spark-BQ connector.
    df_partitions = (
        spark.read.format("bigquery")
        .option("query", query)
        .load()
    )

    # A single Spark action instead of count() followed by collect();
    # the result has only one row per matching table.
    partition_rows = df_partitions.collect()
    if not partition_rows:
        print(f" No CleanSilver tables found for {taxi_type}")
        continue

    # Strip the leading "<taxi_type>_" to get sorted YYYY_MM suffixes.
    year_month_list = sorted(
        row.table_name.removeprefix(f"{taxi_type}_") for row in partition_rows
    )
    print(f" Found partitions for {taxi_type}: {year_month_list}")

In [None]:
def run_master_orchestration():
    """
    Discover CleanSilver month tables for every taxi type and build missing PreML tables.

    For each entry in TAXI_TYPES, the function works in two steps:
      - It lists tables named `<taxi_type>_YYYY_MM` from the CleanSilver
        INFORMATION_SCHEMA, using the Spark BigQuery connector.
      - It passes the sorted YYYY_MM suffixes to validate_and_update_preml.

    Errors for one taxi type are logged and the next type is processed.
    """
    run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_and_upload(f" Starting Master Orchestration Run ID={run_id}")

    for taxi_type in tqdm(TAXI_TYPES):  # iterate across yellow/green/fhv/fhvhv
        try:
            # Only capture tables like taxiType_YYYY_MM
            query = f"""
                SELECT table_name
                FROM `{PROJECT_ID}.{SILVER_DATASET}.INFORMATION_SCHEMA.TABLES`
                WHERE REGEXP_CONTAINS(table_name, r'^{taxi_type}_[0-9]{{4}}_[0-9]{{2}}$')
            """
            # A single Spark action (collect) instead of count() followed by collect();
            # the result has only one row per matching table.
            partition_rows = (
                spark.read.format("bigquery")
                .option("query", query)
                .load()
                .collect()
            )

            # If there are no partitions, skip this taxi type.
            if not partition_rows:
                log_and_upload(f" No CleanSilver tables found for {taxi_type}")
                continue

            # Strip the leading "<taxi_type>_" to get sorted YYYY_MM suffixes.
            prefix = f"{taxi_type}_"
            year_month_list = sorted(row.table_name.removeprefix(prefix) for row in partition_rows)
            log_and_upload(f" Found partitions for {taxi_type}: {year_month_list}")

            # Stage 3 validation/update: delegate to the PreML validator/runner.
            validate_and_update_preml(taxi_type, year_month_list)

        except Exception as e:
            log_and_upload(f" ERROR orchestration {taxi_type}: {str(e)}")

    log_and_upload(" Master Orchestration completed.")


# --- ML Ready Views

In [None]:
# def create_ml_ready_views():
#     """Create unified BigQuery views for ML consumption."""
#     client = bigquery.Client(project=PROJECT_ID)
#     log_and_upload(" Creating ML Ready Views...")

#     # Daily
#     daily_query = f"""
#     CREATE OR REPLACE VIEW `{PROJECT_ID}.{PREML_DATASET}.all_taxi_daily` AS
#     SELECT taxi_type, pickup_date, SUM(trips) as trips, SUM(revenue) as revenue
#     FROM `{PROJECT_ID}.{PREML_DATASET}.*_daily`
#     GROUP BY taxi_type, pickup_date
#     """
#     client.query(daily_query).result()
#     log_and_upload(" Created view: all_taxi_daily")

#     # Hourly
#     hourly_query = f"""
#     CREATE OR REPLACE VIEW `{PROJECT_ID}.{PREML_DATASET}.all_taxi_hourly` AS
#     SELECT taxi_type, pickup_date, pickup_hour, SUM(trips) as trips
#     FROM `{PROJECT_ID}.{PREML_DATASET}.*_hourly`
#     GROUP BY taxi_type, pickup_date, pickup_hour
#     """
#     client.query(hourly_query).result()
#     log_and_upload(" Created view: all_taxi_hourly")

#     # Hotspots
#     hotspot_query = f"""
#     CREATE OR REPLACE VIEW `{PROJECT_ID}.{PREML_DATASET}.all_taxi_hotspots` AS
#     SELECT taxi_type, pickup_date, PULocationID, DOLocationID, SUM(trips) as trips
#     FROM `{PROJECT_ID}.{PREML_DATASET}.*_hotspot`
#     GROUP BY taxi_type, pickup_date, PULocationID, DOLocationID
#     """
#     client.query(hotspot_query).result()
#     log_and_upload(" Created view: all_taxi_hotspots")

#     log_and_upload(" ML Ready Views created successfully!")

In [None]:
# # WARNING (destructive): drops every BASE TABLE in the PreMlGold dataset. Run only to reset PreML.

# import pandas_gbq

# # Construct the SQL query to get all table names in the PreMlGold dataset
# sql_query = f"""
# SELECT table_name
# FROM `{PROJECT_ID}.{PREML_DATASET}.INFORMATION_SCHEMA.TABLES`
# WHERE table_type = 'BASE TABLE'
# """

# # Read the table names into a DataFrame
# df_tables = pandas_gbq.read_gbq(sql_query, project_id=PROJECT_ID, dialect="standard")

# # Iterate through the table names and delete each table
# for table_name in tqdm(df_tables['table_name']):
#     try:
#         delete_query = f"DROP TABLE `{PROJECT_ID}.{PREML_DATASET}.{table_name}`"
#         pandas_gbq.read_gbq(delete_query, project_id=PROJECT_ID, dialect="standard")
#         print(f"Successfully deleted table: {table_name}")
#     except Exception as e:
#         print(f"Error deleting table {table_name}: {e}")

# -- main

In [None]:
# ==============================================================
# Main Entrypoint
# ==============================================================

def main():
    """
    Notebook entry point: run the PreML orchestration and log the outcome.

    A failure is logged and re-raised. The closing log lines, including the GCS
    log location, are written whether or not the run succeeded.
    """
    started = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
    log_and_upload(f"Starting PreML Incremental Aggregation Pipeline | RunID={started}")

    try:
        # Stages 2 & 5: discover CleanSilver months and build any missing PreML tables.
        run_master_orchestration()
        # Stage 6 (ML-ready views via create_ml_ready_views) is currently disabled.
        log_and_upload(" Pipeline completed successfully!")
    except Exception as exc:
        log_and_upload(f" Pipeline failed: {exc}")
        raise
    finally:
        for closing_msg in (" Final log uploaded to GCS", f"Log file location: {LOG_GCS}"):
            log_and_upload(closing_msg)


# Trigger when running notebook
if __name__ == "__main__":
    main()


In [None]:
import time

# Stop the Spark session gracefully; a failure is reported, not raised.
try:
    spark.stop()
except Exception as e:
    print(f" Error while stopping Spark session: {e}")
else:
    print(" Spark session stopped")

# Pause so quota/resources held by the old session can be freed before a new one starts.
print(" Waiting 60s for resources to be released...")
time.sleep(60)
print(" Done waiting. You can safely start a new session now.")
