# NYC Taxi Data Analysis - Incremental Year-on-Year Processing

## Dataset and Goals
- Goal: Predict the duration of NYC taxi rides using features like pickup time and trip coordinates
- **NEW**: Incremental processing strategy for 20-30GB dataset growing at 15GB/year
- Optimized for a GCP Spark notebook with ample memory and compute capacity

## Notebook Workflow
- **Incremental Data Loading**: Load data year by year for efficient processing
- Explore and visualize raw data with temporal analysis
- Engineer new features with historical context
- Examine outliers across different time periods
- Incorporate external datasets
    - Weather data
- Visualize and analyze how these new features impact trip duration over time
- Briefly explore a classification approach to predicting duration ranges
- Build XGBoost models with temporal validation
- **NEW**: Implement incremental model training and evaluation

All of this will be done in PySpark, optimized for large-scale incremental processing.
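
The year-by-year loading described in the workflow can be sketched as a generator of half-open date windows. This is a minimal illustration rather than the notebook's actual loader: `month_batches` is a hypothetical helper, and the default window of 6 months is only an assumption that mirrors a `batch_size_months`-style setting.

```python
from datetime import date
from typing import Iterator, Tuple


def month_batches(start_year: int, end_year: int,
                  batch_size_months: int = 6) -> Iterator[Tuple[date, date]]:
    """Yield half-open [start, end) date windows covering start_year..end_year."""
    if batch_size_months < 1:
        raise ValueError("batch_size_months must be at least 1")
    # Count months from year 0 so that window arithmetic is plain integer math.
    first = start_year * 12        # January of start_year
    last = (end_year + 1) * 12     # January of the year after end_year
    for begin in range(first, last, batch_size_months):
        stop = min(begin + batch_size_months, last)
        yield (date(begin // 12, begin % 12 + 1, 1),
               date(stop // 12, stop % 12 + 1, 1))


# Hypothetical usage: print the predicate each batch would cover.
for window_start, window_end in month_batches(2015, 2016, batch_size_months=6):
    print(f"{window_start} <= pickup_date < {window_end}")
```

Half-open windows are used so that consecutive batches never overlap and no day is skipped at a boundary.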

In [4]:
from google.cloud.dataproc_spark_connect import DataprocSparkSession  # used below to build the session
from pyspark.sql import functions as F  # F is used in the batching cell
from pyspark.sql.functions import col, year, month, dayofyear, when, lit, unix_timestamp, count, avg
from datetime import datetime, timedelta
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.figure import Figure
import numpy as np
import math
from typing import List, Optional, Union, Tuple

## Core PySpark Setup & BigQuery Integration - Optimized for Incremental Loading

In [None]:
# ----------------------------------------------------------------------
# Create Dataproc Spark Connect session with custom configs
# ----------------------------------------------------------------------


# Enhanced Spark configuration for large-scale incremental processing
# 🔧 SPARK CONFIGURATION EXPLAINED:
# 
# ADAPTIVE QUERY EXECUTION (AQE) - Runtime optimization based on actual data
# ├─ "spark.sql.adaptive.enabled" = "true"
# │  └─ Enables AQE for dynamic query optimization during execution
# │
# ├─ "spark.sql.adaptive.coalescePartitions.enabled" = "true" 
# │  └─ Automatically reduces number of partitions after shuffle to optimize performance
# │
# ├─ "spark.sql.adaptive.localShuffleReader.enabled" = "true"
# │  └─ Reduces network I/O by reading shuffle data locally when possible
# │
# └─ "spark.sql.adaptive.advisoryPartitionSizeInBytes" = "256MB"
#    └─ Advisory target size for each shuffle partition after coalescing (Spark's default is 64MB)
#
# SKEW HANDLING - Deals with uneven data distribution
# ├─ "spark.sql.adaptive.skewJoin.enabled" = "true"
# │  └─ Automatically detects and handles data skew in joins
# │
# ├─ "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes" = "256MB"
# │  └─ Minimum partition size before a partition can be treated as skewed
# │
# └─ "spark.sql.adaptive.skewJoin.skewedPartitionFactor" = "5"
#    └─ A partition is skewed if it is larger than 5x the median partition size AND above the threshold
#
# PERFORMANCE OPTIMIZATIONS
# ├─ "spark.sql.shuffle.partitions" = "400"
# │  └─ Number of partitions for shuffle operations (good for 20-30GB datasets)
# │
# ├─ "spark.serializer" = "org.apache.spark.serializer.KryoSerializer"
# │  └─ Faster serialization than default Java serialization (documented here but not set in the builder below)
# │
# └─ "spark.sql.execution.arrow.pyspark.enabled" = "true"
#    └─ Enables Apache Arrow for faster data transfer between JVM and Python
spark = (
    DataprocSparkSession.builder
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    .config("spark.sql.adaptive.localShuffleReader.enabled", "true")
    .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "256MB")
    .config("spark.sql.shuffle.partitions", "400")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
    .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
    .getOrCreate()
)

# ----------------------------------------------------------------------
# GCP Project and Dataset Config
# ----------------------------------------------------------------------

PROJECT_ID = "dtc-de-course-457315"
DATASET_ID = "dbt_production"

INCREMENTAL_CONFIG = {
    "start_year": 2015,
    "end_year": 2016,
    "batch_size_months": 6,
    "checkpoint_dir": "/tmp/spark-checkpoints",   # checkpointing not used in Spark Connect
    "cache_level": "MEMORY_AND_DISK",   # PySpark's StorageLevel has no MEMORY_AND_DISK_SER; Python data is always serialized
    "max_records_per_partition": 1000000
}

print(f"✅ Spark Version: {spark.version}")
print(f"📊 Working with: {PROJECT_ID}.{DATASET_ID}")
print(f"🔄 Incremental Processing: {INCREMENTAL_CONFIG['start_year']}-{INCREMENTAL_CONFIG['end_year']}")

# ----------------------------------------------------------------------
# Example: Read from BigQuery
# ----------------------------------------------------------------------

table_name = f"{PROJECT_ID}.{DATASET_ID}.fact_trips"

df = (
    spark.read
    .format("bigquery")
    .option("table", table_name)
    .load()
)

print("✅ Loaded BigQuery table into Spark DataFrame:")
df.show(5)

✅ Spark Version: 3.5.4
✅ Spark UI: http://192.168.29.162:4040
📊 Working with: dtc-de-course-457315.nyc_taxi_data
🔄 Incremental Processing: 2015-2016




| Column Name                | Description                                                  |
|----------------------------|--------------------------------------------------------------|
| tripid                     | Unique identifier for each taxi trip                         |
| vendorid                   | ID of the vendor/company providing the taxi service          |
| service_type               | Type of taxi service (e.g. Yellow, Green)                    |
| ratecodeid                 | Rate code indicating fare rules for the trip                 |
| pickup_locationid          | Location ID where the trip started                           |
| pickup_borough             | Borough where the trip started                               |
| pickup_zone                | Zone within the borough where the trip started               |
| dropoff_locationid         | Location ID where the trip ended                             |
| dropoff_borough            | Borough where the trip ended                                 |
| dropoff_zone               | Zone within the borough where the trip ended                 |
| pickup_datetime            | Timestamp when the trip started                              |
| pickup_date                | Date (only) of the pickup                                    |
| dropoff_datetime           | Timestamp when the trip ended                                |
| store_and_fwd_flag         | Flag indicating if trip record was stored and forwarded      |
| passenger_count            | Number of passengers in the trip                             |
| trip_distance              | Distance travelled during the trip (in miles)                |
| fare_amount                | Fare charged for the trip                                    |
| extra                      | Additional charges (e.g. surcharge, night fee)               |
| mta_tax                    | MTA (Metropolitan Transportation Authority) tax amount       |
| tip_amount                 | Tip given to the driver                                      |
| tolls_amount               | Total tolls paid during the trip                             |
| improvement_surcharge      | NYC-imposed surcharge to support improvements                |
| total_amount               | Total amount paid (fare + extras + tip + tolls)              |
| payment_type               | Numeric code for the payment type                            |
| payment_type_description   | Description of the payment type (e.g. Credit card, Cash)     |
| climate_date               | Date of associated climate/weather data                      |
| mjd                        | Modified Julian Date (astronomical date format)              |
| cloudCover                 | Fraction of sky covered by clouds                            |
| humidity                   | Relative humidity percentage                                 |
| dewPoint                   | Dew point temperature in degrees Fahrenheit or Celsius       |
| precipIntensity            | Intensity of precipitation during the period                 |
| highTemp                   | High temperature of the day                                  |
| lowTemp                    | Low temperature of the day                                   |
| visibility                 | Visibility distance during the trip (in miles)               |
| windSpeed                  | Wind speed during the trip                                   |


## Define Multi-Panel Plotting Functions

In [None]:
# Multi-panel plotting function (R multiplot equivalent)
# Courtesy of R Cookbooks adapted for Python/PySpark
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.figure import Figure
import numpy as np
import math
from typing import List, Optional, Union, Tuple

def multiplot(*plots, plotlist=None, cols=1, layout=None, figsize=(15, 10), 
              title=None, save_path=None, dpi=300):
    """
    Create multi-panel plots from matplotlib/seaborn/plotly figures
    
    Python equivalent of R's multiplot function from R Cookbooks
    
    Parameters:
    -----------
    *plots : matplotlib figures or plot functions
        Individual plots to be arranged
    plotlist : list, optional
        List of plots as alternative to *plots
    cols : int, default=1
        Number of columns in layout
    layout : array-like, optional
        Matrix specifying the layout. If present, 'cols' is ignored.
        Example: [[1,2], [3,3]] means plot 1 top-left, 2 top-right, 3 bottom spanning both columns
    figsize : tuple, default=(15, 10)
        Figure size (width, height) in inches
    title : str, optional
        Overall title for the multi-panel plot
    save_path : str, optional
        Path to save the figure
    dpi : int, default=300
        Resolution for saved figure
        
    Returns:
    --------
    matplotlib.figure.Figure
        The combined figure with all subplots
        
    Example:
    --------
    # Create individual plots
    fig1, ax1 = plt.subplots()
    ax1.plot([1,2,3], [1,4,2])
    
    fig2, ax2 = plt.subplots()
    ax2.bar([1,2,3], [3,1,4])
    
    # Combine them
    combined_fig = multiplot(fig1, fig2, cols=2, title="Combined Analysis")
    """
    
    # Combine plots from arguments and plotlist
    all_plots = list(plots) if plots else []
    if plotlist:
        all_plots.extend(plotlist)
    
    num_plots = len(all_plots)
    
    if num_plots == 0:
        print("⚠️ No plots provided")
        return None
    
    # Handle single plot case
    if num_plots == 1:
        if hasattr(all_plots[0], 'show'):
            all_plots[0].show()
        else:
            plt.figure(figsize=figsize)
            if title:
                plt.suptitle(title, fontsize=16, fontweight='bold')
            plt.show()
        return all_plots[0]
    
    # Determine layout
    if layout is None:
        # Calculate rows and columns
        nrows = math.ceil(num_plots / cols)
        ncols = cols
        layout_matrix = np.arange(1, cols * nrows + 1).reshape(nrows, ncols)
    else:
        layout_matrix = np.array(layout)
        nrows, ncols = layout_matrix.shape
    
    # Create the main figure
    fig = plt.figure(figsize=figsize)
    
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.95)
    
    # Create GridSpec for flexible subplot arrangement
    gs = gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.3, wspace=0.3)
    
    # Place each plot in the correct position
    for i, plot in enumerate(all_plots, 1):
        if i > num_plots:
            break
            
        # Find positions where this plot should go
        positions = np.where(layout_matrix == i)
        
        if len(positions[0]) == 0:
            continue
            
        # Calculate subplot span
        row_min, row_max = positions[0].min(), positions[0].max()
        col_min, col_max = positions[1].min(), positions[1].max()
        
        # Create subplot
        ax = fig.add_subplot(gs[row_min:row_max+1, col_min:col_max+1])
        
        # Handle different plot types
        if hasattr(plot, 'figure'):
            # It's a matplotlib figure
            _copy_plot_to_axis(plot, ax)
        elif callable(plot):
            # It's a plotting function
            plot(ax)
        elif hasattr(plot, 'axes'):
            # It's a figure with axes
            _copy_plot_to_axis(plot, ax)
        else:
            # Try to handle as data for direct plotting
            ax.text(0.5, 0.5, f'Plot {i}', ha='center', va='center', 
                   transform=ax.transAxes)
    
    plt.tight_layout()
    
    # Save if path provided
    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print(f"💾 Multi-panel plot saved to: {save_path}")
    
    plt.show()
    return fig

def _copy_plot_to_axis(source_fig, target_ax):
    """Helper function to copy plot content from source figure to target axis"""
    try:
        if hasattr(source_fig, 'axes') and source_fig.axes:
            source_ax = source_fig.axes[0]
            
            # Copy lines
            for line in source_ax.get_lines():
                target_ax.plot(line.get_xdata(), line.get_ydata(), 
                             color=line.get_color(), linewidth=line.get_linewidth(),
                             linestyle=line.get_linestyle(), marker=line.get_marker(),
                             label=line.get_label())
            
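            # Re-create bars on the target axis; an artist cannot belong to two figures
            from matplotlib.patches import Rectangle
            for patch in source_ax.patches:
                if isinstance(patch, Rectangle):
                    target_ax.add_patch(Rectangle(
                        patch.get_xy(), patch.get_width(), patch.get_height(),
                        facecolor=patch.get_facecolor(), edgecolor=patch.get_edgecolor()))
            
            # Re-draw scatter points from their offsets instead of moving the collection
            for collection in source_ax.collections:
                offsets = collection.get_offsets()
                if len(offsets):
                    target_ax.scatter(offsets[:, 0], offsets[:, 1], alpha=collection.get_alpha())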
            
            # Copy labels and title
            target_ax.set_xlabel(source_ax.get_xlabel())
            target_ax.set_ylabel(source_ax.get_ylabel())
            target_ax.set_title(source_ax.get_title())
            
            # Copy limits
            target_ax.set_xlim(source_ax.get_xlim())
            target_ax.set_ylim(source_ax.get_ylim())
            
            # Copy legend if exists
            if source_ax.get_legend():
                target_ax.legend()
                
    except Exception as e:
        # Fallback: just add a text placeholder
        target_ax.text(0.5, 0.5, 'Plot Content', ha='center', va='center',
                      transform=target_ax.transAxes)
        print(f"⚠️ Could not copy plot content: {e}")

# Enhanced multiplot for PySpark DataFrames
def multiplot_spark(dataframes_and_plots, cols=2, figsize=(15, 10), title=None):
    """
    Create multi-panel plots specifically for PySpark DataFrame visualizations
    
    Parameters:
    -----------
    dataframes_and_plots : list of tuples
        Each tuple contains (spark_dataframe, plot_config)
        plot_config is a dict with keys: 'type', 'x', 'y', 'title', etc.
    cols : int
        Number of columns
    figsize : tuple
        Figure size
    title : str
        Overall title
        
    Example:
    --------
    plot_configs = [
        (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trip Distance'}),
        (df_2021, {'type': 'scatter', 'x': 'trip_distance', 'y': 'fare_amount', 'title': '2021 Distance vs Fare'}),
        (df_2022, {'type': 'bar', 'x': 'pickup_hour', 'y': 'count', 'title': '2022 Hourly Trips'})
    ]
    multiplot_spark(plot_configs, cols=2, title="Yearly Comparison")
    """
    
    num_plots = len(dataframes_and_plots)
    nrows = math.ceil(num_plots / cols)
    
    # squeeze=False always returns a 2-D array of axes, whatever the grid shape
    fig, axes = plt.subplots(nrows, cols, figsize=figsize, squeeze=False)
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')
    
    # Flatten axes array for easy indexing
    axes = axes.flatten()
    
    for i, (spark_df, plot_config) in enumerate(dataframes_and_plots):
        if i >= len(axes):
            break
            
        ax = axes[i]
        
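        # Convert to Pandas for plotting. toPandas() collects every row to the driver,
        # so pass a pre-aggregated or sampled DataFrame for large tables.
        pandas_df = spark_df.toPandas()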
        
        plot_type = plot_config.get('type', 'scatter')
        x_col = plot_config.get('x')
        y_col = plot_config.get('y')
        plot_title = plot_config.get('title', f'Plot {i+1}')
        
        # Create the appropriate plot
        if plot_type == 'hist':
            ax.hist(pandas_df[x_col], bins=30, alpha=0.7)
            ax.set_xlabel(x_col)
            ax.set_ylabel('Frequency')
        elif plot_type == 'scatter':
            ax.scatter(pandas_df[x_col], pandas_df[y_col], alpha=0.6)
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
        elif plot_type == 'bar':
            if y_col:
                ax.bar(pandas_df[x_col], pandas_df[y_col])
            else:
                value_counts = pandas_df[x_col].value_counts()
                ax.bar(value_counts.index, value_counts.values)
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col or 'Count')
        elif plot_type == 'line':
            ax.plot(pandas_df[x_col], pandas_df[y_col])
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
        
        ax.set_title(plot_title)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(num_plots, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    plt.show()
    return fig

# Quick plotting helper for common PySpark visualizations
def quick_multiplot_comparison(yearly_data_dict, plot_type='hist', column='trip_distance', 
                              cols=2, figsize=(15, 10)):
    """
    Quick comparison plots across multiple years
    
    Parameters:
    -----------
    yearly_data_dict : dict
        Dictionary with year as key, PySpark DataFrame as value
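    plot_type : str
        Type of plot. Only a single column is passed on, so use 'hist' or 'bar'
        (the 'scatter' and 'line' types in multiplot_spark also need a 'y' column)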
    column : str
        Column to plot
    cols : int
        Number of columns in layout
    figsize : tuple
        Figure size
        
    Example:
    --------
    yearly_data = {
        2020: df_2020,
        2021: df_2021,
        2022: df_2022
    }
    quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')
    """
    
    plot_configs = []
    for year, df in yearly_data_dict.items():
        config = {
            'type': plot_type,
            'x': column,
            'title': f'{year} - {column.replace("_", " ").title()}'
        }
        plot_configs.append((df, config))
    
    return multiplot_spark(plot_configs, cols=cols, figsize=figsize, 
                          title=f"Year-over-Year Comparison: {column.replace('_', ' ').title()}")

print("✅ Multi-panel plotting functions created (R multiplot equivalent)")
print("✅ PySpark-specific multiplot functions ready")
print("✅ Quick comparison plotting utilities available")

# Example usage guide
print("\n" + "="*50)
print("MULTIPLOT USAGE EXAMPLES")
print("="*50)
print("""
# Basic usage:
fig1, ax1 = plt.subplots()
ax1.plot([1,2,3], [1,4,2])

fig2, ax2 = plt.subplots() 
ax2.bar([1,2,3], [3,1,4])

combined = multiplot(fig1, fig2, cols=2, title="Side by Side")

# PySpark DataFrame plotting:
plot_configs = [
    (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trips'}),
    (df_2021, {'type': 'hist', 'x': 'trip_distance', 'title': '2021 Trips'})
]
multiplot_spark(plot_configs, cols=2)

# Quick year comparison:
yearly_data = {2020: df_2020, 2021: df_2021, 2022: df_2022}
quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')
""")
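
To make the `layout` argument of `multiplot` concrete, the sketch below computes the grid span that each plot number occupies in a layout matrix. `layout_spans` is a hypothetical helper written in plain Python for illustration; `multiplot` itself performs the equivalent lookup with `np.where`.

```python
from typing import Dict, List, Tuple

# (row_min, row_max, col_min, col_max), all inclusive
Span = Tuple[int, int, int, int]


def layout_spans(layout: List[List[int]]) -> Dict[int, Span]:
    """Map each plot number in a layout matrix to the bounding box of its cells."""
    spans: Dict[int, Span] = {}
    for r, row in enumerate(layout):
        for c, plot_id in enumerate(row):
            if plot_id not in spans:
                spans[plot_id] = (r, r, c, c)
            else:
                r0, r1, c0, c1 = spans[plot_id]
                spans[plot_id] = (min(r0, r), max(r1, r), min(c0, c), max(c1, c))
    return spans


# Hypothetical usage: plot 3 occupies the whole second row.
print(layout_spans([[1, 2], [3, 3]]))
```

In this example plot 1 sits top-left, plot 2 top-right, and plot 3 spans both columns of the bottom row, which matches the example given in the `multiplot` docstring.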

**Test if data batching works**

In [None]:
START_YEAR = 2015
END_YEAR = 2016

print(f"✅ Spark Version: {spark.version}")
print(f"📊 Working with: {PROJECT_ID}.{DATASET_ID}")
print(f"🔄 Incremental Processing for years: {START_YEAR}-{END_YEAR}")

# ---------------------------------------------------------------
# Read data from BigQuery in yearly batches
# ---------------------------------------------------------------
table_name = f"{PROJECT_ID}.{DATASET_ID}.fact_trips"

for year in range(START_YEAR, END_YEAR + 1):

    # Build filter for one year
    year_filter = f"EXTRACT(YEAR FROM pickup_date) = {year}"

    # Load data for that year
    df_year = (
        spark.read
        .format("bigquery")
        .option("table", table_name)
        .option("filter", year_filter)
        .load()
    )

    # Add year_month (e.g. "2015-01") for inspection
    df_with_month = df_year.withColumn(
        "year_month",
        F.date_format("pickup_date", "yyyy-MM")
    )

    # Show distinct months in this batch
    months_df = (
        df_with_month
        .select("year_month")
        .distinct()
        .orderBy("year_month")
    )

    months_list = months_df.collect()
    months_str_list = [row["year_month"] for row in months_list]

    print("✅ Loaded batch for year:", year)
    months_df.show(truncate=False)

    print(f"📅 Number of months found: {len(months_str_list)}")
    if months_str_list:
        print(f"🗓️ First month in batch: {months_str_list[0]}")
        print(f"🗓️ Last month in batch: {months_str_list[-1]}")
    else:
        print("⚠️ No data found for this batch.")

    # ----
    # YOUR LOGIC HERE for processing df_with_month
    # ----

    print("="*80)
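
One possible variation on the `year_filter` string above is a half-open range predicate on `pickup_date`, which avoids wrapping the column in a function. The sketch below only builds the string. `year_range_filter` is a hypothetical helper, and whether it reads less data than the `EXTRACT` form depends on how the table is partitioned, which this notebook does not state.

```python
def year_range_filter(year: int, column: str = "pickup_date") -> str:
    """Build a half-open date-range predicate selecting one calendar year."""
    return f"{column} >= '{year}-01-01' AND {column} < '{year + 1}-01-01'"


# Hypothetical usage: one predicate per yearly batch.
for y in range(2015, 2017):
    print(year_range_filter(y))
```

The resulting string could be passed to the connector's `filter` option in place of `year_filter`.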

## Load Data (for BigQuery)

In [None]:
# PySpark equivalent of R's fread for BigQuery data loading
# R: train <- as_tibble(fread(cmd = 'unzip -p ../input/nyc-taxi-trip-duration/train.zip'))

def load_taxi_data_optimized(years=[2015, 2016], table_name="yellow_taxi_external_table", 
                           sample_fraction=None, cache=True):
    """
    Load taxi data from BigQuery with optimized performance
    Equivalent to R's fread but for BigQuery tables
    
    Parameters:
    -----------
    years : list
        Years to load (default: [2015, 2016])
    table_name : str
        BigQuery table name
    sample_fraction : float, optional
        Sample fraction for faster development (e.g., 0.1 for 10%)
    cache : bool
        Whether to cache the data in memory
        
    Returns:
    --------
    PySpark DataFrame
        Combined data from specified years
    """
    
    print(f"📊 Loading taxi data for years: {years}")
    print(f"📋 Table: {table_name}")
    
    # Build optimized query for multiple years
    years_filter = ', '.join(map(str, years))
    
    # Optimized query with year filtering and column selection
    query = f"""
    SELECT 
        id,
        vendor_id,
        tpep_pickup_datetime,
        tpep_dropoff_datetime,
        passenger_count,
        trip_distance,
        pickup_longitude,
        pickup_latitude,
        dropoff_longitude,
        dropoff_latitude,
        payment_type,
        fare_amount,
        extra,
        mta_tax,
        tip_amount,
        tolls_amount,
        total_amount,
        -- Calculate trip duration in seconds (our target variable)
        TIMESTAMP_DIFF(tpep_dropoff_datetime, tpep_pickup_datetime, SECOND) as trip_duration_seconds,
        -- Extract useful time features
        EXTRACT(YEAR FROM tpep_pickup_datetime) as pickup_year,
        EXTRACT(MONTH FROM tpep_pickup_datetime) as pickup_month,
        EXTRACT(DAY FROM tpep_pickup_datetime) as pickup_day,
        EXTRACT(HOUR FROM tpep_pickup_datetime) as pickup_hour,
        EXTRACT(DAYOFWEEK FROM tpep_pickup_datetime) as pickup_dayofweek
    FROM `{PROJECT_ID}.{DATASET_ID}.{table_name}`
    WHERE 
        EXTRACT(YEAR FROM tpep_pickup_datetime) IN ({years_filter})
        AND tpep_pickup_datetime IS NOT NULL
        AND tpep_dropoff_datetime IS NOT NULL
        AND trip_distance > 0
        AND fare_amount > 0
        AND tpep_pickup_datetime < tpep_dropoff_datetime
    """
    
    # Add sampling if specified
    if sample_fraction:
        query += f" AND RAND() < {sample_fraction}"
        print(f"🎲 Sampling {sample_fraction*100}% of data for faster development")
    
    try:
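        # Load data with the BigQuery connector. The "query" option requires
        # viewsEnabled=true and a materialization dataset for the temporary result.
        df = spark.read \
            .format("bigquery") \
            .option("viewsEnabled", "true") \
            .option("query", query) \
            .option("readDataFormat", "ARROW") \
            .option("materializationProject", PROJECT_ID) \
            .option("materializationDataset", DATASET_ID) \
            .load()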
        
        # Cache if requested
        if cache:
            df = df.cache()
            print(f"💾 Data cached in memory for faster access")
        
        # Show basic info
        total_rows = df.count()
        print(f"✅ Loaded {total_rows:,} records")
        print(f"📊 Columns: {len(df.columns)}")
        
        return df
        
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        return None

# Load 2015 and 2016 data (equivalent to R's train dataset)
print("🚀 Loading NYC Taxi Data (2015 & 2016)...")
taxi_data = load_taxi_data_optimized(years=[2015, 2016], 
                                   table_name="yellow_taxi_external_table",
                                   sample_fraction=0.1,  # 10% sample for development
                                   cache=True)

# Show basic statistics (equivalent to R's summary())
if taxi_data is not None:
    print("\n📈 Dataset Overview:")
    taxi_data.describe().show()
    
    print("\n🗂️ Schema:")
    taxi_data.printSchema()
    
    print("\n📊 Year Distribution:")
    taxi_data.groupBy("pickup_year").count().orderBy("pickup_year").show()
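
The row-level cleaning rules in the query's `WHERE` clause and the `trip_duration_seconds` target can be mirrored in plain Python, which is handy for quick unit checks without a Spark session. This is a sketch under the assumption that timestamps have whole-second resolution; `trip_duration_or_none` is a hypothetical helper, not part of the notebook.

```python
from datetime import datetime
from typing import Optional


def trip_duration_or_none(pickup: Optional[datetime], dropoff: Optional[datetime],
                          trip_distance: float, fare_amount: float) -> Optional[int]:
    """Return the duration in whole seconds, or None if the query would drop the row."""
    if pickup is None or dropoff is None:
        return None  # mirrors the IS NOT NULL checks
    if trip_distance <= 0 or fare_amount <= 0:
        return None  # mirrors trip_distance > 0 AND fare_amount > 0
    if pickup >= dropoff:
        return None  # mirrors pickup < dropoff
    return int((dropoff - pickup).total_seconds())


# Hypothetical usage with one valid trip and one zero-distance trip.
start = datetime(2015, 1, 1, 8, 0, 0)
end = datetime(2015, 1, 1, 8, 12, 30)
print(trip_duration_or_none(start, end, trip_distance=2.4, fare_amount=11.5))
print(trip_duration_or_none(start, end, trip_distance=0.0, fare_amount=11.5))
```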


In [None]:
# Train/Test Splitting Strategies for Year-by-Year Processing
from pyspark.sql.functions import col, rand, when, month, dayofmonth
from pyspark.sql.types import IntegerType

def temporal_train_test_split(df, strategy="chronological", test_ratio=0.2, 
                            random_seed=42, validation_split=True):
    """
    Split time series data into train/test sets using various strategies
    
    Parameters:
    -----------
    df : PySpark DataFrame
        Input data with datetime columns
    strategy : str
        Splitting strategy: 'chronological', 'random', 'stratified_temporal', 'holdout_months'
    test_ratio : float
        Proportion of data for testing (default: 0.2 = 20%)
    random_seed : int
        Random seed for reproducibility
    validation_split : bool
        Whether to create a validation set (splits train further)
        
    Returns:
    --------
    dict
        Dictionary with 'train', 'test', and optionally 'validation' DataFrames
    """
    
    print(f"🔄 Splitting data using '{strategy}' strategy")
    print(f"📊 Test ratio: {test_ratio*100}%")
    
    if strategy == "chronological":
        # RECOMMENDED for time series: 2015 for training, 2016 for testing (test_ratio is ignored)
        train_df = df.filter(col("pickup_year") == 2015)
        test_df = df.filter(col("pickup_year") == 2016)
        
        result = {"train": train_df, "test": test_df}
        
        if validation_split:
            # Use last 2 months of 2015 for validation
            train_final = train_df.filter(col("pickup_month") <= 10)
            validation_df = train_df.filter(col("pickup_month") > 10)
            result["train"] = train_final
            result["validation"] = validation_df
            
        print("✅ Chronological split: 2015 (train) vs 2016 (test)")
        
    elif strategy == "random":
        # Plain row-level random split across all years (not stratified by year)
        df_with_split = df.withColumn("rand_val", rand(seed=random_seed))
        
        train_df = df_with_split.filter(col("rand_val") > test_ratio)
        test_df = df_with_split.filter(col("rand_val") <= test_ratio)
        
        result = {"train": train_df.drop("rand_val"), "test": test_df.drop("rand_val")}
        
        if validation_split:
            val_ratio = test_ratio / 2
            train_final = train_df.filter(col("rand_val") > (test_ratio + val_ratio))
            validation_df = train_df.filter(
                (col("rand_val") > test_ratio) & (col("rand_val") <= (test_ratio + val_ratio))
            )
            result["train"] = train_final.drop("rand_val")
            result["validation"] = validation_df.drop("rand_val")
            
        print("✅ Random split across both years")
        
    elif strategy == "stratified_temporal":
        # Row-level random split. Every month is sampled at the same expected rate,
        # so monthly proportions are preserved approximately, not exactly.
        df_with_split = df.withColumn("rand_val", rand(seed=random_seed))
        
        train_df = df_with_split.filter(col("rand_val") > test_ratio)
        test_df = df_with_split.filter(col("rand_val") <= test_ratio)
        
        # NOTE: validation_split is not applied for this strategy
        result = {"train": train_df.drop("rand_val"), "test": test_df.drop("rand_val")}
        print("✅ Stratified temporal split")
        
    elif strategy == "holdout_months":
        # Hold out specific months for testing (e.g., Dec 2015 and Dec 2016)
        train_df = df.filter(col("pickup_month") != 12)
        test_df = df.filter(col("pickup_month") == 12)
        
        result = {"train": train_df, "test": test_df}
        
        if validation_split:
            # Use November for validation
            train_final = train_df.filter(col("pickup_month") != 11)
            validation_df = train_df.filter(col("pickup_month") == 11)
            result["train"] = train_final
            result["validation"] = validation_df
            
        print("✅ Holdout months split: December for testing")
        
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    
    # Show split statistics
    for split_name, split_df in result.items():
        n_records = split_df.count()
        print(f"📊 {split_name.capitalize()}: {n_records:,} records")
        
        # Show year distribution
        year_dist = split_df.groupBy("pickup_year").count().collect()
        # Use row["count"]: attribute access row.count resolves to tuple.count, not the column
        year_info = ", ".join([f"{row['pickup_year']}: {row['count']:,}" for row in year_dist])
        print(f"   Year distribution: {year_info}")
    
    return result

# RECOMMENDED APPROACH for your use case
def create_production_splits(df, approach="chronological_with_validation"):
    """
    Create production-ready train/test splits optimized for taxi duration prediction
    
    Parameters:
    -----------
    df : PySpark DataFrame
        Full dataset
    approach : str
        'chronological_with_validation' or 'cross_temporal_validation'
        
    Returns:
    --------
    dict
        Split datasets ready for ML pipeline
    """
    
    if approach == "chronological_with_validation":
        # BEST for your case: 2015 for training, 2016 for testing
        splits = temporal_train_test_split(df, strategy="chronological", 
                                         validation_split=True)
        
        print("\n🎯 RECOMMENDED APPROACH:")
        print("   • Train: 2015 (Jan-Oct)")
        print("   • Validation: 2015 (Nov-Dec)")
        print("   • Test: 2016 (Full year)")
        print("   • Benefit: Tests model on completely unseen future data")
        
    elif approach == "cross_temporal_validation":
        # Alternative: Mixed years with temporal awareness
        splits = temporal_train_test_split(df, strategy="stratified_temporal", 
                                         test_ratio=0.2, validation_split=True)
        
        print("\n🔄 ALTERNATIVE APPROACH:")
        print("   • Mixed temporal cross-validation")
        print("   • Maintains seasonal patterns in all splits")
        
    else:
        raise ValueError(f"Unknown approach: {approach}")
        
    return splits

# Apply the recommended splitting strategy
print("🚀 Creating production-ready train/test splits...")
if 'taxi_data' in locals() and taxi_data is not None:
    splits = create_production_splits(taxi_data, approach="chronological_with_validation")
    
    # Store splits for later use
    train_data = splits["train"]
    validation_data = splits["validation"] 
    test_data = splits["test"]
    
    print("\n✅ Splits created successfully!")
    print("📊 Ready for feature engineering and modeling")
else:
    print("⚠️ Load taxi_data first before creating splits")
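
The random strategy above assigns each row by comparing one uniform draw against two thresholds. Below is a minimal plain-Python sketch of that rule, assuming `test_ratio = 0.2` and a validation share of `test_ratio / 2`. `assign_split` is a hypothetical helper for illustration, not part of the notebook.

```python
import random

def assign_split(rand_val: float, test_ratio: float = 0.2) -> str:
    """Mirror the filter thresholds applied to the rand_val column."""
    val_ratio = test_ratio / 2
    if rand_val <= test_ratio:
        return "test"
    if rand_val <= test_ratio + val_ratio:
        return "validation"
    return "train"

# Simulate 100,000 rows with a fixed seed and tally the resulting shares
rng = random.Random(42)
counts = {"train": 0, "validation": 0, "test": 0}
for _ in range(100_000):
    counts[assign_split(rng.random())] += 1

shares = {name: n / 100_000 for name, n in counts.items()}
print(shares)
```

With these assumed ratios the shares land near 70% train, 10% validation and 20% test.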


In [None]:
# Strategic Recommendations for Your Dataset

print("="*60)
print("🎯 STRATEGIC RECOMMENDATIONS FOR YOUR APPROACH")
print("="*60)

recommendations = {
    "1. RECOMMENDED SPLITTING STRATEGY": {
        "approach": "Chronological Split",
        "implementation": "2015 for training, 2016 for testing",
        "benefits": [
            "Tests model on completely unseen future data",
            "Reflects real-world deployment scenario",
            "Captures temporal drift and seasonality changes",
            "Avoids data leakage from future to past"
        ],
        "split_details": {
            "train": "2015 (Jan-Oct) - 83% of year",
            "validation": "2015 (Nov-Dec) - 17% of year", 
            "test": "2016 (Full year) - 100% of year"
        }
    },
    
    "2. INCREMENTAL PROCESSING STRATEGY": {
        "approach": "Year-by-Year Loading with Caching",
        "implementation": "Load → Process → Cache → Combine",
        "benefits": [
            "Memory efficient for large datasets",
            "Allows for year-specific feature engineering",
            "Enables temporal analysis and comparison",
            "Scalable to additional years"
        ],
        "code_pattern": """
        # Year-by-year processing
        df_2015 = load_data_by_year("table", 2015).cache()
        df_2016 = load_data_by_year("table", 2016).cache()
        
        # Combine for modeling (unionByName matches columns by name, not by position)
        combined_df = df_2015.unionByName(df_2016)
        """
    },
    
    "3. ALTERNATIVE STRATEGIES": {
        "random_split": {
            "use_case": "When temporal patterns are not important",
            "pros": ["Balanced distribution", "Standard ML approach"],
            "cons": ["Ignores temporal dependencies", "Data leakage risk"]
        },
        "stratified_temporal": {
            "use_case": "When you need seasonal balance in all splits",
            "pros": ["Maintains seasonal patterns", "Good for seasonal models"],
            "cons": ["Still has some data leakage", "Complex to implement"]
        },
        "holdout_months": {
            "use_case": "Testing specific seasonal performance",
            "pros": ["Tests seasonal robustness", "Easy to interpret"],
            "cons": ["Limited test data", "May not be representative"]
        }
    },
    
    "4. PRODUCTION CONSIDERATIONS": {
        "data_quality": [
            "Filter out invalid trips (duration < 0, distance = 0)",
            "Handle outliers (trips > 24 hours, unrealistic speeds)",
            "Validate coordinate ranges (NYC area only)"
        ],
        "feature_engineering": [
            "Extract temporal features (hour, day, month, season)",
            "Calculate haversine distance for validation",
            "Create pickup/dropoff zone features",
            "Add weather data if available"
        ],
        "model_validation": [
            "Use chronological validation",
            "Monitor for temporal drift",
            "Validate on different time periods",
            "Check performance across seasons"
        ]
    }
}

# Display recommendations
for category, details in recommendations.items():
    print(f"\n{category}")
    print("-" * len(category))
    
    if isinstance(details, dict):
        if "approach" in details:
            print(f"📋 Approach: {details['approach']}")
            print(f"🔧 Implementation: {details['implementation']}")
            
            if "benefits" in details:
                print("✅ Benefits:")
                for benefit in details["benefits"]:
                    print(f"   • {benefit}")
                    
            if "split_details" in details:
                print("📊 Split Details:")
                for split, detail in details["split_details"].items():
                    print(f"   • {split.capitalize()}: {detail}")
                    
            if "code_pattern" in details:
                print("💻 Code Pattern:")
                print(details["code_pattern"])
        else:
            for key, value in details.items():
                print(f"\n🔸 {key.replace('_', ' ').title()}:")
                if isinstance(value, dict):
                    for subkey, subvalue in value.items():
                        print(f"   {subkey}: {subvalue}")
                elif isinstance(value, list):
                    for item in value:
                        print(f"   • {item}")
                else:
                    print(f"   {value}")

print("\n" + "="*60)
print("🚀 NEXT STEPS")
print("="*60)
print("""
1. Load your 2015 & 2016 data using the optimized functions above
2. Apply the chronological splitting strategy
3. Perform exploratory data analysis using the multiplot functions
4. Engineer temporal and spatial features
5. Build and validate your trip duration prediction model
6. Monitor model performance across different time periods

Remember: the chronological split (train on 2015, test on 2016) mirrors deployment,
where the model only predicts trips that happen after its training data.
""")


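The recommended chronological split can be stated as a rule on `pickup_year` and `pickup_month`. Below is a plain-Python sketch under the split details listed above (2015 Jan-Oct for training, 2015 Nov-Dec for validation, 2016 for testing). `chronological_split` is an illustrative name; the notebook's own logic lives in `temporal_train_test_split`.

```python
from typing import Optional

def chronological_split(pickup_year: int, pickup_month: int) -> Optional[str]:
    """Return the split a trip belongs to, or None outside 2015-2016."""
    if pickup_year == 2015:
        return "train" if pickup_month <= 10 else "validation"
    if pickup_year == 2016:
        return "test"
    return None

# Share of 2015 months that end up in the training set
months_2015 = [chronological_split(2015, m) for m in range(1, 13)]
train_share = months_2015.count("train") / 12
print(f"Train share of 2015: {train_share:.0%}")
```

Ten of the twelve months of 2015 fall into training, which is where the 83% figure in the recommendations comes from.
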
In [None]:
# Multi-panel plotting function (R multiplot equivalent)
# Courtesy of R Cookbooks adapted for Python/PySpark
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.figure import Figure
import numpy as np
import math
from typing import List, Optional, Union, Tuple

def multiplot(*plots, plotlist=None, cols=1, layout=None, figsize=(15, 10), 
              title=None, save_path=None, dpi=300):
    """
    Create multi-panel plots from matplotlib figures or plotting functions
    
    Python equivalent of R's multiplot function from R Cookbooks
    
    Parameters:
    -----------
    *plots : matplotlib figures or plot functions
        Individual plots to be arranged. A figure has the content of its first
        axes redrawn in the panel; a function is called as plot(ax).
    plotlist : list, optional
        List of plots as alternative to *plots
    cols : int, default=1
        Number of columns in layout
    layout : array-like, optional
        Matrix specifying the layout. If present, 'cols' is ignored.
        Example: [[1,2], [3,3]] means plot 1 top-left, 2 top-right, 3 bottom spanning both columns
    figsize : tuple, default=(15, 10)
        Figure size (width, height) in inches
    title : str, optional
        Overall title for the multi-panel plot
    save_path : str, optional
        Path to save the figure
    dpi : int, default=300
        Resolution for saved figure
        
    Returns:
    --------
    matplotlib.figure.Figure
        The combined figure with all subplots
        
    Example:
    --------
    # Create individual plots
    fig1, ax1 = plt.subplots()
    ax1.plot([1,2,3], [1,4,2])
    
    fig2, ax2 = plt.subplots()
    ax2.bar([1,2,3], [3,1,4])
    
    # Combine them
    combined_fig = multiplot(fig1, fig2, cols=2, title="Combined Analysis")
    """
    
    # Combine plots from arguments and plotlist
    all_plots = list(plots) if plots else []
    if plotlist:
        all_plots.extend(plotlist)
    
    num_plots = len(all_plots)
    
    if num_plots == 0:
        print("⚠️ No plots provided")
        return None
    
    # Handle single plot case
    if num_plots == 1:
        single = all_plots[0]
        if callable(single):
            # Plotting function: draw it on a fresh figure
            fig, ax = plt.subplots(figsize=figsize)
            single(ax)
            if title:
                fig.suptitle(title, fontsize=16, fontweight='bold')
            plt.show()
            return fig
        if hasattr(single, 'show'):
            single.show()
        return single
    
    # Determine layout
    if layout is None:
        # Calculate rows and columns
        nrows = math.ceil(num_plots / cols)
        ncols = cols
        layout_matrix = np.arange(1, cols * nrows + 1).reshape(nrows, ncols)
    else:
        layout_matrix = np.array(layout)
        nrows, ncols = layout_matrix.shape
    
    # Create the main figure
    fig = plt.figure(figsize=figsize)
    
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.95)
    
    # Create GridSpec for flexible subplot arrangement
    gs = gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.3, wspace=0.3)
    
    # Place each plot in the correct position
    for i, plot in enumerate(all_plots, 1):
            
        # Find positions where this plot should go
        positions = np.where(layout_matrix == i)
        
        if len(positions[0]) == 0:
            continue
            
        # Calculate subplot span
        row_min, row_max = positions[0].min(), positions[0].max()
        col_min, col_max = positions[1].min(), positions[1].max()
        
        # Create subplot
        ax = fig.add_subplot(gs[row_min:row_max+1, col_min:col_max+1])
        
        # Handle different plot types
        if callable(plot):
            # It's a plotting function: let it draw directly on the new axis
            plot(ax)
        elif isinstance(getattr(plot, 'axes', None), list):
            # It's a matplotlib figure: redraw its first axes on the new axis
            _copy_plot_to_axis(plot, ax)
        else:
            # Unsupported input (e.g. a bare Axes): draw a labelled placeholder
            ax.text(0.5, 0.5, f'Plot {i}', ha='center', va='center', 
                   transform=ax.transAxes)
    
    plt.tight_layout()
    
    # Save if path provided
    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print(f"💾 Multi-panel plot saved to: {save_path}")
    
    plt.show()
    return fig

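def _copy_plot_to_axis(source_fig, target_ax):
    """Helper function to redraw the first axes of source_fig on target_ax.

    Matplotlib artists cannot be shared between axes, so lines, rectangular
    patches (bars, histogram bins) and scatter collections are re-created
    rather than re-added. Other artist types are skipped.
    """
    from matplotlib.patches import Rectangle

    try:
        if hasattr(source_fig, 'axes') and source_fig.axes:
            source_ax = source_fig.axes[0]
            
            # Copy lines
            for line in source_ax.get_lines():
                target_ax.plot(line.get_xdata(), line.get_ydata(), 
                             color=line.get_color(), linewidth=line.get_linewidth(),
                             linestyle=line.get_linestyle(), marker=line.get_marker(),
                             label=line.get_label())
            
            # Re-create rectangular patches (bars, histogram bins)
            for patch in source_ax.patches:
                if isinstance(patch, Rectangle):
                    target_ax.add_patch(Rectangle(
                        patch.get_xy(), patch.get_width(), patch.get_height(),
                        facecolor=patch.get_facecolor(), edgecolor=patch.get_edgecolor(),
                        linewidth=patch.get_linewidth(), alpha=patch.get_alpha()))
            
            # Re-create scatter collections from their point offsets
            for collection in source_ax.collections:
                offsets = collection.get_offsets()
                if len(offsets):
                    target_ax.scatter(offsets[:, 0], offsets[:, 1],
                                      s=collection.get_sizes(),
                                      c=collection.get_facecolor(),
                                      alpha=collection.get_alpha())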
            
            # Copy labels and title
            target_ax.set_xlabel(source_ax.get_xlabel())
            target_ax.set_ylabel(source_ax.get_ylabel())
            target_ax.set_title(source_ax.get_title())
            
            # Copy limits
            target_ax.set_xlim(source_ax.get_xlim())
            target_ax.set_ylim(source_ax.get_ylim())
            
            # Copy legend if exists
            if source_ax.get_legend():
                target_ax.legend()
                
    except Exception as e:
        # Fallback: just add a text placeholder
        target_ax.text(0.5, 0.5, 'Plot Content', ha='center', va='center',
                      transform=target_ax.transAxes)
        print(f"⚠️ Could not copy plot content: {e}")

# Enhanced multiplot for PySpark DataFrames
def multiplot_spark(dataframes_and_plots, cols=2, figsize=(15, 10), title=None):
    """
    Create multi-panel plots specifically for PySpark DataFrame visualizations
    
    Parameters:
    -----------
    dataframes_and_plots : list of tuples
        Each tuple contains (spark_dataframe, plot_config)
        plot_config is a dict with keys: 'type', 'x', 'y', 'title', etc.
    cols : int
        Number of columns
    figsize : tuple
        Figure size
    title : str
        Overall title
        
    Example:
    --------
    plot_configs = [
        (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trip Distance'}),
        (df_2021, {'type': 'scatter', 'x': 'trip_distance', 'y': 'fare_amount', 'title': '2021 Distance vs Fare'}),
        (df_2022, {'type': 'bar', 'x': 'pickup_hour', 'y': 'count', 'title': '2022 Hourly Trips'})
    ]
    multiplot_spark(plot_configs, cols=2, title="Yearly Comparison")
    """
    
    num_plots = len(dataframes_and_plots)
    nrows = math.ceil(num_plots / cols)
    
    # squeeze=False always returns a 2-D array, so flatten() works for any grid shape
    fig, axes = plt.subplots(nrows, cols, figsize=figsize, squeeze=False)
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')
    
    # Flatten axes array for easy indexing
    axes = axes.flatten()
    
    for i, (spark_df, plot_config) in enumerate(dataframes_and_plots):
        if i >= len(axes):
            break
            
        ax = axes[i]
        
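        # Convert to Pandas for plotting. toPandas() collects every row to the driver,
        # so pass a sampled or pre-aggregated DataFrame and pull only the plotted columns.
        plot_cols = [c for c in (plot_config.get('x'), plot_config.get('y')) if c]
        pandas_df = spark_df.select(*plot_cols).toPandas()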
        
        plot_type = plot_config.get('type', 'scatter')
        x_col = plot_config.get('x')
        y_col = plot_config.get('y')
        plot_title = plot_config.get('title', f'Plot {i+1}')
        
        # Create the appropriate plot
        if plot_type == 'hist':
            ax.hist(pandas_df[x_col], bins=30, alpha=0.7)
            ax.set_xlabel(x_col)
            ax.set_ylabel('Frequency')
        elif plot_type == 'scatter':
            ax.scatter(pandas_df[x_col], pandas_df[y_col], alpha=0.6)
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
        elif plot_type == 'bar':
            if y_col:
                ax.bar(pandas_df[x_col], pandas_df[y_col])
            else:
                value_counts = pandas_df[x_col].value_counts()
                ax.bar(value_counts.index, value_counts.values)
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col or 'Count')
        elif plot_type == 'line':
            ax.plot(pandas_df[x_col], pandas_df[y_col])
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
        else:
            raise ValueError(f"Unsupported plot type: {plot_type!r}")
        
        ax.set_title(plot_title)
        ax.grid(True, alpha=0.3)
    
    # Hide unused subplots
    for i in range(num_plots, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    plt.show()
    return fig

# Quick plotting helper for common PySpark visualizations
def quick_multiplot_comparison(yearly_data_dict, plot_type='hist', column='trip_distance', 
                              cols=2, figsize=(15, 10)):
    """
    Quick comparison plots across multiple years
    
    Parameters:
    -----------
    yearly_data_dict : dict
        Dictionary with year as key, PySpark DataFrame as value
    plot_type : str
        Type of plot: 'hist' or 'bar'. Only the x column is passed on,
        so 'scatter' and 'line' (which need a y column) are not usable here.
    column : str
        Column to plot
    cols : int
        Number of columns in layout
    figsize : tuple
        Figure size
        
    Example:
    --------
    yearly_data = {
        2020: df_2020,
        2021: df_2021,
        2022: df_2022
    }
    quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')
    """
    
    plot_configs = []
    for year, df in yearly_data_dict.items():
        config = {
            'type': plot_type,
            'x': column,
            'title': f'{year} - {column.replace("_", " ").title()}'
        }
        plot_configs.append((df, config))
    
    return multiplot_spark(plot_configs, cols=cols, figsize=figsize, 
                          title=f"Year-over-Year Comparison: {column.replace('_', ' ').title()}")

print("✅ Multi-panel plotting functions created (R multiplot equivalent)")
print("✅ PySpark-specific multiplot functions ready")
print("✅ Quick comparison plotting utilities available")

# Example usage guide
print("\n" + "="*50)
print("MULTIPLOT USAGE EXAMPLES")
print("="*50)
print("""
# Basic usage:
fig1, ax1 = plt.subplots()
ax1.plot([1,2,3], [1,4,2])

fig2, ax2 = plt.subplots() 
ax2.bar([1,2,3], [3,1,4])

combined = multiplot(fig1, fig2, cols=2, title="Side by Side")

# PySpark DataFrame plotting:
plot_configs = [
    (df_2020, {'type': 'hist', 'x': 'trip_distance', 'title': '2020 Trips'}),
    (df_2021, {'type': 'hist', 'x': 'trip_distance', 'title': '2021 Trips'})
]
multiplot_spark(plot_configs, cols=2)

# Quick year comparison:
yearly_data = {2020: df_2020, 2021: df_2021, 2022: df_2022}
quick_multiplot_comparison(yearly_data, 'hist', 'trip_distance')
""")


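The `layout` argument of `multiplot` maps plot numbers to grid cells, and each plot spans the bounding box of the cells that carry its number. Below is a small plain-Python sketch of that lookup, mirroring the `np.where` logic in the function body without NumPy. `plot_span` is a hypothetical helper name.

```python
from typing import List, Optional, Tuple

def plot_span(layout: List[List[int]], plot_number: int) -> Optional[Tuple[int, int, int, int]]:
    """Return (row_min, row_max, col_min, col_max) for a plot number, or None."""
    cells = [(r, c)
             for r, row in enumerate(layout)
             for c, value in enumerate(row)
             if value == plot_number]
    if not cells:
        return None
    rows = [r for r, _ in cells]
    cols = [c for _, c in cells]
    return min(rows), max(rows), min(cols), max(cols)

# Plot 1 top-left, plot 2 top-right, plot 3 across the bottom row
layout = [[1, 2],
          [3, 3]]
for number in (1, 2, 3):
    print(number, plot_span(layout, number))
```

Plot 3 spans both columns of the second row, which corresponds to `gs[1:2, 0:2]` in the GridSpec call.
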
In [None]:
# PySpark SQL Functions (dplyr equivalent)
from pyspark.sql import functions as F
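# Avoid `from pyspark.sql.functions import *`: it shadows Python built-ins such as
# sum, min, max, abs and round. Use the explicit imports below or the F. prefix.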
from pyspark.sql.types import *
from pyspark.sql.window import Window

# Data manipulation imports (equivalent to dplyr, data.table, tibble, tidyr)
from pyspark.sql import DataFrame
from pyspark.sql.functions import (
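    # dplyr equivalents (otherwise() is a Column method chained after when(),
    # not a module-level function, so it cannot be imported here)
    col, lit, when, desc, asc,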
    sum as spark_sum, mean as spark_mean, count as spark_count,
    min as spark_min, max as spark_max, avg, stddev,
    
    # tidyr equivalents  
    explode, split, array, struct, collect_list, collect_set,
    
    # stringr equivalents
    regexp_replace, regexp_extract, lower, upper, trim, ltrim, rtrim,
    substring, length, concat, concat_ws, split as string_split,
    
    # lubridate equivalents
    year, month, dayofyear, dayofweek, hour, minute, second,
    date_format, to_date, to_timestamp, current_date, current_timestamp,
    datediff, months_between, add_months, date_add, date_sub
)

print("✅ PySpark SQL Functions imported (dplyr/data.table/tidyr/stringr/lubridate equivalents)")

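# Window functions for advanced operations
# Template specs: "zone", "trip_count", "route" and "date" are placeholder column
# names to adapt to your schema. A window without partitionBy (the 'row_number'
# spec) moves all rows into a single partition, so use it on small data only.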
window_specs = {
    'row_number': Window.orderBy("pickup_datetime"),
    'rank': Window.partitionBy("zone").orderBy(desc("trip_count")),
    'lag_lead': Window.partitionBy("route").orderBy("date")
}

print("✅ Window functions configured for advanced analytics")
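
As a reminder of what the `lag_lead` spec computes, here is a plain-Python emulation of `lag` over a window partitioned by `route` and ordered by `date`. It only sketches the semantics on a toy list of dicts. The field names follow the placeholder columns in `window_specs`, and `lag_by_partition` is a hypothetical helper, not a Spark API.

```python
from collections import defaultdict
from typing import Dict, List

def lag_by_partition(rows: List[Dict], partition_key: str, order_key: str,
                     value_key: str) -> List[Dict]:
    """Attach the previous row's value within each partition (None for the first row)."""
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[partition_key]].append(row)

    result = []
    for key in sorted(partitions):
        previous = None
        for row in sorted(partitions[key], key=lambda r: r[order_key]):
            result.append({**row, "prev_" + value_key: previous})
            previous = row[value_key]
    return result

trips = [
    {"route": "A", "date": "2015-01-02", "trip_count": 12},
    {"route": "A", "date": "2015-01-01", "trip_count": 10},
    {"route": "B", "date": "2015-01-01", "trip_count": 7},
]
lagged = lag_by_partition(trips, "route", "date", "trip_count")
for row in lagged:
    print(row)
```

The first row of every route gets `None`, just as `lag` returns null at the start of each partition.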

In [None]:
# Visualization libraries (ggplot2, scales, grid, RColorBrewer equivalents)
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker  # scales equivalent
import matplotlib.gridspec as gridspec  # grid equivalent
import matplotlib.colors as mcolors  # RColorBrewer equivalent
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

import seaborn as sns  # ggplot2 + RColorBrewer equivalent
import plotly.express as px  # ggplot2 equivalent
import plotly.graph_objects as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots

# Configure plotting styles
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# RColorBrewer equivalent color palettes
color_palettes = {
    'Set1': sns.color_palette("Set1", 9),  # Set1 defines 9 distinct colors
    'Set2': sns.color_palette("Set2", 8), 
    'Dark2': sns.color_palette("Dark2", 8),
    'Paired': sns.color_palette("Paired", 12),
    'Blues': sns.color_palette("Blues", 10),
    'Reds': sns.color_palette("Reds", 10),
    'Greens': sns.color_palette("Greens", 10),
    'Viridis': sns.color_palette("viridis", 10),
    'Plasma': sns.color_palette("plasma", 10)
}

# Figure configuration
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.grid'] = True

print("✅ Visualization libraries imported (ggplot2/scales/grid/RColorBrewer equivalents)")
print(f"✅ Available color palettes: {list(color_palettes.keys())}")

# Utility function for quick plotting (ggplot2 qplot equivalent)
def quick_plot(df, x, y=None, kind='scatter', title="", color_col=None):
    """Quick plotting function similar to R's qplot"""
    # Duck-typed check so both classic and Spark Connect DataFrames are converted
    if hasattr(df, "toPandas"):
        df = df.toPandas()
    
    plt.figure(figsize=(10, 6))
    
    if kind == 'scatter':
        if color_col:
            plt.scatter(df[x], df[y], c=df[color_col], alpha=0.6)
        else:
            plt.scatter(df[x], df[y], alpha=0.6)
    elif kind == 'line':
        plt.plot(df[x], df[y])
    elif kind == 'hist':
        plt.hist(df[x], bins=30, alpha=0.7)
    elif kind == 'box':
        df.boxplot(column=x)
    
    plt.title(title)
    plt.xlabel(x)
    if y: plt.ylabel(y)
    plt.grid(True, alpha=0.3)
    plt.show()

print("✅ Quick plot function created (qplot equivalent)")

In [None]:
# Correlation plotting (corrplot equivalent) 
import numpy as np
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler

def spark_corrplot(spark_df, numeric_cols, method='pearson', figsize=(10, 8)):
    """
    Create correlation plot from PySpark DataFrame (corrplot equivalent)
    """
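    # Assemble features for correlation. Rows containing nulls are skipped,
    # because VectorAssembler raises on null values by default.
    assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features",
                                handleInvalid="skip")
    df_assembled = assembler.transform(spark_df.select(*numeric_cols))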
    
    # Calculate correlation matrix
    correlation_matrix = Correlation.corr(df_assembled, "features", method).head()[0]
    correlation_array = np.array(correlation_matrix.toArray())
    
    # Create correlation heatmap
    plt.figure(figsize=figsize)
    mask = np.triu(np.ones_like(correlation_array, dtype=bool))  # Upper triangle mask
    
    sns.heatmap(correlation_array, 
                mask=mask,
                annot=True, 
                cmap='RdBu_r', 
                center=0,
                xticklabels=numeric_cols,
                yticklabels=numeric_cols,
                square=True,
                fmt='.2f')
    
    plt.title(f'{method.title()} Correlation Matrix')
    plt.tight_layout()
    plt.show()
    
    return correlation_array

# Alluvial/Sankey diagrams (alluvial equivalent)
def create_sankey_diagram(source_col, target_col, value_col, df, title="Flow Diagram"):
    """
    Create Sankey diagram from PySpark DataFrame (alluvial equivalent)
    """
    # Duck-typed check so both classic and Spark Connect DataFrames are converted
    if hasattr(df, "toPandas"):
        df_pandas = df.toPandas()
    else:
        df_pandas = df
    
    # Get unique nodes in first-seen order so node indices are reproducible
    sources = df_pandas[source_col].tolist()
    targets = df_pandas[target_col].tolist()
    all_nodes = list(dict.fromkeys(sources + targets))
    
    # Create node mapping
    node_map = {node: idx for idx, node in enumerate(all_nodes)}
    
    # Prepare data for Sankey
    source_indices = [node_map[src] for src in df_pandas[source_col]]
    target_indices = [node_map[tgt] for tgt in df_pandas[target_col]]
    values = df_pandas[value_col].tolist()
    
    # Create Sankey diagram
    fig = go.Figure(data=[go.Sankey(
        node = dict(
            pad = 15,
            thickness = 20,
            line = dict(color = "black", width = 0.5),
            label = all_nodes,
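            # Cycle through the palette so every node gets a color even when nodes outnumber it
            color = [px.colors.qualitative.Set3[i % len(px.colors.qualitative.Set3)]
                     for i in range(len(all_nodes))]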
        ),
        link = dict(
            source = source_indices,
            target = target_indices, 
            value = values
        )
    )])
    
    fig.update_layout(title_text=title, font_size=10)
    fig.show()
    
    return fig

print("✅ Correlation plotting function created (corrplot equivalent)")
print("✅ Sankey diagram function created (alluvial equivalent)")
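
`go.Sankey` expects links as integer indices into the node label list. The mapping step in `create_sankey_diagram` can be sketched on its own, without Plotly. The zone names and counts below are made-up illustration data, and `build_sankey_links` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

def build_sankey_links(flows: List[Tuple[str, str, float]]) -> Dict[str, list]:
    """Turn (source, target, value) rows into label and index lists."""
    # First-seen order keeps node indices stable between runs
    labels = list(dict.fromkeys(name for src, tgt, _ in flows for name in (src, tgt)))
    index = {name: i for i, name in enumerate(labels)}
    return {
        "label": labels,
        "source": [index[src] for src, _, _ in flows],
        "target": [index[tgt] for _, tgt, _ in flows],
        "value": [value for _, _, value in flows],
    }

flows = [("Midtown", "JFK", 120), ("Midtown", "LGA", 80), ("JFK", "Midtown", 95)]
links = build_sankey_links(flows)
print(links["label"])
print(links["source"], links["target"], links["value"])
```

A node that appears as both a source and a target (here `JFK` and `Midtown`) keeps a single index, which is what lets flows chain through it.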


In [None]:
# Geospatial libraries (geosphere, leaflet, maps equivalents)
import folium  # leaflet equivalent
from folium import plugins  # leaflet.extras equivalent
import geopandas as gpd  # maps equivalent
from geopy.distance import geodesic  # geosphere equivalent
from geopy.geocoders import Nominatim

# PySpark geospatial functions
from pyspark.sql.functions import radians, cos, sin, asin, sqrt, atan2

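def haversine_distance_spark(lat1, lon1, lat2, lon2):
    """
    Calculate Haversine distance in PySpark (geosphere equivalent)
    
    Arguments are Columns (or column names) holding coordinates in degrees.
    Returns a Column with the great-circle distance in miles.
    """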
    # Convert to radians
    lat1_rad, lon1_rad = radians(lat1), radians(lon1)
    lat2_rad, lon2_rad = radians(lat2), radians(lon2)
    
    # Haversine formula
    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad
    
    a = sin(dlat/2)**2 + cos(lat1_rad) * cos(lat2_rad) * sin(dlon/2)**2
    c = 2 * asin(sqrt(a))
    
    # Earth radius in miles
    earth_radius_miles = 3959
    distance = earth_radius_miles * c
    
    return distance

def create_interactive_map(df, lat_col, lon_col, popup_col=None, title="Interactive Map"):
    """
    Create interactive map from PySpark DataFrame (leaflet equivalent)
    """
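    # Duck-typed check so both classic and Spark Connect DataFrames are converted
    if hasattr(df, "toPandas"):
        df_pandas = df.toPandas()
    else:
        df_pandas = df
    
    # Drop rows without coordinates; folium rejects NaN locations
    df_pandas = df_pandas.dropna(subset=[lat_col, lon_col])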
    
    # Get center coordinates
    center_lat = df_pandas[lat_col].mean()
    center_lon = df_pandas[lon_col].mean()
    
    # Create base map
    m = folium.Map(
        location=[center_lat, center_lon], 
        zoom_start=10,
        tiles='OpenStreetMap'
    )
    
    # Add markers
    for idx, row in df_pandas.iterrows():
        popup_text = str(row[popup_col]) if popup_col else f"Point {idx}"
        folium.Marker(
            location=[row[lat_col], row[lon_col]],
            popup=popup_text,
            icon=folium.Icon(color='blue', icon='info-sign')
        ).add_to(m)
    
    # Add heatmap if many points
    if len(df_pandas) > 100:
        heat_data = df_pandas[[lat_col, lon_col]].values.tolist()
        plugins.HeatMap(heat_data).add_to(m)
    
    return m

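def create_choropleth_map(geo_data, data_col, title="Choropleth Map"):
    """
    Create choropleth map (leaflet + geospatial equivalent)
    
    geo_data must be a GeoJSON dict whose features have a 'zone_id' property;
    data_col must be a pandas DataFrame with 'zone_id' and 'value' columns.
    Any other geo_data type yields a base map without a choropleth layer.
    """
    # NYC center (used as the initial map location)
    nyc_center = [40.7128, -74.0060]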
    
    m = folium.Map(location=nyc_center, zoom_start=10)
    
    if isinstance(geo_data, dict):  # GeoJSON
        folium.Choropleth(
            geo_data=geo_data,
            data=data_col,
            columns=['zone_id', 'value'],
            key_on='feature.properties.zone_id',
            fill_color='YlOrRd',
            fill_opacity=0.7,
            line_opacity=0.2,
            legend_name=title
        ).add_to(m)
    
    return m

# NYC Taxi Zone utilities
def load_nyc_taxi_zones():
    """Load NYC taxi zone shapefile for mapping"""
    # This would typically load from GeoJSON or shapefile
    # For now, we'll create a placeholder
    print("📍 NYC Taxi Zones loader ready")
    print("   Use: geopandas.read_file('path/to/taxi_zones.geojson')")
    return None

print("✅ Geospatial libraries imported (geosphere/leaflet/maps equivalents)")
print("✅ Haversine distance function created for PySpark")
print("✅ Interactive mapping functions created (leaflet equivalent)")
print("✅ Choropleth mapping ready")
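
For spot-checking the Spark expression on a handful of rows, the same Haversine formula can be evaluated in plain Python with the `math` module. This sketch reuses the 3959-mile Earth radius from `haversine_distance_spark`. `haversine_miles` is an illustrative name, and the clamp before `asin` is an added safeguard against rounding, not something the Spark version does.

```python
import math

EARTH_RADIUS_MILES = 3959

def haversine_miles(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in miles between two points given in degrees."""
    lat1_rad, lon1_rad, lat2_rad, lon2_rad = map(math.radians, (lat1, lon1, lat2, lon2))
    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2) ** 2
    # Clamp against rounding so asin never receives a value above 1
    c = 2 * math.asin(min(1.0, math.sqrt(a)))
    return EARTH_RADIUS_MILES * c

# One degree of latitude along a meridian
one_degree = haversine_miles(40.0, -74.0, 41.0, -74.0)
print(round(one_degree, 2))
```

Along a meridian the formula reduces to radius times the latitude difference in radians, so one degree comes out at roughly 69.1 miles with this radius.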


In [None]:
# Machine Learning libraries (xgboost, caret equivalents)
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
    VectorAssembler, StandardScaler, StringIndexer, OneHotEncoder,
    Bucketizer, QuantileDiscretizer, PCA, MinMaxScaler
)
from pyspark.ml.classification import (
    RandomForestClassifier, GBTClassifier, LogisticRegression,
    DecisionTreeClassifier, LinearSVC, MultilayerPerceptronClassifier
)
from pyspark.ml.regression import (
    RandomForestRegressor, GBTRegressor, LinearRegression,
    DecisionTreeRegressor, GeneralizedLinearRegression
)
from pyspark.ml.evaluation import (
    RegressionEvaluator, BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator
)
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, TrainValidationSplit
from pyspark.ml.stat import Correlation

# Traditional ML libraries for comparison
import xgboost as xgb  # Direct xgboost equivalent
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestRegressor as SKRandomForest
from sklearn.metrics import mean_squared_error, r2_score, classification_report

# ML utility functions (caret equivalent)
def create_ml_pipeline(feature_cols, target_col, model_type='rf', task='regression'):
    """
    Create ML pipeline (caret trainControl equivalent)
    """
    # Feature engineering stages
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
    
    # Model selection
    if task == 'regression':
        if model_type == 'rf':
            model = RandomForestRegressor(featuresCol="scaled_features", labelCol=target_col)
        elif model_type == 'gbt':
            model = GBTRegressor(featuresCol="scaled_features", labelCol=target_col)
        elif model_type == 'linear':
            model = LinearRegression(featuresCol="scaled_features", labelCol=target_col)
        else:
            raise ValueError(f"Unknown regression model_type: {model_type}")
    else:  # classification
        if model_type == 'rf':
            model = RandomForestClassifier(featuresCol="scaled_features", labelCol=target_col)
        elif model_type == 'gbt':
            model = GBTClassifier(featuresCol="scaled_features", labelCol=target_col)
        elif model_type == 'logistic':
            model = LogisticRegression(featuresCol="scaled_features", labelCol=target_col)
        else:
            raise ValueError(f"Unknown classification model_type: {model_type}")
    
    # Create pipeline
    pipeline = Pipeline(stages=[assembler, scaler, model])
    
    return pipeline

def tune_hyperparameters(pipeline, train_data, param_grid, evaluator, folds=3):
    """
    Hyperparameter tuning (caret tune equivalent)
    """
    cv = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=param_grid,
        evaluator=evaluator,
        numFolds=folds,
        seed=42
    )
    
    cv_model = cv.fit(train_data)
    return cv_model

def evaluate_model(model, test_data, task='regression', label_col="label"):
    """
    Model evaluation (caret confusionMatrix equivalent)
    
    label_col must name the same column that was passed as target_col
    when the pipeline was built.
    """
    predictions = model.transform(test_data)
    
    if task == 'regression':
        evaluator = RegressionEvaluator(labelCol=label_col, predictionCol="prediction")
        rmse = evaluator.evaluate(predictions, {evaluator.metricName: "rmse"})
        r2 = evaluator.evaluate(predictions, {evaluator.metricName: "r2"})
        mae = evaluator.evaluate(predictions, {evaluator.metricName: "mae"})
        
        print(f"📊 Regression Metrics:")
        print(f"   RMSE: {rmse:.4f}")
        print(f"   R²: {r2:.4f}")  
        print(f"   MAE: {mae:.4f}")
        
        return {"rmse": rmse, "r2": r2, "mae": mae}
    else:
        # AUC from BinaryClassificationEvaluator is only meaningful for two-class labels
        evaluator = BinaryClassificationEvaluator(labelCol=label_col, rawPredictionCol="rawPrediction")
        auc = evaluator.evaluate(predictions)
        
        mc_evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol="prediction")
        accuracy = mc_evaluator.evaluate(predictions, {mc_evaluator.metricName: "accuracy"})
        precision = mc_evaluator.evaluate(predictions, {mc_evaluator.metricName: "weightedPrecision"})
        recall = mc_evaluator.evaluate(predictions, {mc_evaluator.metricName: "weightedRecall"})
        
        print(f"📊 Classification Metrics:")
        print(f"   AUC: {auc:.4f}")
        print(f"   Accuracy: {accuracy:.4f}")
        print(f"   Precision: {precision:.4f}")
        print(f"   Recall: {recall:.4f}")
        
        return {"auc": auc, "accuracy": accuracy, "precision": precision, "recall": recall}

# Feature importance function
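def get_feature_importance(model, feature_cols):
    """
    Extract feature importance (caret varImp equivalent)

    Accepts a fitted PipelineModel, a CrossValidatorModel, or a bare fitted model.
    """
    # Unwrap the best pipeline when given the output of tune_hyperparameters
    if hasattr(model, 'bestModel'):
        model = model.bestModel
    
    if hasattr(model, 'stages'):
        # Extract from pipeline
        final_model = model.stages[-1]
    else:
        final_model = model
    
    if hasattr(final_model, 'featureImportances'):
        importances = final_model.featureImportances.toArray()
        # Cast NumPy floats to Python floats so Spark can infer the column type
        rows = [(name, float(score)) for name, score in zip(feature_cols, importances)]
        feature_importance_df = spark.createDataFrame(
            rows, 
            ['feature', 'importance']
        ).orderBy(desc('importance'))
        
        return feature_importance_df
    else:
        print("⚠️ Model doesn't support feature importance")
        return None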

print("✅ PySpark ML libraries imported (xgboost/caret equivalents)")
print("✅ ML pipeline utilities created")
print("✅ Hyperparameter tuning functions ready")
print("✅ Model evaluation functions created")
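
Before calling `tune_hyperparameters` on a dataset of this size, it can help to estimate how many model fits a grid implies. The sketch below is plain Python and does not touch Spark. It assumes the usual k-fold behaviour, where every parameter combination is fitted once per fold and the best one is refitted on the full training set. The helper name `count_cv_fits` and the grid values are hypothetical.

```python
from itertools import product


def count_cv_fits(param_values: dict, folds: int = 3) -> dict:
    """Estimate the number of model fits for a k-fold grid search."""
    # The Cartesian product of all candidate values gives the parameter combinations
    combinations = list(product(*param_values.values()))
    # One fit per combination per fold, plus one final refit of the best combination
    total_fits = len(combinations) * folds + 1
    return {"combinations": len(combinations), "total_fits": total_fits}


# Hypothetical grid for a random forest stage
grid = {"numTrees": [50, 100], "maxDepth": [5, 10, 15]}
print(count_cv_fits(grid, folds=3))
```

With 6 combinations and 3 folds this comes to 19 fits, so trimming the grid or tuning on a sample is usually worth considering before running on the full data.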


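As a sanity check on what `evaluate_model` reports for regression, the three metrics can be reproduced by hand on a few values. This is a minimal sketch in plain Python using the standard definitions of RMSE, MAE and R² (one minus residual sum of squares over total sum of squares). The helper name `regression_metrics` and the sample durations are made up for illustration.

```python
import math


def regression_metrics(labels, predictions):
    """Compute RMSE, R² and MAE from two equal-length lists of numbers."""
    n = len(labels)
    errors = [p - y for y, p in zip(labels, predictions)]
    ss_res = sum(e * e for e in errors)                  # residual sum of squares
    mean_label = sum(labels) / n
    ss_tot = sum((y - mean_label) ** 2 for y in labels)  # total sum of squares
    return {
        "rmse": math.sqrt(ss_res / n),
        "r2": 1 - ss_res / ss_tot,
        "mae": sum(abs(e) for e in errors) / n,
    }


# Made-up trip durations in minutes
labels = [10.0, 20.0, 30.0]
predictions = [12.0, 18.0, 33.0]
print(regression_metrics(labels, predictions))
```

RMSE and MAE are in the same unit as the label, so with `trip_duration_min` as the target both read directly as minutes of error.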
In [None]:
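# Load sample NYC taxi data from BigQuery
# hour and dayofweek are used below; import them here so a missing name
# is not silently swallowed by the except block
from pyspark.sql.functions import col, unix_timestamp, hour, dayofweek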
try:
    # Load taxi trip data
    yellow_taxi_df = spark.read \
        .format("bigquery") \
        .option("table", f"{PROJECT_ID}.{DATASET_ID}.yellow_taxi_external_table") \
        .option("maxPartitions", "50") \
        .load()
    
    # Sample for demonstration (remove .sample() for full analysis)
    sample_df = yellow_taxi_df.sample(0.01, seed=42)  # 1% sample
    
    print(f"📊 Loaded sample data: {sample_df.count():,} records")
    
    # Data manipulation using dplyr-equivalent functions
    processed_df = sample_df \
        .filter(
            (col("trip_distance") > 0) & 
            (col("trip_distance") < 50) &
            (col("fare_amount") > 0) &
            (col("PULocationID").isNotNull())
        ) \
        .withColumn("trip_duration_min", 
                   (unix_timestamp("tpep_dropoff_datetime") - 
                    unix_timestamp("tpep_pickup_datetime")) / 60) \
        .withColumn("avg_speed", col("trip_distance") / (col("trip_duration_min") / 60)) \
        .withColumn("pickup_hour", hour("tpep_pickup_datetime")) \
        .withColumn("pickup_day", dayofweek("tpep_pickup_datetime")) \
        .filter((col("trip_duration_min") > 1) & (col("avg_speed") <= 80))
    
    print("✅ Data processed using dplyr-equivalent operations")
    
    # Show basic statistics
    processed_df.select("trip_distance", "fare_amount", "avg_speed", "trip_duration_min") \
        .describe().show()
    
except Exception as e:
    print("⚠️ Could not load or process the BigQuery data; check that your BigQuery tables are set up")
    print(f"Error: {e}")
    
    # Create sample data for demonstration
    from pyspark.sql.types import StructType, StructField, FloatType, IntegerType
    
    sample_data = [
        (10.5, 25.0, 35.2, 18.0, 14, 2),
        (5.2, 12.5, 42.1, 7.4, 9, 1),
        (15.8, 42.0, 28.5, 33.2, 17, 5),
        (3.1, 8.5, 38.7, 4.8, 7, 3)
    ]
    
    schema = StructType([
        StructField("trip_distance", FloatType()),
        StructField("fare_amount", FloatType()),
        StructField("avg_speed", FloatType()),
        StructField("trip_duration_min", FloatType()),
        StructField("pickup_hour", IntegerType()),
        StructField("pickup_day", IntegerType())
    ])
    
    processed_df = spark.createDataFrame(sample_data, schema)
    print("📊 Created sample dataset for demonstration")
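
The derived columns and filters in the cell above can be checked on a single record without a Spark session. The sketch below mirrors only the distance, duration and speed rules (not the `fare_amount` or `PULocationID` checks) in plain Python. It assumes `trip_distance` is in miles and converts the weekday to Spark's `dayofweek` numbering, where 1 is Sunday and 7 is Saturday. The helper name `derive_trip_features` is hypothetical.

```python
from datetime import datetime


def derive_trip_features(pickup: datetime, dropoff: datetime, distance_miles: float):
    """Return the derived features, or None if the record would be filtered out."""
    duration_min = (dropoff - pickup).total_seconds() / 60
    # Same bounds as the Spark filters: 0 < distance < 50 and duration > 1 minute
    if not (0 < distance_miles < 50) or duration_min <= 1:
        return None
    avg_speed = distance_miles / (duration_min / 60)  # miles per hour
    if avg_speed > 80:
        return None
    return {
        "trip_duration_min": duration_min,
        "avg_speed": avg_speed,
        "pickup_hour": pickup.hour,
        # Python isoweekday: Monday=1..Sunday=7; Spark dayofweek: Sunday=1..Saturday=7
        "pickup_day": pickup.isoweekday() % 7 + 1,
    }


# A Monday afternoon trip: 10.5 miles in 18 minutes
features = derive_trip_features(
    datetime(2023, 1, 2, 14, 0), datetime(2023, 1, 2, 14, 18), 10.5
)
print(features)
```

The example record is close to the first row of the fallback sample data (18 minutes, about 35 mph, hour 14, day 2), which makes it a convenient reference when comparing against the Spark output.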