# Batch Training with Ray Datasets

# Introduction

Batch training and tuning are common tasks in simple machine learning use-cases such as time series forecasting. They require fitting of simple models on multiple data batches corresponding to locations, products, etc.

In the context of this notebook, batch training is understood as creating the same model(s) for different and separate datasets or subsets of a dataset. This notebook showcases how to conduct batch training using [Ray Dataset](https://docs.ray.io/en/latest/data/dataset.html).

![Batch training diagram](./images/batch-training.svg)

For the data, we will use the [NYC Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page). This popular tabular dataset contains historical taxi pickups by timestamp and location in NYC. To demonstrate batch training, we will simplify the data to a regression problem to predict `trip_duration` and use scikit-learn.

To demonstrate how batch training can be parallelized, we will train a separate model for each dropoff location. This means we can use the `dropoff_location_id` column in the dataset to group the dataset into data batches. Then we will fit a separate model for each batch and evaluate it.

# Contents

In this this tutorial, you will learn about:
 1. Creating a Ray Dataset,
 2. Inspecting a Ray Dataset,
 3. Running data transformations on a Ray Dataset in parallel,
 4. How to perform batch training with Ray Datasets using *group-by*.

# Walkthrough

Let’s start by importing a few required libraries, including open-source [Ray](https://github.com/ray-project/ray) itself!

In [1]:
from typing import Tuple, List, Union, Optional, Callable
import time
import pandas as pd
import numpy as np
import pyarrow.dataset as pds
from pyarrow import fs
from pyarrow import parquet as pq
from ray.data import Dataset

In [2]:
import ray

ray.init(ignore_reinit_error=True)

0,1
Python version:,3.8.5
Ray version:,2.0.0
Dashboard:,http://console.anyscale-staging.com/api/v2/sessions/ses_ZmHebxHaZpYkw9x9efJ5wBVX/services?redirect_to=dashboard


In [3]:
# For benchmarking purposes, we can print the times of various operations.
# In order to reduce clutter in the output, this is set to False by default.
PRINT_TIMES = False


def print_time(msg: str):
    if PRINT_TIMES:
        print(msg)

In [4]:
# To speed things up, we’ll only use a small subset of the full dataset consisting of two last months of 2019.
# You can choose to use the full dataset for 2018-2019 by setting the SMOKE_TEST variable to False.

SMOKE_TEST = True

## Introduction to Ray Datasets <a class="anchor" id="dataset"></a>

[Ray Datasets](datasets) are the standard way to load and exchange data in Ray libraries and applications. We will use the [Ray Dataset APIs](dataset-api) to read the data and quickly inspect it.

First, we will define some global variables we will use throughout the notebook, such as the list of S3 links to the files making up the dataset and the possible location IDs.

In [5]:
# Define some global variables.
target = "trip_duration"
s3_partitions = pds.dataset(
    "s3://anonymous@air-example-data/ursa-labs-taxi-data/by_year/",
    partitioning=["year", "month"],
)

s3_files = [f"s3://{file}" for file in s3_partitions.files]

# Obtain all location IDs
location_ids = (
    pq.read_table(s3_files[0], columns=["pickup_location_id"])["pickup_location_id"]
    .unique()
    .to_pylist()
)

starting_idx = -2 if SMOKE_TEST else 0

s3_files = s3_files[starting_idx:]
print(f"NYC Taxi using {len(s3_files)} file(s)!")
print(f"s3_files: {s3_files}")

NYC Taxi using 2 file(s)!
s3_files: ['s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 's3://air-example-data/ursa-labs-taxi-data/by_year/2019/06/data.parquet/ab5b9d2b8cc94be19346e260b543ec35_000000.parquet']


Next, we will call `ray.data.read_parquet` to create a Ray dataset from a list of S3 URIs. This will read the files in parallel onto the Ray cluster.

In [6]:
ds = ray.data.read_parquet(s3_files)
ds



Dataset(num_blocks=2, num_rows=14506285, schema={vendor_id: string, pickup_at: timestamp[us], dropoff_at: timestamp[us], passenger_count: int8, trip_distance: float, rate_code_id: string, store_and_fwd_flag: string, pickup_location_id: int32, dropoff_location_id: int32, payment_type: string, fare_amount: float, extra: float, mta_tax: float, tip_amount: float, tolls_amount: float, improvement_surcharge: float, total_amount: float, congestion_surcharge: float})

### Ray Dataset statistics

Let's get some basic statistics about our newly created Ray Dataset.

As our Ray Dataset is backed by Parquet, we can obtain the number of rows from the metadata without triggering a full data read.

In [7]:
print(f"Number of rows: {ds.count()}")

Number of rows: 14506285


Similarly, we can obtain the Dataset size (in bytes) from the metadata.

In [8]:
print(f"Size bytes (from parquet metadata): {ds.size_bytes()}")

Size bytes (from parquet metadata): 1928872430


Let's fetch and inspect the schema of the underlying Parquet files.

In [9]:
print("\nSchema data types:")
data_types = list(zip(ds.schema().names, ds.schema().types))
for s in data_types:
    print(f"{s[0]}: {s[1]}")


Schema data types:
vendor_id: string
pickup_at: timestamp[us]
dropoff_at: timestamp[us]
passenger_count: int8
trip_distance: float
rate_code_id: string
store_and_fwd_flag: string
pickup_location_id: int32
dropoff_location_id: int32
payment_type: string
fare_amount: float
extra: float
mta_tax: float
tip_amount: float
tolls_amount: float
improvement_surcharge: float
total_amount: float
congestion_surcharge: float


### Filter on Read - Projection and Filter Pushdown

Note that Ray Datasets' Parquet reader supports projection (column selection) and row filter pushdown, where we can push the above column selection and the row-based filter to the Parquet read. If we specify column selection at Parquet read time, the unselected columns won't even be read from disk. This can save a lot of memory, especially with big datasets, and allow us to avoid OOM issues.

The row-based filter is specified via [Arrow's dataset field expressions](https://arrow.apache.org/docs/6.0/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression). 

```{tip}
Best practice is to filter as much as you can directly in the Ray Dataset `read_parquet()`.
```

Normally, there is some data exploration to determine the cleaning steps. Let's just assume we know the data cleaning steps are:
- Drop negative trip distances, 0 fares, 0 passengers and trip durations smaller than 1 minute.
- Drop 2 unknown zones: `['264', '265']`.
- Calculate trip duration and add it as a new column.


In [10]:
def pushdown_read_data(files_list: list, sample_ids: list) -> Dataset:
    filter_expr = (
        (pds.field("passenger_count") > 0)
        & (pds.field("trip_distance") > 0)
        & (pds.field("fare_amount") > 0)
        & (~pds.field("pickup_location_id").isin([264, 265]))
        & (~pds.field("dropoff_location_id").isin([264, 265]))
        & (pds.field("dropoff_location_id").isin(sample_ids))
    )

    dataset = ray.data.read_parquet(
        files_list,
        columns=[
            "pickup_at",
            "dropoff_at",
            "pickup_location_id",
            "dropoff_location_id",
            "passenger_count",
            "trip_distance",
            "fare_amount",
        ],
        filter=filter_expr,
    )

    return dataset

In [11]:
# Test the pushdown_read_data function
pushdown_ds = pushdown_read_data(s3_files, location_ids)

print(f"Number rows: {pushdown_ds.count()}")
# Display some metadata about the dataset.
print("\nMetadata: ")
print(pushdown_ds)
# Fetch the schema from the underlying Parquet metadata.
print("\nSchema:")
print(pushdown_ds.schema())
# Take a peek at a single row
print("\nLook at a sample row:")
pushdown_ds.take(1)



Number rows: 14506285

Metadata: 
Dataset(num_blocks=2, num_rows=14506285, schema={pickup_at: timestamp[us], dropoff_at: timestamp[us], pickup_location_id: int32, dropoff_location_id: int32, passenger_count: int8, trip_distance: float, fare_amount: float})

Schema:
pickup_at: timestamp[us]
dropoff_at: timestamp[us]
pickup_location_id: int32
dropoff_location_id: int32
passenger_count: int8
trip_distance: float
fare_amount: float
-- schema metadata --
pandas: '{"index_columns": [{"kind": "range", "name": null, "start": 0, "' + 2548

Look at a sample row:


[ArrowRow({'pickup_at': datetime.datetime(2019, 5, 1, 0, 35, 54),
           'dropoff_at': datetime.datetime(2019, 5, 1, 0, 37, 27),
           'pickup_location_id': 145,
           'dropoff_location_id': 145,
           'passenger_count': 1,
           'trip_distance': 1.5,
           'fare_amount': 3.0})]

We can use `to_pandas` to convert a Ray Dataset into a pandas DataFrame and inspect that.

```{note}
Converting a Ray Dataset to pandas is not recommended with large data sizes, as it will load all the data into the memory of a single node. To help avoid OOM errors, `to_pandas` will by default raise an exception if there are more than 10000 rows in the Dataset. As we only want to see how the data looks like, we grab the first 100 rows and convert that.
```

In [12]:
df = pushdown_ds.limit(100).to_pandas()
df

Read progress: 100%|██████████| 1/1 [00:00<00:00, 781.79it/s]


Unnamed: 0,pickup_at,dropoff_at,pickup_location_id,dropoff_location_id,passenger_count,trip_distance,fare_amount
0,2019-05-01 00:35:54,2019-05-01 00:37:27,145,145,1,1.50,3.0
1,2019-05-01 00:37:45,2019-05-01 00:37:49,145,145,1,1.50,2.5
2,2019-05-01 00:44:57,2019-05-01 00:50:11,161,161,1,0.70,5.0
3,2019-05-01 00:59:48,2019-05-01 01:10:22,163,141,1,2.00,9.5
4,2019-05-01 00:23:20,2019-05-01 00:32:57,260,56,1,2.50,10.0
...,...,...,...,...,...,...,...
95,2019-05-01 00:27:58,2019-05-01 00:30:26,106,106,1,10.60,2.5
96,2019-05-01 00:51:47,2019-05-01 01:15:42,45,241,1,15.50,42.5
97,2019-05-01 00:20:43,2019-05-01 00:26:22,231,87,1,1.16,6.0
98,2019-05-01 00:47:54,2019-05-01 00:53:01,162,162,2,0.88,5.5


### Custom data transform functions

Ray Datasets allows you to specify custom data transform functions. These [user defined functions (UDFs)](transforming_datasets) can be called using `Dataset.map_batches(my_UDF)`. The transformation will be conducted in parallel for each data batch.

```{tip}
You may need to call `Dataset.repartition(n)` first to split the Dataset into more blocks internally. By default, each block corresponds to one file. The upper bound of parallelism is the number of blocks.
```

You can specify the data format you are using in the `batch_format` parameter. The dataset will be divided into batches and those batches converted into the specified format. Available data formats you can specify in the `batch_format` paramater include `"pandas", "pyarrow", "numpy"`. Tabular data will be passed into your UDF by default as a pandas DataFrame. Tensor data will be passed into your UDF as a numpy array.

Here, we will use `batch_format="pandas"` explicitly for clarity.

In [13]:
# A pandas DataFrame UDF for transforming the Dataset in parallel.
def transform_batch(df: pd.DataFrame) -> pd.DataFrame:
    df["trip_duration"] = (df["dropoff_at"] - df["pickup_at"]).dt.seconds
    df = df[df["trip_duration"] > 60]
    df.drop(["dropoff_at", "pickup_at", "pickup_location_id"], axis=1, inplace=True)
    df["dropoff_location_id"] = df["dropoff_location_id"].fillna(-1)
    return df

In [14]:
# Test the transform UDF.
print(f"Number of rows before transformation: {pushdown_ds.count()}")

# Repartition the dataset to allow for higher parallelism.
pushdown_ds = pushdown_ds.repartition(16)

# batch_format="pandas" tells Datasets to provide the UDF with batches
# represented as pandas DataFrames.
pushdown_ds = pushdown_ds.map_batches(transform_batch, batch_format="pandas")

# Verify row count.
print(f"Number of rows after transformation: {pushdown_ds.count()}")

Number of rows before transformation: 14506285


Read: 100%|██████████| 2/2 [00:03<00:00,  1.88s/it]
Repartition: 100%|██████████| 16/16 [00:02<00:00,  7.60it/s]
Map_Batches: 100%|██████████| 16/16 [00:02<00:00,  5.75it/s]


Number of rows after transformation: 13893562


### Tidying up

We'll delete the datasets we have been using in order to free up memory in our Ray cluster.

In [15]:
del ds
del pushdown_ds

To make our code easier to read, let's summarize the data processing functions again here.

In [16]:
# Filter parquet data using Ray Datasets read_parquet()
def pushdown_read_data(files_list: list, sample_ids: list) -> Dataset:

    start = time.time()

    filter_expr = (
        (pds.field("passenger_count") > 0)
        & (pds.field("trip_distance") > 0)
        & (pds.field("fare_amount") > 0)
        & (~pds.field("pickup_location_id").isin([264, 265]))
        & (~pds.field("dropoff_location_id").isin([264, 265]))
        & (pds.field("dropoff_location_id").isin(sample_ids))
    )

    dataset = ray.data.read_parquet(
        files_list,
        columns=[
            "pickup_at",
            "dropoff_at",
            "pickup_location_id",
            "dropoff_location_id",
            "passenger_count",
            "trip_distance",
            "fare_amount",
        ],
        filter=filter_expr,
    )

    data_loading_time = time.time() - start
    print_time(f"Data loading time: {data_loading_time:.2f} seconds")
    return dataset


# A Pandas DataFrame UDF for transforming the underlying blocks of a Dataset in parallel.
def transform_batch(df: pd.DataFrame) -> pd.DataFrame:
    df["trip_duration"] = (df["dropoff_at"] - df["pickup_at"]).dt.seconds
    df = df[df["trip_duration"] > 60]
    df.drop(["dropoff_at", "pickup_at", "pickup_location_id"], axis=1, inplace=True)
    df["dropoff_location_id"] = df["dropoff_location_id"].fillna(-1)
    return df

## Batch training with Ray Datasets <a class="anchor" id="train_func"></a>

Now that we have learned more about our data and written a pandas UDF to transform our data, we are ready to train a model on batches of this data in parallel.

1. We will use the `dropoff_location_id` column in the dataset to group the dataset into data batches. 
2. Then we will fit a separate model for each batch to predict `trip_duration`.

In [17]:
import sklearn
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error
from ray.train.sklearn import SklearnTrainer, SklearnPredictor
from ray.train.batch_predictor import BatchPredictor

### Define training functions

We want to fit a linear regression model to the trip duration for each drop-off location. For scoring, we will calculate mean absolute error on the validation set, and report that as model error per drop-off location.

The `fit_and_score_sklearn` function contains the logic necessary to fit a scikit-learn model and evaluate it using mean absolute error.

In [18]:
def fit_and_score_sklearn(
    train_df: pd.DataFrame, test_df: pd.DataFrame, model: BaseEstimator
) -> Tuple[BaseEstimator, float]:
    # Assemble train/test pandas dfs
    train_X = train_df[["passenger_count", "trip_distance", "fare_amount"]]
    train_y = train_df.trip_duration
    test_X = test_df[["passenger_count", "trip_distance", "fare_amount"]]
    test_y = test_df.trip_duration

    # Start training.
    model = model.fit(train_X, train_y)
    pred_y = model.predict(test_X)
    error = sklearn.metrics.mean_absolute_error(test_y, pred_y)

    return model, error

The `train_and_evaluate` function contains the logic for train-test splitting and fitting of a model using the `fit_and_score_sklearn` function.

As an input, this function takes in a pandas DataFrame. When we call `Dataset.map_batches` or `Dataset.groupby().map_groups()`, the Dataset will be batched into multiple pandas DataFrames and this function will be ran for each one in parallel. We will return the model and its error. Those results will be collected back into a Ray Dataset automatically.

In [19]:
def train_and_evaluate(
    df: pd.DataFrame, model: BaseEstimator, location_id: int
) -> pd.DataFrame:
    # check if input df is big enough for training
    if len(df) < 4:
        print_time(f"Data batch for LocID {location_id} is empty or smaller than 4 rows")
        return None

    start = time.time()

    # Train / test split
    # Randomly split the data into 80/20 train/test.
    train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True)

    results = fit_and_score_sklearn(train_df, test_df, model)

    # Assemble location_id, name of model, and metrics in a pandas DataFrame
    results = [location_id] + list(results)
    results_df = pd.DataFrame([results], columns=["location_id", "model", "error"])

    training_time = time.time() - start
    print_time(f"Training time for LocID {location_id}: {training_time:.2f} seconds")

    return results_df

For the model itself, we will use scikit-learn's Linear Regression.

In [20]:
MODEL = LinearRegression()

Recall how we wrote a data transform `transform_batch` UDF? It was called with pattern:
- `Dataset.map_batches(transform_batch, batch_format="pandas")`

Similarly, a groupby-aggregation function can be used later when we perform a [Ray Dataset *group-by*](datasets-groupbys). We will define our aggregation function `agg_func` which will be ran for each group in parallel. The usage pattern is:
- `Dataset.groupby(column).map_groups(agg_func, batch_format="pandas")`.

In [21]:
# A Pandas DataFrame aggregation function for processing grouped batches of Ray Dataset data.
def agg_func(df: pd.DataFrame) -> pd.DataFrame:
    location_id = df["dropoff_location_id"][0]

    # Handle errors in data groups
    try:
        # Transform the input pandas AND fit_and_evaluate the transformed pandas
        transformed_df = transform_batch(df)
        results_df = train_and_evaluate(transformed_df, MODEL, location_id)
        assert results_df is not None
    except Exception:
        # assemble a null entry
        print(f"Failed on LocID {location_id}!")
        results_df = pd.DataFrame([[location_id, None, None]], columns=["location_id", "model", "error"])

    return results_df

### Run batch training using `map_groups`

Finally, the main "driver code" reads each Parquet file (each file corresponds to one month of NYC taxi data) into a Ray Dataset `ds`. Then we use Ray Dataset *group-by* to map each group into a batch of data and run `agg_func` on each of them in parallel by calling `ds.groupby("dropoff_location_id").map_groups(agg_func, batch_format="pandas")`.

In [22]:
# Driver code to run this.
start = time.time()

# Read data into Ray Dataset
ds = pushdown_read_data(s3_files, location_ids).repartition(16)

# Use Ray Dataset groupby.map_groups() to process each group in parallel and return a Ray Dataset.
results = ds.groupby("dropoff_location_id").map_groups(agg_func, batch_format="pandas")
print(f"groupby.map_groups() finished!")

total_time_taken = time.time() - start
print(f"Total number of models: {results.count()}")
print(f"TOTAL TIME TAKEN: {total_time_taken:.2f} seconds")

Read: 100%|██████████| 2/2 [00:01<00:00,  1.08it/s]
Repartition: 100%|██████████| 16/16 [00:01<00:00, 13.22it/s]
Sort Sample: 100%|██████████| 16/16 [00:01<00:00, 13.27it/s]
Shuffle Map: 100%|██████████| 16/16 [00:01<00:00, 11.53it/s]
Shuffle Reduce: 100%|██████████| 16/16 [00:01<00:00, 12.34it/s]
Map_Batches:  100%|██████████| 16/16 [00:01<00:00,  6.63s/it]

(_map_block_nosplit pid=66620) Failed on LocID 199!


Map_Batches: 100%|██████████| 16/16 [02:53<00:00, 10.83s/it]


groupby.map_groups() finished!
Total number of models: 257
TOTAL TIME TAKEN: 180.74 seconds


Finally, we can inspect the models we have trained and their errors.

In [23]:
results

Dataset(num_blocks=16, num_rows=257, schema={location_id: int32, model: object, error: float64})

In [24]:
# sort values by location id
results_df = results.to_pandas(limit=float("inf"))
results_df.sort_values(by=["location_id"], ascending=True, inplace=True)
results_df

Unnamed: 0,location_id,model,error
0,1,LinearRegression(),1186.550249
1,2,LinearRegression(),207.324870
2,3,LinearRegression(),1412.273083
3,4,LinearRegression(),382.897671
4,5,LinearRegression(),2515.787663
...,...,...,...
252,259,LinearRegression(),495.479018
253,260,LinearRegression(),506.609215
254,261,LinearRegression(),684.562030
255,262,LinearRegression(),389.035014
