In [0]:
import mlflow
import mlflow.spark
from mlflow.tracking import MlflowClient

In [0]:
spark.sql("USE CATALOG hive_metastore")
spark.sql("USE default")

print("="*70)
print("ENVIRONMENT SETUP")
print("="*70)
print(f"Current Catalog: {spark.sql('SELECT current_catalog()').collect()[0][0]}")
print(f"Current Database: {spark.sql('SELECT current_database()').collect()[0][0]}")
print(f"Spark version: {spark.version}")
print("="*70)

ENVIRONMENT SETUP
Current Catalog: hive_metastore
Current Database: default
Spark version: 4.0.0


In [0]:
# ============================================
# Find Best Model from MLFlow
# ============================================

from mlflow.tracking import MlflowClient

client = MlflowClient()

print("="*70)
print("ML MODEL REGISTRY")
print("="*70)

# Get experiment
username = spark.sql("SELECT current_user()").collect()[0][0]
experiment_name = f"/Users/{username}/nyc-taxi-prediction"
experiment = mlflow.get_experiment_by_name(experiment_name)

if experiment is None:
    raise Exception(f"Experiment '{experiment_name}' not found! Run Notebook 04 first.")

print(f"\nExperiment ID: {experiment.experiment_id}")
print(f"Experiment Name: {experiment_name}")

# Method 1: Try to get best model from MLFlow runs (primary method)
print("\nSearching for best model in MLFlow...")

try:
    runs = mlflow.search_runs(
        experiment_ids=[experiment.experiment_id],
        filter_string="metrics.rmse > 0",
        order_by=["metrics.rmse ASC"],
        max_results=1
    )
    
    if len(runs) == 0:
        raise Exception("No runs with logged metrics found")
    
    best_run = runs.iloc[0]
    best_run_id = best_run["run_id"]
    
    # Get model metadata
    best_model_name = best_run['tags.mlflow.runName']
    best_rmse = best_run['metrics.rmse']
    best_mae = best_run['metrics.mae']
    best_r2 = best_run['metrics.r2']
    
    print(f"\nBest Run Found (from MLFlow):")
    print(f"  Run ID: {best_run_id}")
    print(f"  Model: {best_model_name}")
    print(f"  RMSE: {best_rmse:.2f} minutes")
    print(f"  MAE: {best_mae:.2f} minutes")
    print(f"  R²: {best_r2:.4f}")
    
    mlflow_method_success = True

except Exception as e:
    print(f"\nCould not find metrics in MLFlow: {str(e)}")
    mlflow_method_success = False

# Method 2: Fallback to model_comparison_metrics table (backup method)
if not mlflow_method_success:
    print("\nTrying backup method: model_comparison_metrics table...")
    
    try:
        # Load from the table you created in Notebook 04
        model_metrics_df = spark.table("model_comparison_metrics").orderBy("rmse")
        best_model_row = model_metrics_df.first()
        
        if best_model_row is None:
            raise Exception("No models found in comparison table!")
        
        best_model_name = best_model_row["model_name"]
        best_rmse = best_model_row["rmse"]
        best_mae = best_model_row["mae"]
        best_r2 = best_model_row["r2"]
        
        # Find the run_id for this model
        runs_all = mlflow.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string=f"tags.mlflow.runName LIKE '%{best_model_name}%'"
        )
        
        if len(runs_all) > 0:
            best_run_id = runs_all.iloc[0]["run_id"]
        else:
            raise Exception(f"Could not find MLFlow run for {best_model_name}")
        
        print(f"\nBest Model Found (from comparison table):")
        print(f"  Run ID: {best_run_id}")
        print(f"  Model: {best_model_name}")
        print(f"  RMSE: {best_rmse:.2f} minutes")
        print(f"  MAE: {best_mae:.2f} minutes")
        print(f"  R²: {best_r2:.4f}")
        
    except Exception as e:
        print(f"\nFATAL ERROR: Could not find best model anywhere!")
        print(f"   Error: {str(e)}")
        print(f"\nPlease ensure you have:")
        raise


