# Dataproc Spark session setup

In [None]:
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from google.cloud.dataproc_v1 import Session
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType
import logging
import os
from datetime import datetime
from tqdm import tqdm

In [None]:
import time
import random
import grpc
from google.api_core.exceptions import GoogleAPICallError, RetryError, ServiceUnavailable
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from pyspark.errors import PySparkValueError


def get_spark_with_retry(
    initial_delay: float = 10.0,    # seconds between first attempts
    max_backoff: float = 60.0,      # cap for the exponential (pre-jitter) delay
    gentle_retries: int = 3,        # how many tries use linear delay before exponential
    jitter: bool = True,
    verify_ready: bool = True,
    max_attempts: int | None = None,  # None = infinite
    settle_seconds: float = 3.0,    # pause after getOrCreate() before verifying
):
    """
    Return a Dataproc Spark Connect session, retrying on transient failures.

    Each attempt calls ``DataprocSparkSession.builder.getOrCreate()``, waits
    ``settle_seconds`` and, if ``verify_ready`` is true, runs
    ``spark.range(1).count()`` to prove the session can actually execute work.

    Retry schedule after a failed attempt:
      - attempts 1..``gentle_retries``: sleep a fixed ``initial_delay``;
      - later attempts: double the previous delay, capped at ``max_backoff``;
        with ``jitter`` a random extra 0-50 % of that delay is added on top,
        so the actual sleep can reach 1.5 x ``max_backoff``.

    Only ServiceUnavailable / GoogleAPICallError / RetryError /
    grpc.RpcError / PySparkValueError are retried; any other exception
    propagates immediately.

    Args:
        initial_delay: fixed delay (seconds) for the gentle retries, and the
            starting point of the exponential back-off.
        max_backoff: upper bound (seconds) of the exponential delay before jitter.
        gentle_retries: number of failed attempts that use the fixed delay.
        jitter: add random extra delay to exponential retries.
        verify_ready: run a trivial query before declaring success.
        max_attempts: total attempts before giving up; ``None`` retries
            forever. Values ``<= 1`` give up after the first failure.
        settle_seconds: seconds to wait between getOrCreate() and the check.

    Returns:
        The connected Spark session.

    Raises:
        RuntimeError: when ``max_attempts`` is reached; the last underlying
            error is chained as ``__cause__``.
        KeyboardInterrupt: re-raised (after printing a message) if the user
            interrupts at any point, including during a back-off sleep.
    """
    attempt = 0
    backoff = initial_delay

    # The outer try wraps the whole loop so Ctrl-C during the back-off sleep
    # (which runs inside the retry handler below) is reported too.
    try:
        while True:
            attempt += 1
            try:
                spark = DataprocSparkSession.builder.getOrCreate()

                # Give the backend a few seconds to settle before using it.
                time.sleep(settle_seconds)

                # A tiny query confirms the session can really run jobs.
                if verify_ready:
                    spark.range(1).count()

                print(f"✅ Spark connected on attempt {attempt}")
                return spark

            except (ServiceUnavailable, GoogleAPICallError, RetryError,
                    grpc.RpcError, PySparkValueError) as e:

                if max_attempts is not None and attempt >= max_attempts:
                    # Chain the original error so its traceback is preserved.
                    raise RuntimeError(f"Failed after {attempt} attempts: {e!r}") from e

                if attempt <= gentle_retries:
                    # gentle mode: first few retries wait a fixed interval
                    sleep_for = initial_delay
                else:
                    # exponential growth, capped, then optional jitter on top
                    sleep_for = min(backoff * 2, max_backoff)
                    backoff = sleep_for
                    if jitter:
                        sleep_for += random.uniform(0, sleep_for / 2)

                print(f"Attempt {attempt} failed: {e!r}\n   Sleeping {sleep_for:.1f}s before retry...")
                time.sleep(sleep_for)

    except KeyboardInterrupt:
        print("Interrupted by user. Exiting retry loop.")
        raise


# --- Usage ---
# Blocks until a usable session exists; max_attempts=None keeps retrying.
spark = get_spark_with_retry(max_attempts=None)


# --- CONFIG ---

In [None]:
# GCP project and the raw-zone bucket holding the source parquet files.
PROJECT_ID = "nyctaxi-467111"
BUCKET_NAME = "nyc_raw_data_bucket"
RAW_BUCKET = "gs://" + BUCKET_NAME

