In [None]:
# Databricks notebook source
# MAGIC %md
# MAGIC # NYC Yellow Taxi Data Processing Pipeline
# MAGIC 
# MAGIC This notebook processes raw NYC yellow taxi data and creates features for fare prediction.

# COMMAND ----------

# MAGIC %md
# MAGIC ## Setup and Configuration

# COMMAND ----------

import pyspark.sql.functions as F
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from delta.tables import DeltaTable
import boto3
from datetime import datetime
import logging

# Configure root logging at INFO and create the module-level logger that the
# functions in this notebook use for progress messages and "METRIC:" lines.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Job parameters, read from notebook widgets. They are kept as module-level
# globals and read directly by the processing functions further down.
# NOTE(review): widget values presumably arrive as strings, so the
# `processing_year` / `processing_month` columns built from them would be
# string-typed — confirm that is intended.
source_year = dbutils.widgets.get("source_year")
source_month = dbutils.widgets.get("source_month")
source_bucket = dbutils.widgets.get("source_bucket")
source_key = dbutils.widgets.get("source_key")

# Echo the input object for this run into the notebook output.
print(f"Processing: {source_bucket}/{source_key}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## CloudWatch Metrics Setup

# COMMAND ----------

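# NOTE: CloudWatch publishing is currently disabled — the helper below is commented out.
# To enable it, update the Databricks job configuration to include an instance profile
# with CloudWatch permissions, then uncomment the helper and the
# `send_cloudwatch_metric(...)` calls in the cells that follow.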
# def send_cloudwatch_metric(metric_name, value, unit='Count', dimensions=None):
#     """Send custom metrics to CloudWatch"""
#     try:
#         cloudwatch = boto3.client('cloudwatch', region_name='us-east-1')
        
#         metric_data = {
#             'MetricName': metric_name,
#             'Value': value,
#             'Unit': unit,
#             'Timestamp': datetime.utcnow()
#         }
        
#         if dimensions:
#             metric_data['Dimensions'] = dimensions
        
#         cloudwatch.put_metric_data(
#             Namespace='NYCTaxi/Processing',
#             MetricData=[metric_data]
#         )
#         logger.info(f"Sent metric: {metric_name} = {value}")
#     except Exception as e:
#         logger.error(f"Failed to send CloudWatch metric: {e}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Data Ingestion and Initial Validation

# COMMAND ----------

def load_and_validate_data(file_path):
    """Load the raw parquet data from S3 and report how many records it holds.

    Args:
        file_path: Object location in ``<bucket>/<key>`` form, without a URI
            scheme. It is read through the ``s3a://`` filesystem.

    Returns:
        A ``(df, initial_count)`` tuple: the loaded DataFrame and its row count.

    Raises:
        Exception: If the read or the count fails. The message carries the
            original error text and the original exception is chained as
            ``__cause__``.
    """
    try:
        # Read the location the caller asked for instead of the notebook-level
        # widget values, so the path that is logged is the path that is loaded.
        df = spark.read.parquet(f"s3a://{file_path}")

        initial_count = df.count()
        logger.info(f"METRIC: Initial data load: {initial_count} records from {file_path}")

        # CloudWatch publishing is disabled; kept for when the helper is re-enabled.
        # send_cloudwatch_metric(
        #     'RecordsLoaded',
        #     initial_count,
        #     dimensions=[
        #         {'Name': 'Year', 'Value': source_year},
        #         {'Name': 'Month', 'Value': source_month}
        #     ]
        # )

        return df, initial_count

    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        # send_cloudwatch_metric('DataLoadFailed', 1)
        # Chain the cause so the underlying Spark/S3 traceback is not lost.
        raise Exception(f"Failed to load data: {e}") from e

# Load the raw data for this run. `initial_record_count` is kept so the final
# processing summary can report how many records were read.
raw_df, initial_record_count = load_and_validate_data(f"{source_bucket}/{source_key}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Data Quality Checks and Cleaning

# COMMAND ----------

def clean_data(df):
    """Drop trip records that fail basic validity rules and log quality metrics.

    A record is kept only when all of the following hold:
      * pickup and dropoff location IDs and timestamps are present,
      * ``0 < trip_distance < 1000`` and ``0 < fare_amount < 1000``,
      * ``total_amount > 0`` and ``1 <= passenger_count <= 6``,
      * dropoff is strictly after pickup and less than 24 hours later.

    Args:
        df: Raw trip DataFrame containing the columns named above.

    Returns:
        The filtered DataFrame.

    Side effects:
        Triggers two ``count()`` actions (before and after filtering) and logs
        ``RecordsRemoved`` and ``DataQualityScore`` as "METRIC:" lines.
    """

    # Initial count
    start_count = df.count()

    # Presence and range checks. A comparison against NULL evaluates to NULL,
    # which filter() treats as false, so rows missing a compared value are
    # dropped as well.
    df_clean = df.filter(
        (F.col("PULocationID").isNotNull()) &
        (F.col("DOLocationID").isNotNull()) &
        (F.col("trip_distance") > 0) &
        (F.col("trip_distance") < 1000) &  # Remove outliers
        (F.col("fare_amount") > 0) &
        (F.col("fare_amount") < 1000) &    # Remove outliers
        (F.col("total_amount") > 0) &
        (F.col("passenger_count") > 0) &
        (F.col("passenger_count") <= 6) &  # Reasonable passenger count
        (F.col("tpep_pickup_datetime").isNotNull()) &
        (F.col("tpep_dropoff_datetime").isNotNull())
    )

    # Remove trips with zero or negative duration
    df_clean = df_clean.filter(
        F.col("tpep_dropoff_datetime") > F.col("tpep_pickup_datetime")
    )

    # Remove trips of 24 hours or longer (likely data errors)
    df_clean = df_clean.filter(
        (F.unix_timestamp("tpep_dropoff_datetime") -
         F.unix_timestamp("tpep_pickup_datetime")) < 86400
    )

    final_count = df_clean.count()
    removed_count = start_count - final_count

    logger.info(f"Data cleaning: {start_count} -> {final_count} records (removed {removed_count})")

    # An empty input would otherwise raise ZeroDivisionError below; report a
    # score of 0.0 so the metric line is still emitted in its usual format.
    if start_count:
        quality_score = (final_count / start_count) * 100
    else:
        logger.warning("Data cleaning received no records; reporting DataQualityScore as 0.0")
        quality_score = 0.0

    # Send quality metrics
    logger.info(f'METRIC: RecordsRemoved {removed_count}')
    logger.info(f'METRIC: DataQualityScore {quality_score} Percent')

    return df_clean

# Run the cleaning step on the loaded data; `cleaned_df` is the input to
# feature engineering in the next cell.
cleaned_df = clean_data(raw_df)

# COMMAND ----------

# MAGIC %md
# MAGIC ## Feature Engineering

# COMMAND ----------

def engineer_features(df):
    """Add the derived columns used for fare prediction to a trip DataFrame.

    Columns added:
      * duration: ``duration_minutes``
      * temporal: ``pickup_hour``, ``pickup_day_of_week``, ``pickup_month``,
        ``pickup_year``, ``is_weekend``, ``is_rush_hour``, ``season``
      * trip: ``trip_distance_km``, ``avg_speed_kmh``, ``fare_per_mile``,
        ``fare_per_minute``
      * payment: ``tip_percentage``, ``total_amount_per_passenger``
      * location: ``pickup_manhattan``, ``dropoff_manhattan``,
        ``manhattan_trip``, ``is_airport_trip``
      * metadata: ``processed_timestamp``, ``processing_year``,
        ``processing_month`` (the latter two come from the module-level
        ``source_year`` / ``source_month`` job parameters)

    Args:
        df: Trip DataFrame with the pickup/dropoff timestamps, location IDs,
            ``trip_distance``, ``fare_amount``, ``tip_amount``,
            ``total_amount`` and ``passenger_count`` columns.

    Returns:
        A new DataFrame with the columns above appended.
    """

    logger.info("Starting feature engineering...")

    # Calculate trip duration in minutes. unix_timestamp() has whole-second
    # resolution, so two timestamps within the same second give a duration of 0.
    df_features = df.withColumn(
        "duration_minutes",
        (F.unix_timestamp(F.col("tpep_dropoff_datetime")) -
         F.unix_timestamp(F.col("tpep_pickup_datetime"))) / 60
    )

    # Temporal features
    df_features = df_features.withColumn("pickup_hour", F.hour("tpep_pickup_datetime")) \
        .withColumn("pickup_day_of_week", F.dayofweek("tpep_pickup_datetime")) \
        .withColumn("pickup_month", F.month("tpep_pickup_datetime")) \
        .withColumn("pickup_year", F.year("tpep_pickup_datetime"))

    # Derived temporal features. dayofweek() numbers days 1 (Sunday) through
    # 7 (Saturday), so {1, 7} is the weekend; between() is inclusive at both ends.
    df_features = df_features \
        .withColumn("is_weekend", F.when(F.col("pickup_day_of_week").isin([1, 7]), 1).otherwise(0)) \
        .withColumn("is_rush_hour",
                   F.when((F.col("pickup_hour").between(7, 9)) |
                         (F.col("pickup_hour").between(17, 19)), 1).otherwise(0)) \
        .withColumn("season",
                   F.when(F.col("pickup_month").isin([12, 1, 2]), "winter")
                    .when(F.col("pickup_month").isin([3, 4, 5]), "spring")
                    .when(F.col("pickup_month").isin([6, 7, 8]), "summer")
                    .otherwise("fall"))

    # Trip characteristics. The two fare ratios are guarded against a zero
    # divisor, in line with avg_speed_kmh and tip_percentage. when() without
    # otherwise() yields NULL, which is what an unguarded division by zero
    # returns when ANSI SQL mode is off; with ANSI mode on, the unguarded
    # division would fail the job instead.
    df_features = df_features \
        .withColumn("trip_distance_km", F.col("trip_distance") * 1.60934) \
        .withColumn("avg_speed_kmh",
                   F.when(F.col("duration_minutes") > 0,
                         (F.col("trip_distance_km") / F.col("duration_minutes")) * 60)
                    .otherwise(0)) \
        .withColumn("fare_per_mile",
                   F.when(F.col("trip_distance") != 0,
                         F.col("fare_amount") / F.col("trip_distance"))) \
        .withColumn("fare_per_minute",
                   F.when(F.col("duration_minutes") != 0,
                         F.col("fare_amount") / F.col("duration_minutes")))

    # Payment features (per-passenger amount guarded the same way as above)
    df_features = df_features \
        .withColumn("tip_percentage",
                   F.when(F.col("fare_amount") > 0,
                         (F.col("tip_amount") / F.col("fare_amount")) * 100)
                    .otherwise(0)) \
        .withColumn("total_amount_per_passenger",
                   F.when(F.col("passenger_count") != 0,
                         F.col("total_amount") / F.col("passenger_count")))

    # Location features (simplified borough mapping).
    # NOTE(review): zone 1 is listed here as Manhattan and below as Newark
    # airport — verify this list against the TLC taxi zone lookup table.
    manhattan_zones = [1,2,3,4,12,13,24,25,41,42,43,45,48,50,68,74,75,79,87,88,90,100,103,104,107,108,113,114,116,120,121,122,123,124,125,127,128,137,140,141,142,143,144,148,151,152,153,158,161,162,163,164,166,170,186,194,202,209,211,224,229,230,231,232,233,234,236,237,238,239,243,244,246,249,250,261,262,263]

    # Airport zones (JFK: 132, LGA: 138, Newark: 1)
    airport_zones = [132, 138, 1]

    # manhattan_trip is 1 when either end of the trip is in a Manhattan zone.
    df_features = df_features \
        .withColumn("pickup_manhattan",
                   F.when(F.col("PULocationID").isin(manhattan_zones), 1).otherwise(0)) \
        .withColumn("dropoff_manhattan",
                   F.when(F.col("DOLocationID").isin(manhattan_zones), 1).otherwise(0)) \
        .withColumn("manhattan_trip",
                   F.when((F.col("pickup_manhattan") == 1) | (F.col("dropoff_manhattan") == 1), 1).otherwise(0)) \
        .withColumn("is_airport_trip",
                   F.when(F.col("PULocationID").isin(airport_zones) |
                         F.col("DOLocationID").isin(airport_zones), 1).otherwise(0))

    # Add processing metadata; the year/month columns are later used as the
    # Delta partition columns when the data is saved.
    df_features = df_features \
        .withColumn("processed_timestamp", F.current_timestamp()) \
        .withColumn("processing_year", F.lit(source_year)) \
        .withColumn("processing_month", F.lit(source_month))

    logger.info("Feature engineering completed")
    return df_features

# Apply feature engineering to the cleaned data. `feature_df` is what gets
# validated, written to Delta, and counted in the final summary.
feature_df = engineer_features(cleaned_df)

# COMMAND ----------

# MAGIC %md
# MAGIC ## Final Data Quality Checks

# COMMAND ----------

def validate_features(df):
    """Count nulls and implausible values in the engineered features.

    The checks are informational only: results are logged and returned, and
    no rows are removed.

    Args:
        df: DataFrame containing ``duration_minutes``, ``avg_speed_kmh``,
            ``fare_per_mile`` and ``tip_percentage``.

    Returns:
        A dict of counts keyed ``<feature>_nulls`` for each key feature, plus
        ``unreasonable_speed_count`` (speed above 200 km/h) and
        ``unreasonable_duration_count`` (duration above 1440 minutes).
    """

    key_features = ["duration_minutes", "avg_speed_kmh", "fare_per_mile", "tip_percentage"]

    # Compute every check in one aggregation pass instead of one count() action
    # per check; each action would otherwise recompute the input DataFrame.
    # count() skips NULLs and when() without otherwise() yields NULL for rows
    # that do not match, so each aggregate is the number of matching rows.
    checks = [
        F.count(F.when(F.col(feature).isNull(), 1)).alias(f"{feature}_nulls")
        for feature in key_features
    ]
    checks.append(
        F.count(F.when(F.col("avg_speed_kmh") > 200, 1)).alias("unreasonable_speed_count")
    )
    checks.append(
        F.count(F.when(F.col("duration_minutes") > 1440, 1))  # > 24 hours
        .alias("unreasonable_duration_count")
    )

    # A global aggregate always yields exactly one row (all zeros on empty input).
    validation_results = df.agg(*checks).first().asDict()

    for feature in key_features:
        null_count = validation_results[f"{feature}_nulls"]
        if null_count > 0:
            logger.warning(f"Found {null_count} null values in {feature}")

    # Send validation metrics
    for metric_name, value in validation_results.items():
        logger.info(f"METRIC: Validation_{metric_name} = {value}")

    logger.info(f"Feature validation completed: {validation_results}")
    return validation_results

# Run the feature checks. The results are logged by the function itself and
# are not read again in the cells shown here.
validation_results = validate_features(feature_df)

# COMMAND ----------

# MAGIC %md
# MAGIC ## Save to Delta Lake

# COMMAND ----------

def setup_unity_catalog():
    """Ensure the ``nyc_taxi_analytics`` catalog and its ``fare_prediction`` schema exist.

    The catalog's managed location is placed under the ``source_bucket`` job
    parameter. Both statements use ``IF NOT EXISTS``.

    Returns:
        ``True`` once both statements have run.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    try:
        # Managed storage root for the catalog
        managed_root = f"s3a://{source_bucket}/unity-catalog/nyc_taxi_analytics/"

        # Each step pairs a SQL statement with the message logged once it succeeds.
        steps = (
            (
                f"""
            CREATE CATALOG IF NOT EXISTS nyc_taxi_analytics
            MANAGED LOCATION '{managed_root}'
        """,
                "Catalog 'nyc_taxi_analytics' created or already exists",
            ),
            (
                "CREATE SCHEMA IF NOT EXISTS nyc_taxi_analytics.fare_prediction",
                "Schema 'nyc_taxi_analytics.fare_prediction' created or already exists",
            ),
        )

        for statement, confirmation in steps:
            spark.sql(statement)
            logger.info(confirmation)

    except Exception as e:
        logger.error(f"Failed to setup Unity Catalog: {e}")
        raise e

    else:
        return True

# Make sure the target catalog and schema exist before the table is registered below.
setup_unity_catalog()


def save_to_delta(df, table_path, table_name):
    """Append a DataFrame to a Delta table at ``table_path`` and register it.

    Args:
        df: DataFrame to persist; it must contain ``processing_year`` and
            ``processing_month``, which are used as partition columns.
        table_path: Storage location the Delta files are written to.
        table_name: Fully qualified table name to create at that location if
            it does not exist yet.

    Returns:
        ``True`` when the write and the registration both succeed.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """

    try:
        # Counting first triggers a Spark action before the write starts.
        record_count = df.count()
        logger.info(f"Saving {record_count} records to {table_path}")

        # Append mode with schema merging, partitioned by processing year/month
        delta_writer = (
            df.write
            .format("delta")
            .mode("append")
            .option("mergeSchema", "true")
            .partitionBy("processing_year", "processing_month")
        )
        delta_writer.save(table_path)

        # Register the location as a catalog table if it is not registered yet
        register_sql = f"""
            CREATE TABLE IF NOT EXISTS {table_name}
            USING DELTA
            LOCATION '{table_path}'
        """
        spark.sql(register_sql)

        # Success metrics (CloudWatch publishing is disabled)
        logger.info(f"METRIC: RecordsWritten {record_count}")
        # send_cloudwatch_metric('ProcessingSuccess', 1)

        logger.info(f"Successfully saved data to {table_name}")

    except Exception as e:
        # send_cloudwatch_metric('ProcessingFailed', 1)
        logger.error(f"Failed to save data: {e}")
        raise e

    else:
        return True

# Storage location for the Delta files (under the source bucket) and the
# three-part catalog.schema.table name they are registered under.
processed_table_path = f"s3a://{source_bucket}/nyctaxi/processed/yellow_taxi_features/"
catalog_table_name = "nyc_taxi_analytics.fare_prediction.processed_yellow_taxi"

# Save the processed data. `save_success` is reported in the final summary;
# a failed save raises instead of returning, so it is True whenever it is set.
save_success = save_to_delta(feature_df, processed_table_path, catalog_table_name)

# COMMAND ----------

# MAGIC %md
# MAGIC ## Processing Summary and Cleanup

# COMMAND ----------

def log_processing_summary():
    """Log a summary of this run and write it to S3 as JSON.

    Reads the module-level ``source_*`` parameters, ``initial_record_count``,
    ``feature_df`` and ``save_success``.

    Returns:
        The summary dict with keys ``source_file``, ``processing_date``,
        ``initial_records``, ``final_records`` and ``success``.

    Side effects:
        Runs ``feature_df.count()`` (a Spark action) and writes a single-
        partition JSON output under ``nyctaxi/processing_logs/success/`` in
        the source bucket.
    """
    # Local import: only `datetime` itself is imported at file level.
    from datetime import timezone

    # Read the clock once so the summary's processing_date and the log file
    # name refer to the same instant. datetime.utcnow() is deprecated since
    # Python 3.12; dropping tzinfo keeps the naive-UTC string formats that
    # utcnow() produced.
    run_time = datetime.now(timezone.utc).replace(tzinfo=None)

    summary = {
        "source_file": f"{source_bucket}/{source_key}",
        "processing_date": run_time.isoformat(),
        "initial_records": initial_record_count,
        # NOTE(review): this count is a separate Spark action over feature_df;
        # presumably it recomputes the pipeline unless feature_df is cached.
        "final_records": feature_df.count(),
        "success": save_success
    }

    logger.info(f"Processing Summary: {summary}")

    # Save processing log to S3
    log_path = f"s3a://{source_bucket}/nyctaxi/processing_logs/success/{source_year}_{source_month}_{run_time.strftime('%Y%m%d_%H%M%S')}.json"

    # coalesce(1) so the one-row summary is written from a single partition.
    log_df = spark.createDataFrame([summary])
    log_df.coalesce(1).write.mode("overwrite").json(log_path)

    return summary

# Generate the final summary (this also writes the JSON processing log).
processing_summary = log_processing_summary()

# Human-readable completion banner in the notebook output; these lines are
# only reached when every earlier cell ran without raising.
print("✅ Processing completed successfully!")
print(f"📊 Processed {processing_summary['final_records']} records")
print(f"💾 Data saved to: {catalog_table_name}")