ML MODEL REGISTRY

Experiment ID: 2575000388711542
Experiment Name: /Users/hingushrey2707@gmail.com/nyc-taxi-prediction

Searching for best model in MLFlow...

Best Run Found (from MLFlow):
  Run ID: 9254956582fb4d19bec2029d0e6f6fb4
  Model: 03_gbt_v1
  RMSE: 4.79 minutes
  MAE: 2.22 minutes
  R²: 0.8776


In [0]:
# ============================================
# Configure Legacy Workspace Registry
# ============================================

import mlflow

# Switch to legacy Workspace Model Registry
mlflow.set_registry_uri("databricks")

print("="*70)
print("MLFLOW REGISTRY CONFIGURATION")
print("="*70)
print(f"Configured to use Legacy Workspace Model Registry")
print(f"   Registry URI: {mlflow.get_registry_uri()}")
print(f"   Model naming: Simple format (no catalog/schema required)")

MLFLOW REGISTRY CONFIGURATION
Configured to use Legacy Workspace Model Registry
   Registry URI: databricks
   Model naming: Simple format (no catalog/schema required)


In [0]:
# ============================================
# Load Model from DBFS and Register
# ============================================

import re
from pyspark.sql.functions import col

print("="*70)
print("LOADING MODEL FROM DBFS")
print("="*70)

# Clean up model name for registry
clean_model_name = re.sub(r'^\d+_', '', best_model_name)
clean_model_name = re.sub(r'_v\d+$', '', clean_model_name)

# Determine model path based on best model name
# Map run names to saved model paths
model_path_mapping = {
    "linear_regression": "/mnt/taxi-data/models/best_linear_regression",
    "random_forest": "/mnt/taxi-data/models/best_random_forest",
    "gbt": "/mnt/taxi-data/models/best_gradient_boosted_trees",
    "gradient_boosted_trees": "/mnt/taxi-data/models/best_gradient_boosted_trees"
}

# Find the model path
model_path = None
for key, path in model_path_mapping.items():
    if key in clean_model_name.lower():
        model_path = path
        break

if model_path is None:
    # Default to gradient boosted trees (best performing)
    model_path = "/mnt/taxi-data/models/best_gradient_boosted_trees"
    print(f"\nCould not map model name, using default: {model_path}")

print(f"\nLoading model from DBFS...")
print(f"   Path: {model_path}")
print(f"   Model Type: {clean_model_name}")

# Load the appropriate model type
try:
    if "linear" in clean_model_name.lower():
        from pyspark.ml.regression import LinearRegressionModel
        loaded_model = LinearRegressionModel.load(model_path)
    elif "random_forest" in clean_model_name.lower() or "rf" in clean_model_name.lower():
        from pyspark.ml.regression import RandomForestRegressionModel
        loaded_model = RandomForestRegressionModel.load(model_path)
    else:  # GBT
        from pyspark.ml.regression import GBTRegressionModel
        loaded_model = GBTRegressionModel.load(model_path)
    
    print(f"Model loaded successfully from DBFS")
    
except Exception as e:
    print(f"\nFailed to load model from {model_path}")
    print(f"   Error: {str(e)}")
    raise

# Get metrics (either from MLFlow or from comparison table)
print(f"\nModel Performance:")
print(f"   RMSE: {best_rmse:.2f} minutes")
print(f"   MAE: {best_mae:.2f} minutes")
print(f"   R²: {best_r2:.4f}")

# ============================================
# Create New MLFlow Run and Log Model
# ============================================

print(f"\nCreating new MLFlow run with model artifact...")