# Dataset that receives one Bronze table per taxi type and month.
BRONZE_DATASET_NAME = "RawBronze"

# Each taxi type is both a GCS prefix and a Bronze table-name prefix.
TAXI_TYPES = "yellow green fhv fhvhv".split()


# --- LOGGING CONFIG ---

In [None]:
run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"01_GCSToBronzeIngestion_{run_time}.log"
local_log_path = f"/tmp/{log_filename}"     # local log file for this run
gcs_log_path = f"logs/{log_filename}"       # intended GCS object path for the log (nothing in this cell uploads it)
print(gcs_log_path)

# Configure the ROOT logger so any logger in the process that propagates
# (e.g. logging.getLogger(__name__)) reaches both handlers below.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Drop handlers left over from a previous run of this cell (otherwise every
# message is emitted once per run). Each one is also closed: emptying the
# handler list alone would keep the old FileHandler's file open.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

# file handler (explicit UTF-8 so output does not depend on the system locale)
fh = logging.FileHandler(local_log_path, encoding="utf-8")
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)

# console handler (StreamHandler defaults to stderr)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)


# -- dict of taxi_type → list of year-month strings in the GCS bucket --

In [None]:
from google.cloud import storage
import re

# Shared Cloud Storage client; list_year_months_for_taxi() uses it to list
# the raw objects under each taxi-type prefix.
storage_client = storage.Client()

def list_year_months_for_taxi(bucket_name, taxi_type, client=None):
    """
    Return the sorted, de-duplicated year-months available for one taxi type.

    Lists every object under ``{taxi_type}/`` in ``bucket_name`` and keeps
    only objects whose file name is exactly
    ``{taxi_type}_tripdata_YYYY-MM.parquet``. That is the file name
    ``ingest_missing_to_bronze`` later reads, so files with another stem
    (e.g. ``something_2020-01.parquet``) are not reported as available.

    Args:
        bucket_name: GCS bucket name, without the ``gs://`` scheme.
        taxi_type: top-level prefix and file-name stem, e.g. ``"yellow"``.
        client: optional storage client exposing ``list_blobs``; defaults to
            the module-level ``storage_client``.

    Returns:
        Sorted list of ``"YYYY-MM"`` strings (empty if nothing matches).
    """
    client = storage_client if client is None else client
    blobs = client.list_blobs(bucket_name, prefix=f"{taxi_type}/")

    # Anchor at the start or a "/" so only the exact stem matches;
    # re.escape guards against regex metacharacters in taxi_type.
    pattern = re.compile(
        r"(?:^|/)" + re.escape(taxi_type) + r"_tripdata_(\d{4})-(\d{2})\.parquet$"
    )

    year_months = set()
    for blob in tqdm(blobs, desc=f"listing {taxi_type} data files from buckets"):
        match = pattern.search(blob.name)
        if match:
            year_months.add(f"{match.group(1)}-{match.group(2)}")

    return sorted(year_months)

# --- Build dictionary for all taxi types ---
# taxi type -> sorted "YYYY-MM" months found in the raw bucket
taxi_files_map = {
    taxi_type: list_year_months_for_taxi(BUCKET_NAME, taxi_type)
    for taxi_type in tqdm(TAXI_TYPES)
}

# --- Print summary ---
for taxi, months in taxi_files_map.items():
    print(f" {taxi.upper()} : {len(months)} files")
    print(months or " No files found")
    print("-" * 40)


# -- Find raw months that have no Bronze table yet --

In [None]:
def list_bronze_tables():
    """
    Return ``{taxi_type: sorted ["YYYY-MM", ...]}`` for months already in Bronze.

    Table names are expected to be exactly ``<taxi_type>_YYYY_MM``, the
    naming ``ingest_missing_to_bronze`` writes. Names that do not fit
    exactly (e.g. ``yellow_2020_01_backup``) are ignored rather than being
    counted as a loaded month.

    Reads the module-level ``spark``, ``BRONZE_DATASET_NAME`` and ``TAXI_TYPES``.
    """
    # One catalog round-trip for all taxi types instead of one per type.
    table_names = [t.name for t in spark.catalog.listTables(BRONZE_DATASET_NAME)]

    bronze_map = {}
    for taxi_type in tqdm(TAXI_TYPES, desc="listing tables in bronze layer"):
        # fullmatch: "fhv_..." must not pick up "fhvhv_..." tables or
        # tables with extra suffixes.
        pattern = re.compile(re.escape(taxi_type) + r"_(\d{4})_(\d{2})")
        bronze_map[taxi_type] = sorted(
            f"{m.group(1)}-{m.group(2)}"
            for m in map(pattern.fullmatch, table_names)
            if m
        )
    return bronze_map

