In [4]:
from pprint import pprint

from sklearn.ensemble import RandomForestRegressor

from mlflow import MlflowClient

### Initializing the MLflow Client

Depending on where you are running this notebook, your configuration may vary for how you initialize the MLflow Client in the following cell. 

For this example, we're using a locally running tracking server, but other options are available (The easiest is to use the free managed service within [Databricks Community Edition](https://community.cloud.databricks.com/)). 

Please see [the guide to running notebooks here](https://www.mlflow.org/docs/latest/getting-started/running-notebooks/index.html) for more information on setting the tracking server uri and configuring access to either managed or self-managed MLflow tracking servers.

In [2]:
# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the free Databricks Community Edition

client = MlflowClient(tracking_uri="http://127.0.0.1:8080")

#### Search Experiments with the MLflow Client API

Let's take a look at the Default Experiment that is created for us.

This safe 'fallback' experiment will store Runs that we create if we don't specify a 
new experiment. 

In [5]:
# Search experiments without providing query terms behaves effectively as a 'list' action

try:
	all_experiments = client.search_experiments()
	print(all_experiments)
except Exception as e:
	print(f"An error occurred: {e}")

[<Experiment: artifact_location='mlflow-artifacts:/932508242658220903', creation_time=1734093770908, experiment_id='932508242658220903', last_update_time=1734093770908, lifecycle_stage='active', name='Apple_Models', tags={'mlflow.note.content': 'This is the grocery forecasting project. This '
                        'experiment contains the produce models for apples.',
 'project_name': 'grocery-forecasting',
 'project_quarter': 'Q3-2023',
 'store_dept': 'produce',
 'team': 'stores-ml'}>, <Experiment: artifact_location='mlflow-artifacts:/0', creation_time=1734093687161, experiment_id='0', last_update_time=1734093687161, lifecycle_stage='active', name='Default', tags={}>]


In [6]:
# Extract the experiment name and lifecycle_stage

default_experiment = [
    {"name": experiment.name, "lifecycle_stage": experiment.lifecycle_stage}
    for experiment in all_experiments
    if experiment.name == "Default"
][0]

pprint(default_experiment)

{'lifecycle_stage': 'active', 'name': 'Default'}


### Creating a new Experiment

In this section, we'll:

* create a new MLflow Experiment
* apply metadata in the form of Experiment Tags

In [7]:
experiment_description = (
    "This is the grocery forecasting project. "
    "This experiment contains the produce models for apples."
)

experiment_tags = {
    "project_name": "grocery-forecasting",
    "store_dept": "produce",
    "team": "stores-ml",
    "project_quarter": "Q3-2023",
    "mlflow.note.content": experiment_description,
}

produce_apples_experiment = client.create_experiment(name="Apple_Models", tags=experiment_tags)

RestException: RESOURCE_ALREADY_EXISTS: Experiment 'Apple_Models' already exists.

In [8]:
# Use search_experiments() to search on the project_name tag key

apples_experiment = client.search_experiments(
    filter_string="tags.`project_name` = 'grocery-forecasting'"
)

pprint(apples_experiment[0])

<Experiment: artifact_location='mlflow-artifacts:/932508242658220903', creation_time=1734093770908, experiment_id='932508242658220903', last_update_time=1734093770908, lifecycle_stage='active', name='Apple_Models', tags={'mlflow.note.content': 'This is the grocery forecasting project. This '
                        'experiment contains the produce models for apples.',
 'project_name': 'grocery-forecasting',
 'project_quarter': 'Q3-2023',
 'store_dept': 'produce',
 'team': 'stores-ml'}>


In [9]:
# Access individual tag data

print(apples_experiment[0].tags["team"])

stores-ml


### Running our first model training

In this section, we'll:

* create a synthetic data set that is relevant to a simple demand forecasting task
* start an MLflow run
* log metrics, parameters, and tags to the run
* save the model to the run
* register the model during model logging

#### Synthetic data generator for demand of apples

Keep in mind that this is purely for demonstration purposes. 

The demand value is purely artificial and is deliberately covariant with the features. This is not a particularly realistic real-world scenario (if it were, we wouldn't need Data Scientists!). 

In [10]:
from datetime import datetime, timedelta

import numpy as np
import pandas as pd


def generate_apple_sales_data_with_promo_adjustment(base_demand: int = 1000, n_rows: int = 5000):
    """
    Generates a synthetic dataset for predicting apple sales demand with seasonality and inflation.

    This function creates a pandas DataFrame with features relevant to apple sales.
    The features include date, average_temperature, rainfall, weekend flag, holiday flag,
    promotional flag, price_per_kg, and the previous day's demand. The target variable,
    'demand', is generated based on a combination of these features with some added noise.

    Args:
        base_demand (int, optional): Base demand for apples. Defaults to 1000.
        n_rows (int, optional): Number of rows (days) of data to generate. Defaults to 5000.

    Returns:
        pd.DataFrame: DataFrame with features and target variable for apple sales prediction.

    Example:
        >>> df = generate_apple_sales_data_with_seasonality(base_demand=1200, n_rows=6000)
        >>> df.head()
    """

    # Set seed for reproducibility
    np.random.seed(9999)

    # Create date range
    dates = [datetime.now() - timedelta(days=i) for i in range(n_rows)]
    dates.reverse()

    # Generate features
    df = pd.DataFrame(
        {
            "date": dates,
            "average_temperature": np.random.uniform(10, 35, n_rows),
            "rainfall": np.random.exponential(5, n_rows),
            "weekend": [(date.weekday() >= 5) * 1 for date in dates],
            "holiday": np.random.choice([0, 1], n_rows, p=[0.97, 0.03]),
            "price_per_kg": np.random.uniform(0.5, 3, n_rows),
            "month": [date.month for date in dates],
        }
    )

    # Introduce inflation over time (years)
    df["inflation_multiplier"] = 1 + (df["date"].dt.year - df["date"].dt.year.min()) * 0.03

    # Incorporate seasonality due to apple harvests
    df["harvest_effect"] = np.sin(2 * np.pi * (df["month"] - 3) / 12) + np.sin(
        2 * np.pi * (df["month"] - 9) / 12
    )

    # Modify the price_per_kg based on harvest effect
    df["price_per_kg"] = df["price_per_kg"] - df["harvest_effect"] * 0.5

    # Adjust promo periods to coincide with periods lagging peak harvest by 1 month
    peak_months = [4, 10]  # months following the peak availability
    df["promo"] = np.where(
        df["month"].isin(peak_months),
        1,
        np.random.choice([0, 1], n_rows, p=[0.85, 0.15]),
    )

    # Generate target variable based on features
    base_price_effect = -df["price_per_kg"] * 50
    seasonality_effect = df["harvest_effect"] * 50
    promo_effect = df["promo"] * 200

    df["demand"] = (
        base_demand
        + base_price_effect
        + seasonality_effect
        + promo_effect
        + df["weekend"] * 300
        + np.random.normal(0, 50, n_rows)
    ) * df["inflation_multiplier"]  # adding random noise

    # Add previous day's demand
    df["previous_days_demand"] = df["demand"].shift(1)
    df["previous_days_demand"].fillna(method="bfill", inplace=True)  # fill the first row

    # Drop temporary columns
    df.drop(columns=["inflation_multiplier", "harvest_effect", "month"], inplace=True)

    return df

In [11]:
# Generate the dataset!

data = generate_apple_sales_data_with_promo_adjustment(base_demand=1_000, n_rows=1_000)

data[-20:]

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["previous_days_demand"].fillna(method="bfill", inplace=True)  # fill the first row
  df["previous_days_demand"].fillna(method="bfill", inplace=True)  # fill the first row


Unnamed: 0,date,average_temperature,rainfall,weekend,holiday,price_per_kg,promo,demand,previous_days_demand
980,2024-11-24 15:13:31.894357,34.130183,1.454065,1,0,1.449177,0,1289.802447,1319.085782
981,2024-11-25 15:13:31.894357,32.353643,9.462859,0,0,2.856503,0,818.951553,1289.802447
982,2024-11-26 15:13:31.894357,18.816833,0.39147,0,0,1.326429,0,963.352029,818.951553
983,2024-11-27 15:13:31.894357,34.533012,2.120477,0,0,0.970131,0,1039.385504,963.352029
984,2024-11-28 15:13:31.894357,23.057202,2.365705,0,0,1.049931,0,991.427049,1039.385504
985,2024-11-29 15:13:31.894357,34.810165,3.089005,0,0,2.035149,0,974.971149,991.427049
986,2024-11-30 15:13:31.894357,29.208905,3.673292,1,0,2.518098,0,1374.249547,974.971149
987,2024-12-01 15:13:31.894357,16.428676,4.077782,1,0,1.268979,0,1381.118915,1374.249547
988,2024-12-02 15:13:31.894357,32.067512,2.734454,0,0,0.762317,0,1040.492007,1381.118915
989,2024-12-03 15:13:31.894357,31.938203,13.883486,0,0,1.153301,0,967.04047,1040.492007


### Train and log the model

We're now ready to import our model class and train a ``RandomForestRegressor``

In [13]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

import mlflow

# Use the fluent API to set the tracking uri and the active experiment
mlflow.set_tracking_uri("http://127.0.0.1:8080")

# Sets the current active experiment to the "Apple_Models" experiment and returns the Experiment metadata
apple_experiment = mlflow.set_experiment("Apple_Models")

# Define a run name for this iteration of training.
# If this is not set, a unique name will be auto-generated for your run.
run_name = "apples_rf_test"

# Define an artifact path that the model will be saved to.
artifact_path = "rf_apples"

In [14]:
# Split the data into features and target and drop irrelevant date field and target field
X = data.drop(columns=["date", "demand"])
y = data["demand"]

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

params = {
    "n_estimators": 100,
    "max_depth": 6,
    "min_samples_split": 10,
    "min_samples_leaf": 4,
    "bootstrap": True,
    "oob_score": False,
    "random_state": 888,
}

# Train the RandomForestRegressor
rf = RandomForestRegressor(**params)

# Fit the model on the training data
rf.fit(X_train, y_train)

# Predict on the validation set
y_pred = rf.predict(X_val)

# Calculate error metrics
mae = mean_absolute_error(y_val, y_pred)
mse = mean_squared_error(y_val, y_pred)
rmse = np.sqrt(mse)
r2 = r2_score(y_val, y_pred)

# Assemble the metrics we're going to write into a collection
metrics = {"mae": mae, "mse": mse, "rmse": rmse, "r2": r2}

# Initiate the MLflow run context
with mlflow.start_run(run_name=run_name) as run:
    # Log the parameters used for the model fit
    mlflow.log_params(params)

    # Log the error metrics that were calculated during validation
    mlflow.log_metrics(metrics)

    # Log an instance of the trained model for later use
    mlflow.sklearn.log_model(sk_model=rf, input_example=X_val, artifact_path=artifact_path)



🏃 View run apples_rf_test at: http://127.0.0.1:8080/#/experiments/932508242658220903/runs/b2e1a6ee823445fca8369a54db8f43f9
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/932508242658220903


In [18]:
from mlflow.models import validate_serving_input

model_uri = 'runs:/b2e1a6ee823445fca8369a54db8f43f9/rf_apples'

# The model is logged with an input example. MLflow converts
# it into the serving payload format for the deployed model endpoint,
# and saves it to 'serving_input_payload.json'
serving_payload = """{
  "dataframe_split": {
    "columns": [
      "average_temperature",
      "rainfall",
      "weekend",
      "holiday",
      "price_per_kg",
      "promo",
      "previous_days_demand"
    ],
    "data": [
      [
        22.108521493396175,
        4.999862802843424,
        0,
        0,
        0.5831366510376563,
        0,
        1051.328705658278
      ],
      [
        18.5816764411248,
        4.512746531416022,
        0,
        0,
        1.8072856251094098,
        1,
        998.9237425630752
      ],
      [
        33.51975007333654,
        2.8241536413050925,
        0,
        0,
        2.3637409774787232,
        0,
        923.2566397019071
      ],
      [
        21.7615434554752,
        6.430238857667999,
        0,
        0,
        2.0476852776923393,
        0,
        1179.562235728095
      ],
      [
        10.894443056369184,
        0.13964898317696367,
        0,
        0,
        1.1940921516527256,
        0,
        901.3318652793271
      ],
      [
        22.92560063118369,
        7.565755558113839,
        1,
        0,
        2.72148050252662,
        0,
        1100.8492474590696
      ],
      [
        27.970269869665415,
        4.832402669922336,
        0,
        0,
        0.5034190641984683,
        1,
        951.1984757786024
      ],
      [
        24.304273996732025,
        4.3762328748638355,
        0,
        0,
        2.937326926420265,
        0,
        926.8536546124926
      ],
      [
        17.224883481474926,
        15.160548234833247,
        0,
        0,
        2.2034069059056702,
        0,
        834.3910464104001
      ],
      [
        23.29153787550105,
        16.796970135192815,
        0,
        0,
        1.0719556306440046,
        0,
        833.3675169174304
      ],
      [
        13.970545295560433,
        13.469018114664433,
        1,
        0,
        2.8525732698805286,
        0,
        927.0436017995498
      ],
      [
        14.43116012323845,
        0.3653052706104143,
        1,
        0,
        1.8057458814700333,
        0,
        952.2577441169552
      ],
      [
        26.22875718601829,
        0.3453453261157885,
        1,
        0,
        2.198581232446027,
        0,
        881.5704307147403
      ],
      [
        31.474589577787594,
        11.953320045310178,
        1,
        0,
        2.7163756001662582,
        0,
        1244.9540014780737
      ],
      [
        28.84662912211835,
        2.3461959259254677,
        1,
        0,
        2.328025428070177,
        1,
        1492.5460100498874
      ],
      [
        25.053224454895386,
        18.05058099432231,
        0,
        0,
        0.8102362066443983,
        0,
        953.4822329720608
      ],
      [
        29.681849610300794,
        22.69284963979996,
        1,
        0,
        2.2563321799601184,
        1,
        1259.4688509186276
      ],
      [
        29.28975979890451,
        1.2759921348939778,
        0,
        0,
        1.3511310938791963,
        0,
        1154.1672778915063
      ],
      [
        29.317018341946035,
        5.669994191329914,
        0,
        0,
        2.011116359440318,
        1,
        1208.1869871145916
      ],
      [
        24.85223471221545,
        3.1958696527776524,
        0,
        0,
        2.937301302853353,
        0,
        866.4044810170313
      ],
      [
        12.351825611550034,
        6.239153247146133,
        0,
        0,
        0.5687943730193393,
        0,
        978.6256069661443
      ],
      [
        16.004647882132986,
        1.1495763234310585,
        1,
        0,
        2.4864912310749654,
        0,
        936.9901884058146
      ],
      [
        21.43846192245929,
        1.986386725133165,
        1,
        0,
        2.242945480574978,
        0,
        1172.6837230310457
      ],
      [
        32.40053258233264,
        5.370225765171288,
        0,
        0,
        1.1239631003366415,
        0,
        892.2669331439647
      ],
      [
        32.793408413642155,
        8.65443209041338,
        1,
        0,
        0.8773769903916603,
        1,
        1481.6471673405842
      ],
      [
        21.544217839491846,
        38.19607596973325,
        0,
        0,
        0.9527734716576302,
        0,
        943.7140666566439
      ],
      [
        10.761733503766798,
        4.954419615831154,
        0,
        0,
        1.371643996117502,
        0,
        916.0784683848129
      ],
      [
        29.208904708840713,
        3.6732923207851043,
        1,
        0,
        2.5180977727750076,
        0,
        974.971148569308
      ],
      [
        22.17096614080456,
        0.42054670120468496,
        1,
        0,
        1.7323442890597849,
        0,
        883.9147998290535
      ],
      [
        18.93043464969746,
        10.961642332323407,
        0,
        0,
        2.7372853814996363,
        1,
        1274.1183629586462
      ],
      [
        26.634091714791175,
        32.83464278387956,
        0,
        0,
        1.9714276794194812,
        0,
        964.9889771709173
      ],
      [
        17.221157212954072,
        5.306922837042894,
        1,
        0,
        2.156810616033737,
        0,
        868.1978580039328
      ],
      [
        14.710160700623618,
        5.08152279841387,
        0,
        0,
        0.6400372911749268,
        0,
        1010.3970150061883
      ],
      [
        28.138953323004632,
        0.9092279904046898,
        0,
        0,
        1.0364080156551636,
        0,
        1106.300800617545
      ],
      [
        29.695464966640767,
        17.229408962374052,
        1,
        0,
        2.345597780879591,
        0,
        1047.3889622570882
      ],
      [
        28.984015657400555,
        9.254755024707384,
        0,
        0,
        2.622916814035615,
        1,
        1230.9790487826415
      ],
      [
        24.13033280012668,
        13.346697530251326,
        0,
        0,
        2.1209145353219427,
        0,
        1269.5115494741046
      ],
      [
        19.675353604210304,
        1.400553733271207,
        0,
        0,
        0.6300314236106394,
        1,
        1177.1194783989476
      ],
      [
        32.35591583782326,
        2.9245168241988386,
        0,
        0,
        2.5753987125795184,
        0,
        945.8482360065685
      ],
      [
        34.429233249851194,
        1.3603242265053148,
        0,
        0,
        1.1766736582280664,
        0,
        1083.0307459027474
      ],
      [
        22.80873719666382,
        4.506542723551091,
        1,
        0,
        1.8995332999809005,
        0,
        1110.6266504914763
      ],
      [
        31.014553019754942,
        3.910451560665037,
        0,
        0,
        2.372679709541895,
        0,
        854.466881438221
      ],
      [
        19.55663335992133,
        13.36205668127347,
        0,
        0,
        2.1412675158489636,
        0,
        1083.308128268899
      ],
      [
        24.611129081487018,
        3.7992599574614654,
        0,
        0,
        2.0445407004154648,
        0,
        1197.0330053952132
      ],
      [
        30.5305626939251,
        8.524914148649831,
        0,
        0,
        1.1465717118133942,
        0,
        962.3063254295041
      ],
      [
        32.08505345627441,
        7.179913108581863,
        0,
        0,
        1.3305808759565187,
        0,
        923.4239511119424
      ],
      [
        20.53535410101436,
        16.24430730508892,
        1,
        0,
        1.2845130738231754,
        0,
        1286.9665455708932
      ],
      [
        20.61422689345525,
        9.1838022439914,
        1,
        0,
        2.333960339506183,
        0,
        1294.4358275135962
      ],
      [
        17.21798666288529,
        4.690641917379042,
        0,
        0,
        1.3633392262933977,
        0,
        870.6249479018355
      ],
      [
        30.16454024986868,
        0.6795598605579917,
        0,
        0,
        2.674786642263345,
        0,
        946.056869697844
      ],
      [
        18.457686015510724,
        0.9980518802229248,
        0,
        0,
        2.552776081721043,
        1,
        1176.0950738060158
      ],
      [
        29.285065229638317,
        8.08644790570472,
        0,
        0,
        1.0664183015085633,
        0,
        907.0109682207168
      ],
      [
        23.510515741582317,
        4.975060346735287,
        0,
        0,
        1.9006178350585912,
        1,
        860.5664785216209
      ],
      [
        12.092716597801711,
        1.9817241351744879,
        0,
        0,
        2.8830089009938193,
        1,
        1203.1924599355768
      ],
      [
        23.77576416567278,
        0.8787413251084308,
        0,
        0,
        0.8970858255284704,
        1,
        1120.6907323105534
      ],
      [
        18.33262084556379,
        2.6086843703063396,
        0,
        0,
        2.3570538921181923,
        1,
        1043.607501325181
      ],
      [
        28.42827040391559,
        11.330547693659067,
        0,
        0,
        0.6499720658422015,
        1,
        1174.4268022570368
      ],
      [
        32.82623697937525,
        1.6896753641896576,
        0,
        0,
        1.6183631953443873,
        0,
        1037.9943254945838
      ],
      [
        26.819472830721676,
        12.816325735787856,
        0,
        0,
        2.6833120501522902,
        1,
        963.3955707700563
      ],
      [
        17.29031894211911,
        5.423436937938916,
        0,
        0,
        2.9765195602740055,
        0,
        1267.591960762133
      ],
      [
        27.22695793817839,
        0.4853122855753981,
        0,
        0,
        2.55751105219743,
        0,
        847.3943715721261
      ],
      [
        22.577278117642532,
        0.8473811881774718,
        0,
        0,
        2.013186705020111,
        0,
        802.1218254079765
      ],
      [
        20.131178371812325,
        1.3734413966308243,
        0,
        0,
        2.8817076907294243,
        0,
        925.7247313015956
      ],
      [
        19.229374319897968,
        26.046962839951412,
        1,
        0,
        1.0329166753500156,
        1,
        1125.6373937022563
      ],
      [
        13.808789907206078,
        1.7847971755654832,
        0,
        0,
        2.881374455744681,
        0,
        1261.5947442525023
      ],
      [
        30.443774749059422,
        15.741158804343108,
        1,
        0,
        1.8679595911800218,
        0,
        1034.7126103781677
      ],
      [
        30.908415318655074,
        5.409149815353188,
        0,
        0,
        2.6523059009311387,
        1,
        1022.9125502042689
      ],
      [
        24.9118484273987,
        3.094775312564415,
        0,
        0,
        2.264513470709355,
        0,
        972.8555858916286
      ],
      [
        15.411269073238861,
        1.2579356599019835,
        1,
        0,
        2.8215892096455066,
        0,
        1248.002143146448
      ],
      [
        23.521504881291865,
        9.038219508571641,
        0,
        0,
        2.143561229123568,
        0,
        940.6959397125279
      ],
      [
        10.191715361143496,
        2.7555152212649423,
        0,
        0,
        1.6309370952980293,
        1,
        960.3127878649954
      ],
      [
        10.252095785901126,
        1.5616519142432521,
        0,
        0,
        1.3160125317497489,
        0,
        922.4293179281531
      ],
      [
        13.038741982697,
        9.054787861241568,
        0,
        0,
        0.6874408222634198,
        1,
        1156.4771227228684
      ],
      [
        11.221298948928442,
        8.567609144750413,
        1,
        0,
        2.7679552471840037,
        1,
        977.7017937479728
      ],
      [
        30.7937984095172,
        13.040614943735715,
        0,
        0,
        2.789349211507519,
        1,
        928.5107685198442
      ],
      [
        22.668805259819873,
        1.91699997884316,
        0,
        0,
        2.4326477721597675,
        0,
        888.6780485869272
      ],
      [
        21.15368777344179,
        2.9212701855513643,
        0,
        0,
        1.8152806273657491,
        0,
        1222.019345894772
      ],
      [
        15.588290478749924,
        0.17473897993069165,
        0,
        0,
        1.7561351494267992,
        1,
        1109.3665112127276
      ],
      [
        29.359336571214584,
        0.63977581601823,
        0,
        0,
        1.2822758550306448,
        1,
        1470.286195607608
      ],
      [
        16.265527030847366,
        1.188112074342612,
        1,
        0,
        2.922990843938428,
        0,
        858.9622698061885
      ],
      [
        32.14648281895079,
        4.493310252465343,
        1,
        0,
        2.4600614439225033,
        0,
        1274.242245400353
      ],
      [
        30.411808112353285,
        2.1924798493533113,
        0,
        0,
        2.9808101520075523,
        0,
        947.7787959124908
      ],
      [
        29.352578784902754,
        9.427013580052922,
        0,
        0,
        1.3004811464730008,
        0,
        988.7356182199699
      ],
      [
        23.971565217583112,
        4.755236719416376,
        0,
        0,
        1.197550406134796,
        0,
        819.4084648647976
      ],
      [
        13.05748558984076,
        13.983965450755612,
        0,
        0,
        2.2031959166093724,
        0,
        914.3087593126703
      ],
      [
        31.053250746369688,
        20.442206107202434,
        0,
        0,
        2.1556778217105443,
        0,
        1036.1985073649266
      ],
      [
        29.4513006144873,
        5.0214629838405855,
        0,
        0,
        2.493084651779208,
        0,
        1034.4223723493196
      ],
      [
        19.581214534115787,
        8.397297493500734,
        1,
        0,
        1.5502308520307766,
        0,
        1311.8288844734213
      ],
      [
        16.88752439573821,
        2.6724555216908685,
        0,
        0,
        1.4841428576189353,
        1,
        1202.4113886634573
      ],
      [
        15.28141050098759,
        0.08170770789706876,
        0,
        0,
        1.2622526114498314,
        1,
        917.1619140787003
      ],
      [
        28.975649513150216,
        0.5469303183960459,
        0,
        0,
        1.8208457456474942,
        1,
        1170.7816848434784
      ],
      [
        23.803281409632476,
        2.6897398875975664,
        0,
        0,
        2.770200445859715,
        0,
        1150.047878911875
      ],
      [
        25.936016457757887,
        0.10457802708676972,
        0,
        0,
        2.1564707784780848,
        0,
        966.5389313770353
      ],
      [
        30.809603561173674,
        1.5734029042895805,
        0,
        0,
        2.3174412643546702,
        1,
        1486.4238701908637
      ],
      [
        10.747677017006543,
        2.281940605967524,
        0,
        0,
        2.9502518761144296,
        0,
        894.7666990042491
      ],
      [
        25.080426329692088,
        7.34247936621006,
        0,
        0,
        1.1999953074304408,
        0,
        1184.2963978076432
      ],
      [
        12.557020505509222,
        0.5120047620740619,
        0,
        0,
        2.422139905353877,
        1,
        1645.8293742992657
      ],
      [
        12.067055264456702,
        3.000936129900939,
        0,
        0,
        1.049961743123666,
        0,
        1028.1689732569196
      ],
      [
        10.92079474955222,
        2.076343461507871,
        1,
        0,
        2.9447247976536968,
        0,
        1254.6031109991168
      ],
      [
        24.72722650913,
        4.358486611134337,
        1,
        0,
        2.3574936485998705,
        0,
        1263.3516774741918
      ],
      [
        24.975509764751738,
        7.673254147231514,
        0,
        0,
        0.5920340913715826,
        1,
        931.605842141436
      ],
      [
        25.872462363772414,
        7.566978424567661,
        0,
        0,
        2.2470044871200225,
        0,
        1253.0299084655237
      ],
      [
        27.548036805999402,
        5.653288590244777,
        0,
        0,
        1.6051786770056244,
        1,
        1204.9220412027273
      ],
      [
        22.587634137114613,
        6.956538160793878,
        0,
        0,
        2.2620834066954423,
        1,
        916.8978164671827
      ],
      [
        16.698799184584274,
        0.22978037622273473,
        0,
        0,
        2.535241061030907,
        0,
        953.2715956545204
      ],
      [
        20.168993577598137,
        0.3711241639788016,
        1,
        0,
        0.6244656154610311,
        0,
        1324.946342491275
      ],
      [
        27.885111538443507,
        7.642592650870115,
        0,
        0,
        0.9334835405271763,
        0,
        1359.07587464401
      ],
      [
        30.610707640188313,
        1.9904762333531818,
        1,
        0,
        1.2680145271930017,
        0,
        1358.1155492490495
      ],
      [
        15.94750252880475,
        1.5837625632109993,
        1,
        0,
        0.6522837906051349,
        0,
        1300.2710977922022
      ],
      [
        29.123715726344788,
        4.335304605907174,
        1,
        0,
        2.829572356102825,
        0,
        1282.4824291046152
      ],
      [
        14.543606103917647,
        7.826779056083347,
        0,
        0,
        1.1266669482864173,
        0,
        994.0453135678556
      ],
      [
        23.925063192924377,
        0.807550516239855,
        0,
        0,
        1.2580387661843484,
        1,
        1151.704991890692
      ],
      [
        34.81016487077909,
        3.089004633196642,
        0,
        0,
        2.03514942504146,
        0,
        991.4270493783655
      ],
      [
        13.095000154198853,
        2.646102737611062,
        0,
        0,
        0.7266508782108556,
        0,
        989.4649668540077
      ],
      [
        15.3861851118143,
        17.454546770258368,
        0,
        0,
        2.295232891257471,
        0,
        824.3636502727462
      ],
      [
        15.60728735463093,
        5.913457783778428,
        0,
        0,
        1.236183300462886,
        0,
        933.4947798954165
      ],
      [
        18.406009082460088,
        2.090345863094241,
        0,
        0,
        2.3789247402420823,
        0,
        859.6929427271624
      ],
      [
        28.985655586335298,
        8.354725777384974,
        0,
        0,
        1.9618718080686404,
        0,
        1158.6009855992825
      ],
      [
        23.997005326808267,
        5.567135908917223,
        0,
        0,
        1.5364895146552473,
        0,
        929.8005402910641
      ],
      [
        16.93453141367946,
        1.589775681254445,
        1,
        0,
        2.710805838400959,
        0,
        1182.1273701067466
      ],
      [
        29.85436616781977,
        3.6934772404378027,
        0,
        0,
        1.3247531363684137,
        0,
        935.5658476693709
      ],
      [
        33.5749293305596,
        0.34878626925884026,
        0,
        0,
        1.5909463180135992,
        0,
        965.110089764852
      ],
      [
        18.751261387914,
        11.315686921208751,
        1,
        0,
        1.3968195362033047,
        0,
        953.1538915702718
      ],
      [
        23.690749651406044,
        0.9838096025781334,
        0,
        0,
        1.629978720433022,
        0,
        1268.4570815092236
      ],
      [
        27.36436540579858,
        1.6095678722570452,
        1,
        0,
        2.485682318145182,
        0,
        893.6191777274109
      ],
      [
        14.509493015379451,
        3.3584219372201307,
        1,
        0,
        1.5764734674568097,
        0,
        1102.370151810782
      ],
      [
        28.22664669088436,
        2.846734465349468,
        0,
        0,
        2.502747928289,
        1,
        1199.3675065412806
      ],
      [
        27.51726510668766,
        12.239872770926736,
        0,
        0,
        0.5135793722426331,
        0,
        1017.9503878998721
      ],
      [
        21.041017983881915,
        20.196503134789307,
        1,
        0,
        2.174707480258091,
        0,
        971.636069805646
      ],
      [
        31.374652391836065,
        1.9959414014642631,
        0,
        0,
        0.6109285497384578,
        0,
        1037.2819576915733
      ],
      [
        26.427373058543196,
        1.107466380236168,
        0,
        0,
        0.7449173814446428,
        0,
        964.2081215469492
      ],
      [
        24.29775074367221,
        0.13324790859422891,
        0,
        0,
        0.5945321046785268,
        0,
        1006.2713430076722
      ],
      [
        24.90734589887265,
        9.978131646218602,
        1,
        0,
        0.7978973708757173,
        0,
        1170.9564745508562
      ],
      [
        28.661072193798642,
        10.329865041540028,
        0,
        0,
        2.2905906629208417,
        0,
        1378.5763111594313
      ],
      [
        33.70870790183085,
        10.05144692942915,
        0,
        0,
        1.4532929018969003,
        0,
        911.0177096888465
      ],
      [
        26.771346242492438,
        0.36446939847994736,
        1,
        0,
        1.5029147348961605,
        1,
        1001.1014125244027
      ],
      [
        25.908731930876137,
        1.9996811900480105,
        0,
        0,
        2.6626089849532044,
        0,
        1056.8675567763726
      ],
      [
        16.596037676477067,
        2.8034171418464555,
        0,
        0,
        1.4071021436360223,
        1,
        1173.1971224618226
      ],
      [
        20.891628113637136,
        0.8888138167412301,
        0,
        0,
        0.8618641892249961,
        1,
        1088.069187541359
      ],
      [
        17.84403336185514,
        6.956866009730013,
        0,
        0,
        2.4522841543971636,
        0,
        875.3827805280149
      ],
      [
        25.88513582789545,
        2.1772571988211955,
        0,
        0,
        1.2388577550192321,
        0,
        883.7907824880645
      ],
      [
        25.207287333131998,
        18.393240804781602,
        0,
        0,
        2.9717656683827913,
        1,
        878.7205213185906
      ],
      [
        24.951835520216438,
        12.843135771432095,
        1,
        0,
        2.669665476709379,
        0,
        1309.3538279663817
      ],
      [
        14.160536471560146,
        5.77321810302736,
        0,
        0,
        1.410954079524896,
        0,
        1385.7227247819187
      ],
      [
        11.082190270288358,
        7.781476800172927,
        0,
        0,
        1.394455051861505,
        0,
        1012.0985825014927
      ],
      [
        23.132015966346053,
        13.864774875305217,
        0,
        0,
        1.8133593837203439,
        0,
        891.8906760113558
      ],
      [
        29.05469686485519,
        0.2683912355645199,
        0,
        0,
        2.606258252764605,
        0,
        1072.8683580668733
      ],
      [
        11.228267509262832,
        4.624626076069966,
        1,
        0,
        2.9408448965103338,
        0,
        1004.5438439824329
      ],
      [
        33.61226474490128,
        12.288102113265932,
        0,
        0,
        1.1084312000651826,
        1,
        1011.4792555175707
      ],
      [
        25.46813810051694,
        3.939187621884701,
        1,
        0,
        0.6875757983108931,
        0,
        894.3186455288161
      ],
      [
        10.283599710975613,
        1.8557648602243846,
        1,
        0,
        0.5402323706655991,
        0,
        998.3571209049486
      ],
      [
        33.29672725623779,
        1.797184619314932,
        0,
        0,
        2.436981227164976,
        0,
        871.5616511950544
      ],
      [
        19.55360564582868,
        1.5438370836205184,
        0,
        0,
        0.5582257547578429,
        0,
        940.7059269987569
      ],
      [
        12.067791424802891,
        2.9469165986181056,
        0,
        0,
        2.7471703633540527,
        0,
        884.0931050246614
      ],
      [
        26.59227604132641,
        5.212865596071465,
        0,
        1,
        0.9538805221895186,
        1,
        1057.546629916174
      ],
      [
        15.007808544570825,
        11.855552647205291,
        1,
        0,
        2.518514634546121,
        0,
        1567.612214534364
      ],
      [
        20.918072205789535,
        6.922467941088819,
        0,
        0,
        2.285479009224534,
        0,
        1226.8273143376775
      ],
      [
        17.389889298844174,
        0.003891801055316009,
        0,
        0,
        2.421435594445552,
        0,
        1064.9152542092959
      ],
      [
        13.2375456619273,
        24.823390553141223,
        0,
        0,
        1.7670781194096967,
        0,
        957.0598855892824
      ],
      [
        27.679400446122326,
        0.1288357299746379,
        0,
        0,
        1.65674518022293,
        0,
        922.1778246908694
      ],
      [
        25.856942878179876,
        6.539644750495073,
        1,
        0,
        1.0888260283291786,
        0,
        1558.7498283437012
      ],
      [
        19.161213495319892,
        7.808770349086581,
        0,
        0,
        2.668954516328073,
        1,
        1238.3367981201427
      ],
      [
        24.668555544653042,
        7.361391877932926,
        0,
        0,
        2.9646211494155206,
        0,
        934.9766732560884
      ],
      [
        13.84466757905626,
        9.102270187821022,
        1,
        0,
        1.4056691373353833,
        0,
        876.6709700006564
      ],
      [
        20.86769981927791,
        9.420274418798622,
        0,
        0,
        1.5419780869798547,
        0,
        915.8434892513626
      ],
      [
        12.577945863570358,
        2.228281997093106,
        0,
        0,
        1.76611147875859,
        1,
        799.0048603155748
      ],
      [
        25.368869095338958,
        5.633690888132435,
        0,
        0,
        1.2321207564622176,
        0,
        933.5192052695807
      ],
      [
        15.261990970090917,
        0.6964433989669181,
        0,
        0,
        1.2597609370021376,
        0,
        834.9176528177843
      ],
      [
        31.9852871729979,
        0.8058769124762932,
        1,
        0,
        0.9395601108051945,
        0,
        1033.5772166889076
      ],
      [
        11.095879814993491,
        7.2425658298612925,
        0,
        0,
        1.7403939577544512,
        0,
        971.2483392219427
      ],
      [
        28.747165238416766,
        11.135497041941534,
        0,
        0,
        2.241175917574339,
        0,
        1011.3682391121076
      ],
      [
        29.16683854601734,
        1.7817884500033612,
        0,
        0,
        2.31306699252215,
        0,
        1029.3329991646574
      ],
      [
        23.594121243075794,
        7.886189234074968,
        0,
        0,
        1.013020839720443,
        0,
        997.0529126032878
      ],
      [
        31.938202702891232,
        13.883486154214955,
        0,
        0,
        1.1533012486477838,
        0,
        1040.49200748148
      ],
      [
        11.492979301022793,
        0.1534744307569143,
        0,
        0,
        2.592652940242667,
        1,
        985.0997332293492
      ],
      [
        19.17792207181764,
        0.42353577858551456,
        1,
        0,
        2.798863824275413,
        1,
        920.2357176041074
      ],
      [
        12.483095894781972,
        2.8672557865796815,
        0,
        0,
        0.8294037859722545,
        0,
        933.2409192103671
      ],
      [
        29.214988894335665,
        3.2985320677173298,
        1,
        0,
        0.8080822896243616,
        0,
        1306.8216506904625
      ],
      [
        32.658846568321934,
        3.3596804295795546,
        0,
        1,
        1.7940367917452902,
        0,
        906.8671254281309
      ],
      [
        17.96562909672644,
        1.6218495696502067,
        0,
        0,
        0.6005046364810538,
        0,
        906.5549484942018
      ],
      [
        17.90738951309897,
        8.305736388285478,
        0,
        0,
        1.8745334356348458,
        0,
        945.8692814554005
      ],
      [
        10.729638808602243,
        5.5094915566119145,
        0,
        0,
        0.7070385254731941,
        1,
        1050.2418319761248
      ],
      [
        21.430898879013487,
        5.64042117151491,
        1,
        0,
        1.3381749846419178,
        1,
        945.6478583139993
      ],
      [
        32.19147135812911,
        5.601062136667486,
        0,
        0,
        2.4984728218613625,
        0,
        1302.6426069510676
      ],
      [
        31.91378514388379,
        9.091574722716189,
        0,
        0,
        0.5321196956370737,
        0,
        949.1732909890336
      ],
      [
        21.708940966306812,
        8.233107547443158,
        0,
        0,
        2.015120624589527,
        0,
        1043.997396962466
      ],
      [
        27.044475380555966,
        1.8390131913027932,
        0,
        0,
        2.2233071193639566,
        0,
        846.5829457588546
      ],
      [
        14.847890019924932,
        2.946541459559287,
        0,
        0,
        1.9145443509869122,
        0,
        1388.3782983954459
      ],
      [
        28.49907420228602,
        0.3564672253908388,
        0,
        0,
        2.512272075350981,
        1,
        1155.3997452469903
      ],
      [
        27.43195082906649,
        12.054383857751938,
        1,
        1,
        2.7661183234439775,
        1,
        1494.1875337600263
      ],
      [
        11.575088687984616,
        5.870562962466059,
        0,
        0,
        2.353714200758234,
        1,
        1165.178289155776
      ],
      [
        11.682434819350162,
        12.289391314755921,
        0,
        0,
        1.8308565965097778,
        0,
        939.8089795585657
      ],
      [
        12.769594970823299,
        11.940160731486332,
        0,
        1,
        2.602904645445574,
        0,
        989.7363630870133
      ],
      [
        31.44347599563063,
        2.9529361929336764,
        0,
        0,
        1.3971407828357378,
        1,
        1037.815557402177
      ],
      [
        14.571487420957624,
        4.750893975027972,
        0,
        0,
        1.5352147002619279,
        1,
        1220.7526446881486
      ],
      [
        33.7037656488921,
        5.5958010578880995,
        0,
        0,
        0.7944130795068962,
        0,
        1014.6724977122175
      ],
      [
        20.146641498527885,
        3.042891652793189,
        0,
        0,
        2.2192204798869684,
        0,
        866.9508295027033
      ],
      [
        26.630056025805125,
        1.8118656265838866,
        0,
        0,
        2.8345748164486095,
        1,
        1104.3299136570588
      ],
      [
        11.618436599352332,
        4.947186311276282,
        0,
        0,
        2.828524887054267,
        1,
        876.6330034289858
      ],
      [
        13.19727350639349,
        13.107930500087328,
        0,
        0,
        2.0567446584474904,
        0,
        1189.3974550462199
      ]
    ]
  }
}"""

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

array([ 987.48318549, 1114.15466894,  912.88735455,  910.9066891 ,
        961.10582016, 1176.59388616, 1166.39415581,  903.15917518,
        891.64402313,  973.48230807, 1212.81527293, 1275.93515111,
       1248.5004357 , 1168.67226134, 1401.41355159,  976.90626717,
       1398.76445223,  971.74688703, 1161.69432941,  890.55538691,
       1019.5512567 , 1210.15764514, 1237.747613  ,  979.04316723,
       1553.59108408,  975.47357846,  980.63244688, 1193.3114078 ,
       1267.51713585, 1096.08421477,  938.44551208, 1237.32748066,
        992.80431599,  974.05170711, 1201.6216651 , 1132.93734691,
        918.73259208, 1189.69752409,  908.58634911,  978.05622473,
       1269.54173206,  893.56537442,  922.26986379,  917.37377946,
        987.01758421,  971.93765461, 1290.36741436, 1184.90534839,
        964.27676139,  895.22588785, 1092.60798386,  974.51864101,
       1098.92413091, 1102.65702236, 1209.86582887, 1086.39658139,
       1207.75347507,  950.53998414, 1107.39651109,  891.71428

In [19]:
logged_model = 'runs:/b2e1a6ee823445fca8369a54db8f43f9/rf_apples'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)

# Predict on a Pandas DataFrame.
loaded_model.predict(pd.DataFrame(data))



array([1167.61984826,  941.13886303,  975.02560985,  986.48419675,
        987.99115034,  888.85894315, 1268.98093356, 1274.51303265,
        971.80912428,  987.76494552,  950.53998414,  962.96445473,
       1113.60609836, 1419.82399974, 1410.3134755 , 1088.0820796 ,
       1170.14204567, 1154.28821671, 1151.86993756, 1090.14866633,
       1470.95818953, 1416.93861964, 1187.22124061, 1209.86582887,
       1173.13364001, 1204.35265554, 1074.95821574, 1488.27431608,
       1444.36153694, 1092.46889723, 1086.39658139, 1196.52577665,
       1200.20297963, 1086.85471346, 1398.4609471 , 1382.90372246,
       1109.46100471, 1104.93521271, 1127.37518878, 1145.25896315,
       1125.45997931, 1425.01062206, 1283.24991163,  904.31551221,
        963.66386066,  894.02561944,  881.56684885, 1095.78909948,
       1395.76661295, 1401.39360884,  985.51172565,  924.21171875,
        936.04153483,  865.44850867,  916.04020317, 1488.88906595,
       1300.92422879,  894.98220907, 1168.05011498,  978.05622

In [25]:
import mlflow
import mlflow.tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import pandas as pd

# Generate synthetic data
data = generate_apple_sales_data_with_promo_adjustment(base_demand=1_000, n_rows=1_000)
X = data.drop(columns=["date", "demand"])
y = data["demand"]

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the model architecture
model = Sequential([
    Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
    Dense(32, activation='relu'),
    Dense(1)
])

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001), loss='mean_squared_error', metrics=['mae'])

# Log the model and training process with MLflow
mlflow.set_tracking_uri("http://127.0.0.1:8080")
mlflow.set_experiment("Apple_Models")

with mlflow.start_run(run_name="apples_nn_test") as run:
    mlflow.tensorflow.autolog()

    # Train the model
    history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=50, batch_size=32)

    # Get the run ID
    run_id = run.info.run_id
    print(f"Run ID: {run_id}")

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["previous_days_demand"].fillna(method="bfill", inplace=True)  # fill the first row
  df["previous_days_demand"].fillna(method="bfill", inplace=True)  # fill the first row
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m18/25[0m [32m━━━━━━━━━━━━━━[0m[37m━━━━━━[0m [1m0s[0m 3ms/step - loss: 870695.8750 - mae: 902.9325  



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 24ms/step - loss: 773011.2500 - mae: 837.5743 - val_loss: 128284.4062 - val_mae: 315.5689
Epoch 2/50
[1m22/25[0m [32m━━━━━━━━━━━━━━━━━[0m[37m━━━[0m [1m0s[0m 3ms/step - loss: 83767.9531 - mae: 236.6381  



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 80380.5469 - mae: 230.2921 - val_loss: 46895.0898 - val_mae: 159.3569
Epoch 3/50
[1m23/25[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 3ms/step - loss: 47500.7383 - mae: 164.7202 



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 47415.7539 - mae: 164.9060 - val_loss: 45867.4883 - val_mae: 159.6621
Epoch 4/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 46067.8125 - mae: 165.0128 - val_loss: 46261.3750 - val_mae: 162.0467
Epoch 5/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 36ms/step - loss: 44277.8281 - mae: 172.0137



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 44069.8281 - mae: 161.9542 - val_loss: 45851.2461 - val_mae: 159.9083
Epoch 6/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 49116.1797 - mae: 171.1547 - val_loss: 45987.3359 - val_mae: 160.9281
Epoch 7/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 44941.4805 - mae: 162.4988 - val_loss: 45875.6992 - val_mae: 160.4032
Epoch 8/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 42689.9492 - mae: 156.2705 - val_loss: 45965.1094 - val_mae: 161.0311
Epoch 9/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 33ms/step - loss: 32960.0781 - mae: 127.1657



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - loss: 44178.0586 - mae: 160.8337 - val_loss: 45690.8984 - val_mae: 159.2775
Epoch 10/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 44301.8711 - mae: 162.0094 - val_loss: 45720.5195 - val_mae: 159.8416
Epoch 11/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 43764.6836 - mae: 160.5575 - val_loss: 45743.6055 - val_mae: 160.1866
Epoch 12/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 43634.7891 - mae: 158.7435 - val_loss: 45774.1406 - val_mae: 160.5150
Epoch 13/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 34ms/step - loss: 66218.5312 - mae: 217.1445



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 44849.6758 - mae: 163.6651 - val_loss: 45630.0938 - val_mae: 159.7803
Epoch 14/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 43755.7969 - mae: 160.5888 - val_loss: 45786.9414 - val_mae: 160.8586
Epoch 15/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 46935.2070 - mae: 165.1911 - val_loss: 45959.3203 - val_mae: 161.7298
Epoch 16/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 32ms/step - loss: 40450.7109 - mae: 164.0104



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 44703.6680 - mae: 162.0584 - val_loss: 45581.3281 - val_mae: 160.0394
Epoch 17/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 47454.5938 - mae: 166.5486 - val_loss: 45695.0156 - val_mae: 160.7991
Epoch 18/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 45691.0156 - mae: 164.6211 - val_loss: 45653.2969 - val_mae: 157.7888
Epoch 19/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 47225.6875 - mae: 165.3897 - val_loss: 45970.2852 - val_mae: 162.1232
Epoch 20/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 43952.4336 - mae: 161.5458 - val_loss: 45615.3789 - val_mae: 160.8147
Epoch 21/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 40ms/step - loss: 36108.2812 - mae: 154.6852



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - loss: 41573.2852 - mae: 154.0695 - val_loss: 45210.0117 - val_mae: 158.1982
Epoch 22/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2s[0m 104ms/step - loss: 49352.9297 - mae: 181.7302



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - loss: 45140.5195 - mae: 163.2953 - val_loss: 45179.0000 - val_mae: 158.4604
Epoch 23/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 45053.9961 - mae: 162.9442 - val_loss: 45266.6250 - val_mae: 159.4357
Epoch 24/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 41ms/step - loss: 39496.2969 - mae: 173.5960



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - loss: 44421.7422 - mae: 164.4934 - val_loss: 45146.7734 - val_mae: 158.8608
Epoch 25/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 41208.3320 - mae: 155.2137 - val_loss: 45728.8281 - val_mae: 161.8227
Epoch 26/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 33ms/step - loss: 38108.0508 - mae: 150.2316



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 42117.0820 - mae: 156.7954 - val_loss: 44990.1250 - val_mae: 157.7323
Epoch 27/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 47922.6602 - mae: 168.6184 - val_loss: 45164.6094 - val_mae: 159.6009
Epoch 28/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 41ms/step - loss: 47871.6172 - mae: 157.4718



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 44213.5391 - mae: 160.1396 - val_loss: 44891.8594 - val_mae: 157.6864
Epoch 29/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - loss: 41859.1016 - mae: 157.3878 - val_loss: 45493.1914 - val_mae: 161.3634
Epoch 30/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 44459.0117 - mae: 160.5053 - val_loss: 45064.0000 - val_mae: 159.6107
Epoch 31/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 47042.4336 - mae: 166.1865 - val_loss: 45728.9297 - val_mae: 162.4207
Epoch 32/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 34ms/step - loss: 38549.2539 - mae: 152.5335



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - loss: 42964.9062 - mae: 157.9765 - val_loss: 44710.5117 - val_mae: 157.8289
Epoch 33/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 45494.1602 - mae: 162.7274 - val_loss: 44758.1914 - val_mae: 158.4861
Epoch 34/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 43397.4727 - mae: 159.4102 - val_loss: 45310.4883 - val_mae: 156.8609
Epoch 35/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 41859.6406 - mae: 157.7470 - val_loss: 46303.3594 - val_mae: 164.5471
Epoch 36/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 41ms/step - loss: 29729.1289 - mae: 124.3392



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - loss: 47220.8828 - mae: 165.8809 - val_loss: 44635.2812 - val_mae: 158.4330
Epoch 37/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 32ms/step - loss: 56211.7969 - mae: 190.9212



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 45277.6094 - mae: 162.9026 - val_loss: 44447.0547 - val_mae: 157.3495
Epoch 38/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 44554.3789 - mae: 160.5087 - val_loss: 45086.8906 - val_mae: 160.8498
Epoch 39/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 43276.8203 - mae: 157.8212 - val_loss: 44975.0781 - val_mae: 160.5451
Epoch 40/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 33ms/step - loss: 35367.7227 - mae: 140.5056



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - loss: 41135.1055 - mae: 153.0143 - val_loss: 44317.4844 - val_mae: 157.4525
Epoch 41/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 45617.1445 - mae: 164.9426 - val_loss: 44783.8398 - val_mae: 160.0587
Epoch 42/50
[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1s[0m 49ms/step - loss: 68859.3203 - mae: 197.5869



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - loss: 46073.1719 - mae: 162.5866 - val_loss: 44177.2656 - val_mae: 156.7612
Epoch 43/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 43825.7656 - mae: 160.0350 - val_loss: 44222.2109 - val_mae: 157.7308
Epoch 44/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 44192.3672 - mae: 164.0469 - val_loss: 44333.3008 - val_mae: 158.5511
Epoch 45/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 39913.1211 - mae: 152.3642 - val_loss: 44405.3750 - val_mae: 159.1049
Epoch 46/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 41211.8242 - mae: 155.5473 - val_loss: 45087.7695 - val_mae: 161.8024
Epoch 47/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 41424.3008 - mae: 157.1539 - val_loss: 44294.4492 - val_mae: 158.9604
Epoch 48/50
[1m 1/25[0m [37m━━



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - loss: 45327.8086 - mae: 161.8060 - val_loss: 43840.7812 - val_mae: 156.4322
Epoch 49/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 42216.6719 - mae: 157.2460  



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - loss: 42268.8867 - mae: 157.3746 - val_loss: 43781.5312 - val_mae: 156.3678
Epoch 50/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 40527.9727 - mae: 155.3786 - val_loss: 44100.6133 - val_mae: 158.6726




Run ID: 14f866cf6602450692bca9cacaea096a
🏃 View run apples_nn_test at: http://127.0.0.1:8080/#/experiments/932508242658220903/runs/14f866cf6602450692bca9cacaea096a
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/932508242658220903


In [43]:
# Load the model using MLflow
model_uri = f"runs:/14f866cf6602450692bca9cacaea096a/model"
loaded_model = mlflow.tensorflow.load_model(model_uri)

# Prepare fresh, unseen data points
fresh_data = pd.DataFrame({
    "average_temperature": [20.5, 25.3, 30.1],
    "rainfall": [5.2, 3.1, 0.0],
    "weekend": [0, 1, 0],
    "holiday": [0, 0, 1],
    "price_per_kg": [1.5, 2.0, 1.8],
    "promo": [1, 0, 1],
    "previous_days_demand": [950.0, 1020.0, 980.0]
})

# Make predictions
predictions = loaded_model.predict(fresh_data)
print(predictions)





[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 107ms/step
[[928.94867]
 [999.8631 ]
 [962.14667]]


In [44]:
%%writefile streamlit_app.py

import streamlit as st
import mlflow.sklearn
import pandas as pd

# Replace <run_id> with the actual run ID you retrieved
run_id = "14f866cf6602450692bca9cacaea096a"
model_uri = f"runs:/{run_id}/rf_apples"

# Load the model
model = mlflow.sklearn.load_model(model_uri)

st.title("Apple Sales Prediction")

# Define input fields for all features
average_temperature = st.number_input("Average Temperature", value=0.0)
rainfall = st.number_input("Rainfall", value=0.0)
weekend = st.selectbox("Weekend", [0, 1])
holiday = st.selectbox("Holiday", [0, 1])
price_per_kg = st.number_input("Price per Kg", value=0.0)
promo = st.selectbox("Promo", [0, 1])
previous_days_demand = st.number_input("Previous Day's Demand", value=0.0)

# Create a DataFrame from the input
input_data = pd.DataFrame({
    "average_temperature": [average_temperature],
    "rainfall": [rainfall],
    "weekend": [weekend],
    "holiday": [holiday],
    "price_per_kg": [price_per_kg],
    "promo": [promo],
    "previous_days_demand": [previous_days_demand]
})

# Make predictions
if st.button("Predict"):
    predictions = model.predict(input_data)
    st.write(f"Predicted Demand: {predictions[0]}")

Overwriting streamlit_app.py


In [45]:
!streamlit run streamlit_app.py

^C


#### Success!

You've just logged your first MLflow model! 

Navigate to the MLflow UI to see the run that was just created (named "apples_rf_test", logged to the Experiment "Apple_Models"). 