# Create new run and log the model
with mlflow.start_run(run_name=f"{clean_model_name}_for_registry") as new_run:
    
    # Log metrics
    mlflow.log_metric("rmse", best_rmse)
    mlflow.log_metric("mae", best_mae)
    mlflow.log_metric("r2", best_r2)
    
    # Log parameters
    mlflow.log_param("model_type", clean_model_name.replace('_', ' ').title())
    mlflow.log_param("source", "loaded_from_dbfs")
    mlflow.log_param("original_run_id", best_run_id)
    
    # Log the model artifact
    mlflow.spark.log_model(
        spark_model=loaded_model,
        artifact_path="model"
    )
    
    new_run_id = new_run.info.run_id
    
    print(f"\nNew run created with model artifact!")
    print(f"   New Run ID: {new_run_id}")
    print(f"   Original Run ID: {best_run_id}")

# ============================================
# Register Model to Registry
# ============================================

registry_model_name = "nyc_taxi_duration_predictor"

print(f"\n" + "="*70)
print("REGISTERING MODEL TO MLFLOW REGISTRY")
print("="*70)

print(f"\nModel Details:")
print(f"  Registry Name: {registry_model_name}")
print(f"  Model Type: {clean_model_name.replace('_', ' ').title()}")
print(f"  Source Path: {model_path}")
print(f"  Performance: RMSE={best_rmse:.2f}, MAE={best_mae:.2f}, R²={best_r2:.4f}")

# Check existing versions
try:
    existing_versions = client.search_model_versions(f"name='{registry_model_name}'")
    print(f"\nFound {len(existing_versions)} existing version(s) in registry")
except:
    print(f"\nNo existing versions found (first registration)")

# Register the model
print(f"\nRegistering model...")

try:
    model_details = mlflow.register_model(
        model_uri=f"runs:/{new_run_id}/model",
        name=registry_model_name
    )
    
    print(f"\nModel registered successfully!")
    print(f"   Model Name: {registry_model_name}")
    print(f"   Version: {model_details.version}")
    print(f"   Status: {model_details.status}")
    print(f"   Source Model: {best_model_name}")
    
except Exception as e:
    print(f"\nRegistration failed: {str(e)}")
    raise

print("\n" + "="*70)

LOADING MODEL FROM DBFS

Loading model from DBFS...
   Path: /mnt/taxi-data/models/best_gradient_boosted_trees
   Model Type: gbt
Model loaded successfully from DBFS

Model Performance:
   RMSE: 4.79 minutes
   MAE: 2.22 minutes
   R²: 0.8776

Creating new MLFlow run with model artifact...


2025/12/03 03:27:55 INFO mlflow.spark: Inferring pip requirements by reloading the logged model from the databricks artifact repository, which can be time-consuming. To speed up, explicitly specify the conda_env or pip_requirements when calling log_model().


Downloading artifacts:   0%|          | 0/20 [00:00<?, ?it/s]



Uploading artifacts:   0%|          | 0/4 [00:00<?, ?it/s]


New run created with model artifact!
   New Run ID: c8b3d838d0f64e89b9821cd76225b22b
   Original Run ID: 9254956582fb4d19bec2029d0e6f6fb4

REGISTERING MODEL TO MLFLOW REGISTRY

Model Details:
  Registry Name: nyc_taxi_duration_predictor
  Model Type: Gbt
  Source Path: /mnt/taxi-data/models/best_gradient_boosted_trees
  Performance: RMSE=4.79, MAE=2.22, R²=0.8776

Found 2 existing version(s) in registry

Registering model...


Registered model 'nyc_taxi_duration_predictor' already exists. Creating a new version of this model...
2025/12/03 03:28:48 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: nyc_taxi_duration_predictor, version 3



Model registered successfully!
   Model Name: nyc_taxi_duration_predictor
   Version: 3
   Status: PENDING_REGISTRATION
   Source Model: 03_gbt_v1



Created version '3' of model 'nyc_taxi_duration_predictor'.


In [0]:
# ============================================
# Update Model Description and Lifecycle
# ============================================

print("="*70)
print("MODEL LIFECYCLE MANAGEMENT")
print("="*70)

from mlflow.tracking import MlflowClient
client = MlflowClient()