# --- Compare Raw GCS vs Bronze ---
bronze_map = list_bronze_tables()
missing_map = {}

for taxi_type in TAXI_TYPES:
    # months present in the raw bucket that have no Bronze table yet
    already_loaded = bronze_map.get(taxi_type, [])
    missing = sorted(set(taxi_files_map.get(taxi_type, [])).difference(already_loaded))
    missing_map[taxi_type] = missing
    print(f" {taxi_type.upper()} missing in Bronze: {missing}")

# --- SCHEMA DEFINITIONS ---

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, LongType,IntegerType, DoubleType, TimestampType, FloatType


In [None]:

# Canonical Bronze schema for yellow-taxi trip files; every column is nullable.
# normalize_schema() casts raw columns to these types and keeps only these
# columns, in this order.
yellow_schema = StructType([
    StructField(name, dtype, True)
    for name, dtype in (
        ("VendorID", LongType()),
        ("tpep_pickup_datetime", TimestampType()),
        ("tpep_dropoff_datetime", TimestampType()),
        ("passenger_count", DoubleType()),
        ("trip_distance", DoubleType()),
        ("RatecodeID", DoubleType()),
        ("store_and_fwd_flag", StringType()),
        ("PULocationID", LongType()),
        ("DOLocationID", LongType()),
        ("payment_type", DoubleType()),
        ("fare_amount", DoubleType()),
        ("extra", DoubleType()),
        ("mta_tax", DoubleType()),
        ("tip_amount", DoubleType()),
        ("tolls_amount", DoubleType()),
        ("improvement_surcharge", DoubleType()),
        ("total_amount", DoubleType()),
        ("congestion_surcharge", DoubleType()),
        ("airport_fee", DoubleType()),
        ("cbd_congestion_fee", DoubleType()),  # Added from Jan 2025 onwards
    )
])


In [None]:

# Canonical Bronze schema for green-taxi trip files; every column is nullable.
# normalize_schema() casts raw columns to these types and keeps only these
# columns, in this order.
# NOTE(review): PULocationID/DOLocationID are DoubleType here but LongType in
# yellow_schema; confirm the mismatch is intended before combining tables.
green_schema = StructType([
    StructField(name, dtype, True)
    for name, dtype in (
        ("VendorID", LongType()),                  # 1=Creative Mobile, 2=Curb Mobility, 6=Myle Technologies
        ("lpep_pickup_datetime", TimestampType()),
        ("lpep_dropoff_datetime", TimestampType()),
        ("store_and_fwd_flag", StringType()),      # Y/N
        ("RatecodeID", DoubleType()),              # 1=Standard, 2=JFK, etc.
        ("PULocationID", DoubleType()),
        ("DOLocationID", DoubleType()),
        ("passenger_count", DoubleType()),
        ("trip_distance", DoubleType()),
        ("fare_amount", DoubleType()),
        ("extra", DoubleType()),
        ("mta_tax", DoubleType()),
        ("tip_amount", DoubleType()),              # Only credit card tips included
        ("tolls_amount", DoubleType()),
        ("improvement_surcharge", DoubleType()),
        ("total_amount", DoubleType()),            # Excludes cash tips
        ("payment_type", DoubleType()),            # 1=CC, 2=Cash, etc.
        ("trip_type", DoubleType()),               # 1=Street-hail, 2=Dispatch
        ("congestion_surcharge", DoubleType()),
        ("cbd_congestion_fee", DoubleType()),      # New from 2025
    )
])


In [None]:

