In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, year, month, dayofyear, when, lit, unix_timestamp, count, avg
from pyspark.sql.types import IntegerType
from datetime import datetime, timedelta
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.figure import Figure
import numpy as np
import math
from typing import List, Optional, Union, Tuple

In [4]:
# ----------------------------------------------------------------------
# Create Local PySpark Session with BigQuery Integration
# ----------------------------------------------------------------------

# 🆕 Environment Setup for Local PySpark
from google.cloud import bigquery
import pandas as pd

# Set up BigQuery credentials for local development
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/Applications/saggydev/projects_learning/data_engineering_course/secrets/dtc-de-course-466501-e23cbf158abc.json"
print(f"🔐 Using credentials: {os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')}")

# Enhanced Spark configuration for local large-scale incremental processing
# 🔧 LOCAL SPARK CONFIGURATION EXPLAINED:
#
# BIGQUERY INTEGRATION - Local PySpark to BigQuery connectivity
# ├─ "spark.jars.packages" - Automatically downloads BigQuery connector
# │  └─ No manual JAR placement needed, handles all dependencies
# │
# ├─ "spark.sql.execution.arrow.pyspark.enabled" = "false"
# │  └─ Disabled to avoid compatibility issues with BigQuery connector
# │
# └─ Memory configs optimized for local machine processing
#
# ADAPTIVE QUERY EXECUTION (AQE) - Runtime optimization based on actual data
# ├─ "spark.sql.adaptive.enabled" = "true"
# │  └─ Enables AQE for dynamic query optimization during execution
# │
# ├─ "spark.sql.adaptive.coalescePartitions.enabled" = "true"
# │  └─ Automatically reduces number of partitions after shuffle to optimize performance
# │
# ├─ "spark.sql.adaptive.localShuffleReader.enabled" = "true"
# │  └─ Reduces network I/O by reading shuffle data locally when possible
# │
# └─ "spark.sql.adaptive.advisoryPartitionSizeInBytes" = "256MB"
#    └─ Target size for each partition (256MB optimal for most workloads)
#
# SKEW HANDLING - Deals with uneven data distribution
# ├─ "spark.sql.adaptive.skewJoin.enabled" = "true"
# │  └─ Automatically detects and handles data skew in joins
# │
# ├─ "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes" = "256MB"
# │  └─ Partition size threshold to identify skewed partitions
# │
# └─ "spark.sql.adaptive.skewJoin.skewedPartitionFactor" = "5"
#    └─ Factor to determine skew (partition 5x larger than median = skewed)
#
# PERFORMANCE OPTIMIZATIONS
# ├─ "spark.sql.shuffle.partitions" = "200" (reduced for local)
# │  └─ Number of partitions for shuffle operations (reduced for local processing)
# │
# └─ "spark.serializer" = "org.apache.spark.serializer.KryoSerializer"
#    └─ Faster serialization compared to default Java serialization

spark = (
    SparkSession.builder
    .appName("NYC Taxi Duration Prediction - Local")
    # 🆕 BigQuery connector - automatically downloads from Maven Central
    .config("spark.jars.packages", "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.32.0")
    # 🆕 DISABLE Arrow to avoid local compatibility issues
    .config("spark.sql.execution.arrow.pyspark.enabled", "false")
    # 🆕 Memory configs optimized for local machine
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .config("spark.driver.maxResultSize", "2g")
    # Performance optimizations (adjusted for local)
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    .config("spark.sql.adaptive.localShuffleReader.enabled", "true")
    .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "256MB")
    .config("spark.sql.shuffle.partitions", "200")  # Reduced for local processing
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
    .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
    .getOrCreate()
)

# ----------------------------------------------------------------------
# GCP Project and Dataset Config
# ----------------------------------------------------------------------

PROJECT_ID = "dtc-de-course-466501"
DATASET_ID = "dbt_production"

INCREMENTAL_CONFIG = {
    "start_year": 2015,
    "end_year": 2016,
    "batch_size_months": 6,
    "checkpoint_dir": "/tmp/spark-checkpoints",
    "cache_level": "MEMORY_AND_DISK_SER",
    "max_records_per_partition": 1000000,
    # 🆕 Local processing limits
    "local_test_limit": 10000,  # Limit for local testing
    "local_batch_limit": 50000  # Limit for local batch processing
}

print(f"✅ Spark Version: {spark.version}")
print(f"📊 Working with: {PROJECT_ID}.{DATASET_ID}")
print(f"🔄 Incremental Processing: {INCREMENTAL_CONFIG['start_year']}-{INCREMENTAL_CONFIG['end_year']}")
print(f"🏠 Local Processing Mode: Limits enabled for efficient local development")

# ----------------------------------------------------------------------
# Hybrid BigQuery Data Loading Functions
# ----------------------------------------------------------------------

def load_table_hybrid(spark, table_name, project_id=PROJECT_ID, dataset_id=DATASET_ID, limit=None):
    """
    Hybrid loader that tries direct connector first, falls back to client method
    """
    from google.cloud import bigquery
    import logging
    
    # Tables known to have complex schemas that cause issues with direct connector
    complex_tables = ["fact_trips"]
    
    # Skip direct connector for known complex tables
    if table_name in complex_tables:
        print(f"📊 Table '{table_name}' has complex schema, using BigQuery client method")
        
        # Create BigQuery client
        client = bigquery.Client()
        
        # Build query with limit
        limit_clause = f"LIMIT {limit}" if limit else ""
        query = f"SELECT * FROM `{project_id}.{dataset_id}.{table_name}` {limit_clause}"
        
        print(f"🔍 Running query: {query}")
        
        # Execute query and convert to pandas DataFrame
        query_job = client.query(query)
        df_pandas = query_job.to_dataframe()
        
        # Convert pandas DataFrame to Spark DataFrame
        df_spark = spark.createDataFrame(df_pandas)
        
        print(f"✅ BigQuery client successful: {df_spark.count()} rows loaded")
        return df_spark
    
    else:
        # Try direct connector for simple tables
        print(f"📊 Attempting direct connector for table '{table_name}'")
        try:
            reader = (spark.read
                    .format("bigquery")
                    .option("project", project_id)
                    .option("dataset", dataset_id)
                    .option("table", table_name)
                    .option("readDataFormat", "AVRO")
                    .option("maxParallelism", "1")
            )
            
            df = reader.load()
            if limit:
                df = df.limit(limit)
            
            # Test if it works
            df.take(1)
            print(f"✅ Direct connector successful for {table_name}")
            return df
            
        except Exception as e:
            print(f"⚠️ Direct connector failed for {table_name}: {str(e)}")
            print(f"🔄 Falling back to BigQuery client method")
            
            # Fallback to client method
            client = bigquery.Client()
            limit_clause = f"LIMIT {limit}" if limit else ""
            query = f"SELECT * FROM `{project_id}.{dataset_id}.{table_name}` {limit_clause}"
            
            query_job = client.query(query)
            df_pandas = query_job.to_dataframe()
            df_spark = spark.createDataFrame(df_pandas)
            
            print(f"✅ BigQuery client fallback successful: {df_spark.count()} rows loaded")
            return df_spark

def run_bigquery_query(spark, query):
    """
    Run a custom BigQuery query and load results into Spark DataFrame
    """
    client = bigquery.Client()
    
    print(f"🔍 Running custom query...")
    query_job = client.query(query)
    df_pandas = query_job.to_dataframe()
    df_spark = spark.createDataFrame(df_pandas)
    
    print(f"✅ Query successful: {df_spark.count()} rows loaded")
    return df_spark

# ----------------------------------------------------------------------
# Example: Read from BigQuery using Hybrid Approach
# ----------------------------------------------------------------------

print("\n" + "="*80)
print("🚀 Testing BigQuery Connection with Hybrid Approach")
print("="*80)

# Test with a small sample first
print("📋 Loading sample data from fact_trips table...")

df = load_table_hybrid(
    spark, 
    "fact_trips", 
    PROJECT_ID, 
    DATASET_ID, 
    limit=INCREMENTAL_CONFIG["local_test_limit"]  # Use local test limit
)

print("✅ Successfully loaded BigQuery table into Spark DataFrame!")
print(f"📊 Schema:")
df.printSchema()

print(f"📋 Sample data:")
df.show(5)

print(f"📈 Basic statistics:")
print(f"   • Total rows: {df.count():,}")
print(f"   • Total columns: {len(df.columns)}")



🔐 Using credentials: /Applications/saggydev/projects_learning/data_engineering_course/secrets/dtc-de-course-466501-e23cbf158abc.json
:: loading settings :: url = jar:file:/Applications/saggydev/projects_learning/data_engineering_course/.venv/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /Users/saggysimmba/.ivy2/cache
The jars for the packages stored in: /Users/saggysimmba/.ivy2/jars
com.google.cloud.spark#spark-bigquery-with-dependencies_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-9e82a2d8-150b-4775-8c14-478c5827d7ab;1.0
	confs: [default]
	found com.google.cloud.spark#spark-bigquery-with-dependencies_2.12;0.32.0 in central
