In [0]:
%pip install prophet

Collecting prophet
  Downloading prophet-1.1.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.5 kB)
Collecting cmdstanpy>=1.0.4 (from prophet)
  Downloading cmdstanpy-1.2.5-py3-none-any.whl.metadata (4.0 kB)
Collecting holidays<1,>=0.25 (from prophet)
  Downloading holidays-0.79-py3-none-any.whl.metadata (47 kB)
Collecting tqdm>=4.36.1 (from prophet)
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting importlib_resources (from prophet)
  Downloading importlib_resources-6.5.2-py3-none-any.whl.metadata (3.9 kB)
Collecting stanio<2.0.0,>=0.4.0 (from cmdstanpy>=1.0.4->prophet)
  Downloading stanio-0.5.1-py3-none-any.whl.metadata (1.6 kB)
Downloading prophet-1.1.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.4 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/14.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/14.4 MB[0m [31m12.1 MB/s[0m

In [0]:
from prophet import Prophet
import pandas as pd
import numpy as np
from pyspark.sql import functions as F

# Register a per-day, per-franchise aggregate of the sales transactions as a
# temp view: summed totalPrice and a count of transactionID for each group.
spark.sql("""
  CREATE OR REPLACE TEMP VIEW franchise_daily AS
  SELECT
    CAST(dateTime AS DATE) AS ds,
    franchiseID,
    SUM(totalPrice)      AS y_revenue,
    COUNT(transactionID) AS y_txn_count
  FROM samples.bakehouse.sales_transactions
  GROUP BY CAST(dateTime AS DATE), franchiseID
""")

# Pull the aggregated view into pandas for model fitting.
daily = spark.table("franchise_daily").toPandas()

# Fail fast if the view does not expose every column the forecasting code reads.
expected = {"ds","franchiseID","y_revenue","y_txn_count"}
missing = expected.difference(daily.columns)
if missing:
    raise ValueError(f"Missing columns: {missing}")

# Normalise `ds` to datetime64 and order the rows by franchise, then by date.
daily = (
    daily
    .assign(ds=lambda frame: pd.to_datetime(frame["ds"]))
    .sort_values(["franchiseID","ds"])
    .reset_index(drop=True)
)

def _fit_forecast(pdf: pd.DataFrame, value_col: str, horizon_days: int = 30,
                  use_cap: bool = True, cap_multiplier: float = 1.25) -> pd.DataFrame:
    out = []
    for fid, g in pdf.groupby("franchiseID"):
        g = g[["ds", value_col]].dropna()

        g = g.rename(columns={value_col: "y"})
        g["floor"] = 0.0

        if use_cap:
            cap_val = max(g["y"].max() * cap_multiplier, 1.0)
            g["cap"] = cap_val
        else:
            g["cap"] = np.inf

        #model
        m = Prophet(
            growth="logistic",              
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True
        )
        m.fit(g[["ds","y","floor","cap"]])

        future = m.make_future_dataframe(periods=horizon_days, freq="D")
        future["floor"] = 0.0
        future["cap"] = g["cap"].iloc[0]

        fc = m.predict(future)[["ds","yhat","yhat_lower","yhat_upper"]]
        fc[["yhat","yhat_lower","yhat_upper"]] = fc[["yhat","yhat_lower","yhat_upper"]].clip(lower=0)

        fc["franchiseID"] = fid
        fc["target"] = value_col
        out.append(fc)

    if not out:
        return pd.DataFrame(columns=["ds","yhat","yhat_lower","yhat_upper","franchiseID","target"])
    return pd.concat(out, ignore_index=True)

# One input frame per metric, each keeping the date and franchise keys.
rev_in = daily.loc[:, ["ds","franchiseID","y_revenue"]]
cnt_in = daily.loc[:, ["ds","franchiseID","y_txn_count"]]

# Fit revenue first, then transaction count, with identical settings.
fc_rev, fc_cnt = (
    _fit_forecast(metric_frame, metric_name, horizon_days=30, use_cap=True, cap_multiplier=1.25)
    for metric_frame, metric_name in ((rev_in, "y_revenue"), (cnt_in, "y_txn_count"))
)

# Stack both metrics into a single long forecast table.
forecast_all = pd.concat((fc_rev, fc_cnt), ignore_index=True)

# Relabel the observed history so it shares the forecast layout: the metric
# value goes in "y", "target" names the metric, and is_actual marks the row.
actuals_rev = rev_in.rename(columns={"y_revenue":"y"}).assign(
    target="y_revenue", is_actual=True
)

actuals_cnt = cnt_in.rename(columns={"y_txn_count":"y"}).assign(
    target="y_txn_count", is_actual=True
)

actuals_all = pd.concat([actuals_rev, actuals_cnt], ignore_index=True)

# Forecast rows reuse "y" for the point estimate (yhat) and keep their bounds.
fc_plot = (
    forecast_all
    .rename(columns={"yhat":"y"})
    [["ds","franchiseID","y","target","yhat_lower","yhat_upper"]]
    .assign(is_actual=False)
)

# Actuals have no prediction interval. Fill the bound columns with float NaN
# (not pd.NA) so they stay float64 after the concat instead of depending on how
# pandas treats all-NA object columns, and so Spark infers a numeric type.
combined = pd.concat([
    actuals_all.assign(yhat_lower=np.nan, yhat_upper=np.nan),
    fc_plot
], ignore_index=True)

# Expose the combined actuals + forecast table to Spark SQL as a temp view.
combined_sdf = spark.createDataFrame(combined)
combined_sdf.createOrReplaceTempView("franchise_sales_actuals_and_forecast")

# Render the view ordered by franchise, metric, then date.
ordered_rows = spark.sql("""
      SELECT *
      FROM franchise_sales_actuals_and_forecast
      ORDER BY franchiseID, target, ds
    """)
display(ordered_rows)


03:46:24 - cmdstanpy - INFO - Chain [1] start processing
03:46:26 - cmdstanpy - INFO - Chain [1] done processing
03:46:26 - cmdstanpy - INFO - Chain [1] start processing
03:46:27 - cmdstanpy - INFO - Chain [1] done processing
03:46:27 - cmdstanpy - INFO - Chain [1] start processing
03:46:27 - cmdstanpy - INFO - Chain [1] done processing
03:46:28 - cmdstanpy - INFO - Chain [1] start processing
03:46:31 - cmdstanpy - INFO - Chain [1] done processing
03:46:31 - cmdstanpy - INFO - Chain [1] start processing
03:46:32 - cmdstanpy - INFO - Chain [1] done processing
03:46:32 - cmdstanpy - INFO - Chain [1] start processing
03:46:32 - cmdstanpy - INFO - Chain [1] done processing
03:46:32 - cmdstanpy - INFO - Chain [1] start processing
03:46:33 - cmdstanpy - INFO - Chain [1] done processing
03:46:33 - cmdstanpy - INFO - Chain [1] start processing
03:46:35 - cmdstanpy - INFO - Chain [1] done processing
03:46:35 - cmdstanpy - INFO - Chain [1] start processing
03:46:35 - cmdstanpy - INFO - Chain [1]

ds,franchiseID,y,target,is_actual,yhat_lower,yhat_upper
2024-05-01T00:00:00.000Z,3000000,24.0,y_revenue,True,,
2024-05-01T00:00:00.000Z,3000000,66.89690701603415,y_revenue,False,0.0,146.0437730158573
2024-05-02T00:00:00.000Z,3000000,204.0,y_revenue,True,,
2024-05-02T00:00:00.000Z,3000000,198.1891962574821,y_revenue,False,119.68632899619436,281.2561287660565
2024-05-03T00:00:00.000Z,3000000,282.0,y_revenue,True,,
2024-05-03T00:00:00.000Z,3000000,197.15687452732607,y_revenue,False,118.0394520342465,281.88369297655225
2024-05-04T00:00:00.000Z,3000000,96.0,y_revenue,True,,
2024-05-04T00:00:00.000Z,3000000,79.9042846262064,y_revenue,False,0.0,159.0494592267143
2024-05-05T00:00:00.000Z,3000000,252.0,y_revenue,True,,
2024-05-05T00:00:00.000Z,3000000,253.082824608689,y_revenue,False,173.01452706963303,333.1768270517223
