In [9]:
from typing import List, Dict
import warnings

from prophet import Prophet
import numpy as np
import pandas as pd
import pickle
import multiprocessing
try:
    multiprocessing.set_start_method("fork")
except RuntimeError as e:
    pass

from vacances_scolaires_france import SchoolHolidayDates


warnings.simplefilter("ignore")

In [10]:
def load_gz_csv_to_df(filepath: str) -> pd.DataFrame:
    try:
        return pd.read_csv(
            filepath, 
            compression='gzip', 
            header=0, sep=',', 
            error_bad_lines=False
        )
    except FileNotFoundError:
        exit(1)

bu_df = load_gz_csv_to_df('data/bu_feat.csv.gz')
train_df = load_gz_csv_to_df('data/train.csv.gz')
test_df = load_gz_csv_to_df('data/train.csv.gz')

### Processing data

In [11]:
train_bu_df = train_df.merge(bu_df, on="but_num_business_unit")
assert train_bu_df.shape[0] == train_df.shape[0]

In [12]:
# remove correlated feature
train_bu_df = train_bu_df.drop(["but_postcode", "zod_idr_zone_dgr"], axis=1)

In [13]:
train_bu_df["day_id"] = pd.to_datetime(train_bu_df["day_id"], format="%Y-%m-%d")
train_bu_df["year"] = train_bu_df["day_id"].dt.year
train_bu_df["week_number"] = train_bu_df["day_id"].dt.week

In [14]:
train_bu_df.but_num_business_unit.unique().shape

(322,)

In [15]:
def create_holidays_df(start_year: int, end_year: int) -> pd.DataFrame:
    years = np.arange(start_year, end_year +1)
    d = SchoolHolidayDates()
    holidays_list = []
    for holiday_year in years:
        holidays_list += list(d.holidays_for_year(holiday_year).keys())
    df = pd.DataFrame(holidays_list, columns=["ds"])
    df["holiday_name"] = "fr_holiday"
    df.rename(columns={"holiday_name": "holiday"}, inplace=True)
    return df

In [16]:
holidays_df = create_holidays_df(2012, 2017)

In [17]:
def init_prophet_model(holidays_df: pd.DataFrame) -> Prophet:
    m = Prophet(
        holidays=holidays_df,
        holidays_prior_scale=1
    )
    return m

In [18]:
train_weekly_tunrnover_by_store_and_dpt = (
    train_bu_df
        .groupby(["but_num_business_unit", "dpt_num_department", "year", "week_number"], as_index=False)
        .agg({"turnover": "sum"})
)

In [19]:
train_weekly_tunrnover_by_store_and_dpt["but_bu_dpt_id"] = (
    train_weekly_tunrnover_by_store_and_dpt["but_num_business_unit"].astype(str) 
        + "-" 
        + train_weekly_tunrnover_by_store_and_dpt["dpt_num_department"].astype(str)
)

In [20]:
train_weekly_tunrnover_by_store_and_dpt["ds"] = (
    train_weekly_tunrnover_by_store_and_dpt["year"].astype(str) 
        + "-" 
        + train_weekly_tunrnover_by_store_and_dpt["week_number"].astype(str)
)

In [21]:
train_weekly_tunrnover_by_store_and_dpt = (
    train_weekly_tunrnover_by_store_and_dpt[["but_bu_dpt_id", "ds", "turnover"]]
    .rename(columns={"turnover": "y"})
)

In [22]:
train_weekly_tunrnover_by_store_and_dpt["ds"] = (
    pd.to_datetime(
        train_weekly_tunrnover_by_store_and_dpt["ds"] + "-1", 
        format="%Y-%W-%w"
    )
)

In [23]:
def create_store_dpt_df(df: pd.DataFrame) -> List[pd.DataFrame]:
    return [df[df["but_bu_dpt_id"] == i] for i in df["but_bu_dpt_id"].unique()]

In [24]:
train_weekly_tunrnover_by_store_and_dpt_list = create_store_dpt_df(train_weekly_tunrnover_by_store_and_dpt)

In [25]:
def remove_n_last_values(dfs: List[pd.DataFrame], n_value_to_remove: int) -> List[pd.DataFrame]:
    return [df.drop(df.tail(n_value_to_remove).index) for df in dfs]

In [26]:
train_weekly_tunrnover_by_store_and_dpt_list = remove_n_last_values(train_weekly_tunrnover_by_store_and_dpt_list, 8)

