In [None]:
%pip install prophet

In [None]:
import warnings

# NOTE(review): this silences every warning for the whole kernel session,
# including deprecation notices from pandas / xgboost / prophet.
warnings.filterwarnings('ignore')

# Prophet is optional: the forecasting code checks PROPHET_AVAILABLE before
# calling train_prophet_with_ci, so a missing package only disables that model.
try:
    from prophet import Prophet
except ImportError:
    PROPHET_AVAILABLE = False
    print("Prophet not available. Install with: pip install prophet")
else:
    PROPHET_AVAILABLE = True

In [None]:
import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from google.cloud import bigquery
import xgboost as xgb
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.callbacks import EarlyStopping
import os
from datetime import datetime, timedelta

# Configuration

In [None]:
PROJECT_ID = os.environ.get("PROJECT_ID", "nyctaxi-467111")
BUCKET_NAME = "nyc_raw_data_bucket"  # NOTE(review): not referenced in the visible cells
DATASET_NAME = "PreMlGold"  # holds the {taxi_type}_YYYY_MM_hourly tables
OUTPUT_DATASET = "PostMlGold"
MlDATASET_NAME = OUTPUT_DATASET  # holds the {taxi_type}_anomalies tables
TAXI_TYPES = ["yellow", "green", "fhv", "fhvhv"]

# Initialize BigQuery client
client = bigquery.Client(project=PROJECT_ID)

# Module-level state shared between the Gradio tabs. All of it is (re)assigned
# by load_and_prepare_data_with_metric and read by the forecasting functions.
global_df = None
city_ts = None
trip_scaled = None
ml_taxi_type_value = None
# Declared `global` by the ML functions; defining it here means reads of it
# never depend on a load having run first.
selected_metric = "trips"

In [None]:
# Second name for the BigQuery `client` created in the configuration cell.
# NOTE(review): nothing in the visible cells reads `bq_client`; confirm other
# cells use it before removing.
bq_client = client

## US holidays

In [None]:
import holidays

# Years for which holidays are generated. Only holidays in these years can be
# annotated on the anomaly chart, so extend this if the data covers other years.
HOLIDAY_YEARS = [2023, 2024, 2025]

# US holidays for 2023-2025
us_holidays = holidays.US(years=HOLIDAY_YEARS)

us_holidays

In [None]:
# (date string, holiday name) pairs; the anomaly chart parses the date string
# back with pd.to_datetime when placing annotations.
US_HOLIDAYS = []
for holiday_day, holiday_name in us_holidays.items():
    US_HOLIDAYS.append((str(holiday_day), holiday_name))
US_HOLIDAYS

# ==================== Data Access Functions ====================