:: resolution report :: resolve 79ms :: artifacts dl 1ms
	:: modules in use:
	com.google.cloud.spark#spark-bigquery-with-dependencies_2.12;0.32.0 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   1   |   0   |   0   |   0   ||   1   |   0   |
	------------------------------------------------------

✅ Spark Version: 3.5.4
📊 Working with: dtc-de-course-466501.dbt_production
🔄 Incremental Processing: 2015-2016
🏠 Local Processing Mode: Limits enabled for efficient local development

🚀 Testing BigQuery Connection with Hybrid Approach
📋 Loading sample data from fact_trips table...
📊 Table 'fact_trips' has complex schema, using BigQuery client method
🔍 Running query: SELECT * FROM `dtc-de-course-466501.dbt_production.fact_trips` LIMIT 10000


25/08/04 19:59:16 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
25/08/04 19:59:19 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

✅ BigQuery client successful: 10000 rows loaded
✅ Successfully loaded BigQuery table into Spark DataFrame!
📊 Schema:
root
 |-- tripid: string (nullable = true)
 |-- vendorid: long (nullable = true)
 |-- service_type: string (nullable = true)
 |-- ratecodeid: long (nullable = true)
 |-- pickup_locationid: long (nullable = true)
 |-- pickup_borough: string (nullable = true)
 |-- pickup_zone: string (nullable = true)
 |-- dropoff_locationid: long (nullable = true)
 |-- dropoff_borough: string (nullable = true)
 |-- dropoff_zone: string (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- pickup_date: date (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: decimal(38,18) (nullable = true)
 |-- fare_amount: decimal(38,18) (nullable = true)
 |-- extra: decimal(38,18) (nullable = true)
 |-- mta_tax: decimal(38,18) (nullable = true)
 |-- tip_a

**Test if data batching works**

In [6]:
START_YEAR = 2015
END_YEAR = 2016

print(f"✅ Spark Version: {spark.version}")
print(f"📊 Working with: {PROJECT_ID}.{DATASET_ID}")
print(f"🔄 Incremental Processing for years: {START_YEAR}-{END_YEAR}")

# ---------------------------------------------------------------
# Read data from BigQuery in yearly batches
# ---------------------------------------------------------------
table_name = f"{PROJECT_ID}.{DATASET_ID}.fact_trips"

for year in range(START_YEAR, END_YEAR + 1):

    year_filter = f"WHERE EXTRACT(YEAR FROM pickup_date) = {year}"
    query = f"""
    SELECT * FROM `{PROJECT_ID}.{DATASET_ID}.fact_trips`
    {year_filter}
    """

    # Load data for that year
    df_year = run_bigquery_query(spark, query)  # ✅ No Java errors!

    # Add year_month for inspection
    df_with_month = df_year.withColumn(
        "year_month",
        F.substring("pickup_date", 1, 7)
    )

    # Show distinct months in this batch
    months_df = (
        df_with_month
        .select("year_month")
        .distinct()
        .orderBy("year_month")
    )

    months_list = months_df.collect()
    months_str_list = [row["year_month"] for row in months_list]

    print("✅ Loaded batch for year:", year)
    months_df.show(truncate=False)

    print(f"📅 Number of months found: {len(months_str_list)}")
    if months_str_list:
      print(f"🗓️ First month in batch: {months_str_list[0]}")
      print(f"🗓️ Last month in batch: {months_str_list[-1]}")
    else:
      print("⚠️ No data found for this batch.")

    # ----
    # YOUR LOGIC HERE for processing df_with_month
    # ----

    print("="*80)

✅ Spark Version: 3.5.4
📊 Working with: dtc-de-course-466501.dbt_production
🔄 Incremental Processing for years: 2015-2016
🔍 Running custom query...


: 

**Potential Optimization for this query**
- Push‐down your entire date‐range filter so you only hit BigQuery once, rather than once per year.

- Select only the columns you actually need (e.g. pickup_date plus whatever other fields your logic needs) to reduce I/O.

- Use Spark’s built-in date functions (date_format, year) instead of substring.

- Compute your “months seen” stats in one pass via DataFrame aggregations—no collect() of the full list required.

- Cache the filtered DataFrame if you’re going to loop over it multiple times.

- Iterate per month (rather than per year) so you can build true “incremental” logic month-by-month if needed

In [None]:
# ----------------------------------------------------------------------
# FIXED Yearly Batch Processing - LOCAL PYSPARK VERSION
# ----------------------------------------------------------------------
# ⚠️ Use this cell instead of the problematic one above!

START_YEAR = 2015
END_YEAR = 2016

print(f"✅ Spark Version: {spark.version}")
print(f"📊 Working with: {PROJECT_ID}.{DATASET_ID}")
print(f"🔄 Incremental Processing for years: {START_YEAR}-{END_YEAR}")
print(f"🏠 Local Mode: Using batch limits for efficient processing")

# ---------------------------------------------------------------
# Read data from BigQuery in yearly batches - HYBRID APPROACH
# ---------------------------------------------------------------

yearly_dataframes = {}  # Store DataFrames for each year

for year in range(START_YEAR, END_YEAR + 1):
    
    print(f"\n{'='*60}")
    print(f"🗓️ Processing Year: {year}")
    print(f"{'='*60}")
    
    # 🆕 Use custom BigQuery SQL query instead of direct connector
    # This avoids the Java DirectByteBuffer errors
    query = f"""
    SELECT * FROM `{PROJECT_ID}.{DATASET_ID}.fact_trips`
    WHERE EXTRACT(YEAR FROM pickup_date) = {year}
    LIMIT {INCREMENTAL_CONFIG['local_batch_limit']}
    """
    
    print(f"🔍 Loading up to {INCREMENTAL_CONFIG['local_batch_limit']:,} rows for year {year}")
    
    # 🆕 Use the hybrid BigQuery query function (no direct connector issues)
    df_year = run_bigquery_query(spark, query)
    
    # Add year_month for inspection (same as before)
    df_with_month = df_year.withColumn(
        "year_month",
        F.substring("pickup_date", 1, 7)
    )
    
    # Show distinct months in this batch
    months_df = (
        df_with_month
        .select("year_month")
        .distinct()
        .orderBy("year_month")
    )
    
    months_list = months_df.collect()
    months_str_list = [row["year_month"] for row in months_list]
    
    print(f"✅ Loaded batch for year: {year}")
    print(f"📊 Actual rows loaded: {df_with_month.count():,}")
    
    print(f"📅 Number of months found: {len(months_str_list)}")
    if months_str_list:
        print(f"🗓️ First month in batch: {months_str_list[0]}")
        print(f"🗓️ Last month in batch: {months_str_list[-1]}")
        
        # Show month distribution
        print(f"📊 Months in this year:")
        months_df.show(truncate=False)
    else:
        print("⚠️ No data found for this batch.")
    
    # 🆕 Store DataFrame for later use
    yearly_dataframes[year] = df_with_month
    
    # 🆕 Show sample of the data
    print(f"📋 Sample data for {year}:")
    if 'tripid' in df_with_month.columns:
        df_with_month.select(
            "tripid", "pickup_zone", "dropoff_zone", 
            "trip_distance", "fare_amount", "year_month"
        ).show(3)
    else:
        # Show whatever columns are available
        df_with_month.show(3)
    
    # ----
    # YOUR PROCESSING LOGIC HERE for df_with_month
    # ----
    
    print(f"💾 Year {year} processing complete")

print(f"\n🎉 All yearly batches processed!")
print(f"📊 Years available: {list(yearly_dataframes.keys())}")
print(f"💡 Access individual years using: yearly_dataframes[2015], yearly_dataframes[2016], etc.")

# 🆕 Optional: Create combined DataFrame from all years
print(f"\n🔗 Creating combined DataFrame...")
all_years_df = None
for year, df in yearly_dataframes.items():
    if all_years_df is None:
        all_years_df = df
    else:
        all_years_df = all_years_df.union(df)

if all_years_df:
    total_rows = all_years_df.count()
    print(f"✅ Combined DataFrame created with {total_rows:,} total rows")
    print(f"📊 Data spans {len(yearly_dataframes)} years")
else:
    print("⚠️ No data loaded")

print("="*80)


## Multiplot plotting functions

In [12]:
# Multi-panel plotting function (R multiplot equivalent)
# Courtesy of R Cookbooks adapted for Python/PySpark
def multiplot(*plots, plotlist=None, cols=1, layout=None, figsize=(15, 10),
              title=None, save_path=None, dpi=300):
    """
    Create multi-panel plots from matplotlib/seaborn/plotly figures

    Python equivalent of R's multiplot function from R Cookbooks

    Parameters:
    -----------
    *plots : matplotlib figures or plot functions
        Individual plots to be arranged
    plotlist : list, optional
        List of plots as alternative to *plots
    cols : int, default=1
        Number of columns in layout
    layout : array-like, optional
        Matrix specifying the layout. If present, 'cols' is ignored.
        Example: [[1,2], [3,3]] means plot 1 top-left, 2 top-right, 3 bottom spanning both columns
    figsize : tuple, default=(15, 10)
        Figure size (width, height) in inches
    title : str, optional
        Overall title for the multi-panel plot
    save_path : str, optional
        Path to save the figure
    dpi : int, default=300
        Resolution for saved figure

    Returns:
    --------
    matplotlib.figure.Figure
        The combined figure with all subplots

    Example:
    --------
    # Create individual plots
    fig1, ax1 = plt.subplots()
    ax1.plot([1,2,3], [1,4,2])

    fig2, ax2 = plt.subplots()
    ax2.bar([1,2,3], [3,1,4])

    # Combine them
    combined_fig = multiplot(fig1, fig2, cols=2, title="Combined Analysis")
    """

    # Combine plots from arguments and plotlist
    all_plots = list(plots) if plots else []
    if plotlist:
        all_plots.extend(plotlist)

    num_plots = len(all_plots)

    if num_plots == 0:
        print("⚠️ No plots provided")
        return None

    # Handle single plot case
    if num_plots == 1:
        if hasattr(all_plots[0], 'show'):
            all_plots[0].show()
        else:
            plt.figure(figsize=figsize)
            if title:
                plt.suptitle(title, fontsize=16, fontweight='bold')
            plt.show()
        return all_plots[0]

    # Determine layout
    if layout is None:
        # Calculate rows and columns
        nrows = math.ceil(num_plots / cols)
        ncols = cols
        layout_matrix = np.arange(1, cols * nrows + 1).reshape(nrows, ncols)
    else:
        layout_matrix = np.array(layout)
        nrows, ncols = layout_matrix.shape

    # Create the main figure
    fig = plt.figure(figsize=figsize)

    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.95)

    # Create GridSpec for flexible subplot arrangement
    gs = gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.3, wspace=0.3)

    # Place each plot in the correct position
    for i, plot in enumerate(all_plots, 1):
        if i > num_plots:
            break

        # Find positions where this plot should go
        positions = np.where(layout_matrix == i)

        if len(positions[0]) == 0:
            continue

        # Calculate subplot span
        row_min, row_max = positions[0].min(), positions[0].max()
        col_min, col_max = positions[1].min(), positions[1].max()

        # Create subplot
        ax = fig.add_subplot(gs[row_min:row_max+1, col_min:col_max+1])

        # Handle different plot types
        if hasattr(plot, 'figure'):
            # It's a matplotlib figure
            _copy_plot_to_axis(plot, ax)
        elif callable(plot):
            # It's a plotting function
            plot(ax)
        elif hasattr(plot, 'axes'):
            # It's a figure with axes
            _copy_plot_to_axis(plot, ax)
        else:
            # Try to handle as data for direct plotting
            ax.text(0.5, 0.5, f'Plot {i}', ha='center', va='center',
                   transform=ax.transAxes)

    plt.tight_layout()

    # Save if path provided
    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print(f"💾 Multi-panel plot saved to: {save_path}")

    plt.show()
    return fig