# Canonical Bronze schema for FHV trip files; every column is nullable.
# The mixed-case spellings (PUlocationID, DOlocationID, dropOff_datetime) are
# the column names the Bronze tables end up with.
fhv_schema = StructType([
    StructField(name, dtype, True)
    for name, dtype in (
        ("dispatching_base_num", StringType()),    # Base license number of the dispatch
        ("pickup_datetime", TimestampType()),      # Pickup timestamp
        ("dropOff_datetime", TimestampType()),     # Dropoff timestamp
        ("PUlocationID", IntegerType()),           # Pickup location zone ID
        ("DOlocationID", IntegerType()),           # Dropoff location zone ID
        ("SR_Flag", IntegerType()),                # Shared ride flag (1 if shared, else null)
        ("Affiliated_base_number", StringType()),  # Affiliated base number (even if same as dispatch)
    )
])


In [None]:

# Canonical Bronze schema for high-volume FHV ("fhvhv") trip files;
# every column is nullable.
hvfhv_schema = StructType([
    StructField(name, dtype, True)
    for name, dtype in (
        # identifiers
        ("hvfhs_license_num", StringType()),     # HV0002=Juno, HV0003=Uber, HV0004=Via, HV0005=Lyft
        ("dispatching_base_num", StringType()),
        ("originating_base_num", StringType()),
        # timestamps
        ("request_datetime", TimestampType()),
        ("on_scene_datetime", TimestampType()),
        ("pickup_datetime", TimestampType()),
        ("dropoff_datetime", TimestampType()),
        # locations
        ("PULocationID", IntegerType()),
        ("DOLocationID", IntegerType()),
        # trip size
        ("trip_miles", DoubleType()),
        ("trip_time", IntegerType()),            # seconds
        # money
        ("base_passenger_fare", DoubleType()),
        ("tolls", DoubleType()),
        ("bcf", DoubleType()),                   # Black Car Fund
        ("sales_tax", DoubleType()),
        ("congestion_surcharge", DoubleType()),
        ("airport_fee", DoubleType()),
        ("tips", DoubleType()),
        ("driver_pay", DoubleType()),
        # Y/N flags
        ("shared_request_flag", StringType()),
        ("shared_match_flag", StringType()),
        ("access_a_ride_flag", StringType()),
        ("wav_request_flag", StringType()),
        ("wav_match_flag", StringType()),
        ("cbd_congestion_fee", DoubleType()),    # Added from 2025
    )
])


In [None]:
# taxi type (same keys as TAXI_TYPES / GCS prefixes) -> canonical Bronze schema
schemas = dict(
    yellow=yellow_schema,
    green=green_schema,
    fhv=fhv_schema,
    fhvhv=hvfhv_schema,
)

# --- Ingest missing raw files into Bronze tables ---

In [None]:
from pyspark.sql.utils import AnalysisException
import traceback


In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit

def normalize_schema(df: DataFrame, taxi_type: str, schemas: dict) -> DataFrame:
    """
    Project ``df`` onto the canonical schema for ``taxi_type``.

    For every field of ``schemas[taxi_type]``, in schema order:
      - if ``df`` has a matching column (exact name preferred, otherwise a
        case-insensitive match), it is cast to the field's type and renamed
        to the canonical spelling;
      - otherwise a typed NULL column is produced.
    Columns that are not in the canonical schema are dropped.

    Args:
        df: raw DataFrame, e.g. as read from a parquet file.
        taxi_type: key into ``schemas``.
        schemas: mapping of taxi type -> ``StructType``.

    Returns:
        DataFrame with exactly the canonical column names, types and order.

    Raises:
        KeyError: if ``taxi_type`` is not a key of ``schemas``.
    """
    target_schema = schemas[taxi_type]

    # Case-insensitive lookup. With Spark's default case-insensitive
    # resolution (spark.sql.caseSensitive=false), a case-sensitive
    # "name in df.columns" test misses e.g. a raw "Airport_fee" for a schema
    # "airport_fee", and withColumn(name, lit(None)) then overwrites that
    # real column with NULLs.
    existing = set(df.columns)
    actual_by_lower = {c.lower(): c for c in df.columns}

    projected = []
    for field in target_schema:
        if field.name in existing:
            source = field.name
        else:
            source = actual_by_lower.get(field.name.lower())

        if source is None:
            expr = lit(None).cast(field.dataType)
        else:
            expr = df[source].cast(field.dataType)
        projected.append(expr.alias(field.name))

    # One projection instead of a withColumn per field keeps the plan small.
    return df.select(projected)