In [None]:
def get_available_partitions(taxi_type: str) -> list:
    """Return the ``YYYY_MM`` partitions that exist for ``taxi_type``.

    A partition is a table named ``{taxi_type}_YYYY_MM_hourly`` in
    ``{PROJECT_ID}.{DATASET_NAME}``, discovered through INFORMATION_SCHEMA.

    Args:
        taxi_type: Table-name prefix, e.g. ``"yellow"``.

    Returns:
        Sorted list of ``"YYYY_MM"`` strings. Year and month are zero-padded,
        so string order is chronological order.
    """
    import re

    # The prefix is passed as a query parameter instead of being spliced into
    # the SQL text; the exact name check is done by the Python regex below.
    query = f"""
    SELECT table_name
    FROM `{PROJECT_ID}.{DATASET_NAME}.INFORMATION_SCHEMA.TABLES`
    WHERE STARTS_WITH(table_name, @table_prefix)
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("table_prefix", "STRING", f"{taxi_type}_")
        ]
    )
    results = client.query(query, job_config=job_config).result()

    # re.escape keeps regex metacharacters in taxi_type from altering the match.
    pattern = re.compile(rf'^{re.escape(taxi_type)}_(\d{{4}})_(\d{{2}})_hourly$')

    partitions = []
    for row in results:
        match = pattern.match(row.table_name)
        if match:
            year, month = match.groups()
            partitions.append(f"{year}_{month}")

    return sorted(partitions)

def load_partitions(taxi_type: str, partitions: list) -> pd.DataFrame:
    """Fetch raw hourly rows for every requested partition as one frame.

    Args:
        taxi_type: Table-name prefix, e.g. ``"yellow"``.
        partitions: ``"YYYY_MM"`` strings naming ``{taxi_type}_YYYY_MM_hourly`` tables.

    Returns:
        DataFrame with ``pickup_date``, ``pickup_hour`` and ``trips`` columns.

    Raises:
        ValueError: If ``partitions`` is empty.
    """
    if not partitions:
        raise ValueError("No partitions selected")

    per_table_selects = [
        f"""
            SELECT pickup_date, pickup_hour, trips
            FROM `{PROJECT_ID}.{DATASET_NAME}.{taxi_type}_{part}_hourly`
        """
        for part in partitions
    ]
    return client.query(" UNION ALL ".join(per_table_selects)).to_dataframe()

def _aggregate_daily(hourly_df: pd.DataFrame) -> pd.DataFrame:
    """Sum hourly totals per calendar day, emitting only days that have rows.

    Unlike ``groupby(pd.Grouper(freq='D'))``, this does not insert zero-valued
    days for gaps, e.g. between non-contiguous monthly partitions.
    """
    day = hourly_df['pickup_datetime'].dt.normalize()
    return (
        hourly_df.groupby(day)[['total_trips', 'total_revenue']]
        .sum()
        .reset_index()
    )


def fetch_time_series_data(taxi_type: str, partitions: list, granularity: str = "Daily") -> pd.DataFrame:
    """Fetch trip and revenue totals for the selected partitions.

    Args:
        taxi_type: Table-name prefix, e.g. ``"yellow"``.
        partitions: ``"YYYY_MM"`` strings naming ``{taxi_type}_YYYY_MM_hourly`` tables.
        granularity: ``"Daily"`` aggregates to one row per day that has data; any
            other value returns hourly rows.

    Returns:
        DataFrame with ``pickup_datetime``, ``total_trips`` and ``total_revenue``
        (hourly output also keeps ``pickup_date`` and ``pickup_hour``). An empty
        DataFrame when ``partitions`` is empty.
    """
    if not partitions:
        return pd.DataFrame()

    # One pre-aggregated SELECT per monthly table, combined with UNION ALL.
    union_queries = []
    for part in partitions:
        table_name = f"{taxi_type}_{part}_hourly"
        union_queries.append(f"""
            SELECT pickup_date, pickup_hour,
                   SUM(trips) as trips,
                   SUM(revenue) as revenue
            FROM `{PROJECT_ID}.{DATASET_NAME}.{table_name}`
            GROUP BY pickup_date, pickup_hour
        """)

    query = f"""
    WITH all_data AS (
        {" UNION ALL ".join(union_queries)}
    )
    SELECT pickup_date, pickup_hour,
           SUM(trips) as total_trips,
           SUM(revenue) as total_revenue
    FROM all_data
    GROUP BY pickup_date, pickup_hour
    ORDER BY pickup_date, pickup_hour
    """

    df = client.query(query).to_dataframe()

    df['pickup_datetime'] = pd.to_datetime(df['pickup_date']) + pd.to_timedelta(df['pickup_hour'], unit='h')

    if granularity == "Daily":
        df = _aggregate_daily(df)

    return df


# ==================== Feature Engineering ====================


In [None]:
def create_features(df: pd.DataFrame) -> pd.DataFrame:
    """Create time-based features for ML models from hourly data.

    The input frame is not modified; a new frame is returned.

    Args:
        df: Frame with ``pickup_date``, ``pickup_hour`` (hour of day) and ``trips``
            columns. Any other columns are carried through unchanged.

    Returns:
        Frame indexed by a sorted ``datetime`` index, without ``pickup_date`` /
        ``pickup_hour``, with calendar, cyclical, lag and rolling features added.
        Rows containing any NaN are dropped. This always removes at least the
        first 168 rows, whose one-week lag is undefined.
    """
    # Build timestamps arithmetically (date + hours), as the other loaders do,
    # rather than parsing concatenated strings.
    timestamps = (pd.to_datetime(df['pickup_date'].astype(str))
                  + pd.to_timedelta(df['pickup_hour'], unit='h'))
    # drop() returns a new frame, so the caller's DataFrame is left untouched.
    features = (
        df.drop(columns=['pickup_date', 'pickup_hour', 'datetime'], errors='ignore')
          .set_index(pd.DatetimeIndex(timestamps, name='datetime'))
          .sort_index()
    )

    # Calendar features
    features['hour'] = features.index.hour
    features['weekday'] = features.index.dayofweek
    features['day_of_month'] = features.index.day
    # to_numpy() assigns positionally, so duplicate timestamps cannot break alignment.
    features['week_of_year'] = features.index.isocalendar().week.astype(int).to_numpy()
    features['is_weekend'] = (features['weekday'] >= 5).astype(int)

    # Cyclical encoding
    features['hour_sin'] = np.sin(2 * np.pi * features['hour'] / 24)
    features['hour_cos'] = np.cos(2 * np.pi * features['hour'] / 24)
    features['weekday_sin'] = np.sin(2 * np.pi * features['weekday'] / 7)
    features['weekday_cos'] = np.cos(2 * np.pi * features['weekday'] / 7)

    # Lag features (avoid data leakage).
    # NOTE(review): shift() is row-based, so these equal true hour lags only
    # when the series has no missing hours.
    features['lag_1h'] = features['trips'].shift(1)
    features['lag_24h'] = features['trips'].shift(24)
    features['lag_168h'] = features['trips'].shift(168)  # 1 week

    # Rolling statistics over past values only (shift(1) excludes the current hour)
    past_trips = features['trips'].shift(1)
    features['rolling_mean_3h'] = past_trips.rolling(window=3, min_periods=1).mean()
    features['rolling_mean_24h'] = past_trips.rolling(window=24, min_periods=1).mean()
    features['rolling_std_24h'] = past_trips.rolling(window=24, min_periods=1).std()

    # Drop rows with NaN values from lag features
    return features.dropna()


In [None]:




def _index_hourly_actuals(df: pd.DataFrame) -> pd.DataFrame:
    """Index a pickup_date/pickup_hour frame by hourly timestamp, keeping trips and revenue."""
    timestamps = pd.to_datetime(df['pickup_date']) + pd.to_timedelta(df['pickup_hour'], unit='h')
    return (
        df.assign(datetime=timestamps)
          .set_index('datetime')
          .sort_index()[['trips', 'revenue']]
    )


def load_actual_data_from_gcs(taxi_type, start_date, end_date):
    """Load observed hourly values for a forecast window, when they exist.

    Despite the name, the data comes from the BigQuery hourly tables
    ``{taxi_type}_YYYY_MM_hourly`` in ``DATASET_NAME``.

    Args:
        taxi_type: Table-name prefix, e.g. ``"yellow"``.
        start_date: Inclusive window start (datetime-like with ``strftime``).
        end_date: Inclusive window end (datetime-like with ``strftime``).

    Returns:
        DataFrame indexed by hourly ``datetime`` with ``trips`` and ``revenue``
        columns. Revenue is included so callers forecasting revenue can compare
        against the right series. An empty DataFrame is returned when nothing
        matches or any error occurs; the error is printed.
    """
    if not taxi_type:
        return pd.DataFrame()

    try:
        # Monthly partitions overlapping the window ("YYYY_MM" sorts chronologically).
        start_year_month = f"{start_date.year}_{start_date.strftime('%m')}"
        end_year_month = f"{end_date.year}_{end_date.strftime('%m')}"
        relevant_partitions = [
            partition for partition in get_available_partitions(taxi_type)
            if start_year_month <= partition <= end_year_month
        ]

        if not relevant_partitions:
            return pd.DataFrame()

        # Hourly rows restricted to the exact window.
        union_queries = []
        for partition in relevant_partitions:
            table_name = f"{taxi_type}_{partition}_hourly"
            union_queries.append(f"""
                SELECT pickup_date, pickup_hour, trips, revenue
                FROM `{PROJECT_ID}.{DATASET_NAME}.{table_name}`
                WHERE DATETIME(pickup_date, TIME(pickup_hour, 0, 0)) >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
                  AND DATETIME(pickup_date, TIME(pickup_hour, 0, 0)) <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
            """)

        df = client.query(" UNION ALL ".join(union_queries)).to_dataframe()

        if df.empty:
            return pd.DataFrame()

        return _index_hourly_actuals(df)

    except Exception as e:
        # Best-effort overlay: a failure here must not break the forecast plot.
        print(f"Error loading actual future data: {str(e)}")
        return pd.DataFrame()

# ==================== Visualization Functions ====================

In [None]:
def plot_time_series(taxi_type, partitions, metric, granularity):
    """Build an interactive line chart of trips or revenue over time.

    A ``metric`` of ``"Trips"`` plots ``total_trips``; any other value plots
    ``total_revenue``. When inputs are missing or the query returns no rows,
    an empty figure whose title explains why is returned instead.
    """
    if not (taxi_type and partitions):
        return px.line(title="Please select taxi type and partitions")

    series_df = fetch_time_series_data(taxi_type, partitions, granularity)
    if series_df.empty:
        return px.line(title="No data available for selected options")

    value_column = {"Trips": "total_trips"}.get(metric, "total_revenue")
    chart = px.line(
        series_df,
        x="pickup_datetime",
        y=value_column,
        title=f"{taxi_type.title()} Taxi - {metric} ({granularity})",
        markers=True,
    )
    chart.update_layout(
        xaxis_title="Date",
        yaxis_title=metric,
        hovermode="x unified",
        template="plotly_white",
    )
    return chart

# ==================== ML Functions ====================

In [None]:
def load_and_prepare_data_with_metric(taxi_type, start_partition, end_partition, metric="trips"):
    """Load hourly data for a partition range and prepare it for the ML tabs.

    On success this sets the module globals ``global_df`` (feature frame),
    ``city_ts``, ``trip_scaled``, ``ml_taxi_type_value`` and ``selected_metric``.
    On any failure the globals keep their previous values. A failed reload
    therefore cannot leave old data labelled with a new metric or taxi type.

    Args:
        taxi_type: Table-name prefix, e.g. ``"yellow"``.
        start_partition: Inclusive ``"YYYY_MM"`` lower bound.
        end_partition: Inclusive ``"YYYY_MM"`` upper bound.
        metric: Column to forecast, ``"trips"`` or ``"revenue"``.

    Returns:
        A human-readable status message.
    """
    global global_df, city_ts, trip_scaled, ml_taxi_type_value, selected_metric

    if not taxi_type or not start_partition or not end_partition:
        return "Please select taxi type and date range"

    try:
        selected_partitions = [
            part for part in get_available_partitions(taxi_type)
            if start_partition <= part <= end_partition
        ]

        if not selected_partitions:
            return "No data available in selected range"

        # Query includes revenue so either metric can be modelled.
        union_queries = []
        for part in selected_partitions:
            table_name = f"{taxi_type}_{part}_hourly"
            union_queries.append(f"""
                SELECT pickup_date, pickup_hour, trips, revenue
                FROM `{PROJECT_ID}.{DATASET_NAME}.{table_name}`
            """)

        raw_df = client.query(" UNION ALL ".join(union_queries)).to_dataframe()

        # Keep a copy of the selected metric (raises KeyError for an unknown metric),
        # then alias it as 'trips' because create_features and the model code read
        # the target from that column.
        raw_df['metric_value'] = raw_df[metric].copy()
        if metric == "revenue":
            raw_df['trips'] = raw_df['revenue']

        features_df = create_features(raw_df)
        features_df['metric_value'] = features_df['trips'].copy()

        ts_df = pd.DataFrame({
            'date_hour': features_df.index,
            'trips': features_df['trips'].values,
            'metric_value': features_df['trips'].values
        }).reset_index(drop=True)

        # Scaled series for anomaly detection
        scaled = MinMaxScaler().fit_transform(ts_df[['trips']].values)

    except Exception as e:
        return f"Error loading data: {str(e)}"

    # Publish shared state only once everything above has succeeded.
    global_df = features_df
    city_ts = ts_df
    trip_scaled = scaled
    ml_taxi_type_value = taxi_type
    selected_metric = metric

    return f"Successfully loaded {len(selected_partitions)} partition(s) for {taxi_type} taxi\n" \
           f"Data shape: {global_df.shape}\n" \
           f"Date range: {global_df.index.min()} to {global_df.index.max()}\n" \
           f"Forecasting metric: {metric.upper()}"


def _future_feature_row(pred_time, history):
    """Build the one-row feature frame used to predict the hour ``pred_time``.

    Mirrors the columns produced by ``create_features``.

    Args:
        pred_time: Timestamp of the hour being predicted.
        history: Hourly values in time order, ending with the hour immediately
            before ``pred_time`` (observed values followed by earlier predictions).

    Returns:
        Single-row DataFrame indexed by ``pred_time``.
    """
    def lag(k):
        # Fall back to the oldest known value when history is shorter than the lag.
        return history[-k] if len(history) >= k else history[0]

    hour = pred_time.hour
    weekday = pred_time.weekday()
    last_day = history[-24:]
    return pd.DataFrame({
        'hour': hour,
        'weekday': weekday,
        'day_of_month': pred_time.day,
        'week_of_year': pred_time.isocalendar()[1],
        'is_weekend': int(weekday >= 5),
        'hour_sin': np.sin(2 * np.pi * hour / 24),
        'hour_cos': np.cos(2 * np.pi * hour / 24),
        'weekday_sin': np.sin(2 * np.pi * weekday / 7),
        'weekday_cos': np.cos(2 * np.pi * weekday / 7),
        'lag_1h': history[-1],
        'lag_24h': lag(24),
        'lag_168h': lag(168),
        'rolling_mean_3h': np.mean(history[-3:]),
        'rolling_mean_24h': np.mean(last_day),
        # ddof=1 matches pandas rolling().std() used when building training features.
        'rolling_std_24h': np.std(last_day, ddof=1) if len(last_day) > 1 else 0.0,
    }, index=[pred_time])


def train_xgboost_with_ci(train_df, test_df, forecast_days, confidence_level):
    """Train XGBoost and return test and future predictions with intervals.

    Intervals are ``mean ± z * residual_std``. ``z`` is the two-sided normal
    critical value for ``confidence_level`` (1.96 for 0.95, 2.576 for 0.99).
    Future intervals widen by 5% of the base width per forecast hour.

    Args:
        train_df: Feature frame (see ``create_features``) for fitting.
        test_df: Feature frame for evaluation; it directly follows ``train_df`` in time.
        forecast_days: Days to forecast after the last test timestamp.
        confidence_level: Two-sided coverage in (0, 1).

    Returns:
        ``(test_mean, test_lower, test_upper, future_mean, future_lower,
        future_upper, prediction_times)``; arrays are numpy and
        ``prediction_times`` is a list of hourly timestamps.

    Raises:
        ValueError: If ``confidence_level`` is not strictly between 0 and 1.
    """
    from statistics import NormalDist

    if not 0 < confidence_level < 1:
        raise ValueError(f"confidence_level must be in (0, 1), got {confidence_level}")

    # Target and bookkeeping columns are not features.
    excluded = {'trips', 'date_hour', 'metric_value', 'revenue', 'trips_original'}
    feature_cols = [col for col in train_df.columns if col not in excluded]

    X_train, y_train = train_df[feature_cols], train_df['trips']
    X_test, y_test = test_df[feature_cols], test_df['trips']

    alpha = 1 - confidence_level
    z_score = NormalDist().inv_cdf(1 - alpha / 2)

    main_model = xgb.XGBRegressor(
        objective='reg:squarederror',
        n_estimators=300,
        learning_rate=0.03,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=42,
        n_jobs=-1
    )
    main_model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)

    test_pred_mean = main_model.predict(X_test)

    # NOTE(review): in-sample residuals understate out-of-sample error for a
    # boosted-tree model, so these bands are likely too narrow.
    residual_std = np.std(y_train - main_model.predict(X_train))

    test_pred_lower = test_pred_mean - z_score * residual_std
    test_pred_upper = test_pred_mean + z_score * residual_std

    # Recursive forecast: each prediction is appended to `history` so later
    # lag and rolling features see it exactly once.
    last_date = test_df.index.max()
    future_hours = int(forecast_days * 24)
    history = list(train_df['trips'].iloc[-168:].values) + list(test_df['trips'].values)
    prediction_times = []
    future_pred_mean = []

    for step in range(1, future_hours + 1):
        pred_time = last_date + timedelta(hours=step)
        row = _future_feature_row(pred_time, history)
        available_features = [col for col in feature_cols if col in row.columns]
        pred_mean = main_model.predict(row[available_features])[0]

        prediction_times.append(pred_time)
        future_pred_mean.append(pred_mean)
        history.append(pred_mean)

    future_pred_mean = np.array(future_pred_mean)

    # Heuristic widening: +5% of the base interval per forecast hour.
    future_uncertainty_factor = 1 + 0.05 * np.arange(1, future_hours + 1)
    future_pred_lower = future_pred_mean - z_score * residual_std * future_uncertainty_factor
    future_pred_upper = future_pred_mean + z_score * residual_std * future_uncertainty_factor

    return (test_pred_mean, test_pred_lower, test_pred_upper,
            future_pred_mean, future_pred_lower, future_pred_upper,
            prediction_times)


def _add_interval_band(fig, x_values, upper, lower, fillcolor, linecolor, name):
    """Add a filled confidence band between ``lower`` and ``upper`` to ``fig``.

    The polygon runs along ``upper`` left-to-right and back along ``lower``
    right-to-left, so ``fill='toself'`` closes it.
    """
    xs = list(x_values)
    fig.add_trace(go.Scatter(
        x=xs + xs[::-1],
        y=np.asarray(upper).tolist() + np.asarray(lower).tolist()[::-1],
        fill='toself',
        fillcolor=fillcolor,
        line=dict(color=linecolor),
        name=name,
        showlegend=True
    ))


def train_combined_forecast_with_ci(forecast_days=7, confidence_level=0.95, show_actual=False, metric="trips"):
    """Train XGBoost (and Prophet, if installed) and plot forecasts with intervals.

    Uses the module-level ``global_df`` prepared by
    ``load_and_prepare_data_with_metric``. The final 30 days are the test period,
    and forecasts extend ``forecast_days`` past the last observed hour.

    Args:
        forecast_days: Days to forecast; coerced to int.
        confidence_level: Two-sided interval coverage, e.g. 0.95.
        show_actual: Overlay observed values for the forecast window when they
            exist in BigQuery.
        metric: Fallback metric name, used only when ``selected_metric`` is unset.

    Returns:
        A plotly Figure. When no or too little data is loaded, it is a
        placeholder whose title explains why.
    """
    global global_df, selected_metric, ml_taxi_type_value

    if global_df is None or global_df.empty:
        return px.line(title="Please load data first")

    # UI sliders may deliver floats; hour counts and range() need whole days.
    forecast_days = int(forecast_days)

    # Prefer the metric the data was actually loaded with.
    metric_key = globals().get('selected_metric') or metric
    metric_label = metric_key.upper()
    ci_pct = int(confidence_level * 100)

    # Create train/test split: the final 30 days are held out.
    last_date = global_df.index.max()
    cutoff_date = last_date - timedelta(days=30)

    train_df = global_df[global_df.index <= cutoff_date].copy()
    test_df = global_df[global_df.index > cutoff_date].copy()

    if len(train_df) < 100 or len(test_df) < 10:
        return px.line(title="Insufficient data for training. Please load more partitions.")

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=train_df.index, y=train_df['trips'],
        mode='lines',
        name='Historical Data',
        line=dict(color='gray', width=1),
        opacity=0.7
    ))
    fig.add_trace(go.Scatter(
        x=test_df.index, y=test_df['trips'],
        mode='lines',
        name='Actual (Test Period)',
        line=dict(color='black', width=2)
    ))

    # ---- XGBoost ----
    (xgb_test_mean, xgb_test_lower, xgb_test_upper,
     xgb_future_mean, xgb_future_lower, xgb_future_upper,
     xgb_future_times) = train_xgboost_with_ci(train_df, test_df, forecast_days, confidence_level)

    xgb_mae = mean_absolute_error(test_df['trips'], xgb_test_mean)
    xgb_mse = mean_squared_error(test_df['trips'], xgb_test_mean)
    xgb_rmse = np.sqrt(xgb_mse)

    fig.add_trace(go.Scatter(
        x=test_df.index, y=xgb_test_mean,
        mode='lines',
        name='XGBoost (Test)',
        line=dict(color='red', width=2)
    ))
    _add_interval_band(fig, test_df.index, xgb_test_upper, xgb_test_lower,
                       'rgba(255,0,0,0.1)', 'rgba(255,0,0,0)',
                       f'XGBoost {ci_pct}% CI (Test)')
    fig.add_trace(go.Scatter(
        x=xgb_future_times, y=xgb_future_mean,
        mode='lines',
        name='XGBoost (Future)',
        line=dict(color='red', width=2, dash='dash')
    ))
    _add_interval_band(fig, xgb_future_times, xgb_future_upper, xgb_future_lower,
                       'rgba(255,0,0,0.2)', 'rgba(255,0,0,0)',
                       f'XGBoost {ci_pct}% CI (Future)')

    # ---- Prophet (optional, best-effort) ----
    prophet_mae, prophet_rmse, prophet_mse = None, None, None
    if PROPHET_AVAILABLE:
        try:
            (prophet_test, prophet_future, prophet_times, prophet_test_lower, prophet_test_upper,
             prophet_future_lower, prophet_future_upper) = train_prophet_with_ci(
                train_df, test_df, forecast_days, confidence_level
            )

            prophet_mae = mean_absolute_error(test_df['trips'], prophet_test)
            prophet_mse = mean_squared_error(test_df['trips'], prophet_test)
            prophet_rmse = np.sqrt(prophet_mse)

            fig.add_trace(go.Scatter(
                x=test_df.index, y=prophet_test,
                mode='lines',
                name='Prophet (Test)',
                line=dict(color='blue', width=2)
            ))
            _add_interval_band(fig, test_df.index, prophet_test_upper, prophet_test_lower,
                               'rgba(0,0,255,0.1)', 'rgba(0,0,255,0)',
                               f'Prophet {ci_pct}% CI (Test)')
            fig.add_trace(go.Scatter(
                x=prophet_times, y=prophet_future,
                mode='lines',
                name='Prophet (Future)',
                line=dict(color='blue', width=2, dash='dash')
            ))
            _add_interval_band(fig, prophet_times, prophet_future_upper, prophet_future_lower,
                               'rgba(0,0,255,0.2)', 'rgba(0,0,255,0)',
                               f'Prophet {ci_pct}% CI (Future)')
        except Exception as e:
            # A Prophet failure leaves the XGBoost plot intact.
            print(f"Prophet error: {e}")

    # ---- Observed values for the forecast window ----
    if show_actual and ml_taxi_type_value:
        future_start = last_date + timedelta(hours=1)
        # Same horizon as the forecasts: the last predicted hour is
        # last_date + forecast_days.
        future_end = last_date + timedelta(days=forecast_days)
        actual_future = load_actual_data_from_gcs(ml_taxi_type_value, future_start, future_end)

        if not actual_future.empty:
            actual_col = 'revenue' if metric_key == "revenue" else 'trips'
            if actual_col in actual_future.columns:
                fig.add_trace(go.Scatter(
                    x=actual_future.index, y=actual_future[actual_col],
                    mode='markers',
                    name='Actual (Future)',
                    marker=dict(color='green', size=8, symbol='star')
                ))
            else:
                # Plotting trips under a revenue label would be misleading.
                print(f"Actual future {actual_col} not available; skipping overlay.")

    # ---- Annotations ----
    fig.add_annotation(
        x=cutoff_date, y=train_df['trips'].max(),
        text="Test Start", showarrow=True,
        arrowhead=2, arrowsize=1, arrowwidth=2,
        arrowcolor="gray", ax=20, ay=-30
    )
    fig.add_annotation(
        x=last_date, y=test_df['trips'].max(),
        text="Forecast Start", showarrow=True,
        arrowhead=2, arrowsize=1, arrowwidth=2,
        arrowcolor="purple", ax=20, ay=-30
    )

    metrics_text = f"<b>XGBoost Metrics:</b><br>MAE: {xgb_mae:.2f}<br>RMSE: {xgb_rmse:.2f}<br>MSE: {xgb_mse:.2f}"
    if prophet_mae is not None:
        metrics_text += f"<br><br><b>Prophet Metrics:</b><br>MAE: {prophet_mae:.2f}<br>RMSE: {prophet_rmse:.2f}<br>MSE: {prophet_mse:.2f}"

    fig.add_annotation(
        text=metrics_text,
        xref="paper", yref="paper",
        x=0.02, y=0.98,
        showarrow=False,
        font=dict(size=12),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
        xanchor="left",
        yanchor="top"
    )

    # Shade the forecast period (nothing to shade for a zero-day horizon).
    if xgb_future_times:
        fig.add_vrect(
            x0=last_date, x1=xgb_future_times[-1],
            fillcolor="lightgray", opacity=0.2,
            layer="below", line_width=0
        )

    title_text = f"Combined Forecast Comparison for {metric_label}<br>"
    title_text += f"<sub>{ci_pct}% Confidence Intervals | {forecast_days} Days Ahead</sub>"

    fig.update_layout(
        title=title_text,
        xaxis_title="Date",
        yaxis_title=f"Number of {metric_label}",
        hovermode='x unified',
        template="plotly_white",
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
        height=650
    )

    return fig

In [None]:
def train_prophet_with_ci(train_df, test_df, forecast_days, confidence_level):
    """Fit Prophet on the training split and return forecasts with intervals.

    Args:
        train_df: Frame indexed by hourly timestamps with a ``trips`` column.
        test_df: Frame whose index supplies the test timestamps; only the index is used.
        forecast_days: Days to forecast after the last test timestamp (coerced to int).
        confidence_level: Passed to Prophet as ``interval_width``.

    Returns:
        ``(test_yhat, future_yhat, future_dates, test_lower, test_upper,
        future_lower, future_upper)``; values are numpy arrays and
        ``future_dates`` is a DatetimeIndex.
    """
    prophet_train = pd.DataFrame({
        'ds': train_df.index,
        'y': train_df['trips'].values
    })

    model = Prophet(
        # NOTE(review): yearly seasonality is poorly identified with less than a
        # year of history; consider 'auto' if only a few months are loaded.
        yearly_seasonality=True,
        weekly_seasonality=True,
        # Prophet measures seasonality periods in days, so the hour-of-day cycle
        # is the daily (period=1) component. Give it 8 Fourier terms. A custom
        # seasonality with period=24 would be a 24-day cycle.
        daily_seasonality=8,
        changepoint_prior_scale=0.05,
        seasonality_prior_scale=10,
        interval_width=confidence_level
    )
    model.fit(prophet_train)

    test_forecast = model.predict(pd.DataFrame({'ds': test_df.index}))

    last_date = test_df.index.max()
    future_dates = pd.date_range(
        start=last_date + timedelta(hours=1),
        periods=int(forecast_days) * 24,
        freq='h'
    )
    future_forecast = model.predict(pd.DataFrame({'ds': future_dates}))

    return (test_forecast['yhat'].values, future_forecast['yhat'].values, future_dates,
            test_forecast['yhat_lower'].values, test_forecast['yhat_upper'].values,
            future_forecast['yhat_lower'].values, future_forecast['yhat_upper'].values)

In [None]:
def detect_anomalies_from_db(taxi_type):
    """Plot hourly traffic with pre-computed anomalies and US holiday annotations.

    Reads ``{taxi_type}_anomalies`` from ``MlDATASET_NAME`` and overlays hourly
    trips from the ``{taxi_type}_YYYY_MM_hourly`` tables covering the same span.
    It colours anomalies by ``anomaly_score`` and annotates holidays from
    ``US_HOLIDAYS``.

    Args:
        taxi_type: Table-name prefix such as ``"yellow"``.

    Returns:
        A plotly Figure. On any error, a placeholder figure whose title carries
        the error message.
    """
    if not taxi_type:
        return px.line(title="Please select a taxi type")

    try:
        # ORDER BY is on the outer query: ordering inside a CTE is not guaranteed to survive.
        query = f"""
        SELECT
            datetime,
            trips,
            is_anomaly,
            anomaly_score,
            reconstruction_error,
            threshold
        FROM `{PROJECT_ID}.{MlDATASET_NAME}.{taxi_type}_anomalies`
        ORDER BY datetime
        """

        df = client.query(query).to_dataframe()

        if df.empty:
            return px.line(title=f"No anomaly data found for {taxi_type} taxi. Run the anomaly detection pipeline first.")

        fig = go.Figure()

        # The covered span comes from the rows already fetched, so no second
        # MIN/MAX query against the same table is needed.
        timestamps = pd.to_datetime(df['datetime'])
        min_date = timestamps.min()
        max_date = timestamps.max()

        # Monthly partitions that can hold data in that span ("YYYY_MM" sorts chronologically).
        min_year_month = f"{min_date.year}_{min_date.strftime('%m')}"
        max_year_month = f"{max_date.year}_{max_date.strftime('%m')}"
        relevant_partitions = [p for p in get_available_partitions(taxi_type)
                               if min_year_month <= p <= max_year_month]

        union_queries = [
            f"""
                SELECT
                    DATETIME(pickup_date, TIME(pickup_hour, 0, 0)) as datetime,
                    trips
                FROM `{PROJECT_ID}.{DATASET_NAME}.{taxi_type}_{partition}_hourly`
                WHERE DATETIME(pickup_date, TIME(pickup_hour, 0, 0)) >= DATETIME(TIMESTAMP('{min_date}'))
                AND DATETIME(pickup_date, TIME(pickup_hour, 0, 0)) <= DATETIME(TIMESTAMP('{max_date}'))
            """
            for partition in relevant_partitions
        ]

        if union_queries:
            hourly_query = " UNION ALL ".join(union_queries) + " ORDER BY datetime"
            hourly_df = client.query(hourly_query).to_dataframe()

            if not hourly_df.empty:
                fig.add_trace(go.Scatter(
                    x=hourly_df['datetime'],
                    y=hourly_df['trips'],
                    mode='lines',
                    name='Hourly Traffic',
                    line=dict(color='lightblue', width=1),
                    opacity=0.7,
                    hovertemplate='<b>%{x}</b><br>Trips: %{y}<extra></extra>'
                ))

        # Anomaly markers, one trace per severity band (mild, moderate, severe).
        anomalies = df[df['is_anomaly']]
        scores = anomalies['anomaly_score']
        severity_bands = [
            (scores <= 1.5, 'Mild Anomalies (Score ≤ 1.5)', 'yellow', 6, 'circle'),
            ((scores > 1.5) & (scores <= 3.0), 'Moderate Anomalies (1.5 < Score ≤ 3.0)', 'orange', 8, 'diamond'),
            (scores > 3.0, 'Severe Anomalies (Score > 3.0)', 'red', 10, 'x'),
        ]
        for mask, label, color, size, symbol in severity_bands:
            band = anomalies[mask]
            if band.empty:
                continue
            fig.add_trace(go.Scatter(
                x=band['datetime'],
                y=band['trips'],
                mode='markers',
                name=label,
                marker=dict(color=color, size=size, symbol=symbol),
                hovertemplate='<b>%{x}</b><br>Trips: %{y}<br>Score: %{customdata:.2f}<extra></extra>',
                customdata=band['anomaly_score']
            ))

        # Holiday annotations: red border if the day contains an anomaly, green otherwise.
        row_dates = timestamps.dt.date
        range_start = min_date.tz_localize(None)
        range_end = max_date.tz_localize(None)

        for holiday_date, holiday_name in US_HOLIDAYS:
            holiday_dt = pd.to_datetime(holiday_date).tz_localize(None)
            if not (range_start <= holiday_dt <= range_end):
                continue

            holiday_data = df[row_dates == holiday_dt.date()]
            if holiday_data.empty:
                continue

            y_val = holiday_data['trips'].max()
            holiday_anomalies = holiday_data[holiday_data['is_anomaly']]

            if not holiday_anomalies.empty:
                max_score = holiday_anomalies['anomaly_score'].max()
                annotation_text = f"{holiday_name}<br>Max Score: {max_score:.2f}"
                border_color = "red"
            else:
                annotation_text = holiday_name
                border_color = "green"

            fig.add_annotation(
                x=holiday_dt,
                y=y_val,
                text=annotation_text,
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor=border_color,
                ax=0,
                ay=-40,
                bgcolor="white",
                bordercolor=border_color,
                borderwidth=2,
                font=dict(size=10)
            )

        total_anomalies = df['is_anomaly'].sum()
        avg_score = anomalies['anomaly_score'].mean()

        fig.update_layout(
            title=f"Anomaly Detection Results for {taxi_type.title()} Taxi<br>" +
                  f"<sub>Total Anomalies: {total_anomalies} | Average Anomaly Score: {avg_score:.2f}</sub>",
            xaxis_title="Date",
            yaxis_title="Number of Trips",
            hovermode='x unified',
            template="plotly_white",
            height=600,
            legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
        )

        return fig

    except Exception as e:
        # Log the full error, then surface it in the figure title.
        print(f"Error loading anomaly data for {taxi_type}: {e}")
        return px.line(title=f"Error loading anomaly data: {str(e)}")

def get_anomaly_summary(taxi_type):
    """Build a Markdown summary of the anomaly table for ``taxi_type``.

    Queries ``{PROJECT_ID}.{MlDATASET_NAME}.{taxi_type}_anomalies`` for coverage
    (hours, days, months), anomaly totals and scores, and a severity breakdown of
    the rows flagged ``is_anomaly``.

    Args:
        taxi_type: Taxi type whose ``<taxi_type>_anomalies`` table is summarised.
            A falsy value short-circuits with a prompt to load data first.

    Returns:
        A Markdown string. NULL scores (no flagged rows) render as "N/A" and NULL
        counts as 0. Query failures are returned as an
        "Error loading summary: ..." string rather than raised.
    """

    if not taxi_type:
        return "Please load data first"

    try:
        # Query summary statistics
        query = f"""
        WITH summary AS (
            SELECT
                COUNT(*) as total_hours,
                SUM(CAST(is_anomaly AS INT64)) as total_anomalies,
                AVG(CASE WHEN is_anomaly THEN anomaly_score END) as avg_anomaly_score,
                MAX(CASE WHEN is_anomaly THEN anomaly_score END) as max_anomaly_score,
                COUNT(DISTINCT DATE(datetime)) as days_analyzed,
                COUNT(DISTINCT FORMAT_DATETIME('%Y-%m', datetime)) as months_analyzed
            FROM `{PROJECT_ID}.{MlDATASET_NAME}.{taxi_type}_anomalies`
        ),
        severity_breakdown AS (
            SELECT
                SUM(CASE WHEN anomaly_score > 3.0 THEN 1 ELSE 0 END) as severe_count,
                SUM(CASE WHEN anomaly_score > 1.5 AND anomaly_score <= 3.0 THEN 1 ELSE 0 END) as moderate_count,
                SUM(CASE WHEN anomaly_score > 1.0 AND anomaly_score <= 1.5 THEN 1 ELSE 0 END) as mild_count
            FROM `{PROJECT_ID}.{MlDATASET_NAME}.{taxi_type}_anomalies`
            WHERE is_anomaly = TRUE
        )
        SELECT * FROM summary CROSS JOIN severity_breakdown
        """

        result = list(client.query(query).result())[0]

        # SUM/AVG/MAX yield NULL (None) when no rows match, e.g. when no hour is
        # flagged. Normalise here so the formatting below cannot fail on None.
        total_hours = result.total_hours or 0
        total_anomalies = result.total_anomalies or 0
        # Guard the percentage against an empty table (total_hours == 0).
        anomaly_pct = (total_anomalies / total_hours * 100) if total_hours else 0.0

        def fmt_score(value):
            """Format an optional score to two decimals, or "N/A" when NULL."""
            return "N/A" if value is None else f"{value:.2f}"

        summary_text = f"""
