In [1]:
%pip install protobuf==3.20.*
%matplotlib inline

[0mNote: you may need to restart the kernel to use updated packages.


In [7]:
import mlflow
import json
import pandas as pd
import numpy as np
from prophet import Prophet, serialize
from prophet.diagnostics import cross_validation, performance_metrics

SOURCE_DATA = (
    "https://raw.githubusercontent.com/facebook/prophet/master/examples/example_retail_sales.csv"
)
ARTIFACT_PATH = "model"
np.random.seed(12345)

In [8]:
tracking_server_uri = "http://138.68.70.41:5000"  # set to your server URI

try:
    mlflow.set_tracking_uri(tracking_server_uri)
except:
    print(
        """Couldn't connect to remote MLFLOW tracking server""")

In [9]:
def extract_params(pr_model):
    return {attr: getattr(pr_model, attr) for attr in serialize.SIMPLE_ATTRIBUTES}


sales_data = pd.read_csv(SOURCE_DATA)

In [5]:
with mlflow.start_run():

    model = Prophet().fit(sales_data)

    params = extract_params(model)

    metric_keys = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
    metrics_raw = cross_validation(
        model=model,
        horizon="365 days",
        period="180 days",
        initial="710 days",
        parallel="threads",
        disable_tqdm=True,
    )
    cv_metrics = performance_metrics(metrics_raw)
    metrics = {k: cv_metrics[k].mean() for k in metric_keys}

    print(f"Logged Metrics: \n{json.dumps(metrics, indent=2)}")
    print(f"Logged Params: \n{json.dumps(params, indent=2)}")

    mlflow.prophet.log_model(model, artifact_path=ARTIFACT_PATH)
    mlflow.log_params(params)
    mlflow.log_metrics(metrics)
    model_uri = mlflow.get_artifact_uri(ARTIFACT_PATH)
    print(f"Model artifact logged to: {model_uri}")

06:39:07 - cmdstanpy - INFO - Chain [1] start processing
06:39:07 - cmdstanpy - INFO - Chain [1] done processing
06:39:07 - cmdstanpy - INFO - Chain [1] start processing
06:39:08 - cmdstanpy - INFO - Chain [1] start processing
06:39:08 - cmdstanpy - INFO - Chain [1] start processing
06:39:08 - cmdstanpy - INFO - Chain [1] start processing
06:39:08 - cmdstanpy - INFO - Chain [1] start processing
06:39:08 - cmdstanpy - INFO - Chain [1] start processing
06:39:09 - cmdstanpy - INFO - Chain [1] done processing
06:39:09 - cmdstanpy - INFO - Chain [1] done processing
06:39:09 - cmdstanpy - INFO - Chain [1] done processing
06:39:09 - cmdstanpy - INFO - Chain [1] done processing
06:39:09 - cmdstanpy - INFO - Chain [1] done processing
06:39:21 - cmdstanpy - INFO - Chain [1] start processing
06:39:21 - cmdstanpy - INFO - Chain [1] start processing
06:39:22 - cmdstanpy - INFO - Chain [1] start processing
06:39:22 - cmdstanpy - INFO - Chain [1] done processing
06:39:22 - cmdstanpy - INFO - Chain [1

Logged Metrics: 
{
  "mse": 340564777.91752225,
  "rmse": 18179.623413231748,
  "mae": 12593.541194330715,
  "mape": 0.03706827648289505,
  "mdape": 0.025694433404124896,
  "smape": 0.03687997438944849,
  "coverage": 0.38520673668275146
}
Logged Params: 
{
  "growth": "linear",
  "n_changepoints": 25,
  "specified_changepoints": false,
  "changepoint_range": 0.8,
  "yearly_seasonality": "auto",
  "weekly_seasonality": "auto",
  "daily_seasonality": "auto",
  "seasonality_mode": "additive",
  "seasonality_prior_scale": 10.0,
  "changepoint_prior_scale": 0.05,
  "holidays_prior_scale": 10.0,
  "mcmc_samples": 0,
  "interval_width": 0.8,
  "uncertainty_samples": 1000,
  "y_scale": 518253.0,
  "logistic_floor": false,
  "country_holidays": null,
  "component_modes": {
    "additive": [
      "yearly",
      "additive_terms",
      "extra_regressors_additive",
      "holidays"
    ],
    "multiplicative": [
      "multiplicative_terms",
      "extra_regressors_multiplicative"
    ]
  }
}
Mo

In [6]:
loaded_model = mlflow.prophet.load_model(model_uri)

forecast = loaded_model.predict(loaded_model.make_future_dataframe(60))

print(f"forecast:\n${forecast.head(30)}")

  setattr(model, attribute, pd.Timestamp.utcfromtimestamp(model_dict[attribute]))


forecast:
$           ds          trend     yhat_lower     yhat_upper    trend_lower  \
0  1992-01-01  162820.063263  118166.649881  138688.731517  162820.063263   
1  1992-02-01  163871.565185  123400.860208  144044.438588  163871.565185   
2  1992-03-01  164855.228273  158737.323973  180320.923371  164855.228273   
3  1992-04-01  165906.730195  153254.639305  172532.936708  165906.730195   
4  1992-05-01  166924.312701  168505.442049  189376.576692  166924.312701   
5  1992-06-01  167975.814623  161177.737814  180549.360919  167975.814623   
6  1992-07-01  168993.397128  161534.718692  181448.304837  168993.397128   
7  1992-08-01  170044.899050  168502.285194  189169.600248  170044.899050   
8  1992-09-01  171096.400972  149014.628845  168958.146803  171096.400972   
9  1992-10-01  172113.983477  159547.797557  179608.039499  172113.983477   
10 1992-11-01  173165.485427  162309.471495  182575.543383  173165.485427   
11 1992-12-01  174183.067960  208567.460734  229124.725880  17418