In [None]:
import json
import pandas as pd
from prophet import Prophet, serialize
from prophet.diagnostics import cross_validation, performance_metrics
import mlflow
import sys

import logging
# Root logging configuration: emit DEBUG-and-above records to stdout using a
# verbose format (timestamp, level, logger name, file, function, line, message).
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stdout,
    format=(
        "%(asctime)s - [%(levelname)s] - %(name)s - "
        "(%(filename)s).%(funcName)s(%(lineno)d) - %(message)s"
    ),
)

# Logger used by the code in this notebook.
logger = logging.getLogger(__name__)

# Artifact path handed to mlflow.prophet.log_model / mlflow.get_artifact_uri.
ARTIFACT_PATH = "model"

In [None]:
# URL of the CSV file used as the training data source.
SOURCE_DATA = "https://raw.githubusercontent.com/rkrikbaev/model-training/master/jupyter/project/fp_archives.csv"

In [None]:
def extract_params(self, pr_model=None):
    """Collect a model's simple attributes into a flat dict.

    This is a plain module-level function, so the leading ``self`` parameter is
    a leftover. It is kept for backward compatibility, and both call styles
    are supported:

    * ``extract_params(model)`` - the single argument is the model.
    * ``extract_params(anything, model)`` - the first argument is ignored.

    Args:
        self: The model when called with one argument; otherwise unused.
        pr_model: The model to read attributes from (optional).

    Returns:
        dict mapping each name in ``serialize.SIMPLE_ATTRIBUTES`` to the value
        of that attribute on the model.

    Raises:
        AttributeError: if the model lacks one of the listed attributes.
    """
    target = self if pr_model is None else pr_model
    return {attr: getattr(target, attr) for attr in serialize.SIMPLE_ATTRIBUTES}

In [None]:
def _cv_duration(value, units=None):
    """Attach ``units`` to a bare numeric duration; leave anything else untouched."""
    if not units or isinstance(value, bool) or not isinstance(value, (int, float, str)):
        return value
    text = str(value).strip()
    try:
        float(text)
    except ValueError:
        # Already carries its own unit (e.g. "48 hours").
        return value
    return f"{text} {units}"


def _cross_validation_metrics(fitted_model, cv_settings, metric_keys):
    """Run Prophet cross-validation and return the mean of each available metric."""
    units = cv_settings.get('units')
    metrics_raw = cross_validation(
        model=fitted_model,
        horizon=_cv_duration(cv_settings.get('horizon'), units),
        period=_cv_duration(cv_settings.get('period'), units),
        initial=_cv_duration(cv_settings.get('initial'), units),
        parallel=cv_settings.get('parallel'),
        disable_tqdm=bool(cv_settings.get('disable_tqdm', False)),
    )
    cv_metrics = performance_metrics(metrics_raw)
    # Only aggregate the metrics that performance_metrics actually produced.
    return {k: cv_metrics[k].mean() for k in metric_keys if k in cv_metrics}


def train(model, dataframe, settings):
    """Fit a Prophet model on ``dataframe`` and log it to MLflow.

    A new ``Prophet`` instance is built from ``settings``, fitted inside an
    MLflow run, and logged together with its parameters. When cross-validation
    is configured and enabled, the mean cross-validation metrics are logged too.

    Args:
        model: Unused; a fresh ``Prophet`` instance is always created from
            ``settings``. Kept so existing call sites keep working.
        dataframe: Training data passed to ``Prophet.fit``.
        settings: dict with the required keys ``growth``, ``seasonality_mode``,
            ``changepoint_prior_scale``, ``seasonality_prior_scale``,
            ``daily_seasonality``, ``weekly_seasonality`` and
            ``yearly_seasonality``. Optional keys:

            * ``interval_width`` - forwarded to ``Prophet`` when present.
            * ``seasonality`` - list of dicts with ``name``, ``period`` and
              ``fourier_order``, each added via ``add_seasonality``.
            * ``cross_validation`` - dict with ``horizon``, ``period``,
              ``initial``, ``parallel``, ``disable_tqdm`` and optionally
              ``units``, which is appended to bare numeric durations.
            * ``cross_validation_enabled`` - cross-validation only runs when
              this is truthy and ``cross_validation`` is non-empty.

    Returns:
        None.

    Raises:
        KeyError: if a required key is missing from ``settings``.
    """
    # Init prophet model
    prophet_kwargs = {
        "growth": settings["growth"],
        "seasonality_mode": settings["seasonality_mode"],
        "changepoint_prior_scale": settings['changepoint_prior_scale'],
        "seasonality_prior_scale": settings['seasonality_prior_scale'],
        "daily_seasonality": settings['daily_seasonality'],
        "weekly_seasonality": settings['weekly_seasonality'],
        "yearly_seasonality": settings['yearly_seasonality'],
    }
    # Only override Prophet's own default when the caller configured a value.
    if "interval_width" in settings:
        prophet_kwargs["interval_width"] = settings["interval_width"]
    model = Prophet(**prophet_kwargs)

    for season in settings.get('seasonality') or []:
        model.add_seasonality(
            name=season['name'],
            period=season['period'],
            fourier_order=season['fourier_order']
        )

    metric_keys = ["mse", "rmse", "mae",
                   "mape", "mdape", "smape", "coverage"]

    cross_validation_params = settings.get('cross_validation')
    cross_validation_enable = settings.get('cross_validation_enabled')

    with mlflow.start_run():

        fitted_model = model.fit(dataframe)
        # The first positional slot of extract_params is unused; the model is
        # passed explicitly as pr_model.
        params = extract_params(None, pr_model=fitted_model)

        # Stays empty when cross-validation is disabled or not configured.
        metrics = {}
        if cross_validation_params and cross_validation_enable:
            metrics = _cross_validation_metrics(
                fitted_model, cross_validation_params, metric_keys)

        mlflow.prophet.log_model(fitted_model, artifact_path=ARTIFACT_PATH)
        mlflow.log_params(params)
        if metrics:
            mlflow.log_metrics(metrics)

        model_uri = mlflow.get_artifact_uri(ARTIFACT_PATH)

        logger.debug(f"Model artifact logged to: {model_uri}")


In [None]:
# model settings
# NOTE: this is a Python dict, not JSON - booleans must be True/False.
settings = {
        "growth": "linear", 
        "seasonality_mode": "multiplicative", 
        "changepoint_prior_scale": 30,
        "seasonality_prior_scale": 35,
        "interval_width": 0.98,
        "daily_seasonality": "auto",
        "weekly_seasonality": "auto",
        "yearly_seasonality": False, 
        # NOTE(review): Prophet seasonality periods are expressed in days; 0.417
        # days is about 10 hours, while one hour would be 1/24 (about 0.0417) - confirm.
        "seasonality": [{"name": "hour","period": 0.417, "fourier_order": 5}], 
        # train() only runs cross-validation when this flag is truthy.
        "cross_validation_enabled": True,
        # Durations carry an explicit unit; a bare "24" has none.
        "cross_validation":{
            "horizon":"48 hours", "period":"24 hours", "initial":"144 hours","parallel":"threads","disable_tqdm":True
        }
    }

In [None]:
# Load the training data from SOURCE_DATA so this cell runs on a fresh kernel.
# NOTE(review): Prophet.fit expects `ds` (timestamp) and `y` (value) columns -
# confirm the CSV schema and rename columns here if it differs.
dataframe = pd.read_csv(SOURCE_DATA)

# train() builds its own Prophet instance from `settings`, so the first
# argument is only a placeholder.
model = None

train(model, dataframe, settings)