**Anomaly Detection Summary for {taxi_type.title()} Taxi**

**Coverage:**
- Months Analyzed: {result.months_analyzed}
- Days Analyzed: {result.days_analyzed}
- Total Hours: {total_hours:,}

**Anomaly Statistics:**
- Total Anomalies: {total_anomalies:,} ({anomaly_pct:.2f}%)
- Average Anomaly Score: {fmt_score(result.avg_anomaly_score)}
- Maximum Anomaly Score: {fmt_score(result.max_anomaly_score)}

**Severity Breakdown:**
- Severe (Score > 3.0): {result.severe_count or 0}
- Moderate (1.5 < Score ≤ 3.0): {result.moderate_count or 0}
- Mild (1.0 < Score ≤ 1.5): {result.mild_count or 0}
        """

        return summary_text

    except Exception as e:
        return f"Error loading summary: {str(e)}"

# ==================== Data Loading Functions ====================

In [None]:


def update_partition_choices(taxi_type):
    """Refresh the two partition dropdowns when the taxi type changes.

    Args:
        taxi_type: Selected taxi type. A falsy value clears both dropdowns.

    Returns:
        A pair of ``gr.update`` objects. Both get the partitions returned by
        ``get_available_partitions`` as choices. The first dropdown defaults to the
        first partition and the second to the last (presumably a start/end range;
        confirm against the UI wiring). Both dropdowns are cleared (no choices, no
        value) when nothing is selected or the lookup fails.
    """
    if not taxi_type:
        # Reset the value too, so a selection from a previous taxi type is not left
        # behind in a dropdown whose choices are now empty.
        return gr.update(choices=[], value=None), gr.update(choices=[], value=None)

    try:
        partitions = get_available_partitions(taxi_type)
    except Exception as e:
        # Best-effort UI refresh: report the failure instead of silently hiding it,
        # but keep the app usable. (A bare except would also trap KeyboardInterrupt.)
        print(f"Failed to load partitions for {taxi_type}: {e}")
        return gr.update(choices=[], value=None), gr.update(choices=[], value=None)

    first_partition = partitions[0] if partitions else None
    last_partition = partitions[-1] if partitions else None
    return (
        gr.update(choices=partitions, value=first_partition),
        gr.update(choices=partitions, value=last_partition),
    )


# Map Data

In [None]:
import pandas as pd
from io import BytesIO
from google.cloud import storage

# Load the taxi zone lookup CSV straight from the raw-data bucket into memory.
storage_client = storage.Client()

# Use the bucket from the config cell instead of repeating the literal
# (BUCKET_NAME holds the same 'nyc_raw_data_bucket' value).
bucket_name = BUCKET_NAME
blob_name = 'taxi_zone_lookup.csv'

bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)

# download_as_bytes() fetches the whole object; BytesIO lets pandas parse it
# without writing a temporary file.
content = blob.download_as_bytes()
taxi_zone_lookup = pd.read_csv(BytesIO(content))

# Quick sanity check of the loaded lookup table.
print(taxi_zone_lookup.head())


import geopandas as gpd
from google.cloud import storage
import os
import tempfile
import shutil

# Download the NYC taxi-zone shapefile once, read it, and reproject it to lon/lat.
# Produces taxi_zones_gdf (source CRS) and taxi_zones_gdf_wgs84 (EPSG:4326).
storage_client = storage.Client()
bucket_name = BUCKET_NAME
shapefile_prefix = 'taxi_zones/taxi_zones'

# .shp/.shx/.dbf hold the geometries, their index and the attribute table.
# .prj carries the source CRS; without it to_crs() below cannot reproject.
# Spatial-index/metadata sidecars (.sbn/.sbx/.shp.xml) are not needed for reading.
required_extensions = ['.shp', '.shx', '.dbf', '.prj']

shapefile_bucket = storage_client.bucket(bucket_name)

# TemporaryDirectory removes the downloaded files even if a download or read fails.
with tempfile.TemporaryDirectory() as temp_dir:
    try:
        for extension in required_extensions:
            blob_name = shapefile_prefix + extension
            blob = shapefile_bucket.blob(blob_name)
            local_path = os.path.join(temp_dir, os.path.basename(blob_name))
            blob.download_to_filename(local_path)
            print(f"Downloaded {blob_name}")

        shp_path = os.path.join(temp_dir, 'taxi_zones.shp')
        # The frame is read while the files still exist; it is used after cleanup.
        taxi_zones_gdf = gpd.read_file(shp_path)
    except Exception as e:
        # Re-raise: the merge below needs these frames, and swallowing the error
        # would only resurface later as a confusing NameError.
        print(f"Error during shapefile processing: {e}")
        raise

print("Original CRS:")
print(taxi_zones_gdf.crs)

# WGS 84 (EPSG:4326) is the standard longitude/latitude CRS used for mapping.
print("\nReprojecting to WGS 84 (EPSG:4326)...")
taxi_zones_gdf_wgs84 = taxi_zones_gdf.to_crs(epsg=4326)

print("\nNew CRS:")
print(taxi_zones_gdf_wgs84.crs)

print("\nHead of GeoDataFrame with reprojected coordinates:")
print(taxi_zones_gdf_wgs84.head())

# Attach the lookup attributes to the reprojected zone shapes.
# NOTE(review): the shapefile-side join key is OBJECTID, and the final LocationID
# column is overwritten from OBJECTID. Confirm this matches the lookup's numbering.
taxi_zone_lookup['LocationID'] = taxi_zone_lookup['LocationID'].astype(int)

merged_df = taxi_zone_lookup.merge(
    taxi_zones_gdf_wgs84,
    how='inner',
    left_on='LocationID',
    right_on='OBJECTID',
)

# Zones without a geometry cannot be drawn on the map, so drop them first.
merged_df_cleaned = merged_df.dropna(subset=['geometry'])

# Re-wrap the merge result as a GeoDataFrame. The shapes in taxi_zones_gdf_wgs84
# are in EPSG:4326 (lon/lat), so declare that CRS explicitly.
merged_gdf_cleaned = gpd.GeoDataFrame(merged_df_cleaned, geometry='geometry', crs='EPSG:4326')
merged_gdf_cleaned["LocationID"] = merged_gdf_cleaned["OBJECTID"]

# Fleet Insights Utilities

In [None]:
def get_fleet_date_range_new(taxi_type):
    """Return the date range covered by the XGBoost fleet prediction table.

    Args:
        taxi_type: Taxi type used in the
            ``fleet_recommender_<taxi_type>_predictions_new_xgb`` table name.

    Returns:
        ``(min_date, max_date, range_text)`` with the dates as 'YYYY-MM-DD'
        strings, or ``(None, None, "No data available")`` when there are no dates.
    """
    query = f"""
    SELECT
        MIN(date) as min_date,
        MAX(date) as max_date
    FROM `{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_predictions_new_xgb`
    """
    df = bq_client.query(query).to_dataframe()

    if df.empty:
        return None, None, "No data available"

    min_value = df.iloc[0]['min_date']
    max_value = df.iloc[0]['max_date']
    # An aggregate without GROUP BY always returns one row. On an empty table
    # MIN/MAX are NULL (NaT/None here), which have no usable strftime().
    if pd.isna(min_value) or pd.isna(max_value):
        return None, None, "No data available"

    min_date = min_value.strftime('%Y-%m-%d')
    max_date = max_value.strftime('%Y-%m-%d')
    range_text = f"Available: {min_date} to {max_date}"
    return min_date, max_date, range_text

def get_zones_with_predictions(taxi_type):
    """Build zone dropdown choices ordered by total predicted future trips.

    Totals are the sum of the four XGBoost per-passenger-type predictions over
    rows with ``prediction_type = 'future'``.

    Args:
        taxi_type: Taxi type used in the prediction table name.

    Returns:
        ``["All Zones", "Zone <id> (<n> trips)", ...]`` in descending trip order.
        Each label carries an integer id so ``extract_zone_id`` can parse it back.
    """
    query = f"""
    SELECT zone_id,
           SUM(xgb_pred_y_single + xgb_pred_y_small + xgb_pred_y_medium + xgb_pred_y_large) as total_trips
    FROM `{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_predictions_new_xgb`
    WHERE prediction_type = 'future'
    GROUP BY zone_id
    ORDER BY total_trips DESC
    """
    df = bq_client.query(query).to_dataframe()

    zone_list = ["All Zones"]
    # Walk the columns directly instead of DataFrame.iterrows(). iterrows packs
    # each row into one Series, which can upcast an integer zone_id to float next
    # to the float total. That yields labels like "Zone 132.0" that
    # extract_zone_id's int() cannot parse.
    for zone_id, total_trips in zip(df['zone_id'], df['total_trips']):
        if pd.isna(zone_id):
            # A NULL zone cannot be selected or parsed back into an id.
            continue
        trips = 0 if pd.isna(total_trips) else total_trips
        zone_list.append(f"Zone {int(zone_id)} ({trips:.0f} trips)")
    return zone_list

def load_zone_metrics_summary(taxi_type):
    """Render per-zone average SMAPE and MAE for each model and target as text.

    Args:
        taxi_type: Taxi type used in the ``fleet_recommender_<taxi_type>_metrics_new``
            table name.

    Returns:
        A newline-joined report with one "=== Zone N ===" section per zone and an
        "XGB Model:" / "KNN Model:" subsection per model present, or
        "No metrics available" when the query returns no rows.
    """
    query = f"""
    SELECT zone_id, model_type, target,
           ROUND(AVG(smape), 1) as avg_smape,
           ROUND(AVG(mae), 2) as avg_mae,
           ROUND(AVG(rmse), 2) as avg_rmse
    FROM `{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_metrics_new`
    GROUP BY zone_id, model_type, target
    ORDER BY zone_id, model_type, target
    """
    metrics = bq_client.query(query).to_dataframe()

    if metrics.empty:
        return "No metrics available"

    report_lines = []
    for zone in metrics['zone_id'].unique():
        zone_rows = metrics[metrics['zone_id'] == zone]
        report_lines.append(f"\n=== Zone {zone} ===")

        for model_name in ('xgb', 'knn'):
            model_rows = zone_rows[zone_rows['model_type'] == model_name]
            if model_rows.empty:
                continue
            report_lines.append(f"\n{model_name.upper()} Model:")
            report_lines.extend(
                f"  {target}: SMAPE={smape}%, MAE={mae}"
                for target, smape, mae in zip(
                    model_rows['target'], model_rows['avg_smape'], model_rows['avg_mae']
                )
            )

    return "\n".join(report_lines)


def extract_zone_id(zone_string):
    """Extract the numeric zone ID from a zone dropdown label.

    Args:
        zone_string: A label of the form "Zone <id> (<n> trips)", or "All Zones".

    Returns:
        The zone ID as an int. Returns None for "All Zones" or an empty/None
        selection; both mean "no zone filter".
    """
    if not zone_string or zone_string == "All Zones":
        return None
    # float() first so labels rendered from a float id ("Zone 132.0 ...") still parse.
    return int(float(zone_string.split()[1]))

In [None]:
# Passenger-type suffixes shared by every prediction/actual column in the fleet tables.
_FLEET_PASSENGER_TARGETS = ('y_single', 'y_small', 'y_medium', 'y_large')


def _build_fleet_prediction_query(taxi_type, prefix, model_label, zone_id, where_clause):
    """Build the SELECT for one model's fleet prediction table.

    When ``zone_id`` is None, every value column is SUM-aggregated per
    (date, hr, prediction_type) so all zones collapse into one series. Otherwise
    the filtered zone's rows are returned as-is, together with ``zone_id``.
    Point predictions, per-type lower/upper bounds and actuals are selected for
    every passenger type, aliased without the model prefix (e.g. ``pred_y_small_upper``).
    """
    agg = 'SUM' if zone_id is None else ''
    select_items = ['date', 'hr']
    if zone_id is not None:
        select_items.append('zone_id')
    select_items.append('prediction_type')
    select_items += [f"{agg}({prefix}_pred_{t}) as pred_{t}" for t in _FLEET_PASSENGER_TARGETS]
    select_items += [
        f"{agg}({prefix}_pred_{t}_{bound}) as pred_{t}_{bound}"
        for t in _FLEET_PASSENGER_TARGETS
        for bound in ('lower', 'upper')
    ]
    select_items += [f"{agg}(actual_{t}) as actual_{t}" for t in _FLEET_PASSENGER_TARGETS]
    select_items.append(f"'{model_label}' as model")

    select_list = ",\n            ".join(select_items)
    group_by = 'GROUP BY date, hr, prediction_type' if zone_id is None else ''
    return f"""
        SELECT
            {select_list}
        FROM `{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_predictions_new_{prefix}`
        WHERE {where_clause}
        {group_by}
        ORDER BY date, hr
        """


def plot_fleet_timeseries_new(taxi_type, model_type, start_date, end_date, zone_id, display_type, pred_type):
    """Create the time-series figure for the zone-based fleet predictions.

    Args:
        taxi_type: Taxi type used in the prediction table names.
        model_type: Table suffix of a single model (e.g. "xgb" or "knn"), or
            "both" to query the XGBoost and KNN tables and concatenate them.
        start_date, end_date: Inclusive date bounds; any value pandas can parse.
        zone_id: Zone to plot, or None to sum across all zones.
        display_type: "All Passengers", "By Passenger Type", or anything else
            for the model-comparison view.
        pred_type: "Test Only", "Future Only", or anything else for no
            prediction_type filter.

    Returns:
        A plotly Figure. A placeholder figure is returned when the dates are
        missing or invalid, or when no rows match the filters.
    """
    # Normalise the user-entered dates before they are spliced into SQL. This
    # rejects malformed input early and keeps arbitrary text out of the query.
    try:
        start_ts = pd.to_datetime(start_date)
        end_ts = pd.to_datetime(end_date)
    except (TypeError, ValueError):
        return create_empty_plot("Invalid date format. Please use YYYY-MM-DD")
    if start_ts is None or end_ts is None or pd.isna(start_ts) or pd.isna(end_ts):
        return create_empty_plot("Please provide both a start and an end date")

    start_str = start_ts.strftime('%Y-%m-%d')
    end_str = end_ts.strftime('%Y-%m-%d')

    # Build the WHERE clause based on filters
    where_conditions = [f"date >= '{start_str}'", f"date <= '{end_str}'"]

    if zone_id is not None:
        # int() guarantees only a plain integer reaches the SQL text.
        where_conditions.append(f"zone_id = {int(zone_id)}")

    # Filter by prediction type
    if pred_type == "Test Only":
        where_conditions.append("prediction_type = 'test'")
    elif pred_type == "Future Only":
        where_conditions.append("prediction_type = 'future'")

    where_clause = " AND ".join(where_conditions)

    # Every query selects all per-type lower/upper bounds. The all-passengers plot
    # sums the bounds of all four types, so selecting only the single-type bounds
    # (as the "both" path once did) made it fail with a KeyError.
    if model_type == "both":
        frames = [
            bq_client.query(
                _build_fleet_prediction_query(taxi_type, prefix, label, zone_id, where_clause)
            ).to_dataframe()
            for prefix, label in (('xgb', 'XGBoost'), ('knn', 'KNN'))
        ]
        df = pd.concat(frames, ignore_index=True)
    else:
        query = _build_fleet_prediction_query(
            taxi_type, model_type, model_type.upper(), zone_id, where_clause
        )
        df = bq_client.query(query).to_dataframe()

    if df.empty:
        return create_empty_plot("No data available for selected filters")

    # Create datetime column
    df['datetime'] = pd.to_datetime(df['date']) + pd.to_timedelta(df['hr'], unit='h')
    df = df.sort_values('datetime')

    # Calculate total trips
    df['pred_total'] = df['pred_y_single'] + df['pred_y_small'] + df['pred_y_medium'] + df['pred_y_large']

    # Create the plot based on display type
    if display_type == "All Passengers":
        fig = create_all_passengers_plot(df, zone_id, model_type)
    elif display_type == "By Passenger Type":
        fig = create_by_passenger_type_plot(df, zone_id, model_type)
    else:  # Model Comparison
        fig = create_model_comparison_plot(df, zone_id)

    return fig

def _add_total_ci_band(fig, frame, line_color, band_name):
    """Shade the band between the summed lower and upper bounds of all passenger types.

    The band is skipped unless ``frame`` has both bounds for every passenger type.
    A frame with only some bound columns would otherwise raise KeyError.

    Args:
        fig: Figure to draw on.
        frame: Rows for one model, with ``datetime`` and ``pred_<type>_lower`` /
            ``pred_<type>_upper`` columns.
        line_color: '#rrggbb' colour of the matching prediction line. The band
            uses it at 20% opacity.
        band_name: Legend label for the band.
    """
    passenger_types = ('y_single', 'y_small', 'y_medium', 'y_large')
    bound_columns = [f'pred_{p}_{bound}' for p in passenger_types for bound in ('lower', 'upper')]
    if not all(column in frame.columns for column in bound_columns):
        return

    # Simplified total interval: plain sum of the per-type bounds, missing bounds count as 0.
    upper_total = sum(frame[f'pred_{p}_upper'].fillna(0) for p in passenger_types)
    lower_total = sum(frame[f'pred_{p}_lower'].fillna(0) for p in passenger_types)

    hex_digits = line_color.lstrip('#')
    red, green, blue = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))

    # Invisible upper edge first; the lower trace then fills 'tonexty' back up to it.
    fig.add_trace(go.Scatter(
        x=frame['datetime'],
        y=upper_total,
        fill=None,
        mode='lines',
        line_color='rgba(0,0,0,0)',
        showlegend=False,
        hoverinfo='skip'
    ))
    fig.add_trace(go.Scatter(
        x=frame['datetime'],
        y=lower_total,
        fill='tonexty',
        mode='lines',
        line_color='rgba(0,0,0,0)',
        name=band_name,
        fillcolor=f'rgba({red}, {green}, {blue}, 0.2)'
    ))


def create_all_passengers_plot(df, zone_id, model_type):
    """Plot total predicted passengers (sum of all passenger types) over time.

    Args:
        df: Prediction rows with 'datetime', 'pred_total', 'prediction_type',
            optional 'pred_y_*_lower' / 'pred_y_*_upper' and 'actual_y_*' columns,
            and in "both" mode a 'model' column ('XGBoost' / 'KNN').
        zone_id: Zone named in the title, or None for "All Zones".
        model_type: "both" overlays the two models; anything else plots the
            single model contained in ``df``.

    Returns:
        A plotly Figure with prediction lines, confidence bands when all bound
        columns are present, actuals for the test rows, and shaded test/future spans.
    """
    fig = go.Figure()

    # Colors for different elements
    colors = {
        'XGBoost': '#1f77b4',
        'KNN': '#ff7f0e',
        'actual': '#2ca02c',
        'test': 'rgba(255, 0, 0, 0.1)',
        'future': 'rgba(0, 0, 255, 0.1)'
    }

    if model_type == "both":
        for model in ['XGBoost', 'KNN']:
            model_df = df[df['model'] == model]
            fig.add_trace(go.Scatter(
                x=model_df['datetime'],
                y=model_df['pred_total'],
                name=f'{model} Prediction',
                line=dict(color=colors[model], width=2),
                mode='lines'
            ))
            _add_total_ci_band(fig, model_df, colors[model], f'{model} 95% CI')
    else:
        line_color = colors.get(model_type.upper(), '#1f77b4')
        fig.add_trace(go.Scatter(
            x=df['datetime'],
            y=df['pred_total'],
            name=f'{model_type.upper()} Prediction',
            line=dict(color=line_color, width=2),
            mode='lines'
        ))
        # The band reuses the line colour, so e.g. a KNN band matches its orange line.
        _add_total_ci_band(fig, df, line_color, '95% CI')

    # Add actual values if present
    test_df = df[df['prediction_type'] == 'test'].copy()
    if not test_df.empty and 'actual_y_single' in test_df.columns:
        test_df['actual_total'] = test_df['actual_y_single'].fillna(0) + \
                                  test_df['actual_y_small'].fillna(0) + \
                                  test_df['actual_y_medium'].fillna(0) + \
                                  test_df['actual_y_large'].fillna(0)
        fig.add_trace(go.Scatter(
            x=test_df['datetime'],
            y=test_df['actual_total'],
            name='Actual',
            line=dict(color=colors['actual'], width=2, dash='dot'),
            mode='lines'
        ))

    # Add shaded regions for test/future periods
    test_periods = df[df['prediction_type'] == 'test']['datetime']
    future_periods = df[df['prediction_type'] == 'future']['datetime']

    if not test_periods.empty:
        fig.add_vrect(
            x0=test_periods.min(), x1=test_periods.max(),
            fillcolor=colors['test'], layer="below", line_width=0,
            annotation_text="Test Period", annotation_position="top left"
        )

    if not future_periods.empty:
        fig.add_vrect(
            x0=future_periods.min(), x1=future_periods.max(),
            fillcolor=colors['future'], layer="below", line_width=0,
            annotation_text="Future Predictions", annotation_position="top left"
        )

    # None means "no zone filter"; compare explicitly so a zone id of 0 is still named.
    zone_text = f"Zone {zone_id}" if zone_id is not None else "All Zones"
    fig.update_layout(
        title=f"Passenger Predictions - {zone_text}",
        xaxis_title="Date/Time",
        yaxis_title="Total Passengers",
        hovermode='x unified',
        height=600
    )

    return fig

def create_by_passenger_type_plot(df, zone_id, model_type):
    """Plot predictions in a 2x2 grid, one subplot per passenger type.

    Args:
        df: Prediction rows with 'datetime', 'prediction_type', 'pred_y_*' and,
            optionally, 'pred_y_*_lower' / 'pred_y_*_upper' and 'actual_y_*'
            columns. In "both" mode it also has a 'model' column ('XGBoost' / 'KNN').
        zone_id: Zone named in the title, or None for "All Zones".
        model_type: "both" overlays the two models; anything else plots the
            single model in ``df`` with its per-type confidence band.

    Returns:
        A plotly Figure with shared x-axes.
    """
    from plotly.subplots import make_subplots

    passenger_types = ['y_single', 'y_small', 'y_medium', 'y_large']
    passenger_labels = ['Single', 'Small Groups', 'Medium Groups', 'Large Groups']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    # Fixed per-model colours. Without them plotly assigns a new colour to every
    # trace, so a model's line changed colour from subplot to subplot while the
    # legend (shown for the first subplot only) described just one of them.
    model_colors = {'XGBoost': '#1f77b4', 'KNN': '#ff7f0e'}

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=passenger_labels,
        shared_xaxes=True,
        vertical_spacing=0.1
    )

    # Row selections that do not depend on the subplot, computed once.
    test_df = df[df['prediction_type'] == 'test']
    model_frames = (
        {model: df[df['model'] == model] for model in model_colors}
        if model_type == "both" else {}
    )

    for idx, (ptype, label) in enumerate(zip(passenger_types, passenger_labels)):
        row = idx // 2 + 1
        col = idx % 2 + 1

        if model_type == "both":
            for model, model_df in model_frames.items():
                fig.add_trace(
                    go.Scatter(
                        x=model_df['datetime'],
                        y=model_df[f'pred_{ptype}'],
                        name=model,
                        # One legend entry toggles this model in all four subplots.
                        legendgroup=model,
                        line=dict(color=model_colors[model], width=2),
                        showlegend=(idx == 0),  # Only show legend for first subplot
                    ),
                    row=row, col=col
                )
        else:
            fig.add_trace(
                go.Scatter(
                    x=df['datetime'],
                    y=df[f'pred_{ptype}'],
                    name=label,
                    line=dict(color=colors[idx], width=2),
                ),
                row=row, col=col
            )

            # Add confidence intervals: invisible upper edge, then a lower trace
            # that fills 'tonexty' back up to it.
            if f'pred_{ptype}_upper' in df.columns:
                fig.add_trace(
                    go.Scatter(
                        x=df['datetime'],
                        y=df[f'pred_{ptype}_upper'],
                        fill=None,
                        mode='lines',
                        line_color='rgba(0,0,0,0)',
                        showlegend=False,
                        hoverinfo='skip'
                    ),
                    row=row, col=col
                )

                fig.add_trace(
                    go.Scatter(
                        x=df['datetime'],
                        y=df[f'pred_{ptype}_lower'],
                        fill='tonexty',
                        mode='lines',
                        line_color='rgba(0,0,0,0)',
                        showlegend=False,
                        fillcolor=f'rgba{tuple(list(int(colors[idx].lstrip("#")[i:i+2], 16) for i in (0, 2, 4)) + [0.2])}'
                    ),
                    row=row, col=col
                )

        # Add actuals if available
        if not test_df.empty and f'actual_{ptype}' in test_df.columns:
            fig.add_trace(
                go.Scatter(
                    x=test_df['datetime'],
                    y=test_df[f'actual_{ptype}'],
                    name='Actual' if idx == 0 else '',
                    legendgroup='actual',
                    line=dict(color='black', width=1, dash='dot'),
                    showlegend=(idx == 0),
                ),
                row=row, col=col
            )

    # None means "no zone filter"; compare explicitly so a zone id of 0 is still named.
    zone_text = f"Zone {zone_id}" if zone_id is not None else "All Zones"
    fig.update_layout(
        title=f"Predictions by Passenger Type - {zone_text}",
        height=800,
        hovermode='x unified'
    )

    return fig

def create_model_comparison_plot(df, zone_id):
    """Overlay XGBoost and KNN total-passenger predictions, with actuals for test rows.

    Args:
        df: Combined prediction rows with 'model', 'datetime', 'pred_total',
            'prediction_type' and, optionally, 'actual_y_*' columns.
        zone_id: Zone named in the title; a falsy value is shown as "All Zones".

    Returns:
        A plotly Figure, or the placeholder from ``create_empty_plot`` when either
        model's rows are missing.
    """
    available_models = set(df['model'].values)
    if 'XGBoost' not in available_models or 'KNN' not in available_models:
        return create_empty_plot("Both models needed for comparison")

    def model_totals(label, alias):
        # One model's total-prediction series, renamed so the two can be joined side by side.
        return df.loc[df['model'] == label, ['datetime', 'pred_total']].rename(columns={'pred_total': alias})

    # Inner join on datetime keeps only timestamps predicted by both models.
    merged = model_totals('XGBoost', 'xgb_total').merge(model_totals('KNN', 'knn_total'), on='datetime')

    fig = go.Figure()
    for alias, label, color in (('xgb_total', 'XGBoost', '#1f77b4'), ('knn_total', 'KNN', '#ff7f0e')):
        fig.add_trace(go.Scatter(
            x=merged['datetime'],
            y=merged[alias],
            name=label,
            line=dict(color=color, width=2),
            mode='lines'
        ))

    # Both models carry the same timestamps, so collapse to one row per datetime.
    actuals = df[df['prediction_type'] == 'test'].groupby('datetime').first()
    if not actuals.empty and 'actual_y_single' in actuals.columns:
        actuals['actual_total'] = (
            actuals['actual_y_single'].fillna(0)
            + actuals['actual_y_small'].fillna(0)
            + actuals['actual_y_medium'].fillna(0)
            + actuals['actual_y_large'].fillna(0)
        )
        fig.add_trace(go.Scatter(
            x=actuals.index,
            y=actuals['actual_total'],
            name='Actual',
            line=dict(color='#2ca02c', width=2, dash='dot'),
            mode='lines'
        ))

    zone_text = "All Zones" if not zone_id else f"Zone {zone_id}"
    fig.update_layout(
        title=f"Model Comparison - {zone_text}",
        xaxis_title="Date/Time",
        yaxis_title="Total Passengers",
        hovermode='x unified',
        height=600
    )

    return fig

def create_empty_plot(message):
    """Return a 400px-high figure with hidden axes that shows only ``message``, centred."""
    fig = go.Figure()
    fig.update_layout(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        height=400
    )
    # Paper coordinates (0..1) place the text in the middle regardless of data ranges.
    fig.add_annotation(
        text=message,
        x=0.5, y=0.5,
        xref="paper", yref="paper",
        showarrow=False,
        font=dict(size=20)
    )
    return fig

In [None]:
from functools import lru_cache

## Fleet detailed predictions table utility

In [None]:
def load_zone_predictions_table(taxi_type, zone_id, date, pred_type):
    """Load a per-hour XGBoost vs KNN prediction table for one date.

    Args:
        taxi_type: Taxi type used in the prediction table names.
        zone_id: Zone to restrict to, or None for all zones.
        date: Date to load; any value pandas can parse (e.g. "YYYY-MM-DD").
        pred_type: Prediction type, matched case-insensitively against the
            ``prediction_type`` column. "test" also adds Actual_* columns.

    Returns:
        A DataFrame with one row per (zone_id, hour), point predictions and CI
        strings for both models, per-model totals, and a trailing "TOTAL" row
        that sums the numeric columns. A one-column "Error" frame is returned for
        a missing or unparseable date, and an "Info" frame when no rows match.
    """

    # Parse the date. None parses to None and "" to NaT; neither is a usable date
    # and would otherwise be spliced into the SQL as 'None' / 'NaT'.
    try:
        parsed_date = pd.to_datetime(date)
    except (TypeError, ValueError):
        parsed_date = None
    if parsed_date is None or pd.isna(parsed_date):
        return pd.DataFrame({"Error": ["Invalid date format. Please use YYYY-MM-DD"]})
    selected_date = parsed_date.date()

    pred_type_lower = pred_type.lower()
    include_actuals = pred_type_lower == 'test'

    # Build WHERE clause
    where_conditions = [
        f"date = '{selected_date}'",
        f"prediction_type = '{pred_type_lower}'"
    ]

    if zone_id is not None:
        where_conditions.append(f"zone_id = {zone_id}")

    where_clause = " AND ".join(where_conditions)

    # Query both models and align them per (zone_id, hour); FULL OUTER JOIN keeps
    # hours predicted by only one of the models.
    query = f"""
    WITH xgb_data AS (
        SELECT
            zone_id,
            hr as hour,
            xgb_pred_y_single as xgb_single,
            xgb_pred_y_single_lower as xgb_single_lower,
            xgb_pred_y_single_upper as xgb_single_upper,
            xgb_pred_y_small as xgb_small,
            xgb_pred_y_small_lower as xgb_small_lower,
            xgb_pred_y_small_upper as xgb_small_upper,
            xgb_pred_y_medium as xgb_medium,
            xgb_pred_y_medium_lower as xgb_medium_lower,
            xgb_pred_y_medium_upper as xgb_medium_upper,
            xgb_pred_y_large as xgb_large,
            xgb_pred_y_large_lower as xgb_large_lower,
            xgb_pred_y_large_upper as xgb_large_upper,
            actual_y_single,
            actual_y_small,
            actual_y_medium,
            actual_y_large
        FROM `{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_predictions_new_xgb`
        WHERE {where_clause}
    ),
    knn_data AS (
        SELECT
            zone_id,
            hr as hour,
            knn_pred_y_single as knn_single,
            knn_pred_y_single_lower as knn_single_lower,
            knn_pred_y_single_upper as knn_single_upper,
            knn_pred_y_small as knn_small,
            knn_pred_y_small_lower as knn_small_lower,
            knn_pred_y_small_upper as knn_small_upper,
            knn_pred_y_medium as knn_medium,
            knn_pred_y_medium_lower as knn_medium_lower,
            knn_pred_y_medium_upper as knn_medium_upper,
            knn_pred_y_large as knn_large,
            knn_pred_y_large_lower as knn_large_lower,
            knn_pred_y_large_upper as knn_large_upper
        FROM `{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_predictions_new_knn`
        WHERE {where_clause}
    )
    SELECT
        COALESCE(x.zone_id, k.zone_id) as zone_id,
        COALESCE(x.hour, k.hour) as hour,
        -- XGBoost predictions with CIs
        ROUND(x.xgb_single, 2) as XGB_Single,
        CONCAT('[', ROUND(x.xgb_single_lower, 2), '-', ROUND(x.xgb_single_upper, 2), ']') as XGB_Single_CI,
        ROUND(x.xgb_small, 2) as XGB_Small,
        CONCAT('[', ROUND(x.xgb_small_lower, 2), '-', ROUND(x.xgb_small_upper, 2), ']') as XGB_Small_CI,
        ROUND(x.xgb_medium, 2) as XGB_Medium,
        CONCAT('[', ROUND(x.xgb_medium_lower, 2), '-', ROUND(x.xgb_medium_upper, 2), ']') as XGB_Medium_CI,
        ROUND(x.xgb_large, 2) as XGB_Large,
        CONCAT('[', ROUND(x.xgb_large_lower, 2), '-', ROUND(x.xgb_large_upper, 2), ']') as XGB_Large_CI,
        -- KNN predictions with CIs
        ROUND(k.knn_single, 2) as KNN_Single,
        CONCAT('[', ROUND(k.knn_single_lower, 2), '-', ROUND(k.knn_single_upper, 2), ']') as KNN_Single_CI,
        ROUND(k.knn_small, 2) as KNN_Small,
        CONCAT('[', ROUND(k.knn_small_lower, 2), '-', ROUND(k.knn_small_upper, 2), ']') as KNN_Small_CI,
        ROUND(k.knn_medium, 2) as KNN_Medium,
        CONCAT('[', ROUND(k.knn_medium_lower, 2), '-', ROUND(k.knn_medium_upper, 2), ']') as KNN_Medium_CI,
        ROUND(k.knn_large, 2) as KNN_Large,
        CONCAT('[', ROUND(k.knn_large_lower, 2), '-', ROUND(k.knn_large_upper, 2), ']') as KNN_Large_CI,
        -- Totals
        ROUND(x.xgb_single + x.xgb_small + x.xgb_medium + x.xgb_large, 2) as XGB_Total,
        ROUND(k.knn_single + k.knn_small + k.knn_medium + k.knn_large, 2) as KNN_Total
        {', ROUND(x.actual_y_single, 2) as Actual_Single' if include_actuals else ''}
        {', ROUND(x.actual_y_small, 2) as Actual_Small' if include_actuals else ''}
        {', ROUND(x.actual_y_medium, 2) as Actual_Medium' if include_actuals else ''}
        {', ROUND(x.actual_y_large, 2) as Actual_Large' if include_actuals else ''}
        {', ROUND(x.actual_y_single + x.actual_y_small + x.actual_y_medium + x.actual_y_large, 2) as Actual_Total' if include_actuals else ''}
    FROM xgb_data x
    FULL OUTER JOIN knn_data k ON x.zone_id = k.zone_id AND x.hour = k.hour
    ORDER BY zone_id, hour
    """

    df = bq_client.query(query).to_dataframe()

    if df.empty:
        return pd.DataFrame({"Info": [f"No {pred_type_lower} predictions found for {date}"]})

    # Stringify hour and zone_id first so the summary row can hold 'TOTAL' / '-'
    # and so they are excluded from the numeric sums below.
    df['hour'] = df['hour'].astype(str)
    df['zone_id'] = df['zone_id'].astype(str)

    # Summary row: column sums for numeric columns, placeholders for the rest.
    summary = {col: df[col].sum() for col in df.select_dtypes(include=[np.number]).columns}
    summary['hour'] = 'TOTAL'
    summary['zone_id'] = '-'
    summary.update({col: '-' for col in df.columns if '_CI' in col and col not in summary})

    return pd.concat([df, pd.DataFrame([summary])], ignore_index=True)

## plot_zone_performance_comparison and other utilities

In [None]:
def plot_zone_performance_comparison(taxi_type, perf_metric, perf_passenger_type):
    """Create a grouped bar chart comparing per-zone XGBoost vs KNN error metrics.

    Reads ``fleet_recommender_{taxi_type}_metrics_new`` from ``OUTPUT_DATASET`` and
    pivots the ``model_type`` rows ('xgb' / 'knn') into one row per zone.

    Args:
        taxi_type: Taxi type used in the metrics table name (e.g. "yellow").
        perf_metric: "MAE" or "RMSE" (mapped to the ``mae`` / ``rmse`` columns).
        perf_passenger_type: "All" to average the metric over every ``target`` per
            zone and model, otherwise a single ``target`` value such as "y_single".

    Returns:
        A plotly Figure, or ``create_empty_plot(...)`` when the query returns no rows.

    Raises:
        KeyError: If ``perf_metric`` is not "MAE" or "RMSE".
    """
    # Map display metric names to column names; an unknown name raises KeyError,
    # so only these known column names are ever interpolated into the SQL.
    metric_map = {
        'MAE': 'mae',
        'RMSE': 'rmse'
    }
    metric_col = metric_map[perf_metric]
    metrics_table = f"`{PROJECT_ID}.{OUTPUT_DATASET}.fleet_recommender_{taxi_type}_metrics_new`"

    if perf_passenger_type == "All":
        # Average across all passenger types, then pivot the models into columns
        query = f"""
        WITH zone_metrics AS (
            SELECT
                zone_id,
                model_type,
                AVG({metric_col}) as avg_metric
            FROM {metrics_table}
            GROUP BY zone_id, model_type
        )
        SELECT
            zone_id,
            MAX(CASE WHEN model_type = 'xgb' THEN avg_metric END) as xgb_metric,
            MAX(CASE WHEN model_type = 'knn' THEN avg_metric END) as knn_metric
        FROM zone_metrics
        GROUP BY zone_id
        ORDER BY zone_id
        """
        job_config = None
    else:
        # Specific passenger type: bound as a query parameter instead of spliced into SQL
        query = f"""
        SELECT
            zone_id,
            MAX(CASE WHEN model_type = 'xgb' THEN {metric_col} END) as xgb_metric,
            MAX(CASE WHEN model_type = 'knn' THEN {metric_col} END) as knn_metric
        FROM {metrics_table}
        WHERE target = @target
        GROUP BY zone_id
        ORDER BY zone_id
        """
        job_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ScalarQueryParameter("target", "STRING", perf_passenger_type)
            ]
        )

    df = bq_client.query(query, job_config=job_config).to_dataframe()

    if df.empty:
        return create_empty_plot("No performance metrics available")

    # A zone can lack one model's metric (the CASE pivot yields NULL -> NaN).
    metric_cols = ['xgb_metric', 'knn_metric']
    df[metric_cols] = df[metric_cols].apply(pd.to_numeric, errors='coerce')
    xgb_missing = df['xgb_metric'].isna()
    knn_missing = df['knn_metric'].isna()

    # Lower error wins; ties go to KNN. NaN comparisons are always False, so missing
    # metrics are handled explicitly -- otherwise a model with no metric could be
    # reported as the better one.
    df['better_model'] = np.select(
        [xgb_missing & knn_missing,
         knn_missing | (df['xgb_metric'] < df['knn_metric'])],
        ['N/A', 'XGBoost'],
        default='KNN',
    )

    # Order zones by their better (lower) metric; zones with no metric at all sort last
    df['best_metric'] = df[metric_cols].min(axis=1)
    df = df.sort_values('best_metric').reset_index(drop=True)

    # Average metrics (NaN-skipping)
    xgb_avg = df['xgb_metric'].mean()
    knn_avg = df['knn_metric'].mean()
    overall_avg = df[metric_cols].mean().mean()

    zone_labels = df['zone_id'].astype(str)

    # Create grouped bar chart
    fig = go.Figure()

    fig.add_trace(go.Bar(
        name='XGBoost',
        x=zone_labels,
        y=df['xgb_metric'],
        marker_color='#1f77b4',
        text=df['xgb_metric'].round(2),
        textposition='auto',
    ))

    fig.add_trace(go.Bar(
        name='KNN',
        x=zone_labels,
        y=df['knn_metric'],
        marker_color='#ff7f0e',
        text=df['knn_metric'].round(2),
        textposition='auto',
    ))

    # Average lines (skipped when a model has no metrics at all)
    for avg, color, label in ((xgb_avg, '#1f77b4', 'XGBoost'), (knn_avg, '#ff7f0e', 'KNN')):
        if pd.notna(avg):
            fig.add_hline(y=avg, line_dash="dash", line_color=color,
                          annotation_text=f"{label} Avg: {avg:.2f}",
                          annotation_position="right")

    passenger_label = "All Types" if perf_passenger_type == "All" else perf_passenger_type.replace('y_', '').title()

    fig.update_layout(
        title=f'{perf_metric} by Zone - {passenger_label} Passengers',
        xaxis_title='Zone ID',
        yaxis_title=perf_metric,
        barmode='group',
        height=600,
        showlegend=True,
        hovermode='x unified'
    )

    # Zone IDs are numeric strings, which plotly would otherwise auto-type as a linear
    # axis (bars placed at their numeric ID). Force a categorical axis in sorted order.
    fig.update_xaxes(type='category', categoryorder='array',
                     categoryarray=zone_labels.tolist())

    # Summary box: average metrics and the best performing zones
    best_zones = df.nsmallest(3, 'best_metric')
    annotation_text = f"<b>Average {perf_metric}:</b><br>"
    annotation_text += f"XGBoost: {xgb_avg:.3f}<br>"
    annotation_text += f"KNN: {knn_avg:.3f}<br>"
    annotation_text += f"Overall: {overall_avg:.3f}<br><br>"
    annotation_text += "<b>Top 3 Zones:</b><br>"
    for _, row in best_zones.iterrows():
        annotation_text += f"Zone {row['zone_id']}: {row['better_model']} ({row['best_metric']:.2f})<br>"

    fig.add_annotation(
        text=annotation_text,
        xref="paper", yref="paper",
        x=0.02, y=0.98,
        showarrow=False,
        bgcolor="white",
        bordercolor="gray",
        borderwidth=1,
        align="left"
    )

    # With many zones, label only every nth category (tick values are category names)
    if len(df) > 20:
        step = len(df) // 20
        shown = zone_labels.iloc[::step].tolist()
        fig.update_xaxes(tickmode='array', tickvals=shown, ticktext=shown)

    return fig

In [None]:
# Zone geometry does not change during a session, so it is downloaded once and memoized.
from functools import lru_cache

@lru_cache(maxsize=1)
def get_cached_zone_geometry():
    """Download the taxi zone lookup CSV and zone shapefile from GCS and join them.

    The result is memoized with ``lru_cache``, so the download happens once per
    kernel and every caller receives the same object -- copy it before mutating.

    Returns:
        GeoDataFrame in EPSG:4326: the lookup table inner-joined on
        ``LocationID == OBJECTID`` with the shapefile, rows without geometry
        dropped, and ``LocationID`` set from ``OBJECTID``.
    """
    import geopandas as gpd
    from google.cloud import storage
    from io import BytesIO  # previously used without being imported
    import os
    import tempfile

    # Same bucket as the BUCKET_NAME config constant; resolved once for all downloads
    bucket = storage.Client().bucket('nyc_raw_data_bucket')

    # Load the lookup table
    lookup_bytes = bucket.blob('taxi_zone_lookup.csv').download_as_bytes()
    taxi_zone_lookup = pd.read_csv(BytesIO(lookup_bytes))
    taxi_zone_lookup['LocationID'] = taxi_zone_lookup['LocationID'].astype(int)

    # A shapefile is several sidecar files that must sit together in one directory
    shapefile_prefix = 'taxi_zones/taxi_zones'
    required_extensions = ['.shp', '.shx', '.dbf', '.prj']

    # TemporaryDirectory removes the downloads even if a download or read fails
    with tempfile.TemporaryDirectory() as temp_dir:
        for extension in required_extensions:
            blob_name = shapefile_prefix + extension
            local_path = os.path.join(temp_dir, os.path.basename(blob_name))
            bucket.blob(blob_name).download_to_filename(local_path)

        taxi_zones_gdf = gpd.read_file(os.path.join(temp_dir, 'taxi_zones.shp'))

    taxi_zones_gdf_wgs84 = taxi_zones_gdf.to_crs(epsg=4326)

    # Merge with lookup
    merged_df = pd.merge(taxi_zone_lookup, taxi_zones_gdf_wgs84,
                         left_on='LocationID', right_on='OBJECTID', how='inner')
    merged_df_cleaned = merged_df.dropna(subset=['geometry'])
    merged_gdf_cleaned = gpd.GeoDataFrame(merged_df_cleaned,
                                          geometry='geometry',
                                          crs='EPSG:4326')
    # NOTE(review): presumably re-created because the merge can suffix a LocationID
    # column that the shapefile may also carry -- confirm against the shapefile schema.
    merged_gdf_cleaned["LocationID"] = merged_gdf_cleaned["OBJECTID"]

    return merged_gdf_cleaned

# Choropleth of predicted demand per taxi zone (falls back to a bar chart on map errors)
def plot_zone_demand_map(taxi_type, date_type, single_date, hour, start_date, end_date,
                         passenger_type, model, pred_type):
    """Create a geographic visualization of predicted trips per taxi zone.

    Sums the ``{model}_pred_*`` columns of
    ``fleet_recommender_{taxi_type}_predictions_new_{model}`` per zone, for either a
    single date and hour or an inclusive date range, and joins the totals onto the
    cached zone geometry.

    Args:
        taxi_type: Taxi type used in the predictions table name.
        date_type: "Single Date" filters on ``single_date`` and ``hour``; any other
            value filters on the inclusive ``start_date``..``end_date`` range.
        single_date, start_date, end_date: Dates as YYYY-MM-DD text (free-text UI input).
        hour: Hour of day 0-23, used only for a single date.
        passenger_type: "Total" or one of "y_single", "y_small", "y_medium", "y_large".
        model: "xgb" or "knn".
        pred_type: "All", or a value whose lower-cased form filters ``prediction_type``.

    Returns:
        A plotly Figure: the choropleth map, a bar-chart fallback if building the map
        raises, or ``create_empty_plot(...)`` for invalid input or no data.
    """
    import re
    from datetime import datetime
    import plotly.express as px

    def _iso_date(text):
        """Return ``text`` normalized to YYYY-MM-DD, or None if it is not a valid date."""
        try:
            return datetime.strptime(str(text).strip(), '%Y-%m-%d').date().isoformat()
        except (TypeError, ValueError):
            return None

    # model / passenger_type become table and column names, so whitelist them
    if model not in ("xgb", "knn"):
        return create_empty_plot(f"Unsupported model: {model}")
    if passenger_type not in ("Total", "y_single", "y_small", "y_medium", "y_large"):
        return create_empty_plot(f"Unsupported passenger type: {passenger_type}")

    # Build the date filter from validated dates only (textbox input is untrusted,
    # and an empty default would otherwise make BigQuery fail to cast '' to a date)
    if date_type == "Single Date":
        day = _iso_date(single_date)
        if day is None:
            return create_empty_plot("Enter a valid date (YYYY-MM-DD)")
        try:
            hour_int = int(hour)
        except (TypeError, ValueError):
            return create_empty_plot("Select a valid hour (0-23)")
        if not 0 <= hour_int <= 23:
            return create_empty_plot("Select a valid hour (0-23)")
        date_condition = f"date = '{day}' AND hr = {hour_int}"
        title_suffix = f"on {day} at {hour_int}:00"
    else:
        start_day, end_day = _iso_date(start_date), _iso_date(end_date)
        if start_day is None or end_day is None:
            return create_empty_plot("Enter valid start and end dates (YYYY-MM-DD)")
        date_condition = f"date >= '{start_day}' AND date <= '{end_day}'"
        title_suffix = f"from {start_day} to {end_day}"

    # Prediction type filter; restricted to word characters so it can be quoted safely
    if pred_type != "All":
        pred_value = str(pred_type).lower()
        if not re.fullmatch(r"[a-z_]+", pred_value):
            return create_empty_plot(f"Unsupported prediction type: {pred_type}")
        date_condition += f" AND prediction_type = '{pred_value}'"

    table_name = f"fleet_recommender_{taxi_type}_predictions_new_{model}"

    if passenger_type == "Total":
        metric_calc = f"""
        SUM({model}_pred_y_single + {model}_pred_y_small +
            {model}_pred_y_medium + {model}_pred_y_large) as total_pred_trips
        """
    else:
        metric_calc = f"SUM({model}_pred_{passenger_type}) as total_pred_trips"

    query = f"""
    SELECT
        zone_id,
        {metric_calc},
        COUNT(*) as num_predictions
    FROM `{PROJECT_ID}.{OUTPUT_DATASET}.{table_name}`
    WHERE {date_condition}
    GROUP BY zone_id
    ORDER BY zone_id
    """

    df = bq_client.query(query).to_dataframe()

    if df.empty:
        return create_empty_plot("No data available for selected parameters")

    try:
        # Copy so the lru_cache'd GeoDataFrame is never modified
        zone_gdf = get_cached_zone_geometry().copy()

        # Left join keeps zones without predictions on the map
        map_df = zone_gdf.merge(df, left_on='LocationID', right_on='zone_id', how='left')

        # Plotly cannot handle NA values, so fill the prediction columns
        map_df = map_df.fillna({
            'total_pred_trips': 0,
            'zone_id': 0,
            'num_predictions': 0
        })

        # For string columns, replace any remaining NA with empty string
        for col in ['Zone', 'Borough', 'zone', 'borough']:
            if col in map_df.columns:
                map_df[col] = map_df[col].fillna('').astype(str)

        # Ensure the lower-case hover columns exist
        if 'zone' not in map_df.columns:
            map_df['zone'] = map_df['Zone'] if 'Zone' in map_df.columns else ''
        if 'borough' not in map_df.columns:
            map_df['borough'] = map_df['Borough'] if 'Borough' in map_df.columns else ''

        # Convert to proper types and handle any pandas NA
        map_df['total_pred_trips'] = pd.to_numeric(map_df['total_pred_trips'], errors='coerce').fillna(0)
        map_df['zone_id'] = pd.to_numeric(map_df['zone_id'], errors='coerce').fillna(0).astype(int)

        fig = px.choropleth_mapbox(
            map_df,
            geojson=map_df.geometry,
            locations=map_df.index,
            color='total_pred_trips',
            color_continuous_scale='YlOrRd',
            mapbox_style="carto-positron",
            center={"lat": 40.7128, "lon": -74.0060},
            zoom=9,
            opacity=0.7,
            labels={'total_pred_trips': 'Predicted Trips'},
            hover_data={
                'zone_id': True,
                'zone': True,
                'borough': True,
                'total_pred_trips': ':.1f'
            }
        )

        passenger_label = "Total" if passenger_type == "Total" else passenger_type.replace('y_', '').title()
        fig.update_layout(
            title=f"{passenger_label} Passenger Predictions - {model.upper()} Model<br>{title_suffix}",
            height=700,
            margin={"r": 0, "t": 60, "l": 0, "b": 0}
        )

        fig.update_coloraxes(
            colorbar_title_text="Predicted<br>Trips",
            colorbar_thickness=15
        )

        return fig

    except Exception as e:
        # Best effort: report the failure, then fall back to a plain bar chart
        print(f"Error creating map: {str(e)}")
        import traceback
        traceback.print_exc()
        return create_zone_demand_bar_chart(df, passenger_type, model, title_suffix)

def create_zone_demand_bar_chart(df, passenger_type, model, title_suffix):
    """Bar chart of the 30 zones with the highest predicted demand.

    Used as the fallback view when the choropleth map cannot be rendered.

    Args:
        df: Frame with ``zone_id`` and ``total_pred_trips`` columns.
        passenger_type: "Total" or a ``y_*`` passenger bucket; shown in the title.
        model: Model key (e.g. "xgb"), upper-cased in the title.
        title_suffix: Date description appended to the title.

    Returns:
        A plotly ``go.Figure`` containing a single bar trace.
    """
    import plotly.graph_objects as go

    top_zones = df.sort_values(by='total_pred_trips', ascending=False).iloc[:30]
    demand = top_zones['total_pred_trips']

    if passenger_type == "Total":
        label = "Total"
    else:
        label = passenger_type.replace('y_', '').title()

    bars = go.Bar(
        x=top_zones['zone_id'].astype(str),
        y=demand,
        text=demand.round(1),
        textposition='auto',
        marker_color='lightcoral',
        hovertemplate='Zone %{x}<br>Predicted Trips: %{y:.1f}<extra></extra>',
    )

    figure = go.Figure(data=[bars])
    figure.update_layout(
        title=f"Top 30 Zones by {label} Passenger Demand<br>{model.upper()} Model - {title_suffix}",
        xaxis_title="Zone ID",
        yaxis_title="Predicted Trips",
        height=600,
        showlegend=False,
        hovermode='x unified',
    )
    return figure

# Gradio Dashboard

In [None]:
with gr.Blocks(title="NYC Taxi Dashboard") as demo:
    gr.Markdown("""
    # NYC Taxi Interactive Dashboard

    Analyze NYC taxi trip patterns, forecast demand, and detect anomalies using BigQuery data.
    """)

    with gr.Tab("Data Explorer"):
        with gr.Row():
            with gr.Column(scale=1):
                explore_taxi_type = gr.Dropdown(
                    choices=TAXI_TYPES,
                    label="Select Taxi Type",
                    value=None
                )
                explore_partitions = gr.CheckboxGroup(
                    choices=[],
                    label="Select Time Periods (YYYY_MM)",
                    value=[]
                )

            with gr.Column(scale=1):
                metric = gr.Radio(
                    ["Trips", "Revenue"],
                    value="Trips",
                    label="Metric to Display"
                )
                granularity = gr.Radio(
                    ["Daily", "Hourly"],
                    value="Daily",
                    label="Time Granularity"
                )

        explore_plot = gr.Plot(label="Time Series Visualization")

        # Update partitions when taxi type changes
        explore_taxi_type.change(
            lambda x: gr.update(choices=get_available_partitions(x) if x else []),
            inputs=[explore_taxi_type],
            outputs=[explore_partitions]
        )

        # Update plot when any parameter changes
        for component in [explore_taxi_type, explore_partitions, metric, granularity]:
            component.change(
                plot_time_series,
                inputs=[explore_taxi_type, explore_partitions, metric, granularity],
                outputs=[explore_plot]
            )

    with gr.Tab("ML & Forecasting"):
        gr.Markdown("""
        ### Load Data and Generate Forecasts
        Select data range, load it, then generate forecasts with confidence intervals.
        """)

        # Data Loading Section
        gr.Markdown("#### Step 1: Load Training Data")
        with gr.Row():
            ml_taxi_type = gr.Dropdown(
                choices=TAXI_TYPES,
                label="Select Taxi Type",
                value=None
            )
            forecast_metric = gr.Radio(
                ["trips", "revenue"],
                value="trips",
                label="Metric to Forecast"
            )

        with gr.Row():
            start_partition = gr.Dropdown(
                choices=[],
                label="Start Period (YYYY_MM)",
                value=None
            )
            end_partition = gr.Dropdown(
                choices=[],
                label="End Period (YYYY_MM)",
                value=None
            )

        load_button = gr.Button("Load and Prepare Data", variant="primary")
        load_status = gr.Textbox(label="Loading Status", lines=3)

        # Forecasting Section
        gr.Markdown("#### Step 2: Configure and Generate Forecast")
        with gr.Row():
            forecast_days = gr.Slider(
                minimum=1,
                maximum=14,
                value=7,
                step=1,
                label="Days to Forecast",
                info="Number of days to predict into the future"
            )
            confidence_level = gr.Radio(
                choices=[0.95, 0.99],
                value=0.95,
                label="Confidence Level",
                info="Confidence interval coverage"
            )
            show_actual_future = gr.Checkbox(
                label="Show Actual Future Data (if available)",
                value=False
            )

        forecast_button = gr.Button("Generate Forecast", variant="primary")
        forecast_plot = gr.Plot(label="Forecast Results with Confidence Intervals")

        # Event handlers
        ml_taxi_type.change(
            update_partition_choices,
            inputs=[ml_taxi_type],
            outputs=[start_partition, end_partition]
        )

        # The loader takes (taxi, start, end, metric); reorder the UI inputs to match
        load_button.click(
            lambda taxi, metric, start, end: load_and_prepare_data_with_metric(taxi, start, end, metric),
            inputs=[ml_taxi_type, forecast_metric, start_partition, end_partition],
            outputs=[load_status]
        )

        forecast_button.click(
            lambda days, conf, actual, metric: train_combined_forecast_with_ci(days, conf, actual, metric),
            inputs=[forecast_days, confidence_level, show_actual_future, forecast_metric],
            outputs=[forecast_plot]
        )

    with gr.Tab("Anomaly Detection"):
        gr.Markdown("""
        ### Pre-computed Anomaly Detection Results
        View anomalies detected by the pipeline, categorized by severity with holiday annotations.
        """)

        # Taxi type selector
        anomaly_taxi_type = gr.Dropdown(
            choices=TAXI_TYPES,
            label="Select Taxi Type",
            value=None
        )

        with gr.Row():
            with gr.Column(scale=3):
                anomaly_plot = gr.Plot(label="Anomaly Detection Results")
            with gr.Column(scale=1):
                anomaly_summary = gr.Textbox(
                    label="Anomaly Summary",
                    lines=15,
                    value="Select a taxi type to see summary"
                )
        anomaly_button = gr.Button("Load Anomaly Results", variant="primary")

        anomaly_button.click(
            detect_anomalies_from_db,
            inputs=[anomaly_taxi_type],
            outputs=[anomaly_plot]
        )

        anomaly_button.click(
            get_anomaly_summary,
            inputs=[anomaly_taxi_type],
            outputs=[anomaly_summary]
        )

    with gr.Tab("Fleet Recommendations"):
        gr.Markdown("""
        ### Zone-Based Fleet Optimization Predictions
        View passenger demand predictions by zone using XGBoost and KNN models.
        Compare test performance with future predictions.
        """)

        with gr.Row():
            fleet_taxi_type = gr.Dropdown(
                choices=["yellow", "green"],
                label="Taxi Type",
                value="yellow"
            )
            fleet_model_type = gr.Dropdown(
                choices=["xgb", "knn", "both"],
                label="Model Type",
                value="both"
            )
            fleet_date_range = gr.Textbox(
                label="Available Date Range",
                value="Loading...",
                interactive=False
            )

        # Hidden per-session state shared by the sub-tabs below
        fleet_min_date = gr.State(value=None)
        fleet_max_date = gr.State(value=None)
        zone_df_state = gr.State(value=pd.DataFrame())
        # Zone metrics summary text; there is currently no visible component for it
        metrics_display = gr.State(value=None)

        with gr.Tabs():
            with gr.Tab("Time Series View"):
                with gr.Row():
                    zone_select = gr.Dropdown(
                        choices=["All Zones"],
                        label="Zone ID (Ordered by Trip Volume)",
                        value="All Zones"
                    )
                    ts_display_type = gr.Radio(
                        choices=["By Passenger Type", "Model Comparison"],
                        value="By Passenger Type",
                        label="Display Type"
                    )
                    prediction_type_filter = gr.Radio(
                        choices=["All"],  # "Test Only" / "Future Only" are not offered yet
                        value="All",
                        label="Data (All : Test+Future)"
                    )

                with gr.Row():
                    ts_start_date = gr.Textbox(
                        label="Start Date (YYYY-MM-DD)",
                        value="2025-06-24"
                    )
                    ts_end_date = gr.Textbox(
                        label="End Date (YYYY-MM-DD)",
                        value="2025-07-15"
                    )

                fleet_ts_button = gr.Button("Generate Time Series", variant="primary")

                with gr.Row():
                    fleet_ts_plot = gr.Plot(label="Zone Predictions Time Series")

                def plot_zone_timeseries(taxi_type, model_type, start_date, end_date, zone_string, display_type, pred_type):
                    """Resolve the zone dropdown label to an ID and plot its time series."""
                    zone_id = extract_zone_id(zone_string)
                    return plot_fleet_timeseries_new(taxi_type, model_type, start_date, end_date, zone_id, display_type, pred_type)

                fleet_ts_button.click(
                    plot_zone_timeseries,
                    inputs=[fleet_taxi_type, fleet_model_type, ts_start_date, ts_end_date,
                            zone_select, ts_display_type, prediction_type_filter],
                    outputs=[fleet_ts_plot]
                )

            with gr.Tab("Zone Performance Comparison"):
                with gr.Row():
                    perf_metric = gr.Radio(
                        choices=["MAE", "RMSE"],
                        value="MAE",
                        label="Performance Metric"
                    )
                    perf_passenger_type = gr.Dropdown(
                        choices=["All", "y_single", "y_small", "y_medium", "y_large"],
                        value="All",
                        label="Passenger Type"
                    )

                perf_button = gr.Button("Generate Performance Comparison", variant="primary")
                perf_plot = gr.Plot(label="Zone Performance Comparison")

                perf_button.click(
                    plot_zone_performance_comparison,
                    inputs=[fleet_taxi_type, perf_metric, perf_passenger_type],
                    outputs=[perf_plot]
                )

            with gr.Tab("Geographic View"):
                gr.Markdown("""
                ### Geographic Distribution of Predicted Demand
                Visualize passenger predictions by zone on an interactive map.
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        geo_date_type = gr.Radio(
                            choices=["Single Date", "Date Range"],
                            value="Single Date",
                            label="Date Selection Type"
                        )

                        # Single date inputs
                        geo_single_date = gr.Textbox(
                            label="Select Date (YYYY-MM-DD)",
                            value="",
                            visible=True
                        )
                        geo_hour = gr.Slider(
                            minimum=0,
                            maximum=23,
                            value=12,
                            step=1,
                            label="Hour of Day (for single date)",
                            visible=True
                        )

                        # Date range inputs
                        geo_start_date = gr.Textbox(
                            label="Start Date (YYYY-MM-DD)",
                            value="",
                            visible=False
                        )
                        geo_end_date = gr.Textbox(
                            label="End Date (YYYY-MM-DD)",
                            value="",
                            visible=False
                        )

                    with gr.Column(scale=1):
                        geo_passenger_type = gr.Dropdown(
                            choices=["Total", "y_single", "y_small", "y_medium", "y_large"],
                            value="Total",
                            label="Passenger Type"
                        )
                        geo_model_select = gr.Radio(
                            choices=["xgb", "knn"],
                            value="xgb",
                            label="Model"
                        )
                        geo_pred_type = gr.Radio(
                            choices=["All"],
                            value="All",
                            label="Prediction Type"
                        )

                geo_map_button = gr.Button("Generate Map", variant="primary")
                geo_map = gr.Plot(label="Zone Demand Heatmap")

                def toggle_date_inputs(date_type):
                    """Show the single-date inputs or the date-range inputs, never both."""
                    single = date_type == "Single Date"
                    return (
                        gr.update(visible=single),      # geo_single_date
                        gr.update(visible=single),      # geo_hour
                        gr.update(visible=not single),  # geo_start_date
                        gr.update(visible=not single),  # geo_end_date
                    )

                geo_date_type.change(
                    toggle_date_inputs,
                    inputs=[geo_date_type],
                    outputs=[geo_single_date, geo_hour, geo_start_date, geo_end_date]
                )

                geo_map_button.click(
                    plot_zone_demand_map,
                    inputs=[fleet_taxi_type, geo_date_type, geo_single_date, geo_hour,
                            geo_start_date, geo_end_date, geo_passenger_type,
                            geo_model_select, geo_pred_type],
                    outputs=[geo_map]
                )

            with gr.Tab("Detailed Predictions Table"):
                with gr.Row():
                    table_zone = gr.Dropdown(
                        choices=["All Zones"],
                        label="Zone ID",
                        value="All Zones"
                    )
                    table_date = gr.Textbox(
                        label="Date (YYYY-MM-DD)",
                        value=""
                    )
                    table_pred_type = gr.Radio(
                        choices=["Test", "Future"],
                        value="Future",
                        label="Prediction Type"
                    )

                table_button = gr.Button("Load Predictions Table", variant="primary")

                with gr.Row():
                    predictions_table = gr.Dataframe(
                        label="Hourly Predictions with Confidence Intervals"
                    )

                def load_predictions_table(taxi_type, zone_string, date, pred_type):
                    """Resolve the zone dropdown label to an ID and load its prediction table."""
                    zone_id = extract_zone_id(zone_string)
                    return load_zone_predictions_table(taxi_type, zone_id, date, pred_type)

                table_button.click(
                    load_predictions_table,
                    inputs=[fleet_taxi_type, table_zone, table_date, table_pred_type],
                    outputs=[predictions_table]
                )

        def update_fleet_interface(taxi_type):
            """Refresh the date range, zone dropdowns and metrics summary for ``taxi_type``."""
            min_date, max_date, range_text = get_fleet_date_range_new(taxi_type)
            zone_list = get_zones_with_predictions(taxi_type)
            metrics_text = load_zone_metrics_summary(taxi_type)
            default_zone = zone_list[0] if zone_list else "All Zones"

            return (
                range_text,
                min_date,
                max_date,
                gr.update(choices=zone_list, value=default_zone),  # zone_select
                metrics_text,                                      # metrics_display
                gr.update(choices=zone_list, value=default_zone),  # table_zone
            )

        def populate_dates(min_date, max_date):
            """Default the time-series window to the 21 days ending at ``max_date``."""
            if min_date and max_date:
                test_start = pd.to_datetime(max_date) - pd.Timedelta(days=21)
                test_end = max_date
                return test_start.strftime('%Y-%m-%d'), test_end
            return "", ""

        def populate_all_dates(min_date, max_date):
            """Fill every date input of the fleet sub-tabs from the loaded date range."""
            ts_start, ts_end = populate_dates(min_date, max_date)
            latest = max_date if max_date else ""
            if min_date and max_date:
                geo_start, geo_end = min_date, max_date
            else:
                geo_start, geo_end = "", ""
            return ts_start, ts_end, latest, latest, geo_start, geo_end

        fleet_outputs = [fleet_date_range, fleet_min_date, fleet_max_date,
                         zone_select, metrics_display, table_zone]
        date_outputs = [ts_start_date, ts_end_date, table_date,
                        geo_single_date, geo_start_date, geo_end_date]

        # Dates are filled in a chained .then() step so they read the min/max State
        # written by update_fleet_interface; separate parallel listeners on the same
        # trigger could read the previous taxi type's range.
        fleet_taxi_type.change(
            update_fleet_interface,
            inputs=[fleet_taxi_type],
            outputs=fleet_outputs
        ).then(
            populate_all_dates,
            inputs=[fleet_min_date, fleet_max_date],
            outputs=date_outputs
        )

        # Same refresh on initial page load, using the dropdown's current value
        demo.load(
            update_fleet_interface,
            inputs=[fleet_taxi_type],
            outputs=fleet_outputs
        ).then(
            populate_all_dates,
            inputs=[fleet_min_date, fleet_max_date],
            outputs=date_outputs
        )


demo.launch(debug=True)