# Create comprehensive model description
description = f"""
NYC Yellow Taxi Trip Duration Prediction Model

MODEL INFORMATION:
- Algorithm: {clean_model_name.replace('_', ' ').title()}
- Training Framework: Apache Spark MLlib
- Use Case: Trip duration prediction for dispatch planning

PERFORMANCE METRICS:
- RMSE: {best_rmse:.2f} minutes
- MAE: {best_mae:.2f} minutes  
- R² Score: {best_r2:.4f}
- Best among: Linear Regression, Random Forest, Gradient Boosted Trees

FEATURES (8 total):
1. trip_distance - Distance in miles
2. passenger_count - Number of passengers (1-6)
3. fare_amount - Trip fare in USD
4. hour_of_day - Hour of day (0-23)
5. day_of_week - Day of week (1=Sunday, 7=Saturday)
6. is_weekend - Weekend indicator (0/1)
7. PULocationID - Pickup location zone (indexed)
8. DOLocationID - Dropoff location zone (indexed)

TRAINING DETAILS:
- Dataset: NYC TLC Yellow Taxi (Jan 2024 - Jun 2025)
- Training Samples: ~42.5M trips
- Test Samples: ~10.6M trips
- Data Size: 1.05 GB (53M+ total rows)
- Distributed Training: Yes (Spark cluster with 2-4 workers)

ARCHITECTURE:
- Platform: Azure Databricks
- Storage: Delta Lake (ACID compliant)
- Partitioning: By year and month
- ML Pipeline: VectorAssembler + StandardScaler + Model
- Experiment Tracking: MLFlow

METADATA:
- Model Version: {model_details.version}
- Source Run: {best_run_id}
- Run Name: {best_model_name}
- Registered: {model_details.creation_timestamp}

DEPLOYMENT:
Ready for production inference on batch and streaming data.
"""

# Update model description
client.update_registered_model(
    name=registry_model_name,
    description=description
)

print(f"\nModel description updated")

# Lifecycle Stage Transitions
print(f"\nTransitioning model through lifecycle stages...")

# Stage 1: Move to Staging
print(f"\n[1/2] Transitioning to Staging...")
client.transition_model_version_stage(
    name=registry_model_name,
    version=model_details.version,
    stage="Staging",
    archive_existing_versions=False  # Keep old versions
)
print(f"Version {model_details.version} → Staging")

# Stage 2: Move to Production  
print(f"\n[2/2] Transitioning to Production...")
client.transition_model_version_stage(
    name=registry_model_name,
    version=model_details.version,
    stage="Production",
    archive_existing_versions=True  # Archive old production versions
)
print(f"Version {model_details.version} → Production")

print("\n" + "="*70)
print("MODEL LIFECYCLE STAGES:")
print("  None → Staging → Production")
print("="*70)

# Add version description
version_description = f"""
Best performing model: {clean_model_name.replace('_', ' ').title()}
RMSE: {best_rmse:.2f} min | MAE: {best_mae:.2f} min | R²: {best_r2:.4f}
Promoted to Production based on lowest RMSE among all baseline models.
"""

client.update_model_version(
    name=registry_model_name,
    version=model_details.version,
    description=version_description
)

print(f"\nVersion {model_details.version} description updated")

MODEL LIFECYCLE MANAGEMENT

Model description updated

Transitioning model through lifecycle stages...