In [31]:
train_weekly_tunrnover_by_store_and_dpt_list_two_values = train_weekly_tunrnover_by_store_and_dpt_list[:2]

## Train model

In [32]:
def fit_model(dfs: List[pd.DataFrame], holiday_df: pd.DataFrame) -> List[Prophet]:
    models = {}
    for df in dfs:
        df = df.reset_index(drop=True)
        model = init_prophet_model(holiday_df)
        df
        df_fit = df[["ds", "y"]]
        model.fit(df_fit)
        key = df.loc[0]["but_bu_dpt_id"]
        models[key] = model
    return models

In [33]:
models = fit_model(train_weekly_tunrnover_by_store_and_dpt_list_two_values, holidays_df)

INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.


Initial log joint probability = -5.2924
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
      99       374.093   5.51942e-06       101.613      0.4176      0.4176      132   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     106       374.098   3.63117e-05       95.1402   3.786e-07       0.001      174  LS failed, Hessian reset 
     192       374.101   9.39217e-09       79.8676      0.4686      0.4686      285   
Optimization terminated normally: 
  Convergence detected: absolute parameter change was below tolerance
Initial log joint probability = -4.55831
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
      99       405.493   0.000777116       81.4875           1           1      126   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     169       405.613   5.83851e-05       87.2615   7.034e-07       0.001

In [None]:
# save model
def save_models(filename: str, models: Dict[str, Prophet]) -> None:
    with open(filename, "wb") as f:
        pickle.dump(models, f, pickle.HIGHEST_PROTOCOL)
save_models("models.pkl", models)

## Inference

In [36]:
def inference(models: Dict[str, Prophet]) -> Dict[str, pd.DataFrame]:
    return {key: models.get(key).predict(models.get(key).make_future_dataframe(periods=8, freq='W')) for key in models.keys()}

In [37]:
predictions = inference(models)

In [39]:
predictions.get("1-73")

Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,fr_holiday,...,holidays,holidays_lower,holidays_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
0,2012-12-24,26.800854,-32.356730,40.165866,26.800854,26.800854,-22.036001,-22.036001,-22.036001,-8.224847,...,-8.224847,-8.224847,-8.224847,-13.811155,-13.811155,-13.811155,0.0,0.0,0.0,4.764853
1,2013-01-07,27.136429,-13.726616,56.039773,27.136429,27.136429,-4.651057,-4.651057,-4.651057,0.000000,...,0.000000,0.000000,0.000000,-4.651057,-4.651057,-4.651057,0.0,0.0,0.0,22.485373
2,2013-01-14,27.304217,-17.401051,59.960374,27.304217,27.304217,-6.803341,-6.803341,-6.803341,0.000000,...,0.000000,0.000000,0.000000,-6.803341,-6.803341,-6.803341,0.0,0.0,0.0,20.500876
3,2013-01-21,27.472004,-20.639267,50.633936,27.472004,27.472004,-11.900751,-11.900751,-11.900751,0.000000,...,0.000000,0.000000,0.000000,-11.900751,-11.900751,-11.900751,0.0,0.0,0.0,15.571254
4,2013-01-28,27.639792,-26.849437,46.794354,27.639792,27.639792,-17.764322,-17.764322,-17.764322,0.000000,...,0.000000,0.000000,0.000000,-17.764322,-17.764322,-17.764322,0.0,0.0,0.0,9.875469
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
243,2017-08-27,64.775538,34.471818,102.194680,64.773618,64.777634,5.012546,5.012546,5.012546,-8.224847,...,-8.224847,-8.224847,-8.224847,13.237392,13.237392,13.237392,0.0,0.0,0.0,69.788083
244,2017-09-03,64.894109,33.813533,103.195968,64.890723,64.898068,3.532285,3.532285,3.532285,-8.224847,...,-8.224847,-8.224847,-8.224847,11.757132,11.757132,11.757132,0.0,0.0,0.0,68.426395
245,2017-09-10,65.012681,37.136921,108.546184,65.007965,65.018719,9.116945,9.116945,9.116945,0.000000,...,0.000000,0.000000,0.000000,9.116945,9.116945,9.116945,0.0,0.0,0.0,74.129626
246,2017-09-17,65.131253,31.790833,104.168679,65.124130,65.139845,4.000725,4.000725,4.000725,0.000000,...,0.000000,0.000000,0.000000,4.000725,4.000725,4.000725,0.0,0.0,0.0,69.131978
