In [None]:
import os
import ray

from ray import tune
from ray import serve
from ray.air.config import ScalingConfig
from ray.train.xgboost import XGBoostTrainer
from ray.train.xgboost import XGBoostPredictor
from ray.train.batch_predictor import BatchPredictor
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import pandas_read_json
from ray.tune import Tuner, TuneConfig

import requests

# Start (or connect to) Ray. `ignore_reinit_error=True` keeps this cell
# re-runnable: without it, executing the cell a second time on a live kernel
# raises because Ray has already been initialized.
ray.init(ignore_reinit_error=True)

# Ray Train

## Intro

### Outline

-   Goals
-   Trainer
    - Design
    - Flavors
    - In-depth with TensorFlow Trainer

### Model scenarios with Ray + TensorFlow Trainer

- Start with a minimal model and focus on key elements for Ray Train workflow
- Port a minimal word2vec model from training locally in TF/Keras to Ray Train

### Context: Ray AIR

Ray AIR is the Ray AI Runtime, a set of high-level easy-to-use APIs for
ingesting data, training models – including reinforcement learning
models – tuning those models and then serving them.

<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Introduction_to_Ray_AIR/e2e_air.png" width=600 loading="lazy"/>

Key principles behind Ray and Ray AIR are
* Performance
* Developer experience and simplicity

__Read, preprocess with Ray Data__

In [None]:
# Read the Parquet file from S3 into a Ray Dataset.
dataset = ray.data.read_parquet(
    "s3://anyscale-training-data/intro-to-ray-air/nyc_taxi_2021.parquet"
)

# Hold out 30% of the data for validation.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

__Fit model with Ray Train__

In [None]:
# Distributed XGBoost binary classifier for the `is_big_tip` label, trained on
# 32 CPU-only workers with both splits registered under "train" and "valid".
trainer = XGBoostTrainer(
    scaling_config=ScalingConfig(num_workers=32, use_gpu=False),
    label_column="is_big_tip",
    params={"objective": "binary:logistic"},
    datasets={"train": train_dataset, "valid": valid_dataset},
)

# Run training; `result` carries the resulting checkpoint (used by the Serve cell).
result = trainer.fit()

__Optimize hyperparams with Ray Tune__

In [None]:
# Search over tree depth, using the XGBoost trainer as the trainable.
tuner = Tuner(
    trainer,
    param_space={"params": {"max_depth": tune.randint(2, 12)}},
    # Rank trials by loss on the held-out "valid" dataset rather than on
    # "train-logloss": training loss keeps improving as trees get deeper, so
    # selecting on it biases the search towards overfit models.
    tune_config=TuneConfig(num_samples=10, metric="valid-logloss", mode="min"),
)

# Run the 10 trials and keep the checkpoint of the best one.
checkpoint = tuner.fit().get_best_result().checkpoint

__Batch prediction__

In [None]:
# Build a batch predictor from the best tuning checkpoint.
batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)

# Drop the label column so that only the remaining columns are scored.
features_only = valid_dataset.drop_columns(["is_big_tip"])
predicted_probabilities = batch_predictor.predict(features_only)

__Online prediction with Ray Serve__

In [None]:
# Wrap the predictor in a Serve deployment; request bodies are parsed by the
# `pandas_read_json` HTTP adapter.
# NOTE(review): this serves `result.checkpoint` (from the initial
# `trainer.fit()`), not the tuned `checkpoint` used for batch prediction —
# confirm that is intended.
deployment = PredictorDeployment.bind(
    XGBoostPredictor,
    result.checkpoint,
    http_adapter=pandas_read_json,
)

serve.run(deployment)

__HTTP or Python services__

In [None]:
# Build a single-record request payload from one validation row.
sample_input = dict(valid_dataset.take(1)[0])

# Remove the label and the "__index_level_0__" column from the payload.
# `pop(..., None)` tolerates a column that is not present, whereas `del`
# would raise KeyError and abort the cell.
for column in ("is_big_tip", "__index_level_0__"):
    sample_input.pop(column, None)

response = requests.post("http://localhost:8000/", json=[sample_input])
# Surface HTTP errors directly instead of as a confusing JSON decode failure.
response.raise_for_status()
response.json()

## Train Goals

* Developer experience
* Flexibility
* Performance and simplicity via delegation
    * Train does not re-implement distributed optimizers
    * Train coordinates and delegates native platform distributed training

## `Trainer` design and usage

### Idea: Trainer -> Checkpoint
   
* Trainer used by Train, Tune
* Checkpoint used for inference (Ray Data [batch], Serve [online]) and reporting

### Trainer Flavors

* Tree - e.g., XGBoost
* Library - e.g., Hugging Face
* DL Trainers
    * PyTorch, TensorFlow, Horovod, Lightning, Accelerate

### Focus: TensorFlow Trainer

"Hello World" (iris) example with minimal model to look at data/train structure

In [None]:
import tensorflow as tf

from ray.air import session
from ray.air.integrations.keras import ReportCheckpointCallback
from ray.train.tensorflow import TensorflowTrainer
from ray.air.config import ScalingConfig

In [None]:
ds = ray.data.read_csv("s3://air-example-data/iris.csv")
ds

In [None]:
ds.take(2)

"If your dataset contains multiple features but your model accepts a single tensor as input, combine features with Concatenator."
https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.to_tf.html

In [None]:
from ray.data.preprocessors import Concatenator

# Combine every column except "target" into a single "features" column, so the
# model can take one tensor as input.
preprocessor = Concatenator(output_column_name="features", exclude="target")

# NOTE(review): this rebinds `ds`, so re-running the cell applies the
# Concatenator to already-transformed data; re-run the `read_csv` cell first.
ds = preprocessor.transform(ds)

ds

In [None]:
def build_model() -> tf.keras.Model:
    """Return an uncompiled Keras ``Sequential`` model for the iris example.

    The stack is an input layer of shape ``(4,)`` followed by ``Dense(5)`` and
    ``Dense(1)``; no activations are specified.
    """
    layer_stack = [
        tf.keras.layers.InputLayer(input_shape=(4,)),
        tf.keras.layers.Dense(5),
        tf.keras.layers.Dense(1),
    ]
    return tf.keras.Sequential(layer_stack)

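Training function (`train_func`): the per-worker training loop that is passed to the trainer as `train_loop_per_worker`.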

In [None]:
def train_func(config: dict):
    """Per-worker training loop for the iris example.

    Reads ``batch_size`` (default 64), ``epochs`` (default 3) and ``lr``
    (default 1e-3) from ``config``, fits the model from ``build_model`` on this
    worker's "train" dataset shard one ``fit`` call per epoch, and returns the
    list of Keras ``history.history`` dicts (one per epoch).
    """
    learning_rate = config.get("lr", 1e-3)
    num_epochs = config.get("epochs", 3)
    batch_size = config.get("batch_size", 64)

    mirrored = tf.distribute.MultiWorkerMirroredStrategy()
    # The model has to be created and compiled while the strategy scope is active.
    with mirrored.scope():
        model = build_model()
        model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    train_shard = session.get_dataset_shard("train")

    def fit_one_epoch() -> dict:
        # A fresh tf.data pipeline is built from the shard for every pass.
        batches = train_shard.to_tf(
            feature_columns="features",
            label_columns="target",
            batch_size=batch_size,
        )
        fit_history = model.fit(batches, callbacks=[ReportCheckpointCallback()])
        return fit_history.history

    return [fit_one_epoch() for _ in range(num_epochs)]

<img src='https://docs.ray.io/en/latest/_images/session.svg' width=800 />

* https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy
* details
    * https://docs.ray.io/en/latest/ray-air/api/session.html
    * ray.air.integrations.keras.ReportCheckpointCallback https://docs.ray.io/en/latest/tune/api/doc/ray.air.integrations.keras.ReportCheckpointCallback.html
    * "To save a model to use for the TensorflowPredictor, you must save it under the “model” kwarg in Checkpoint passed to session.report()."
        * https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowTrainer.html#ray.train.tensorflow.TensorflowTrainer

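Manual checkpointing (alternative to `ReportCheckpointCallback`): build a `Checkpoint` yourself and pass it to `session.report()`, for example: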

`checkpoint = Checkpoint.from_dict(dict(epoch=epoch, model_weights=model.get_weights()))` 
https://docs.ray.io/en/latest/train/dl_guide.html

In [None]:
# Hyperparameters that `train_func` reads via `config.get(...)`.
train_config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}

# Two CPU-only training workers.
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)

# `ds` is registered as "train", the name `train_func` uses to fetch its shard.
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    scaling_config=scaling_config,
    datasets={"train": ds},
)

In [None]:
result = trainer.fit()

In [None]:
result.metrics

In [None]:
result.checkpoint

Training dataset from the TensorFlow word2vec tutorial (https://www.tensorflow.org/tutorials/text/word2vec)

In [None]:
dataset = tf.data.Dataset.load('w2v.data.tf')

In [None]:
dataset

In [None]:
from tensorflow.keras import layers

class Word2Vec(tf.keras.Model):
    """Keras model that scores a target word against a set of context words.

    Two embedding tables are kept, one for target words and one for context
    words; the output is the dot product of the target embedding with each
    context embedding.

    Args:
        vocab_size: Number of entries in each embedding table.
        embedding_dim: Size of each embedding vector.
        num_ns: Must match the value used when the dataset was constructed
            (presumably the number of negative samples per target, as in the
            TensorFlow word2vec tutorial — confirm). Defaults to 4, the value
            that was previously hard-coded.
    """

    def __init__(self, vocab_size, embedding_dim, num_ns=4):
        super().__init__()
        self.target_embedding = layers.Embedding(
            vocab_size,
            embedding_dim,
            input_length=1,
            name="w2v_embedding",
        )
        # Each example carries `num_ns + 1` context entries.
        self.context_embedding = layers.Embedding(
            vocab_size,
            embedding_dim,
            input_length=num_ns + 1,
        )

    def call(self, pair):
        """Return dot-product scores of shape (batch, context).

        Args:
            pair: A ``(target, context)`` tuple of index tensors.
        """
        target, context = pair
        # target: (batch, dummy?)  # The dummy axis doesn't exist in TF2.7+
        # context: (batch, context)
        if len(target.shape) == 2:
            target = tf.squeeze(target, axis=1)
        # target: (batch,)
        word_emb = self.target_embedding(target)
        # word_emb: (batch, embed)
        context_emb = self.context_embedding(context)
        # context_emb: (batch, context, embed)
        dots = tf.einsum('be,bce->bc', word_emb, context_emb)
        # dots: (batch, context)
        return dots

In [None]:
# Model dimensions for the local (non-distributed) baseline.
vocab_size = 4096
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
# The model emits raw dot-product scores, hence `from_logits=True`.
word2vec.compile(optimizer='adam',
                 loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])

In [None]:
# Shuffle-buffer and batch sizes; the distributed training config reuses both.
BUFFER_SIZE = 10000
BATCH_SIZE = 1024

# Local baseline: train on this machine with plain Keras for 10 epochs.
word2vec.fit(dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE), epochs=10)

https://docs.ray.io/en/latest/ray-air/examples/convert_existing_tf_code_to_ray_air.html

In [None]:
# 1. Pass in the hyperparameter config
def train_func(config: dict):
    """Per-worker Ray Train loop that fits a `Word2Vec` model with Keras.

    Config keys:
        epochs: Number of training epochs (default 5).
        batch_size: Per-worker batch size (default 32).
        buffer_size: Shuffle buffer size (default 8192).
        tf_data: Path of a dataset readable by `tf.data.Dataset.load`.
            Required; the path must be accessible from every worker.

    Raises:
        ValueError: If `tf_data` is missing from `config`.
    """
    epochs = config.get("epochs", 5)
    batch_size_per_worker = config.get("batch_size", 32)
    buffer_size = config.get("buffer_size", 8192)

    # if we're using classic TF Data, this must be globally accessible
    ds_path = config.get("tf_data")
    if ds_path is None:
        # Fail with a clear message instead of passing None to Dataset.load.
        raise ValueError("train_func requires a dataset path in config['tf_data']")

    # 2. Synchronized model setup
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        vocab_size = 4096
        embedding_dim = 128
        model = Word2Vec(vocab_size, embedding_dim)
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'],
        )

    # 3. Shard the dataset across `session.get_world_size()` workers
    global_batch_size = batch_size_per_worker * session.get_world_size()

    # NOTE(review): `cache()` placed after `shuffle()` replays the first
    # epoch's order on later epochs; move it before `shuffle()` if per-epoch
    # reshuffling is wanted.
    train_ds = (
        tf.data.Dataset.load(ds_path)
        .shuffle(buffer_size)
        .batch(global_batch_size)
        .cache()
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    if session.get_world_rank() == 0:
        print(f"\nDataset is sharded across {session.get_world_size()} workers:")
        # `len(train_ds)` raises TypeError when the cardinality is unknown or
        # infinite, so query it explicitly (negative values signal those cases).
        num_batches = int(train_ds.cardinality().numpy())
        if num_batches >= 0:
            # The number of samples is approximate because the dataset size is
            # not always a multiple of the batch size, so some batches could
            # contain fewer than `batch_size_per_worker` samples.
            print(
                f"# training batches per worker = {num_batches} "
                f"(~{num_batches * batch_size_per_worker} samples)"
            )
        else:
            print("# training batches per worker = unknown")

    # 4. Report metrics and checkpoint the model
    report_metrics_and_checkpoint_callback = ReportCheckpointCallback(report_metrics_on="epoch_end")
    # `batch_size` is deliberately not passed: `train_ds` is already batched,
    # and Keras documents that it must not be specified for dataset inputs.
    model.fit(
        train_ds,
        epochs=epochs,
        callbacks=[report_metrics_and_checkpoint_callback],
    )

In [None]:
import os

# Config consumed by the word2vec `train_func`. "tf_data" is resolved to an
# absolute path on the driver.
# NOTE(review): workers on other nodes need that same path to be readable
# (e.g. shared storage) — confirm for multi-node clusters.
train_config = {"batch_size": BATCH_SIZE, "epochs": 4, "buffer_size" : BUFFER_SIZE, "tf_data" : os.path.abspath('w2v.data.tf')}

# Eight CPU-only training workers.
scaling_config = ScalingConfig(num_workers=8, use_gpu=False)

# No `datasets=` argument: the training function loads its data itself.
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    scaling_config=scaling_config,
)

In [None]:
result = trainer.fit()