In [0]:
%pip install prophet
%pip install plotly

In [0]:
# Notebook parameters exposed as Databricks widgets and read by later cells:
#   "us-state"              -> state to filter on; "All States" disables the filter
#   "forecast-forward-days" -> forecast horizon in days (parsed with int() when read)
# NOTE(review): the original comment says a Dash app consumes these inputs — confirm.
# NOTE(review): "us-state" is a free-text widget even though its label says "Dropdown".
dbutils.widgets.text(name="us-state", defaultValue="All States", label="State Dropdown")
dbutils.widgets.text(name="forecast-forward-days", defaultValue="180", label="Forecast days")

In [0]:
# Load the raw retail data through Databricks' PySpark interface.
from pyspark.sql import SparkSession

# getOrCreate() returns the already-running session when one exists, otherwise builds one.
spark = (
    SparkSession.builder
    .appName("ProductForecasting")
    .getOrCreate()
)

# CSV read: the first row supplies column names, and inferSchema asks Spark to guess column
# types from the data (one extra pass over the files).
# NOTE(review): the glob matches everything under retail-org/, while the columns selected
# later look like a single customers table — confirm whether a narrower path was intended.
product_data = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("/databricks-datasets/retail-org/*")
)

In [0]:
from pyspark.sql.functions import from_unixtime, to_date

# Narrow the raw frame to the columns used downstream.
selected_data = product_data.select(
    "customer_id",
    "state",
    "city",
    "valid_from",
    "units_purchased",
    "loyalty_segment",
)

# Append `purchase_date`: valid_from (UNIX seconds) -> formatted timestamp string -> calendar date.
cleaned_data = selected_data.select(
    "*",
    to_date(from_unixtime("valid_from")).alias("purchase_date"),
)


In [0]:
# Self-contained cleaning cell: it rebuilds `cleaned_data` from the raw files on every run, so
# re-running it after changing the "us-state" widget never filters an already-filtered frame.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_unixtime, to_date

spark = SparkSession.builder.appName("ProductForecasting").getOrCreate()

# NOTE(review): duplicates the loading cell above (same path and options) — keep them in sync.
product_data = spark.read.csv("/databricks-datasets/retail-org/*", header=True, inferSchema=True)

selected_data = product_data.select("customer_id", "state", "city", "valid_from", "units_purchased", "loyalty_segment")

# Convert the `valid_from` column from a UNIX timestamp to a date
cleaned_data = selected_data.withColumn("purchase_date", to_date(from_unixtime("valid_from")))

# Keep only rows whose units_purchased is a plain non-negative decimal (e.g. "3" or "2.5");
# null or non-matching values are dropped. The pattern is a raw string because a backslash-dot
# inside an ordinary string literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
cleaned_data = cleaned_data.filter(col("units_purchased").rlike(r"^[0-9]+(\.[0-9]+)?$"))

# Filter by state based on the widget value. Surrounding whitespace is ignored, and an empty
# widget behaves like "All States" instead of matching no rows at all.
selected_state = dbutils.widgets.get("us-state").strip()
if selected_state and selected_state != "All States":
    cleaned_data = cleaned_data.filter(col("state") == selected_state)

In [0]:
from prophet import Prophet
from prophet.plot import plot_plotly, plot_components_plotly
from pyspark.sql.functions import to_date

# Collect the filtered purchases to the driver as pandas, using Prophet's required column
# names: `ds` (date) and `y` (value to forecast).
pdf = (
    cleaned_data
    .select(to_date("purchase_date").alias("ds"), "units_purchased")
    .toPandas()
    .astype({"units_purchased": float})
    .rename(columns={"units_purchased": "y"})
)

# One row per day with the total units purchased that day (groupby drops null dates).
prophet_df = pdf.groupby("ds", as_index=False)["y"].sum()

# Prophet cannot fit fewer than two observations; fail with an actionable message instead
# (e.g. when the "us-state" widget matches no rows).
if len(prophet_df) < 2:
    raise ValueError(
        f"Need at least 2 days of data to fit Prophet, got {len(prophet_df)}; "
        "check the 'us-state' widget value."
    )

# Initialize Prophet and fit the model.
# NOTE(review): the history has one row per day, so daily (sub-daily) seasonality has nothing
# to learn; consider daily_seasonality=False. Left unchanged to keep forecasts identical.
model = Prophet(yearly_seasonality=True, daily_seasonality=True)
model.fit(prophet_df)

# Forecast horizon in days comes from the "forecast-forward-days" widget (default "180");
# int() raises ValueError on non-numeric input.
num_days = int(dbutils.widgets.get("forecast-forward-days"))
if num_days < 0:
    raise ValueError(f"'forecast-forward-days' must be >= 0, got {num_days}")
future = model.make_future_dataframe(periods=num_days)
forecast = model.predict(future)

# Plot the forecast with Prophet's Plotly helper, then apply the notebook's styling.
fig = plot_plotly(model, forecast)


def _axis_style(title_text):
    """Return Plotly axis settings: a styled title, light-grey grid/zero line, dark tick labels.

    The title font is set via ``title.font``. The older ``titlefont`` property is deprecated
    and is rejected by plotly.js 3 / plotly.py 6, which an unpinned ``%pip install plotly``
    can pull in.
    """
    return dict(
        title=dict(text=title_text, font=dict(size=18, color="#1B3139")),
        showgrid=True,
        gridcolor="lightgrey",
        gridwidth=0.5,
        zerolinecolor="lightgrey",
        zerolinewidth=0.5,
        tickfont=dict(size=14, color="#1B3139"),
    )


fig.update_layout(
    # Size to the container instead of a hardcoded width/height
    autosize=True,
    width=None,
    height=None,
    # Off-white background
    plot_bgcolor="#F9F7F4",
    paper_bgcolor="#F9F7F4",
    # Titles and fonts
    title="Forecast Results",
    title_font=dict(size=24, family="Arial, sans-serif", color="#1B3139"),
    # x carries the dates (Prophet's `ds`); y carries the forecast units (`y`)
    xaxis=_axis_style("Order date"),
    yaxis=_axis_style("Number of product units"),
    # Legend styling
    legend=dict(font=dict(size=14, color="grey")),
)
fig

In [0]:
import plotly.tools as tls
import json
import plotly.utils

# Export target for the serialized figure.
# NOTE(review): dbutils.fs paths without a scheme presumably resolve to DBFS (dbfs:/tmp/...),
# not the driver's local /tmp — confirm for this workspace.
path_to_save = "/tmp/forecast_plot.json"

# PlotlyJSONEncoder turns the figure object (and numpy/pandas/datetime values inside it) into JSON.
fig_json = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

# Replace any previous export at the same path.
dbutils.fs.put(path_to_save, fig_json, overwrite=True)