def _copy_plot_to_axis(source_fig, target_ax):
    """Helper function to copy plot content from source figure to target axis"""
    try:
        if hasattr(source_fig, 'axes') and source_fig.axes:
            source_ax = source_fig.axes[0]

            # Copy lines
            for line in source_ax.get_lines():
                target_ax.plot(line.get_xdata(), line.get_ydata(),
                             color=line.get_color(), linewidth=line.get_linewidth(),
                             linestyle=line.get_linestyle(), marker=line.get_marker(),
                             label=line.get_label())

            # Copy patches (bars, etc.)
            for patch in source_ax.patches:
                target_ax.add_patch(patch)

            # Copy collections (scatter plots, etc.)
            for collection in source_ax.collections:
                target_ax.add_collection(collection)

            # Copy labels and title
            target_ax.set_xlabel(source_ax.get_xlabel())
            target_ax.set_ylabel(source_ax.get_ylabel())
            target_ax.set_title(source_ax.get_title())

            # Copy limits
            target_ax.set_xlim(source_ax.get_xlim())
            target_ax.set_ylim(source_ax.get_ylim())

            # Copy legend if exists
            if source_ax.get_legend():
                target_ax.legend()

    except Exception as e:
        # Fallback: just add a text placeholder
        target_ax.text(0.5, 0.5, 'Plot Content', ha='center', va='center',
                      transform=target_ax.transAxes)
        print(f"⚠️ Could not copy plot content: {e}")