[1/2] Transitioning to Staging...


  client.transition_model_version_stage(
  client.transition_model_version_stage(


Version 3 → Staging

[2/2] Transitioning to Production...
Version 3 → Production

MODEL LIFECYCLE STAGES:
  None → Staging → Production

Version 3 description updated


In [0]:
# ============================================
# Load and Test Production Model
# ============================================

from pyspark.sql.functions import avg, abs as spark_abs, col

print("="*70)
print("PRODUCTION MODEL TESTING")
print("="*70)

print(f"\nLoading production model from registry...")

production_model_uri = f"models:/{registry_model_name}/Production"

try:
    production_model = mlflow.spark.load_model(production_model_uri)
    print(f"Production model loaded successfully!")
    print(f"   Model: {registry_model_name}")
    print(f"   Stage: Production")
    print(f"   Version: {model_details.version}")
except Exception as e:
    print(f"Failed to load production model: {str(e)}")
    raise

# Test prediction on sample data
print(f"\nTesting model with sample data...")

try:
    test_sample = spark.table("taxi_ml_test").limit(10)
    test_predictions = production_model.transform(test_sample)
    
    print(f"\nPredictions successful!")
    print(f"\nSample Predictions (First 10 rows):")
    print("="*70)
    
    test_predictions.select(
        col("trip_duration_minutes").alias("Actual (min)"),
        col("prediction").alias("Predicted (min)"),
        spark_abs(col("prediction") - col("trip_duration_minutes")).alias("Error (min)")
    ).show(10, truncate=False)
    
    # Calculate sample accuracy
    sample_metrics = test_predictions.select(
        avg(spark_abs(col("prediction") - col("trip_duration_minutes"))).alias("avg_error"),
        avg(col("trip_duration_minutes")).alias("avg_actual"),
        avg(col("prediction")).alias("avg_predicted")
    ).first()
    
    print(f"\nSample Statistics:")
    print(f"  Average Actual Duration: {sample_metrics['avg_actual']:.2f} minutes")
    print(f"  Average Predicted Duration: {sample_metrics['avg_predicted']:.2f} minutes")
    print(f"  Average Absolute Error: {sample_metrics['avg_error']:.2f} minutes")
    
except Exception as e:
    print(f"Prediction test failed: {str(e)}")
    raise

print("\n" + "="*70)
print("PRODUCTION MODEL READY FOR DEPLOYMENT")
print("="*70)

PRODUCTION MODEL TESTING

Loading production model from registry...


Downloading artifacts:   0%|          | 0/24 [00:00<?, ?it/s]

Production model loaded successfully!
   Model: nyc_taxi_duration_predictor
   Stage: Production
   Version: 3

Testing model with sample data...

Predictions successful!

Sample Predictions (First 10 rows):
+------------------+------------------+-------------------+
|Actual (min)      |Predicted (min)   |Error (min)        |
+------------------+------------------+-------------------+
|33.35             |34.10186822188318 |0.7518682218831785 |
|9.15              |6.955730689199445 |2.1942693108005553 |
|8.583333333333334 |7.414722298119862 |1.1686110352134715 |
|2.8666666666666667|2.4456718750729975|0.42099479159366915|
|8.316666666666666 |7.673308148864867 |0.6433585178017998 |
|44.2              |37.95390848711456 |6.246091512885442  |
|9.983333333333333 |10.249253729994667|0.2659203966613344 |
|9.683333333333334 |9.259646573300543 |0.4236867600327905 |
|13.2              |13.063136089601484|0.13686391039851564|
|10.3              |9.793411708832465 |0.5065882911675352 |
+-----------

In [0]:
# ============================================
# Prepare New Data for Inference
# ============================================

print("\nPreparing new data for batch inference...")

# Simulate "new" data using a specific month as unseen data
# Using most recent month in your dataset
new_trips = spark.table("taxi_trips").filter(
    (col("year") == 2024) & (col("month") == 12)
)

print(f"New trips to process: {new_trips.count():,}")

# Load preprocessing pipeline
from pyspark.ml import PipelineModel

preprocessing = PipelineModel.load("/mnt/taxi-data/models/preprocessing_pipeline")

# Apply preprocessing
new_data_processed = preprocessing.transform(new_trips)
new_data_ml = new_data_processed.select(
    "features",
    "trip_duration_minutes",
    "trip_distance",
    "hour_of_day",
    "day_of_week"
)

print("✓ Data preprocessed and ready for inference")


Preparing new data for batch inference...
New trips to process: 3,185,044
✓ Data preprocessed and ready for inference


In [0]:
# ============================================
# Run Batch Predictions
# DS FEATURE: Distributed batch processing
# ============================================

import time 

print("\nRunning batch predictions...")
print("DS FEATURE: Distributed inference across cluster\n")

start_time = time.time()

# Make predictions
prod_model = mlflow.spark.load_model(production_model_uri)
predictions = prod_model.transform(new_data_ml)

# Force evaluation and count
pred_count = predictions.count()

inference_time = time.time() - start_time
throughput = pred_count / inference_time if inference_time > 0 else 0

print("\n" + "="*70)
print("BATCH INFERENCE PERFORMANCE")
print("="*70)
print(f"Total Predictions: {pred_count:,}")
print(f"Inference Time: {inference_time:.2f} seconds")
print(f"Throughput: {throughput:.0f} predictions/second")
print("="*70)


Running batch predictions...
DS FEATURE: Distributed inference across cluster



Downloading artifacts:   0%|          | 0/24 [00:00<?, ?it/s]


BATCH INFERENCE PERFORMANCE
Total Predictions: 3,185,044
Inference Time: 10.25 seconds
Throughput: 310730 predictions/second


In [0]:
# ============================================
# Enrich Predictions with Business Logic
# ============================================

from pyspark.sql.functions import *

print("\nEnriching predictions with business metrics...")

predictions_enriched = predictions.withColumn(
    "predicted_duration_minutes",
    col("prediction")
).withColumn(
    "actual_duration_minutes",
    col("trip_duration_minutes")
).withColumn(
    "error_minutes",
    abs(col("prediction") - col("trip_duration_minutes"))
).withColumn(
    "error_percentage",
    (abs(col("prediction") - col("trip_duration_minutes")) / col("trip_duration_minutes")) * 100
).withColumn(
    "accuracy_category",
    when(col("error_minutes") < 2, "Excellent")
    .when(col("error_minutes") < 5, "Good")
    .when(col("error_minutes") < 10, "Fair")
    .otherwise("Poor")
).withColumn(
    "prediction_timestamp",
    current_timestamp()
)

print("Predictions enriched with error metrics and categories")

# Show sample enriched predictions
print("\nSample Enriched Predictions:")
predictions_enriched.select(
    "actual_duration_minutes",
    "predicted_duration_minutes",
    "error_minutes",
    "error_percentage",
    "accuracy_category"
).show(10)


Enriching predictions with business metrics...
Predictions enriched with error metrics and categories

Sample Enriched Predictions:
+-----------------------+--------------------------+-------------------+------------------+-----------------+
|actual_duration_minutes|predicted_duration_minutes|      error_minutes|  error_percentage|accuracy_category|
+-----------------------+--------------------------+-------------------+------------------+-----------------+
|     36.583333333333336|         41.89129394341718|  5.307960610083846|14.509231735992289|             Fair|
|      8.483333333333333|         8.955167405314453|0.47183407198112093| 5.561894758127163|        Excellent|
|     26.116666666666667|        24.886576815853495| 1.2300898508131723| 4.709980283904935|        Excellent|
|     13.916666666666666|         14.49185406043568| 0.5751873937690135|   4.1330830689989|        Excellent|
|                    9.4|         9.802226407944177|0.40222640794417686| 4.279004339831668|      

In [0]:
# ============================================
# Save Predictions to Delta Lake
# DS FEATURE: ACID transactions
# ============================================

print("\nSaving predictions to Delta Lake...")

predictions_path = "/mnt/taxi-data/delta/predictions"

predictions_enriched.write \
    .format("delta") \
    .mode("overwrite") \
    .partitionBy("accuracy_category") \
    .save(predictions_path)

# Register as table
spark.sql(f"""
    CREATE TABLE IF NOT EXISTS taxi_predictions
    USING DELTA
    LOCATION '{predictions_path}'
""")

print(f"Predictions saved to: {predictions_path}")
print("Table created: taxi_predictions")
print("Partitioned by: accuracy_category")
print("Format: Delta Lake (ACID-compliant)")


Saving predictions to Delta Lake...
Predictions saved to: /mnt/taxi-data/delta/predictions
Table created: taxi_predictions
Partitioned by: accuracy_category
Format: Delta Lake (ACID-compliant)


In [0]:
# ============================================
# Prediction Quality Analysis
# ============================================

print("\n" + "="*70)
print("PREDICTION QUALITY ANALYSIS")
print("="*70)

# Overall statistics
overall_stats = predictions_enriched.select(
    count("*").alias("total_predictions"),
    round(avg("error_minutes"), 2).alias("avg_error_minutes"),
    round(stddev("error_minutes"), 2).alias("stddev_error"),
    round(percentile_approx("error_minutes", 0.5), 2).alias("median_error"),
    round(percentile_approx("error_minutes", 0.95), 2).alias("p95_error")
).collect()[0]

print("\nOverall Statistics:")
print(f"  Total Predictions: {overall_stats['total_predictions']:,}")
print(f"  Avg Error: {overall_stats['avg_error_minutes']:.2f} minutes")
print(f"  Std Dev: {overall_stats['stddev_error']:.2f} minutes")
print(f"  Median Error: {overall_stats['median_error']:.2f} minutes")
print(f"  95th Percentile: {overall_stats['p95_error']:.2f} minutes")

# By accuracy category
print("\nPredictions by Accuracy Category:")
spark.sql("""
    SELECT 
        accuracy_category,
        COUNT(*) as count,
        ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 1) as percentage,
        ROUND(AVG(error_minutes), 2) as avg_error
    FROM taxi_predictions
    GROUP BY accuracy_category
    ORDER BY 
        CASE accuracy_category
            WHEN 'Excellent' THEN 1
            WHEN 'Good' THEN 2
            WHEN 'Fair' THEN 3
            ELSE 4
        END
""").show()

# By hour of day
print("\nPrediction Accuracy by Hour of Day:")
spark.sql("""
    SELECT 
        hour_of_day,
        COUNT(*) as predictions,
        ROUND(AVG(error_minutes), 2) as avg_error
    FROM taxi_predictions
    GROUP BY hour_of_day
    ORDER BY hour_of_day
""").show(24)

# By day of week
print("\nPrediction Accuracy by Day of Week:")
spark.sql("""
    SELECT 
        day_of_week,
        CASE day_of_week
            WHEN 1 THEN 'Sunday'
            WHEN 2 THEN 'Monday'
            WHEN 3 THEN 'Tuesday'
            WHEN 4 THEN 'Wednesday'
            WHEN 5 THEN 'Thursday'
            WHEN 6 THEN 'Friday'
            WHEN 7 THEN 'Saturday'
        END as day_name,
        COUNT(*) as predictions,
        ROUND(AVG(error_minutes), 2) as avg_error
    FROM taxi_predictions
    GROUP BY day_of_week
    ORDER BY day_of_week
""").show()



PREDICTION QUALITY ANALYSIS

Overall Statistics:
  Total Predictions: 3,185,044
  Avg Error: 2.41 minutes
  Std Dev: 4.73 minutes
  Median Error: 1.01 minutes
  95th Percentile: 9.65 minutes

Predictions by Accuracy Category:
+-----------------+-------+----------+---------+
|accuracy_category|  count|percentage|avg_error|
+-----------------+-------+----------+---------+
|        Excellent|2313264|      72.6|     0.76|
|             Good| 522448|      16.4|     3.06|
|             Fair| 197543|       6.2|     6.98|
|             Poor| 151789|       4.8|    19.43|
+-----------------+-------+----------+---------+


Prediction Accuracy by Hour of Day:
+-----------+-----------+---------+
|hour_of_day|predictions|avg_error|
+-----------+-----------+---------+
|          0|      89274|     2.18|
|          1|      54184|     1.84|
|          2|      32495|     1.67|
|          3|      20318|     1.81|
|          4|      14534|     2.75|
|          5|      16138|     3.87|
|          6|      