In [1]:
import pandas as pd
from datetime import datetime

from src.paths import TRAINING_DATA

# Starts

## Train-test split on starting data

In [3]:
start_table = pd.read_parquet(path = TRAINING_DATA/"integers/start_june.parquet")

In [None]:
def train_test_split(
        data: pd.DataFrame,
        scenario: str,
        cutoff_date: datetime,
        target_column: str
) -> tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
        
    """
    This is just a primitive splitting function that treats all data
    before a certain date as training data, and everything after that
    as test data
    """

    training_data = data[data[f"{scenario}_hour"] < cutoff_date].reset_index(drop=True)

    test_data = data[data[f"{scenario}_hour"] > cutoff_date].reset_index(drop=True)

    x_train = training_data.drop(columns=[target_column])
    y_train = training_data[target_column]

    x_test = test_data.drop(columns=[target_column])
    y_test = test_data[target_column]

    return x_train, y_train, x_test, y_test


In [4]:
x_train, y_train, x_test, y_test = train_test_split(
    data=start_table,
    scenario="start",
    cutoff_date=datetime(2023,6,1,0,0,0),
    target_column="trips_next_hour"
)

#for data in [x_train, y_train, x_test, y_test]:

print(f"{x_train.shape = }")
print(f"{y_train.shape = }")
print(f"{x_test.shape = }")
print(f"{y_test.shape = }")

x_train.shape = (0, 674)
y_train.shape = (0,)
x_test.shape = (2624327, 674)
y_test.shape = (2624327,)


### Defining the base model

In [5]:
import numpy as np 

class BaselineModel:
    
    def fit(self, x_train: pd.DataFrame, y_train: pd.Series):
        pass 
    
    def predict(self, x_test: pd.DataFrame) -> np.array:
        
        return x_test["trips_previous_1_hour"]

### Base Model Predictions 

In [6]:
model = BaselineModel()
predictions = model.predict(x_test)

### Mean absolute error

In [7]:
from sklearn.metrics import mean_absolute_error

start_test_mae = mean_absolute_error(y_test, predictions)

In [8]:
start_test_mae

0.026907469991353974

# Stops