In [None]:
def ingest_missing_to_bronze(missing_map: dict):
    """
    Load each missing raw month into its own Bronze BigQuery table.

    For every ``taxi_type -> ["YYYY-MM", ...]`` entry this reads
    ``{RAW_BUCKET}/{taxi_type}/{YYYY}/{taxi_type}_tripdata_{YYYY-MM}.parquet``,
    aligns it with ``normalize_schema`` and overwrites
    ``{BRONZE_DATASET_NAME}.{taxi_type}_YYYY_MM`` through the BigQuery
    connector (``writeMethod=direct``). Taxi types without an entry in
    ``schemas`` are skipped with a warning.

    A failing month is logged with its traceback and the loop moves on, so
    one bad file does not stop the run. Messages go through the ``logging``
    module (and so to whatever handlers the root logger has) instead of
    ``print``.

    Args:
        missing_map: taxi type -> list of ``"YYYY-MM"`` strings to ingest.

    Returns:
        List of ``(taxi_type, "YYYY-MM")`` pairs that failed to read or
        write; empty when every month succeeded.
    """
    log = logging.getLogger(__name__)
    failed = []

    for taxi_type, ym_list in missing_map.items():
        schema = schemas.get(taxi_type)
        if schema is None:
            log.warning("No schema for %s, skipping", taxi_type)
            continue

        for ym in ym_list:
            year = ym.split("-")[0]
            file_path = f"{RAW_BUCKET}/{taxi_type}/{year}/{taxi_type}_tripdata_{ym}.parquet"
            table_name = f"{BRONZE_DATASET_NAME}.{taxi_type}_{ym.replace('-', '_')}"

            log.info("Processing %s → %s", file_path, table_name)

            try:
                # Let Spark infer the raw schema, then align it.
                raw_df = spark.read.parquet(file_path)
                # normalize_schema only projects/casts columns, so the row
                # count is unchanged: count once rather than re-scanning the
                # source again after the write.
                row_count = raw_df.count()
                log.info("Read %d rows with schema %s", row_count, raw_df.dtypes)
                df = normalize_schema(raw_df, taxi_type, schemas)
            except Exception:
                log.exception("Failed reading %s", file_path)
                failed.append((taxi_type, ym))
                continue

            try:
                (
                    df.write
                    .format("bigquery")
                    .option("table", table_name)
                    .option("writeMethod", "direct")
                    .mode("overwrite")
                    .save()
                )
            except Exception:
                log.exception("Failed writing %s", table_name)
                failed.append((taxi_type, ym))
                continue

            log.info("Saved %d rows to %s", row_count, table_name)

    if failed:
        log.warning("%d month(s) failed to ingest: %s", len(failed), failed)
    return failed


In [None]:
ingest_missing_to_bronze(missing_map=missing_map)

## -- end

In [None]:

# yellow_schema = StructType([
#     StructField("VendorID", LongType(), True), # Changed to LongType
#     StructField("tpep_pickup_datetime", TimestampType(), True),
#     StructField("tpep_dropoff_datetime", TimestampType(), True),
#     StructField("passenger_count", DoubleType(), True),
#     StructField("trip_distance", DoubleType(), True),
#     StructField("RatecodeID", DoubleType(), True),
#     StructField("store_and_fwd_flag", StringType(), True),
#     StructField("PULocationID", LongType(), True),
#     StructField("DOLocationID", LongType(), True),
#     StructField("payment_type", LongType(), True),
#     StructField("fare_amount", DoubleType(), True),
#     StructField("extra", DoubleType(), True),
#     StructField("mta_tax", DoubleType(), True),
#     StructField("tip_amount", DoubleType(), True),
#     StructField("tolls_amount", DoubleType(), True),
#     StructField("improvement_surcharge", DoubleType(), True),
#     StructField("total_amount", DoubleType(), True),
#     StructField("congestion_surcharge", DoubleType(), True),
#     StructField("airport_fee", DoubleType(), True),
#     StructField("cbd_congestion_fee", DoubleType(), True),  # Added from Jan 2025 onwards
# ])

