In [None]:
import os
import boto3
import zipfile
import requests

import ray
from ray import tune
from ray import serve
from ray.air.config import ScalingConfig
from ray.train.xgboost import XGBoostTrainer
from ray.train.xgboost import XGBoostPredictor
from ray.train.batch_predictor import BatchPredictor
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import pandas_read_json
from ray.tune import Tuner, TuneConfig

# Initialise Ray for this notebook session. `ignore_reinit_error=True` keeps
# this cell re-runnable: a second call in the same kernel is a no-op instead of
# raising an error about calling ray.init twice.
ray.init(ignore_reinit_error=True)

# Ray AIR

__Ray AIR is the Ray AI Runtime__, a set of high-level, easy-to-use APIs for
ingesting data, training models (including reinforcement learning
models), tuning those models, and then serving them.

<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Introduction_to_Ray_AIR/e2e_air.png" width=600 loading="lazy"/>

The key principles behind Ray and Ray AIR are:
* Performance
* Developer experience and simplicity

__Read and split the data with Ray Data__

In [None]:
# Parquet table of NYC taxi trips (2021); its `is_big_tip` column is used as the
# training label further down.
TAXI_PARQUET_URI = "s3://anyscale-training-data/intro-to-ray-air/nyc_taxi_2021.parquet"
dataset = ray.data.read_parquet(TAXI_PARQUET_URI)

# Hold out 30% of the rows for validation and keep the other 70% for training.
# NOTE(review): no `shuffle=` is passed, so the split follows Ray's default
# (row-position based) — confirm this is the intended split.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

__Fit a model with Ray Train__

In [None]:
# XGBoost booster parameters: a binary classifier with a logistic output.
xgb_params = {"objective": "binary:logistic"}

trainer = XGBoostTrainer(
    # "train" is the training split; the other entry ("valid") is evaluated
    # alongside it.
    datasets={"train": train_dataset, "valid": valid_dataset},
    label_column="is_big_tip",
    params=xgb_params,
    # Four CPU-only training workers.
    scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
)

# Run a single training job; the returned result carries the trained checkpoint.
result = trainer.fit()

__Optimize hyperparameters with Ray Tune__

In [None]:
# Search over XGBoost tree depth: each trial reuses `trainer` with a sampled
# `max_depth` applied to its `params` (tune.randint's upper bound is exclusive,
# so depths 2..11 are sampled).
tuner = Tuner(
    trainer,
    param_space={"params": {"max_depth": tune.randint(2, 12)}},
    tune_config=TuneConfig(
        num_samples=4,
        # Rank trials on the held-out "valid" dataset. Ranking on training loss
        # ("train-logloss") always favours the deepest, most over-fit trees.
        metric="valid-logloss",
        mode="min",
    ),
)

# Keep the full ResultGrid for inspection, then take the best trial's checkpoint.
tuning_results = tuner.fit()
checkpoint = tuning_results.get_best_result().checkpoint

__Batch prediction__

In [None]:
# Build a batch predictor from the best checkpoint chosen by the tuning step.
batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)

# Remove the label so only feature columns are scored.
valid_features = valid_dataset.drop_columns(['is_big_tip'])
predicted_probabilities = batch_predictor.predict(valid_features)

__Online prediction with Ray Serve__

In [None]:
# Deploy the best tuned checkpoint (the same model used for batch prediction),
# not the untuned single-fit `result.checkpoint`. `pandas_read_json` parses the
# JSON request body into a pandas DataFrame for the predictor.
deployment = PredictorDeployment.bind(
    XGBoostPredictor, checkpoint, http_adapter=pandas_read_json
)

serve.run(deployment)

__HTTP or Python services__

In [None]:
# Build one request record from the first validation row, with the label removed
# so only features are sent.
sample_input = dict(valid_dataset.take(1)[0])
sample_input.pop('is_big_tip')
# NOTE(review): `__index_level_0__` (presumably a pandas index column stored in
# the parquet file) is dropped here but not in the batch-prediction cell —
# confirm which columns the model was trained on.
sample_input.pop('__index_level_0__')

# Without a timeout, requests waits forever on a stalled endpoint. Surface HTTP
# errors explicitly instead of as a confusing JSON decode failure.
response = requests.post("http://localhost:8000/", json=[sample_input], timeout=30)
response.raise_for_status()
response.json()