# Many Model Training with Ray Tune

This template is a quickstart to using [Ray Tune](todo) for many model training. Ray Tune is one of many libraries under the [Ray AI Runtime](air). See [this blog post](https://www.anyscale.com/blog/training-one-million-machine-learning-models-in-record-time-with-ray) for more information on the benefits of performing many model training with Ray!

This template walks through time-series forecasting using `statsforecast` (with `sklearn` supplying only the evaluation metric), but the framework and data format can be swapped out easily -- they are there just to help you build your own application!

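At a high level, this template will:

1. Define the training function for a single partition of data.
2. Define a Tune search space to run training over many partitions of data.
3. Extract the best model per data partition from the Tune experiment output.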

> Slot in your code below wherever you see the ✂️ icon to build a many model training Ray application off of this template!

```python
import pandas as pd
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA, AutoETS
from pyarrow import parquet as pq
from sklearn.metrics import mean_squared_error

import ray
from ray import tune
from ray.air import session
```

> ✂️ Replace this value to change the number of data partitions you will use. This will be the total number of Tune trials you will run!
>
> Note that this template will fit two models on each data partition and report the best-performing one.

```python
NUM_DATA_PARTITIONS: int = 10
```

> ✂️ Replace the following with your own data-loading and evaluation helper functions. (Or, just delete these!)

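```python
def get_m5_partition(unique_id: str) -> pd.DataFrame:
    """Load the rows for a single item (one data partition) from the M5 benchmark data."""
    df = pq.read_table(
        "s3://anonymous@m5-benchmarks/data/train/target.parquet",
        columns=["item_id", "timestamp", "demand"],
        filters=[("item_id", "=", unique_id)],
    ).to_pandas().rename(
        # StatsForecast expects the columns `unique_id`, `ds`, and `y`.
        columns={"item_id": "unique_id", "timestamp": "ds", "demand": "y"}
    )
    df["unique_id"] = df["unique_id"].astype(str)
    df["ds"] = pd.to_datetime(df["ds"])
    return df.dropna()

def evaluate_cross_validation(df, metric):
    """Score every model column per series and pick the one with the lowest error."""
    # Every column other than `ds`, `cutoff`, and `y` holds one model's forecasts.
    models = df.drop(columns=['ds', 'cutoff', 'y']).columns.tolist()
    evals = []
    for model in models:
        # Compute the metric separately for each (series, cross-validation window).
        eval_ = df.groupby(['unique_id', 'cutoff']).apply(
            lambda x: metric(x['y'].values, x[model].values)
        ).to_frame()
        eval_.columns = [model]
        evals.append(eval_)
    evals = pd.concat(evals, axis=1)
    # Average over windows, then select the model with the smallest mean error.
    evals = evals.groupby(['unique_id']).mean(numeric_only=True)
    evals['best_model'] = evals.idxmin(axis=1)
    return evals
```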
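
To see what the evaluation helper computes, here is a minimal sketch of the same idea on a toy cross-validation table. The function `pick_best_model`, the metric `mse`, and the model names `ModelX` and `ModelY` are hypothetical stand-ins for illustration. Unlike the helper above, this sketch assumes `unique_id` is a regular column rather than the index.

```python
import pandas as pd

def mse(y_true, y_pred):
    # Plain mean squared error on two NumPy arrays.
    return float(((y_true - y_pred) ** 2).mean())

def pick_best_model(cv_df: pd.DataFrame, metric) -> pd.DataFrame:
    # Every column that is not bookkeeping is treated as one model's forecasts.
    model_cols = [c for c in cv_df.columns if c not in ("unique_id", "ds", "cutoff", "y")]
    rows = []
    # Score each model on each (series, cross-validation window) pair.
    for (uid, cutoff), window in cv_df.groupby(["unique_id", "cutoff"]):
        row = {"unique_id": uid, "cutoff": cutoff}
        for model in model_cols:
            row[model] = metric(window["y"].to_numpy(), window[model].to_numpy())
        rows.append(row)
    per_window = pd.DataFrame(rows)
    # Average over windows, then take the column with the smallest error.
    evals = per_window.groupby("unique_id")[model_cols].mean()
    evals["best_model"] = evals.idxmin(axis=1)
    return evals

toy_cv = pd.DataFrame({
    "unique_id": ["A"] * 4,
    "ds": pd.date_range("2023-01-01", periods=4, freq="D"),
    "cutoff": [pd.Timestamp("2022-12-31")] * 4,
    "y": [1.0, 2.0, 3.0, 4.0],
    "ModelX": [1.0, 2.0, 3.0, 5.0],  # off by 1 on the last step only
    "ModelY": [2.0, 3.0, 4.0, 5.0],  # off by 1 on every step
})
toy_evals = pick_best_model(toy_cv, mse)
print(toy_evals)
```

In this toy table `ModelX` has the lower mean squared error (0.25 versus 1.0), so it is selected as `best_model` for series `A`.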

> ✂️ Replace this with your own training logic.

```python
# Candidate model classes that are fit on every data partition.
model_classes = [AutoARIMA, AutoETS]
# Number of cross-validation windows per partition.
n_windows = 1

def train_fn(config: dict):
    # Each trial trains on the single partition named in its config.
    data_partition_id = config["data_partition_id"]
    train_df = get_m5_partition(data_partition_id)

    models = [model_cls() for model_cls in model_classes]
    forecast_horizon = 4

    sf = StatsForecast(
        df=train_df,
        models=models,
        freq="D",
        n_jobs=n_windows * len(models),
    )
    # Fit every model and forecast `forecast_horizon` steps in each window.
    cv_df = sf.cross_validation(
        h=forecast_horizon,
        step_size=forecast_horizon,
        n_windows=n_windows,
    )

    eval_df = evaluate_cross_validation(df=cv_df, metric=mean_squared_error)
    best_model = eval_df["best_model"][data_partition_id]
    forecast_mse = eval_df[best_model][data_partition_id]

    # Report the best-performing model and its corresponding eval metric.
    session.report({"forecast_mse": forecast_mse, "best_model": best_model})

trainable = train_fn
# Request one CPU per (model, window) pair for each trial.
trainable = tune.with_resources(
    trainable,
    resources={"CPU": len(model_classes) * n_windows}
)
```
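
The `h`, `step_size`, and `n_windows` arguments control how the end of each series is split into evaluation windows. The sketch below shows one common way such sliding windows are laid out, assuming windows are counted back from the end of the series. The helper `cv_windows` is hypothetical and is not StatsForecast's implementation.

```python
def cv_windows(n_obs: int, h: int, step_size: int, n_windows: int) -> list:
    """Return (train_end, test_end) index pairs, oldest window first.

    Each window trains on observations [0, train_end) and is evaluated
    on the next `h` observations, [train_end, test_end).
    """
    windows = []
    for i in range(n_windows):
        # Later windows end closer to the end of the series.
        test_end = n_obs - (n_windows - 1 - i) * step_size
        train_end = test_end - h
        windows.append((train_end, test_end))
    return windows

# The settings used in this template: one window covering the last 4 steps.
single = cv_windows(n_obs=100, h=4, step_size=4, n_windows=1)
# Three back-to-back windows, for comparison.
triple = cv_windows(n_obs=100, h=4, step_size=4, n_windows=3)
print(single)
print(triple)
```

Under these assumptions, the template's settings evaluate each model on only the final four observations of a partition. Raising `n_windows` gives a more stable error estimate at the cost of more fitting time per trial.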

```{note}
`tune.with_resources` is used at the end to specify the resources to assign to *each trial*.
Feel free to change this to the resources required by your application! You can also comment out the `tune.with_resources` block to assign `1 CPU` (the default) to each trial.

Note that this is purely for Tune to know how many trials to schedule concurrently -- setting the number of CPUs does not actually enforce any kind of resource isolation!
```
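
To make the scheduling effect concrete, here is a small sketch of the arithmetic, assuming CPUs are the only resource that limits scheduling. The helper `max_concurrent_trials` and the 8-CPU cluster size are hypothetical values chosen for illustration. They are not part of Ray's API.

```python
def max_concurrent_trials(cluster_cpus: int, cpus_per_trial: int) -> int:
    """Upper bound on trials that fit at once when CPU is the only constraint."""
    if cpus_per_trial <= 0:
        raise ValueError("cpus_per_trial must be positive")
    return cluster_cpus // cpus_per_trial

# This template requests len(model_classes) * n_windows = 2 * 1 CPUs per trial.
cpus_per_trial = 2 * 1
concurrent = max_concurrent_trials(cluster_cpus=8, cpus_per_trial=cpus_per_trial)
print(concurrent)  # 4
```

Requesting fewer CPUs per trial lets more trials run side by side, but each trial then has to share the machine with more neighbors, since the request is not enforced as a hard limit.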

> ✂️ Replace this with your desired hyperparameter search space!
>
> For example, this template searches over the data partition ID to train a model on.

```python
data_partitions = list(pd.read_csv("item_ids.csv")["item_id"])
param_space = {
    "data_partition_id": tune.grid_search(data_partitions[:NUM_DATA_PARTITIONS]),
}
```
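
A grid search over a list of values produces one trial per value, so the search space above yields one trial per data partition. The sketch below imitates that expansion in plain Python. The helper `expand_grid` is hypothetical and only illustrates the effect, not how Tune implements it.

```python
from itertools import product

def expand_grid(grid: dict) -> list:
    """Return one config dict per combination of the listed values."""
    keys = list(grid)
    return [dict(zip(keys, combo)) for combo in product(*(grid[k] for k in keys))]

# Partition IDs taken from the run shown in this template.
partition_ids = ["FOODS_1_001_CA_1", "FOODS_1_001_CA_2", "FOODS_1_001_CA_3"]
num_data_partitions = 2

configs = expand_grid({"data_partition_id": partition_ids[:num_data_partitions]})
for config in configs:
    print(config)
```

Each resulting dict plays the role of the `config` argument that a training function receives, which is why the number of partitions in the slice equals the number of trials.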

Run many model training using Ray Tune!

```python
tuner = tune.Tuner(trainable, param_space=param_space)
result_grid = tuner.fit()
```

| Field | Value |
|---|---|
| Current time | 2023-03-13 19:33:03 |
| Running for | 00:01:13.78 |
| Memory | 9.0/30.9 GiB |

| Trial name | status | loc | data_partition_id | iter | total time (s) | forecast_mse |
|---|---|---|---|---|---|---|
| train_fn_5f8c1_00000 | TERMINATED | 10.0.59.45:120769 | FOODS_1_001_CA_1 | 1 | 38.3696 | 0.642752 |
| train_fn_5f8c1_00001 | TERMINATED | 10.0.59.45:120924 | FOODS_1_001_CA_2 | 1 | 38.0588 | 0.693251 |
| train_fn_5f8c1_00002 | TERMINATED | 10.0.59.45:121062 | FOODS_1_001_CA_3 | 1 | 38.6348 | 1.74863 |
| train_fn_5f8c1_00003 | TERMINATED | 10.0.59.45:121205 | FOODS_1_001_CA_4 | 1 | 38.6539 | 0.187969 |
| train_fn_5f8c1_00004 | TERMINATED | 10.0.59.45:120769 | FOODS_1_001_TX_1 | 1 | 7.745 | 1.62253 |
| train_fn_5f8c1_00005 | TERMINATED | 10.0.59.45:120924 | FOODS_1_001_TX_2 | 1 | 6.58181 | 0.217498 |
| train_fn_5f8c1_00006 | TERMINATED | 10.0.59.45:120769 | FOODS_1_001_TX_3 | 1 | 9.11934 | 0.213073 |
| train_fn_5f8c1_00007 | TERMINATED | 10.0.59.45:120924 | FOODS_1_001_WI_1 | 1 | 7.80567 | 0.254881 |
| train_fn_5f8c1_00008 | TERMINATED | 10.0.59.45:121062 | FOODS_1_001_WI_2 | 1 | 6.68468 | 1.69451 |
| train_fn_5f8c1_00009 | TERMINATED | 10.0.59.45:120769 | FOODS_1_001_WI_3 | 1 | 8.56359 | 0.18749 |


| Trial name | best_model | date | done | experiment_tag | forecast_mse | hostname | iterations_since_restore | node_ip | pid | time_since_restore | time_this_iter_s | time_total_s | timestamp | training_iteration | trial_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| train_fn_5f8c1_00000 | AutoARIMA | 2023-03-13_19-32-37 | True | 0_data_partition_id=FOODS_1_001_CA_1 | 0.642752 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120769 | 38.3696 | 38.3696 | 38.3696 | 1678761157 | 1 | 5f8c1_00000 |
| train_fn_5f8c1_00001 | AutoARIMA | 2023-03-13_19-32-44 | True | 1_data_partition_id=FOODS_1_001_CA_2 | 0.693251 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120924 | 38.0588 | 38.0588 | 38.0588 | 1678761164 | 1 | 5f8c1_00001 |
| train_fn_5f8c1_00002 | AutoARIMA | 2023-03-13_19-32-51 | True | 2_data_partition_id=FOODS_1_001_CA_3 | 1.74863 | ip-10-0-59-45 | 1 | 10.0.59.45 | 121062 | 38.6348 | 38.6348 | 38.6348 | 1678761171 | 1 | 5f8c1_00002 |
| train_fn_5f8c1_00003 | AutoARIMA | 2023-03-13_19-32-59 | True | 3_data_partition_id=FOODS_1_001_CA_4 | 0.187969 | ip-10-0-59-45 | 1 | 10.0.59.45 | 121205 | 38.6539 | 38.6539 | 38.6539 | 1678761179 | 1 | 5f8c1_00003 |
| train_fn_5f8c1_00004 | AutoARIMA | 2023-03-13_19-32-45 | True | 4_data_partition_id=FOODS_1_001_TX_1 | 1.62253 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120769 | 7.745 | 7.745 | 7.745 | 1678761165 | 1 | 5f8c1_00004 |
| train_fn_5f8c1_00005 | AutoETS | 2023-03-13_19-32-50 | True | 5_data_partition_id=FOODS_1_001_TX_2 | 0.217498 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120924 | 6.58181 | 6.58181 | 6.58181 | 1678761170 | 1 | 5f8c1_00005 |
| train_fn_5f8c1_00006 | AutoARIMA | 2023-03-13_19-32-54 | True | 6_data_partition_id=FOODS_1_001_TX_3 | 0.213073 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120769 | 9.11934 | 9.11934 | 9.11934 | 1678761174 | 1 | 5f8c1_00006 |
| train_fn_5f8c1_00007 | AutoETS | 2023-03-13_19-32-58 | True | 7_data_partition_id=FOODS_1_001_WI_1 | 0.254881 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120924 | 7.80567 | 7.80567 | 7.80567 | 1678761178 | 1 | 5f8c1_00007 |
| train_fn_5f8c1_00008 | AutoARIMA | 2023-03-13_19-32-58 | True | 8_data_partition_id=FOODS_1_001_WI_2 | 1.69451 | ip-10-0-59-45 | 1 | 10.0.59.45 | 121062 | 6.68468 | 6.68468 | 6.68468 | 1678761178 | 1 | 5f8c1_00008 |
| train_fn_5f8c1_00009 | AutoETS | 2023-03-13_19-33-03 | True | 9_data_partition_id=FOODS_1_001_WI_3 | 0.18749 | ip-10-0-59-45 | 1 | 10.0.59.45 | 120769 | 8.56359 | 8.56359 | 8.56359 | 1678761183 | 1 | 5f8c1_00009 |


2023-03-13 19:33:03,453	INFO tune.py:825 -- Total run time: 73.80 seconds (73.78 seconds for the tuning loop).


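> ✂️ Inspect the results of a single trial below. The metrics dictionary contains whatever you reported in your training function (here, `forecast_mse` and `best_model`), alongside bookkeeping fields that Tune adds.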

```python
sample_result = result_grid[0]
sample_result.metrics
```

```
{'forecast_mse': 0.64275163,
 'best_model': 'AutoARIMA',
 'time_this_iter_s': 38.36956787109375,
 'done': True,
 'training_iteration': 1,
 'trial_id': '5f8c1_00000',
 'date': '2023-03-13_19-32-37',
 'timestamp': 1678761157,
 'time_total_s': 38.36956787109375,
 'pid': 120769,
 'hostname': 'ip-10-0-59-45',
 'node_ip': '10.0.59.45',
 'config': {'data_partition_id': 'FOODS_1_001_CA_1'},
 'time_since_restore': 38.36956787109375,
 'iterations_since_restore': 1,
 'experiment_tag': '0_data_partition_id=FOODS_1_001_CA_1'}
```
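
Once you have the metrics of every trial, a natural next step is to summarize which model won on each partition. The sketch below does this in plain Python on values copied from the run above. The list `trial_metrics` is a hand-built stand-in. In a real run you would presumably assemble it from `result_grid[i].metrics`, which includes the trial's `config` as shown above.

```python
from collections import Counter

# (partition, winning model, MSE) for each trial, copied from the run above.
trial_metrics = [
    {"data_partition_id": "FOODS_1_001_CA_1", "best_model": "AutoARIMA", "forecast_mse": 0.642752},
    {"data_partition_id": "FOODS_1_001_CA_2", "best_model": "AutoARIMA", "forecast_mse": 0.693251},
    {"data_partition_id": "FOODS_1_001_CA_3", "best_model": "AutoARIMA", "forecast_mse": 1.74863},
    {"data_partition_id": "FOODS_1_001_CA_4", "best_model": "AutoARIMA", "forecast_mse": 0.187969},
    {"data_partition_id": "FOODS_1_001_TX_1", "best_model": "AutoARIMA", "forecast_mse": 1.62253},
    {"data_partition_id": "FOODS_1_001_TX_2", "best_model": "AutoETS", "forecast_mse": 0.217498},
    {"data_partition_id": "FOODS_1_001_TX_3", "best_model": "AutoARIMA", "forecast_mse": 0.213073},
    {"data_partition_id": "FOODS_1_001_WI_1", "best_model": "AutoETS", "forecast_mse": 0.254881},
    {"data_partition_id": "FOODS_1_001_WI_2", "best_model": "AutoARIMA", "forecast_mse": 1.69451},
    {"data_partition_id": "FOODS_1_001_WI_3", "best_model": "AutoETS", "forecast_mse": 0.18749},
]

# How often each model class was the best one for a partition.
wins = Counter(m["best_model"] for m in trial_metrics)
# The partition whose best model achieved the lowest error.
easiest = min(trial_metrics, key=lambda m: m["forecast_mse"])

print(wins)
print(easiest["data_partition_id"], easiest["forecast_mse"])
```

For this run, `AutoARIMA` is selected on 7 of the 10 partitions and `AutoETS` on the remaining 3.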