# green_schema = StructType([
#     StructField("VendorID", IntegerType (), True),  # 1=Creative Mobile, 2=Curb Mobility, 6=Myle Technologies
#     StructField("lpep_pickup_datetime", TimestampType(), True),
#     StructField("lpep_dropoff_datetime", TimestampType(), True),
#     StructField("store_and_fwd_flag", StringType(), True),  # Y/N
#     StructField("RatecodeID", LongType(), True),  # 1=Standard, 2=JFK, etc.
#     StructField("PULocationID", LongType(), True),
#     StructField("DOLocationID", LongType(), True),
#     StructField("passenger_count", DoubleType(), True),
#     StructField("trip_distance", DoubleType(), True),
#     StructField("fare_amount", DoubleType(), True),
#     StructField("extra", DoubleType(), True),
#     StructField("mta_tax", DoubleType(), True),
#     StructField("tip_amount", DoubleType(), True),   # Only credit card tips included
#     StructField("tolls_amount", DoubleType(), True),
#     StructField("improvement_surcharge", DoubleType(), True),
#     StructField("total_amount", DoubleType(), True),  # Excludes cash tips
#     StructField("payment_type", DoubleType(), True),  # 1=CC, 2=Cash, etc.
#     StructField("trip_type", DoubleType(), True),    # 1=Street-hail, 2=Dispatch
#     StructField("congestion_surcharge", DoubleType(), True),
#     StructField("cbd_congestion_fee", DoubleType(), True)  # New from 2025 :contentReference[oaicite:0]{index=0}
# ])


# fhv_schema = StructType([
#     StructField("dispatching_base_num", StringType(), True),   # Base license number of the dispatch
#     StructField("pickup_datetime", TimestampType(), True),     # Pickup timestamp
#     StructField("dropOff_datetime", TimestampType(), True),    # Dropoff timestamp
#     StructField("PUlocationID", LongType(), True),          # Pickup location zone ID
#     StructField("DOlocationID", LongType(), True),          # Dropoff location zone ID
#     StructField("SR_Flag", IntegerType(), True),               # Shared ride flag (1 if shared, else null)
#     StructField("Affiliated_base_number", StringType(), True)  # Affiliated base number (even if same as dispatch)
# ])

# hvfhs_schema = StructType([
#     StructField("hvfhs_license_num", StringType(), True),     # HV0002=Juno, HV0003=Uber, HV0004=Via, HV0005=Lyft
#     StructField("dispatching_base_num", StringType(), True),
#     StructField("originating_base_num", StringType(), True),

#     StructField("request_datetime", TimestampType(), True),
#     StructField("on_scene_datetime", TimestampType(), True),
#     StructField("pickup_datetime", TimestampType(), True),
#     StructField("dropoff_datetime", TimestampType(), True),

#     StructField("PULocationID", LongType(), True),
#     StructField("DOLocationID", LongType(), True),

#     StructField("trip_miles", DoubleType(), True),
#     StructField("trip_time", IntegerType(), True),            # seconds

#     StructField("base_passenger_fare", DoubleType(), True),
#     StructField("tolls", DoubleType(), True),
#     StructField("bcf", DoubleType(), True),                   # Black Car Fund
#     StructField("sales_tax", DoubleType(), True),
#     StructField("congestion_surcharge", DoubleType(), True),
#     StructField("airport_fee", DoubleType(), True),
#     StructField("tips", DoubleType(), True),
#     StructField("driver_pay", DoubleType(), True),

#     StructField("shared_request_flag", StringType(), True),   # Y/N
#     StructField("shared_match_flag", StringType(), True),     # Y/N
#     StructField("access_a_ride_flag", StringType(), True),    # Y/N
#     StructField("wav_request_flag", StringType(), True),      # Y/N
#     StructField("wav_match_flag", StringType(), True),        # Y/N

#     StructField("cbd_congestion_fee", DoubleType(), True)     # Added from 2025
# ])

# schemas = {
#   "yellow": yellow_schema,
#   "green": green_schema,
#   "fhv": fhv_schema,
#   "hvfhs": hvfhs_schema
# }

In [None]:
import time

# Best-effort shutdown: a failure to stop is reported, not raised, so the
# cool-down below always runs.
try:
    spark.stop()
except Exception as stop_error:
    print(f"Error while stopping Spark session: {stop_error}")
else:
    print("Spark session stopped")

# Pause so quota/resources can be released before a new session is started.
print("Waiting 60s for resources to be released...")
time.sleep(60)
print("Done waiting. You can safely start a new session now.")


In [None]:
#  sleep 1 minute
