# Inference Script for VTA Ridership Prediction

***For predictions without Weather Data***

## Import packages

In [1]:
import pandas as pd
import numpy as np
import pickle
import os

In [2]:
CLEAN_DATA_FOLDER = "clean_data"
STAGING_DATA_FOLDER = "staging_data"
MODEL_DATA_FOLDER = "models"
model = pickle.load(open(os.path.join(MODEL_DATA_FOLDER, "base_xgboost_wo_weather.pkl"), "rb"))
line_sequence_stop = pd.read_csv(os.path.join(STAGING_DATA_FOLDER, "line_sequence_stop.csv"))
stop_names = pd.read_csv(os.path.join(STAGING_DATA_FOLDER, "stop_names.csv"))
stops = pd.read_csv(os.path.join(STAGING_DATA_FOLDER, "stops.csv"))

## User Input here

- Enter stop name in small caps
- TMAX is maximum temperature on that day in Fahrenheit
- TMIN is minimum temperature on that day in Fahrenheit
- PRCP is precipitation in inches

In [3]:
INPUT_DATE = "2018-04-24"
INPUT_HOLIDAY = False
INPUT_SPECIAL = False
INPUT_STOP_NAME = "baypointe"

In [4]:
def determine_service(month, day, weekday, holiday, special):
    # July 4th is considered regardless of the weekday or holiday status unless it is special
    if month == 7 and day == 4:
        return 4

    # Special days handling
    if special:
        if weekday in range(5):  # Monday to Friday
            return 5
        elif weekday == 5:  # Saturday
            return 6
        elif weekday == 6 or holiday:  # Sunday or holiday
            return 7

    # Regular days handling
    if holiday:
        return 3  # Sunday/Holiday mapping
    if weekday == 5:
        return 2  # Saturday mapping
    if weekday == 6:
        return 3  # Sunday mapping

    # Default to weekday if no other conditions are met
    return 1

date = pd.to_datetime(INPUT_DATE).date()
year = pd.to_datetime(INPUT_DATE).year
month = pd.to_datetime(INPUT_DATE).month
day = pd.to_datetime(INPUT_DATE).day_of_year
weekday = pd.to_datetime(INPUT_DATE).weekday
service = determine_service(month, day, weekday, INPUT_HOLIDAY, INPUT_SPECIAL)
date_df = pd.DataFrame({"Year": [year], "Day": [day], "Service": [service], "Date": [date]})
date_df

Unnamed: 0,Year,Day,Service,Date
0,2018,114,1,2018-04-24


In [5]:
stop_names = stop_names[stop_names["Stop Name"].str.contains(INPUT_STOP_NAME.upper())]
stop_names

Unnamed: 0,Stop Id,Stop Name
1099,1366,TASMAN & BAYPOINTE
1103,1377,TASMAN & BAYPOINTE
3454,4760,BAYPOINTE STATION (0)
3455,4761,BAYPOINTE STATION (1)
3499,4801,BAYPOINTE STATION (0)
3504,4806,BAYPOINTE STATION (1)


## User Input here

In [6]:
INPUT_STOP_IDS = [4760, 4761]

In [7]:
stops = stops[stops["Stop Id"].isin(INPUT_STOP_IDS)]
input_df = stop_names.merge(stops, how="inner", on="Stop Id")
input_df = input_df.merge(line_sequence_stop, on="Stop Id", how="inner").sort_values(
    ["Stop Id", "Line", "Direction Number"]
)
input_df = input_df.merge(date_df, how="cross")
input_df[
    [
        "Day",
        "Line",
        "Service",
        "Direction Number",
        "Sequence",
        "Latitude",
        "Longitude",
    ]
]

Unnamed: 0,Day,Line,Service,Direction Number,Sequence,Latitude,Longitude
0,114,912,1,0,17,37.410778,-121.94153
1,114,913,1,0,5,37.410778,-121.94153
2,114,917,1,0,2,37.410778,-121.94153
3,114,921,1,0,28,37.410778,-121.94153
4,114,912,1,1,10,37.41053,-121.942314


In [8]:
predictions = model.predict(
    input_df[
        [
            "Day",
            "Line",
            "Service",
            "Direction Number",
            "Sequence",
            "Latitude",
            "Longitude",
        ]
    ]
)
pred_df = pd.DataFrame({"On": predictions}).apply(np.floor)
output_df = pd.concat(
    [
        input_df[["Date", "Stop Name", "Line", "Service", "Direction Number", "Sequence"]],
        pred_df,
    ],
    axis=1,
).sort_values(["Date", "Line", "Service", "Direction Number"])
output_df

Unnamed: 0,Date,Stop Name,Line,Service,Direction Number,Sequence,On
0,2018-04-24,BAYPOINTE STATION (0),912,1,0,17,-15.0
4,2018-04-24,BAYPOINTE STATION (1),912,1,1,10,-9.0
1,2018-04-24,BAYPOINTE STATION (0),913,1,0,5,6.0
2,2018-04-24,BAYPOINTE STATION (0),917,1,0,2,-47.0
3,2018-04-24,BAYPOINTE STATION (0),921,1,0,28,-87.0