# Enhanced multiplot for PySpark DataFrames
def multiplot_spark(dataframes_and_plots, cols=2, figsize=(15, 10), title=None):
    """
    Create multi-panel plots specifically for PySpark DataFrame visualizations

    Parameters:
    -----------
    dataframes_and_plots : list of tuples
        Each tuple contains (spark_dataframe, plot_config)
        plot_config is a dict with keys: 'type', 'x', 'y', 'title', etc.
    cols : int
        Number of columns
    figsize : tuple
        Figure size
    title : str
        Overall title

    Example:
    --------
    plot_configs = [
        (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trip Distance'}),
        (df_2021, {'type': 'scatter', 'x': 'trip_distance', 'y': 'fare_amount', 'title': '2021 Distance vs Fare'}),
        (df_2022, {'type': 'bar', 'x': 'pickup_hour', 'y': 'count', 'title': '2022 Hourly Trips'})
    ]
    multiplot_spark(plot_configs, cols=2, title="Yearly Comparison")
    """

    num_plots = len(dataframes_and_plots)
    nrows = math.ceil(num_plots / cols)

    fig, axes = plt.subplots(nrows, cols, figsize=figsize)
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')

    # Flatten axes array for easy indexing
    if num_plots == 1:
        axes = [axes]
    elif nrows == 1:
        axes = axes
    else:
        axes = axes.flatten()

    for i, (spark_df, plot_config) in enumerate(dataframes_and_plots):
        if i >= len(axes):
            break

        ax = axes[i]

        # Convert to Pandas for plotting
        pandas_df = spark_df.toPandas()

        plot_type = plot_config.get('type', 'scatter')
        x_col = plot_config.get('x')
        y_col = plot_config.get('y')
        plot_title = plot_config.get('title', f'Plot {i+1}')

        # Create the appropriate plot
        if plot_type == 'hist':
            ax.hist(pandas_df[x_col], bins=30, alpha=0.7)
            ax.set_xlabel(x_col)
            ax.set_ylabel('Frequency')
        elif plot_type == 'scatter':
            ax.scatter(pandas_df[x_col], pandas_df[y_col], alpha=0.6)
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
        elif plot_type == 'bar':
            if y_col:
                ax.bar(pandas_df[x_col], pandas_df[y_col])
            else:
                value_counts = pandas_df[x_col].value_counts()
                ax.bar(value_counts.index, value_counts.values)
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col or 'Count')
        elif plot_type == 'line':
            ax.plot(pandas_df[x_col], pandas_df[y_col])
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)

        ax.set_title(plot_title)
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for i in range(num_plots, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig

# Quick plotting helper for common PySpark visualizations
def quick_multiplot_comparison(yearly_data_dict, plot_type='hist', column='trip_distance',
                              cols=2, figsize=(15, 10)):
    """
    Quick comparison plots across multiple years

    Parameters:
    -----------
    yearly_data_dict : dict
        Dictionary with year as key, PySpark DataFrame as value
    plot_type : str
        Type of plot ('hist', 'box', 'violin')
    column : str
        Column to plot
    cols : int
        Number of columns in layout
    figsize : tuple
        Figure size

    Example:
    --------
    yearly_data = {
        2020: df_2020,
        2021: df_2021,
        2022: df_2022
    }
    quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')
    """

    plot_configs = []
    for year, df in yearly_data_dict.items():
        config = {
            'type': plot_type,
            'x': column,
            'title': f'{year} - {column.replace("_", " ").title()}'
        }
        plot_configs.append((df, config))

    return multiplot_spark(plot_configs, cols=cols, figsize=figsize,
                          title=f"Year-over-Year Comparison: {column.replace('_', ' ').title()}")

print("✅ Multi-panel plotting functions created (R multiplot equivalent)")
print("✅ PySpark-specific multiplot functions ready")
print("✅ Quick comparison plotting utilities available")

# Example usage guide
print("\n" + "="*50)
print("MULTIPLOT USAGE EXAMPLES")
print("="*50)
print("""
# Basic usage:
fig1, ax1 = plt.subplots()
ax1.plot([1,2,3], [1,4,2])

fig2, ax2 = plt.subplots()
ax2.bar([1,2,3], [3,1,4])

combined = multiplot(fig1, fig2, cols=2, title="Side by Side")

# PySpark DataFrame plotting:
plot_configs = [
    (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trips'}),
    (df_2021, {'type': 'hist', 'x': 'trip_distance', 'title': '2021 Trips'})
]
multiplot_spark(plot_configs, cols=2)

# Quick year comparison:
yearly_data = {2020: df_2020, 2021: df_2021, 2022: df_2022}
quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')
""")

✅ Multi-panel plotting functions created (R multiplot equivalent)
✅ PySpark-specific multiplot functions ready
✅ Quick comparison plotting utilities available

MULTIPLOT USAGE EXAMPLES

# Basic usage:
fig1, ax1 = plt.subplots()
ax1.plot([1,2,3], [1,4,2])

fig2, ax2 = plt.subplots()
ax2.bar([1,2,3], [3,1,4])

combined = multiplot(fig1, fig2, cols=2, title="Side by Side")

# PySpark DataFrame plotting:
plot_configs = [
    (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trips'}),
    (df_2021, {'type': 'hist', 'x': 'trip_distance', 'title': '2021 Trips'})
]
multiplot_spark(plot_configs, cols=2)

# Quick year comparison:
yearly_data = {2020: df_2020, 2021: df_2021, 2022: df_2022}
quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')



In [13]:
combined_df = None

for year in range(2015, 2017):
  df_year = (
      spark.read
           .format("bigquery")
           .option("table", table_name)
           .option("filter", f"EXTRACT(YEAR FROM pickup_date) = {year}")
           .load()
  )

  if combined_df is None:
    combined_df = df_year
  else:
    combined_df = combined_df.unionByName(df_year)

print("✅ Loaded combined data across years")
combined_df.show(5)

✅ Loaded combined data across years


+--------------------+--------+------------+----------+-----------------+--------------+-----------------+------------------+---------------+--------------------+-------------------+-----------+-------------------+------------------+---------------+-------------+------------+-----------+-----------+-----------+------------+---------------------+------------+------------+------------------------+------------+-----+-----------+-----------+------------+---------------+------------+------------+------------+-----------+
|              tripid|vendorid|service_type|ratecodeid|pickup_locationid|pickup_borough|      pickup_zone|dropoff_locationid|dropoff_borough|        dropoff_zone|    pickup_datetime|pickup_date|   dropoff_datetime|store_and_fwd_flag|passenger_count|trip_distance| fare_amount|      extra|    mta_tax| tip_amount|tolls_amount|improvement_surcharge|total_amount|payment_type|payment_type_description|climate_date|  mjd| cloudCover|   humidity|    dewPoint|precipIntensity|    high

**Potential Optimization for this query**
- None as of now. It is super-fast

In [14]:
df_with_time = (
    combined_df
    .withColumn("pickup_year", F.year("pickup_date"))
    .withColumn("pickup_month", F.month("pickup_date"))
)

df_with_time.show(5)

+--------------------+--------+------------+----------+-----------------+--------------+-----------------+------------------+---------------+--------------------+-------------------+-----------+-------------------+------------------+---------------+-------------+------------+-----------+-----------+-----------+------------+---------------------+------------+------------+------------------------+------------+-----+-----------+-----------+------------+---------------+------------+------------+------------+-----------+-----------+------------+
|              tripid|vendorid|service_type|ratecodeid|pickup_locationid|pickup_borough|      pickup_zone|dropoff_locationid|dropoff_borough|        dropoff_zone|    pickup_datetime|pickup_date|   dropoff_datetime|store_and_fwd_flag|passenger_count|trip_distance| fare_amount|      extra|    mta_tax| tip_amount|tolls_amount|improvement_surcharge|total_amount|payment_type|payment_type_description|climate_date|  mjd| cloudCover|   humidity|    dewPoint

**Consistency Check**
- A column called `trip_durations` needs to be created which calculates the intervals between `pickup_datetime` and `dropoff_datetime`

**Potential Optimization for this query**
- None as of now. Works for the time being

In [15]:
df_with_duration = df_with_time.withColumn(
    "trip_duration_min",
    (
        (F.col("dropoff_datetime").cast("long")
         - F.col("pickup_datetime").cast("long"))
        .cast("double")
        / F.lit(60.0)
    )
)

df_with_duration.select(
    "pickup_datetime", "dropoff_datetime", "trip_duration_min"
).show(5, truncate=False)

+-------------------+-------------------+------------------+
|pickup_datetime    |dropoff_datetime   |trip_duration_min |
+-------------------+-------------------+------------------+
|2015-01-01 00:00:42|2015-01-01 00:22:22|21.666666666666668|
|2015-01-01 00:06:04|2015-01-01 00:39:02|32.96666666666667 |
|2015-01-01 00:10:40|2015-01-01 00:15:34|4.9               |
|2015-01-01 00:21:14|2015-01-01 00:38:55|17.683333333333334|
|2015-01-01 00:22:38|2015-01-01 00:36:44|14.1              |
+-------------------+-------------------+------------------+
only showing top 5 rows



**Potential Optimization for this query**
- None needed as of now. Fast enough IMO.

In [16]:
df_with_duration.explain(extended=True)

== Parsed Logical Plan ==
Project [tripid#3084, vendorid#3085L, service_type#3086, ratecodeid#3087L, pickup_locationid#3088L, pickup_borough#3089, pickup_zone#3090, dropoff_locationid#3091L, dropoff_borough#3092, dropoff_zone#3093, pickup_datetime#3094, pickup_date#3095, dropoff_datetime#3096, store_and_fwd_flag#3097, passenger_count#3098L, trip_distance#3099, fare_amount#3100, extra#3101, mta_tax#3102, tip_amount#3103, tolls_amount#3104, improvement_surcharge#3105, total_amount#3106, payment_type#3107L, ... 14 more fields]
+- Project [tripid#3084, vendorid#3085L, service_type#3086, ratecodeid#3087L, pickup_locationid#3088L, pickup_borough#3089, pickup_zone#3090, dropoff_locationid#3091L, dropoff_borough#3092, dropoff_zone#3093, pickup_datetime#3094, pickup_date#3095, dropoff_datetime#3096, store_and_fwd_flag#3097, passenger_count#3098L, trip_distance#3099, fare_amount#3100, extra#3101, mta_tax#3102, tip_amount#3103, tolls_amount#3104, improvement_surcharge#3105, total_amount#3106, pay

In [17]:
def temporal_train_test_split(df, strategy="chronological", test_ratio=0.2,
                            random_seed=42, validation_split=True):
    """
    Split time series data into train/test sets using various strategies

    Parameters:
    -----------
    df : PySpark DataFrame
        Input data with datetime columns
    strategy : str
        Splitting strategy: 'chronological', 'random', 'stratified_temporal', 'holdout_months'
    test_ratio : float
        Proportion of data for testing (default: 0.2 = 20%)
    random_seed : int
        Random seed for reproducibility
    validation_split : bool
        Whether to create a validation set (splits train further)

    Returns:
    --------
    dict
        Dictionary with 'train', 'test', and optionally 'validation' DataFrames
    """

    print(f"🔄 Splitting data using '{strategy}' strategy")
    print(f"📊 Test ratio: {test_ratio*100}%")

    if strategy == "chronological":
        # RECOMMENDED for time series: Use 2015 for training, 2016 for testing
        train_df = df.filter(col("pickup_year") == 2015)
        test_df = df.filter(col("pickup_year") == 2016)

        result = {"train": train_df, "test": test_df}

        if validation_split:
            # Use last 2 months of 2015 for validation
            train_final = train_df.filter(col("pickup_month") <= 10)
            validation_df = train_df.filter(col("pickup_month") > 10)
            result["train"] = train_final
            result["validation"] = validation_df

        print("✅ Chronological split: 2015 (train) vs 2016 (test)")

    elif strategy == "random":
        # Random split within each year to maintain temporal balance
        df_with_split = df.withColumn("rand_val", rand(seed=random_seed))

        train_df = df_with_split.filter(col("rand_val") > test_ratio)
        test_df = df_with_split.filter(col("rand_val") <= test_ratio)

        result = {"train": train_df.drop("rand_val"), "test": test_df.drop("rand_val")}

        if validation_split:
            val_ratio = test_ratio / 2
            train_final = train_df.filter(col("rand_val") > (test_ratio + val_ratio))
            validation_df = train_df.filter(
                (col("rand_val") > test_ratio) & (col("rand_val") <= (test_ratio + val_ratio))
            )
            result["train"] = train_final.drop("rand_val")
            result["validation"] = validation_df.drop("rand_val")

        print("✅ Random split across both years")

    elif strategy == "stratified_temporal":
        # Stratified split maintaining temporal patterns (by month)
        df_with_split = df.withColumn("rand_val", rand(seed=random_seed))

        # Create stratified split by month
        train_df = df_with_split.filter(col("rand_val") > test_ratio)
        test_df = df_with_split.filter(col("rand_val") <= test_ratio)

        result = {"train": train_df.drop("rand_val"), "test": test_df.drop("rand_val")}
        print("✅ Stratified temporal split")

    elif strategy == "holdout_months":
        # Hold out specific months for testing (e.g., Dec 2015 and Dec 2016)
        train_df = df.filter(col("pickup_month") != 12)
        test_df = df.filter(col("pickup_month") == 12)

        result = {"train": train_df, "test": test_df}

        if validation_split:
            # Use November for validation
            train_final = train_df.filter(col("pickup_month") != 11)
            validation_df = train_df.filter(col("pickup_month") == 11)
            result["train"] = train_final
            result["validation"] = validation_df

        print("✅ Holdout months split: December for testing")

    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    # Show split statistics
    for split_name, split_df in result.items():
        count = split_df.count()
        print(f"📊 {split_name.capitalize()}: {count:,} records")

        # Show year distribution
        year_dist = split_df.groupBy("pickup_year").count().collect()
        year_info = ", ".join([f"{row['pickup_year']}: {row['count']:,}" for row in year_dist])
        print(f"   Year distribution: {year_info}")

    return result

In [18]:
# RECOMMENDED APPROACH for your use case
def create_production_splits(df, approach="chronological_with_validation"):
    """
    Create production-ready train/test splits optimized for taxi duration prediction

    Parameters:
    -----------
    df : PySpark DataFrame
        Full dataset
    approach : str
        'chronological_with_validation' or 'cross_temporal_validation'

    Returns:
    --------
    dict
        Split datasets ready for ML pipeline
    """

    if approach == "chronological_with_validation":
        # BEST for your case: 2015 for training, 2016 for testing
        splits = temporal_train_test_split(df, strategy="chronological",
                                         validation_split=True)

        print("\n🎯 RECOMMENDED APPROACH:")
        print("   • Train: 2015 (Jan-Oct)")
        print("   • Validation: 2015 (Nov-Dec)")
        print("   • Test: 2016 (Full year)")
        print("   • Benefit: Tests model on completely unseen future data")

    elif approach == "cross_temporal_validation":
        # Alternative: Mixed years with temporal awareness
        splits = temporal_train_test_split(df, strategy="stratified_temporal",
                                         test_ratio=0.2, validation_split=True)

        print("\n🔄 ALTERNATIVE APPROACH:")
        print("   • Mixed temporal cross-validation")
        print("   • Maintains seasonal patterns in all splits")

    return splits

In [19]:
splits = create_production_splits(df_with_duration, approach="chronological_with_validation")

train_data = splits["train"]
validation_data = splits["validation"]
test_data = splits["test"]

🔄 Splitting data using 'chronological' strategy
📊 Test ratio: 20.0%
✅ Chronological split: 2015 (train) vs 2016 (test)


📊 Train: 55,243,075 records


   Year distribution: 2015: 55,243,075


📊 Test: 62,724,650 records


   Year distribution: 2016: 62,724,650


📊 Validation: 10,814,005 records


   Year distribution: 2015: 10,814,005

🎯 RECOMMENDED APPROACH:
   • Train: 2015 (Jan-Oct)
   • Validation: 2015 (Nov-Dec)
   • Test: 2016 (Full year)
   • Benefit: Tests model on completely unseen future data


In [20]:
# e.g. show the last 5 rows by descending pickup_datetime
df_with_duration.orderBy(F.col("pickup_datetime").desc()).show(5, truncate=False)

+--------------------------------+--------+------------+----------+-----------------+--------------+-------------------------+------------------+---------------+--------------+-------------------+-----------+-------------------+------------------+---------------+-------------+------------+-----------+-----------+-----------+------------+---------------------+------------+------------+------------------------+------------+-----+-----------+-----------+------------+---------------+------------+------------+-----------+-----------+-----------+------------+------------------+
|tripid                          |vendorid|service_type|ratecodeid|pickup_locationid|pickup_borough|pickup_zone              |dropoff_locationid|dropoff_borough|dropoff_zone  |pickup_datetime    |pickup_date|dropoff_datetime   |store_and_fwd_flag|passenger_count|trip_distance|fare_amount |extra      |mta_tax    |tip_amount |tolls_amount|improvement_surcharge|total_amount|payment_type|payment_type_description|climate_d

In [21]:
test_data.show(5)

+--------------------+--------+------------+----------+-----------------+--------------+--------------------+------------------+---------------+--------------------+-------------------+-----------+-------------------+------------------+---------------+-------------+------------+-----------+-----------+-----------+------------+---------------------+------------+------------+------------------------+------------+-----+-----------+-----------+------------+---------------+------------+------------+-----------+-----------+-----------+------------+------------------+
|              tripid|vendorid|service_type|ratecodeid|pickup_locationid|pickup_borough|         pickup_zone|dropoff_locationid|dropoff_borough|        dropoff_zone|    pickup_datetime|pickup_date|   dropoff_datetime|store_and_fwd_flag|passenger_count|trip_distance| fare_amount|      extra|    mta_tax| tip_amount|tolls_amount|improvement_surcharge|total_amount|payment_type|payment_type_description|climate_date|  mjd| cloudCover|

**Missing Values**

Knowing about missing values is important because they indicate how much we don’t know about our data. Making inferences based on just a few cases is often unwise. In addition, many modelling procedures break down when missing values are involved and the corresponding rows will either have to be removed completely or the values need to be estimated somehow.

Here, we are in the fortunate position that our data is complete and there are no missing values.

In [None]:
%%time
# Compute null‐counts and total‐rows in one aggregation
agg_exprs = [
    F.count(F.when(F.col(c).isNull(), c)).alias(f"{c}_nulls")
    for c in df_with_time.columns
] + [
    F.count("*").alias("total_rows")
]

stats = df_with_time.agg(*agg_exprs).collect()[0].asDict()

# Now compute percentages in Python
null_percentages = {
    c.replace("_nulls",""): (count / stats["total_rows"]) * 100
    for c, count in stats.items() if c.endswith("_nulls")
}

print(f"Total rows: {stats['total_rows']}")
for col, pct in null_percentages.items():
    print(f"{col:20s}: {pct:.2f}% null")

**Reformatting Features**

For our following analysis, we will turn the data and time from characters into date objects. We also recode vendor_id as a factor. This makes it easier to visualise relationships that involve these features.

In [None]:
# Show full schema
df_with_duration.printSchema()

In [None]:
# Filter trips ≤ 1 minute
short_trips = df_with_duration.filter(F.col("trip_duration_min") <= F.lit(1.0))

count_short = short_trips.count()
count_all   = df_with_duration.count()

print(f"Short trips (≤1 min): {count_short:,} out of {count_all:,}")
print(f"That’s {count_short/count_all:.2%} of all trips")

**Observations**
- Sub-1-minute rides usually come from GPS rounding, driver repositioning, or data glitches—not people really hailing a cab for 30 seconds.
- They 're rare (<1% of trips).
- They can either be kept for realism or be filtered as outliers, depending on modeling goals.

**Analysis**
- For the time being, I will be keeping them as they possibly won't affect our final predictions much.

**Individual Feature Visualizations**
- Visualisations of feature distributions and their relations are key to understanding a data set, and they often open up new lines of inquiry. I always recommend to examine the data from as many different perspectives as possible to notice even subtle trends and correlations.

- In this section we will begin by having a look at the distributions of the individual data features.

- We start with a map of NYC and overlay a managable number of pickup coordinates to get a general overview of the locations and distances in question. For this visualisation I have used the `leaflet` package, which includes a variety of cool tools for interactive maps. In this map you can zoom and pan through the pickup locations

**Note**
- Additional data to be added into the pipeline later

Let’s start with plotting the target feature `trip_duration`

In [None]:
train_data.columns

In [None]:
pdf = (
    train_data
      .withColumn("rand", F.rand(1234))
      .orderBy("rand")
      .limit(8000)
      .toPandas()
)

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
durations = pdf["trip_duration_min"]
plt.figure(figsize=(8,5))
plt.hist(durations, bins=150, color="red")
plt.xscale("log")
plt.xlabel("Trip Duration (minutes)")
plt.ylabel("Count")
plt.title("Histogram of Trip Duration")
plt.tight_layout()
plt.show()

In [None]:
# Compute key percentiles
med = np.percentile(durations, 50)
p75 = np.percentile(durations, 75)
p95 = np.percentile(durations, 95)

# Define log-spaced bins
bins = np.logspace(np.log10(durations.min()+1e-3), np.log10(durations.max()), 50)

plt.figure(figsize=(10, 6))
n, bins_out, patches = plt.hist(
    durations, bins=bins, color='red', alpha=0.6, edgecolor='black'
)

# Add vertical lines for percentiles
for val, color, label in [(med, 'blue', 'Median'),
                          (p75, 'green', '75th pct'),
                          (p95, 'purple', '95th pct')]:
    plt.axvline(val, color=color, linestyle='--', linewidth=2, label=f'{label}: {val:.2f} min')

plt.xscale('log')
plt.xlabel("Trip Duration (minutes)")
plt.ylabel("Count")
plt.title("Histogram of Trip Duration (Log‐spaced bins)")
plt.legend()
plt.grid(True, which='both', ls=':', linewidth=0.7)
plt.tight_layout()
plt.show()

Note the logarithmic x-axis and square-root y-axis.

We find:

- the majority of rides follow a rather smooth distribution that looks almost log-normal with a peak just short of 10 min.

- There are some suspiciously short rides with less than a minute duration.

In [None]:
from pyspark.sql.functions import desc

# Use the .desc() method on the column expression
ordered_df = train_data.orderBy(train_data.trip_duration_min.desc())

# Reorder columns
first_cols = ["trip_duration_min", "pickup_datetime", "dropoff_datetime"]
rest_cols = [c for c in train_data.columns if c not in first_cols]
reordered_df = ordered_df.select(*first_cols, *rest_cols)

# Show the top 10 longest trips
reordered_df.show(10, truncate=False)

- These trip durations are clearly unrealistic and indicative of serious data quality issues
- This minor subset of records spans several months to years, when in fact its not practically feasible to have such long rides

Over the year, the distributions of **pickup\_datetime** and **dropoff\_datetime** look like this:

In [None]:
import pandas as pd

In [None]:
# Assuming you already have `pdf` as a Pandas DataFrame:
# Convert datetime columns just in case
pdf["pickup_datetime"] = pd.to_datetime(pdf["pickup_datetime"])
pdf["dropoff_datetime"] = pd.to_datetime(pdf["dropoff_datetime"])

# ✅ Cut off after July 2015
cutoff_date = pd.to_datetime("2015-08-01")
pdf_filtered = pdf[(pdf["pickup_datetime"] < cutoff_date) & (pdf["dropoff_datetime"] < cutoff_date)]

# Plot histograms
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

# Pickup histogram
ax1.hist(pdf_filtered["pickup_datetime"], bins=120, color="red")
ax1.set_title("Pickup Dates")
ax1.set_ylabel("Frequency")

# Dropoff histogram
ax2.hist(pdf_filtered["dropoff_datetime"], bins=120, color="blue")
ax2.set_title("Dropoff Dates")
ax2.set_xlabel("Datetime")
ax2.set_ylabel("Frequency")

plt.tight_layout()
plt.show()

Fairly homogeneous, covering half a year between January and July 2015. There is an interesting drop around late January early February

In [None]:
from pyspark.sql.functions import col

In [None]:
# Filter rows between Jan 20, 2016 and Feb 10, 2016
filtered_data = train_data.filter(
    (col("pickup_datetime") > "2015-01-20") &
    (col("pickup_datetime") < "2015-02-10")
)

# Select pickup_datetime and collect to Pandas
pickup_pdf = filtered_data.select("pickup_datetime").toPandas()

# Plot histogram
plt.figure(figsize=(10, 5))
plt.hist(pickup_pdf["pickup_datetime"], bins=120, color="red")
plt.xlabel("Pickup Datetime")
plt.ylabel("Frequency")
plt.title("Histogram of Pickup Datetimes (Jan 20–Feb 10, 2015)")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

That’s winter in NYC, so maybe snow storms or other heavy weather? Events like this should be taken into account, maybe through some handy external data set?

In the plot above we can already see some daily and weekly modulations in the number of trips. Let’s investigate these variations together with the distributions of passenger_count and vendor_id by creating a multi-plot panel with different components

In [None]:
import seaborn as sns

In [None]:
# Parse pickup_datetime
pdf["pickup_datetime"] = pd.to_datetime(pdf["pickup_datetime"])
# √-scaled bar plot for passenger_count
p1_data = pdf["passenger_count"].value_counts().reset_index()
p1_data.columns = ["passenger_count", "n"]
p1_data["n_sqrt"] = p1_data["n"]**0.5

plt.subplot(3, 2, 1)
sns.barplot(data=p1_data, x="passenger_count", y="n_sqrt", hue="passenger_count", palette="Reds", dodge=False)
plt.title("Passenger Count (√ scale)")
plt.xlabel("Passenger Count")
plt.ylabel("√Count")
plt.legend().remove()

In [None]:
plt.subplot(3, 2, 2)
sns.countplot(
    data=pdf,
    x="vendorid",
    hue="vendorid",
    palette="YlOrBr",  # yellow-orange-brown sequential
    dodge=False
)
plt.title("Vendor ID Count")
plt.xlabel("Vendor ID")
plt.ylabel("Count")
plt.legend().remove()

In [None]:
plt.subplot(3, 2, 3)
sns.countplot(data=pdf, x="store_and_fwd_flag", color="gray")
plt.yscale("log")  # Log scale on y-axis
plt.title("Store and Forward Flag (Log Scale)")
plt.xlabel("store_and_fwd_flag")
plt.ylabel("Log Count")
plt.legend().remove()  # Remove legend like theme(legend.position = "none")

In [None]:
from pyspark.sql.functions import date_format

In [None]:
# 1) Ensure pickup_datetime is a pandas datetime
pdf["pickup_datetime"] = pd.to_datetime(pdf["pickup_datetime"])

# 2) Add abbreviated weekday (Mon, Tue, …) with Monday as first day
pdf["wday"] = pdf["pickup_datetime"].dt.strftime("%a")

# 3) Group by weekday & vendor_id, then count
p4_pdf = (
    pdf
    .groupby(["wday", "vendorid"])
    .size()
    .reset_index(name="count")
)

# 4) Order the weekdays properly
day_order = ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
p4_pdf["wday"] = pd.Categorical(p4_pdf["wday"], categories=day_order, ordered=True)
p4_pdf = p4_pdf.sort_values("wday")

# 5) Plot exactly like your ggplot2 version
plt.figure(figsize=(8, 4))
sns.scatterplot(
    data=p4_pdf,
    x="wday",
    y="count",
    hue="vendorid",
    s=100,
    legend=False
)
plt.xlabel("Day of the week")
plt.ylabel("Total number of pickups")
plt.title("Pickups by weekday & vendor")
plt.tight_layout()
plt.show()

In [None]:
from pyspark.sql.functions import hour

In [None]:
# Ensure pickup_datetime is a pandas datetime (if not already)
pdf["pickup_datetime"] = pd.to_datetime(pdf["pickup_datetime"])

# Extract hour of day
pdf["hpick"] = pdf["pickup_datetime"].dt.hour

# Group by hour & vendorid, then count trips
p5_pdf = (
    pdf
    .groupby(["hpick", "vendorid"])
    .size()
    .reset_index(name="count")
)

# Order by hour for clarity
p5_pdf = p5_pdf.sort_values("hpick")

# Plot just like ggplot2’s geom_point(size=4), legend off
plt.figure(figsize=(8, 4))
sns.scatterplot(
    data=p5_pdf,
    x="hpick",
    y="count",
    hue="vendorid",
    s=100,       # size ~4pt
    legend=False
)
plt.xlabel("Hour of the day")
plt.ylabel("Total number of pickups")
plt.title("Pickups by hour & vendor")
plt.tight_layout()
plt.show()

**Check for 0 or between 7-9 passengers**

In [None]:
# 1) Filter for passenger_count == 0 or between 7 and 9
anomalous_df = train_data.filter(
    (col("passenger_count") == 0) |
    (col("passenger_count").between(7, 9))
)

# 2) See how many of each you have
anomalous_df.groupBy("passenger_count") \
    .count() \
    .orderBy("passenger_count") \
    .show()

# 3) (Optional) Peek at a few of the actual rows
anomalous_df.show(10, truncate=False)

There are a few trips with zero, or seven to nine passengers but they are a rare exception.

Towards larger passenger numbers we are seeing a smooth decline through 3 to 4, until the larger crowds (and larger cars) give us another peak at 5 to 6 passengers.

Vendor 2 has significantly more trips in this data set than vendor 1 (note the logarithmic y-axis). This is true for every day of the week.

We find an interesting pattern with Monday being the quietest day and Saturday very busy. This is the same for the two different vendors, with vendor_id == 2 showing significantly higher trip numbers.

The `store_and_fwd_flag` values, indicating whether the trip data was sent immediately to the vendor (“N”) or held in the memory of the taxi because there was no connection to the server (“Y”), show that there was almost no storing taking place (note again the logarithmic y-axis):

In [None]:
# 1) Count occurrences in the pandas DataFrame
counts_pdf = (
    pdf["store_and_fwd_flag"]
    .value_counts(dropna=False)                  # include NaNs if you want
    .reset_index(name="count")
    .rename(columns={"index": "store_and_fwd_flag"})
)

# 2) Plot the data
plt.figure(figsize=(6, 4))
sns.barplot(
    data=counts_pdf,
    x="store_and_fwd_flag",
    y="count",
    palette="pastel"
)
plt.xlabel("Store and Forward Flag")
plt.ylabel("Number of Trips")
plt.title("Trips by Store-and-Forward Flag")
plt.tight_layout()
plt.show()

These numbers are equivalent to about half a percent of trips not being transmitted immediately.

The trip volume per hour of the day depends somewhat on the month and strongly on the day of the week

In [None]:
# 1) Ensure pickup_datetime is datetime and extract hour & month abbreviation
pdf["pickup_datetime"] = pd.to_datetime(pdf["pickup_datetime"])
pdf["hpick"] = pdf["pickup_datetime"].dt.hour
pdf["Month"] = pdf["pickup_datetime"].dt.strftime("%b")

# 2) Group by hour & Month, then count trips
p1_pdf = (
    pdf
    .groupby(["hpick", "Month"])
    .size()
    .reset_index(name="count")
)

# 3) hpick is already int from dt.hour

# 4) Order Month factor for a proper legend
month_order = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
p1_pdf["Month"] = pd.Categorical(p1_pdf["Month"], categories=month_order, ordered=True)
p1_pdf = p1_pdf.sort_values(["Month", "hpick"])

# 5) Plot with Seaborn lineplot
plt.figure(figsize=(10, 6))
sns.lineplot(
    data=p1_pdf,
    x="hpick",
    y="count",
    hue="Month",
    linewidth=1.5
)
plt.xlabel("Hour of the day")
plt.ylabel("Count")
plt.title("Pickups by Hour of the Day & Month")
plt.legend(title="Month", bbox_to_anchor=(1.02, 1), loc="upper left")
plt.tight_layout()
plt.show()

**Not a good pattern. Let's convert the entire dataframe**

In [None]:
# 1) Extract hour of day and month abbreviation from pickup_datetime
df2 = (
    train_data
    .withColumn("hpick", hour("pickup_datetime"))
    .withColumn("Month", date_format("pickup_datetime", "MMM"))
)

# 2) Group by hour & Month, then count trips
p1_df = (
    df2
    .groupBy("hpick", "Month")
    .count()
    .orderBy("hpick", "Month")
)

# 3) Collect to Pandas for plotting
p1_pdf = p1_df.toPandas()
p1_pdf["hpick"] = p1_pdf["hpick"].astype(int)

# 4) (Optional) Order Month factor for a proper legend
month_order = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
p1_pdf["Month"] = pd.Categorical(p1_pdf["Month"], categories=month_order, ordered=True)
p1_pdf = p1_pdf.sort_values(["Month","hpick"])

# 5) Plot with Seaborn lineplot (geom_line equivalent)
plt.figure(figsize=(10, 6))
sns.lineplot(
    data=p1_pdf,
    x="hpick",
    y="count",
    hue="Month",
    linewidth=1.5
)
plt.xlabel("Hour of the day")
plt.ylabel("Count")
plt.title("Pickups by Hour of the Day & Month")
plt.legend(title="Month", bbox_to_anchor=(1.02, 1), loc="upper left")
plt.tight_layout()
plt.show()

- Deep pre-dawn trough around 4–5 AM on all days (counts fall below ~100 K), reflecting minimal overnight demand

- Two clear rush-hour peaks—a morning spike and an evening spike (~5–7 PM, ~400–420 K rides)—driven by commuter travel

- Saturday demand is highest overall (evening peaks >430 K and a shallower early-morning dip)

**Finally, we will look at a simple overview visualisation of the pickup/dropoff latitudes and longitudes**. We will look more into this when we have the latitudes and longitudes

## **Feature Relations**

While the previous section looked primarily at the distributions of the individual features, here we will examine in more detail how those features are related to each other and to our target `trip_duration_min`.

### 3.1 Pickup date/time vs trip_duration

How does the variation in trip numbers throughout the day and the week affect the average trip duration? Do quieter days and hours lead to faster trips? Here we include the vendor_id as an additional feature. Furthermore, for the hours of the day we add a smoothing layer to indicate the extent of the variation and its uncertainties

In [None]:
from pyspark.sql.functions import expr

In [None]:
from matplotlib.ticker import FuncFormatter

In [None]:
# 1) Extract hour of day
df_hpick = train_data.withColumn("hpick", hour("pickup_datetime"))

# 2) Compute median trip_duration in minutes
p2_df = (
    df_hpick
    .groupBy("hpick", "vendorid")
    .agg(
        (expr("percentile_approx(trip_duration_min, 0.5) / 60"))
        .alias("median_min")
    )
    .orderBy("hpick", "vendorid")
)

# 3) Collect to Pandas
p2_pdf = p2_df.toPandas()
p2_pdf["hpick"] = p2_pdf["hpick"].astype(int)

# 4) Plot with LOESS smoothing and points, legend on
g = sns.lmplot(
    data=p2_pdf,
    x="hpick",
    y="median_min",
    hue="vendorid",
    lowess=True,
    height=5,
    aspect=2,
    scatter_kws={"s": 100},
    line_kws={"linewidth": 1.5},
    legend=True          # enable legend
)

# 5) Format the y-axis to mm:ss
ax = g.ax
def mmss(x, pos):
    m = int(x)
    s = int(round((x - m) * 60))
    return f"{m}:{s:02d}"
ax.yaxis.set_major_formatter(FuncFormatter(mmss))

# 6) Tidy up labels and legend title
ax.set_xlabel("Hour of the day")
ax.set_ylabel("Median trip duration [min:sec]")
ax.set_title("Median Trip Duration by Hour & Vendor (LOESS smoothing)")
g._legend.set_title("Vendor ID")

plt.tight_layout()
plt.show()

**Pre-dawn dips to shortest trips**
- Both vendors hit their minimum median durations (~7.5-9 min) around 5-6 AM, when traffic is lightest.

**Steady climb into mid-afternoon peak**
- From 6 AM onward, median trip times rise, reaching a high of ~11-11.5 minutes between 2-4 PM, reflecting heavier traffic.

**Evening taper**
- After 4 PM, durations gradually fall back toward 10 minutes by 10-11 PM as congestion eases.

**Vendor parity with slight offsets**
- Vendor 1 (blue) and Vendor 2 (orange) follow almost identical curves; Vendor 1's median is typically 5-15 seconds longer during the mid-day and evening peaks.

### Passenger count and Vendor vs trip_duration
The next question we are asking is whether different numbers of passengers and/or the different vendors are correlated with the duration of the trip. We choose to examine this issue using a series of boxplots for the `passenger_counts` together with a facet wrap which contrasts the two `vendor_ids`

In [None]:
# 1) Sample 10% of the data (or use .limit(N) if you want an exact number)
df_bc = train_data \
    .select("passenger_count", "trip_duration_min", "vendorid") \
    .sample(fraction=0.10, seed=42)

# 2) Bring back to Pandas
pdf_bc = df_bc.toPandas()

# 3) Ensure dtypes
pdf_bc["passenger_count"] = pdf_bc["passenger_count"].astype(int)
pdf_bc["vendorid"]          = pdf_bc["vendorid"].astype(str)

# 4) Plot with seaborn’s catplot (boxplots) and log‐scale y–axis
g = sns.catplot(
    data=pdf_bc,
    x="passenger_count",
    y="trip_duration_min",
    col="vendorid",
    kind="box",
    sharey=False,
    palette="tab10",
    height=4,
    aspect=1
)

# 5) Apply log‐scale to each facet, relabel axes
for ax in g.axes.flatten():
    ax.set_yscale("log")
    ax.set_xlabel("Number of passengers")
    ax.set_ylabel("Trip duration [min]")

plt.tight_layout()
plt.show()

In [None]:
pdf.shape

In [None]:
# 1) Pull just the two columns, limit to e.g. 100 000 rows to avoid OOM
pdf_den = (
    train_data
    .select("trip_duration_min", "vendorid")
    .limit(100000)
    .toPandas()
)

# 2) Make sure vendorid is a string for hue
pdf_den["vendorid"] = pdf_den["vendorid"].astype(str)

# 3) Plot a stacked density on a log‐x axis
plt.figure(figsize=(10, 6))
sns.kdeplot(
    data=pdf_den,
    x="trip_duration_min",
    hue="vendorid",
    fill=True,
    common_norm=False,    # each density scaled to its own area
    multiple="stack",     # stack them up
    alpha=0.7
)
plt.xscale("log")
plt.xlabel("Trip duration (seconds)")
plt.ylabel("Density")
plt.title("Stacked density of trip duration by vendor")
plt.show()

Comparing the densities of the trip_duration distribution for the two vendors we find that the medians are very similar, whereas the means are likely skewed by vendor 2 containing most of the long-duration outliers

In [None]:
from pyspark.sql.functions import mean

# 1) Group by vendorid and compute mean & median of trip_duration (in seconds)
summary_df = (
    train_data
    .groupBy("vendorid")
    .agg(
        mean("trip_duration_min").alias("mean_duration"),
        expr("percentile_approx(trip_duration_min, 0.5)").alias("median_duration")
    )
)

# 2) Show results
summary_df.show()

### Store and Forward vs. Trip_Duration

In [None]:
# Group by vendorid & store_and_forward_flag, then count trips
counts_df = train_data.groupBy("vendorid", "store_and_fwd_flag").count()

# Display the result
counts_df.show()

In [None]:
# Filter for Vendor 1 and select relevant columns
df_filtered = (
            train_data
            .filter(col("vendorid") == 1)
            .select("passenger_count", "trip_duration_min", "store_and_fwd_flag")
)

# Sample a subset to avoid memory issues, then convert to Pandas
pdf = df_filtered.sample(fraction=0.1, seed=42).toPandas()

# Ensure categorical types for plotting
pdf["passenger_count"] = pdf["passenger_count"].astype(str)
pdf["store_and_fwd_flag"] = pdf["store_and_fwd_flag"].astype(str)

# Create faceted boxplots with log-scaled y-axis
g = sns.catplot(
    data=pdf,
    x="passenger_count",
    y="trip_duration_min",
    hue="passenger_count",
    col="store_and_fwd_flag",
    kind="box",
    sharey=False,
    palette="tab10",
    height=4,
    aspect=1,
)

# Add global title
g.fig.suptitle("Store_and_fwd_flag impact", y=1.03)

plt.tight_layout()
plt.show()

We find that there is no overwhelming differences between the stored and non-stored trips. The stored ones might be slightly longer, though, and don’t include any of the suspiciously long trips.

## **Feature Engineering**

In this section we try to build new features from the existing ones, trying to find better predictors for the target variable.

**To Do**
- Define all of the new features and analyze them. Pick OR discard on the go


The new temporal features (date, mos, wday, hr.) are derived from the `pickup_datetime`. We get the JFK and La Guardia airport co-ordinates from Wikipedia. The blizzard feature is based on the external weather data.

### Direct distance of the trip

From the coordinates of the trip and dropoff points we can calculate the direct distance (as the crow flies) between the two points, and compare it to our trip_duration(s). Since taxis aren’t crows (in most practical scenarios), these values correspond to the minimum possible travel distance.

To compute these distances we are using the distCosine function of the `geosphere` package for spherical trigonometry. This method gives us the shortest distance between two points on the spherical Earth. For the purpose of this analysis, we choose to ignore ellipsoidal distortion of the Earth's shape. Here are the raw values of distance vs. duration (based on a downsized sample to speed up the kernel)

In [None]:
train_data.columns

In [None]:
# Load the public Taxi geometry table from Bigquery
zone_centroids = (
    spark.read.format("bigquery")
    .option("query", """
        SELECT
            zone_id AS locationid,
            ST_X(ST_CENTROID(zone_geom)) AS lon,
            ST_Y(ST_CENTROID(zone_geom)) AS lat
        FROM `bigquery-public-data.new_york_taxi_trips.taxi_zone_geom`
    """)
    .load()
)

In [None]:
zone_centroids.show(5)

In [None]:
train_data = (
    train_data
    .join(
        zone_centroids
            .withColumnRenamed("locationid", "pickup_locationid")
            .withColumnRenamed("lon", "pickup_lon")
            .withColumnRenamed("lat", "pickup_lat"),
        on="pickup_locationid",
        how="left"
    )
    .join(
        zone_centroids
            .withColumnRenamed("locationid", "dropoff_locationid")
            .withColumnRenamed("lon", "dropoff_lon")
            .withColumnRenamed("lat", "dropoff_lat"),
        on="dropoff_locationid",
        how="left"
    )
)

In [None]:
validation_data = (
    validation_data
    .join(
        zone_centroids
            .withColumnRenamed("locationid", "pickup_locationid")
            .withColumnRenamed("lon", "pickup_lon")
            .withColumnRenamed("lat", "pickup_lat"),
        on="pickup_locationid",
        how="left"
    )
    .join(
        zone_centroids
            .withColumnRenamed("locationid", "dropoff_locationid")
            .withColumnRenamed("lon", "dropoff_lon")
            .withColumnRenamed("lat", "dropoff_lat"),
        on="dropoff_locationid",
        how="left"
    )
)

In [None]:
test_data = (
    test_data
    .join(
        zone_centroids
            .withColumnRenamed("locationid", "pickup_locationid")
            .withColumnRenamed("lon", "pickup_lon")
            .withColumnRenamed("lat", "pickup_lat"),
        on="pickup_locationid",
        how="left"
    )
    .join(
        zone_centroids
            .withColumnRenamed("locationid", "dropoff_locationid")
            .withColumnRenamed("lon", "dropoff_lon")
            .withColumnRenamed("lat", "dropoff_lat"),
        on="dropoff_locationid",
        how="left"
    )
)

In [None]:
train_data.show(5)

In [None]:
from pyspark.sql.functions import (
    col, radians, sin, cos, sqrt, asin, expr, rand
)

# Earth radius in meters
EARTH_RADIUS = 6_371_000.0

df_dist = train_data.select(
    "*", # keep all existing columns
    # convert to radians
    radians(col("pickup_lat")).alias("plat"),
    radians(col("pickup_lon")).alias("plon"),
    radians(col("dropoff_lat")).alias("dlat"),
    radians(col("dropoff_lon")).alias("dlon"),
    # compute deltas
    (radians(col("dropoff_lat")) - radians(col("pickup_lat"))).alias("Δlat"),
    (radians(col("dropoff_lon")) - radians(col("pickup_lon"))).alias("Δlon"),
    # Haversine “a” term
    (
        sin((radians(col("dropoff_lat")) - radians(col("pickup_lat"))) / 2)**2
        + cos(radians(col("pickup_lat"))) * cos(radians(col("dropoff_lat")))
          * sin((radians(col("dropoff_lon")) - radians(col("pickup_lon"))) / 2)**2
    ).alias("a"),
    # angular distance c
    (
        2
        * asin(
            sqrt(
                sin((radians(col("dropoff_lat")) - radians(col("pickup_lat"))) / 2)**2
                + cos(radians(col("pickup_lat"))) * cos(radians(col("dropoff_lat")))
                  * sin((radians(col("dropoff_lon")) - radians(col("pickup_lon"))) / 2)**2
            )
        )
    ).alias("c"),
    # direct distance (m)
    (col("c") * EARTH_RADIUS).alias("direct_dist_m")
)

In [None]:
# Sample 50K rows for plotting
sampled = (
    df_dist
    .orderBy(rand(seed=4321))
    .limit(50000)
    .select("direct_dist_m", "trip_duration_min")
)

In [None]:
# Convert trip_duration_min -> trip_duration_sec in the sampled Spark dataframe
sampled = sampled.withColumn(
        "trip_duration_sec",
        col("trip_duration_min")*60
)

# If you only want the distance and the new seconds column:
sampled = sampled.select(
    "direct_dist_m",
    "trip_duration_sec"
)

In [None]:
# 3) Plot in Pandas (log–log)
pdf = sampled.toPandas()

plt.figure(figsize=(8,6))
sns.scatterplot(
    x="direct_dist_m",
    y="trip_duration_sec",
    data=pdf,
    alpha=0.3,
    s=10
)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Direct distance [m]")
plt.ylabel("Trip duration [sec]")
plt.title("Direct distance vs. Trip duration")
plt.tight_layout()
plt.show()

We find:
- The distance generally increases with increasing trip_duration
- The trip distances and the corresponding duration seem reasonable in this data and the associated plot
- A significant number of trips are less than a km but their durations are significant. Seems to be erroneous recording

Let’s filter the data a little bit to remove the extreme (and the extremely suspicious) data points, and bin the data into a 2-d histogram. This plot shows that in log-log space the trip_duration is increasing slower than linear for larger distance values.


In [None]:
# Filter the dataframe
filtered_df = df_dist.filter(
    (col("trip_duration_min")<3600) & (col("trip_duration_min")>120) &
    (col("direct_dist_m")>100) &(col("direct_dist_m")<100e3)
)

# Convert to Pandas for plotting
pdf = filtered_df.select("direct_dist_m", "trip_duration_min").toPandas()

# Plot using seaborn or matplotlib
plt.figure(figsize=(8,6))
hb = plt.hexbin(
    x=np.log10(pdf["direct_dist_m"]),
    y=np.log10(pdf["trip_duration_min"]),
    gridsize=500,
    cmap="viridis"
)
cb = plt.colorbar(hb)
cb.set_label('Counts')

plt.xlabel("Direct distance [log10(m)]")
plt.ylabel("Trip duration [log10(s)]")
plt.title("Trip Duration vs Distance (log-log)")
plt.tight_layout()
plt.show()

## **Travel Speed**
Distance over time is of course velocity, and by computing the average apparent velocity of our taxis we will have another diagnostic to remove bogus values. Of course, we won’t be able to use speed as a predictor for our model, since it requires knowing the travel time, but it can still be helpful in cleaning up our training data and finding other features with predictive power. This is the speed distribution

In [None]:
df_dist.columns

In [None]:
# Filter the values
df_dist = df_dist.filter(col("trip_duration_min")>0 )
df_dist = df_dist.withColumn("speed", col("direct_dist_m")/(col("trip_duration_min")*60))

# Convert to pandas
df_dist_plot = df_dist.toPandas()
df_dist_plot = df_dist_plot[(df_dist_plot["speed"]>2) & (df_dist_plot["speed"]<1e2)]

# Plot the figure
plt.figure(figsize=(8, 6))
plt.hist(pdf["speed"], bins=50, color="red", edgecolor="black")
plt.xlabel("Average speed [km/h] (direct distance)")
plt.ylabel("Count")
plt.title("Histogram of Average Speed")
plt.grid(True)
plt.tight_layout()
plt.show()