# Fine-Tuning with Ray AIR and DeepSpeed on Databricks

In this example, we will show how to use Ray AIR to fine-tune **pythia-12b**, a causal language model with 12 billion parameters trained on the Pile dataset (825 GB). For more information on pythia-12b, see its [model card](https://huggingface.co/EleutherAI/pythia-12b).

We will use Ray AIR (with the 🤗 Transformers integration) and a pretrained model from the Hugging Face Hub. You can easily adapt this example to other similar models by changing `pretrained_model_name_or_path`; the code cells below are currently set to `tiiuae/falcon-7b`.

This example focuses more on the performance and distributed computing aspects of Ray AIR. 

It is highly recommended to read [Ray AIR Key Concepts](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-air/examples/air-key-concepts) and [Ray Data Key Concepts](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-air/examples/data_key_concepts) before starting this example.

```{note}
To run this example, make sure your Ray cluster has access to at least 8 GPUs with 24 GB or more of memory each. The amount of memory needed depends on the model. This notebook has been tested with 4 g5.24xlarge worker nodes and a g4dn.8xlarge head node.
```
Benefits:

- Runs from the notebook without a command-line trigger and provides real-time updates
- The Ray Dashboard helps you understand job performance and tune the batch size and performance configs
- The Ray Data API supports data in Parquet, CSV, and Hugging Face formats
- MLflow integration to track experiments

In this notebook, we will:
1. [Add dependencies to run DeepSpeed](#deepspeed)
2. [Set up Ray](#setup)
3. [Load the dataset](#load)
4. [Preprocess the dataset with Ray AIR](#preprocess)
5. [Run the training with Ray AIR](#train)
6. [Generate text from prompt with Ray AIR](#predict)

## Load Dependencies for DeepSpeed <a name="deepspeed"></a>
Uncomment and run the following cell to create an init script that installs the CUDA development libraries (cuSPARSE, cuBLAS, cuSOLVER, and cuRAND) required by DeepSpeed (this notebook was tested with `transformers==4.26.0`):

In [0]:
# kernel_gateway_init = """
# #!/bin/bash

# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcusparse-dev-11-3_11.5.0.58-1_amd64.deb -O /tmp/libcusparse-dev-11-3_11.5.0.58-1_amd64.deb && \
# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcublas-dev-11-3_11.5.1.109-1_amd64.deb -O /tmp/libcublas-dev-11-3_11.5.1.109-1_amd64.deb && \
# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcusolver-dev-11-3_11.1.2.109-1_amd64.deb -O /tmp/libcusolver-dev-11-3_11.1.2.109-1_amd64.deb && \
# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcurand-dev-11-3_10.2.4.109-1_amd64.deb -O /tmp/libcurand-dev-11-3_10.2.4.109-1_amd64.deb && \
# dpkg -i /tmp/libcusparse-dev-11-3_11.5.0.58-1_amd64.deb && \
# dpkg -i /tmp/libcublas-dev-11-3_11.5.1.109-1_amd64.deb && \
# dpkg -i /tmp/libcusolver-dev-11-3_11.1.2.109-1_amd64.deb && \
# dpkg -i /tmp/libcurand-dev-11-3_10.2.4.109-1_amd64.deb
# """ 
# # Change 'username' to your Databricks username in DBFS
# # Example: username = "stephen.offer@databricks.com"
# username = "puneet.jain@databricks.com"
# dbutils.fs.put("dbfs:/Users/{0}/init/ray.sh".format(username), kernel_gateway_init, True)
# "dbfs:/Users/{0}/init/ray.sh".format(username)

## Set up Ray <a name="setup"></a>

First, let us start a Ray cluster based on the Databricks cluster configuration. To create the correct multi-node setup, we need to pass the number of CPU cores and GPUs available per worker node to **setup_ray_cluster**.

In [0]:
#install dependencies
%pip install -r requirement.txt

Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
Collecting accelerate<0.19.0,>=0.18.0
  Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)
Collecting deepspeed<0.9.0,>=0.8.3
  Downloading deepspeed-0.8.3.tar.gz (765 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting transformers[torch]<5,>=4.28.1
  Downloading transformers-4.30.1-py3-none-any.whl (7.2 MB)
Collecting langchain>=0.0.139
  Downloading langchain-0.0.195-py3-none-any.whl (1.0 MB)
Collecting awscli
  Downloading awscli-1.27.151-py3-none-any.whl (4.1 MB)

In [0]:
# Import all the packages
import os
import re
import json
import logging
import subprocess
import mlflow

from pathlib import Path
from functools import partial
from datetime import datetime
from typing import Any, Dict, List, Tuple, Union

import ray
from ray.air import session
import ray.util.scheduling_strategies
from ray.train.huggingface import HuggingFaceTrainer
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.air.config import ScalingConfig

from ray.data.preprocessors import Chain
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster, MAX_NUM_WORKER_NODES

from datasets import Dataset, load_dataset, load_from_disk
from huggingface_hub import snapshot_download

import numpy as np
import pandas as pd
import torch 

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

### Define variables (these can also be added as Databricks widgets)

In [0]:
pretrained_model_name_or_path = "tiiuae/falcon-7b"
dataset_path = "teknium/GPT4-LLM-Cleaned"
dataset_type = "alpaca:chat"
use_gpu = True
num_workers = 4 # configure based on the total number of GPUs across all worker nodes
num_cpu_cores_per_worker = 6 # CPU cores available to each Ray worker node
num_gpu_per_worker = 1 # GPUs available to each Ray worker node
max_length = 2048
local_output_dir = '/tmp/run/details'
gradient_checkpointing = True
seed = 5432 
username = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().apply('user')
experiment_location = f"/Users/{username}/dolly_multi-gpu"


In [0]:
# shutdown_ray_cluster()

In [0]:
# Start the ray cluster
setup_ray_cluster(
  num_worker_nodes=MAX_NUM_WORKER_NODES,
  num_cpus_per_node=num_cpu_cores_per_worker,
  num_gpus_per_node=num_gpu_per_worker,
  collect_log_to_path="/dbfs/path/to/ray_collected_logs")

2023-06-10 11:53:06,143	INFO usage_lib.py:408 -- Usage stats collection is enabled by default without user confirmation because this terminal is detected to be non-interactive. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
2023-06-10 11:53:06,143	INFO scripts.py:712 -- Local node IP: 10.68.186.210
2023-06-10 11:53:08,100	SUCC scripts.py:749 -- --------------------
2023-06-10 11:53:08,100	SUCC scripts.py:750 -- Ray runtime started.
2023-06-10 11:53:08,101	SUCC scripts.py:751 -- --------------------
2023-06-10 11:53:08,101	INFO scripts.py:753 -- Next steps
2023-06-10 11:53:08,101	INFO scripts.py:756 -- To add another node to this Ray cluster, run
2023-06-10 11:53:08,101	INFO scripts.py:759 --   ray start --address='10.68.186.210

2023-06-10 11:53:39,939	INFO worker.py:1452 -- Connecting to existing Ray cluster at address: 10.68.186.210:9569...
2023-06-10 11:53:39,951	INFO worker.py:1627 -- Connected to Ray cluster. View the dashboard at http://10.68.186.210:9380


To monitor and debug Ray from Databricks, view the dashboard at 
 https://dbc-dp-6051921418418893.cloud.databricks.com/driver-proxy/o/6051921418418893/0603-224203-e80gcfmk/9380/


'10.68.186.210:9569'

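We will use `ray.init()` to connect to the Ray cluster from the current session.

We define a `runtime_env` to pass environment variables to the Ray workers; here, `RAY_memory_monitor_refresh_ms=0` is intended to disable Ray's memory monitor (Ray reads this variable when a node starts, so it may only take effect if set before the cluster is launched). A `runtime_env` can also make sure the Ray workers have access to all the necessary packages; you can omit the `runtime_env` argument if all of the packages are already installed on each node in your cluster.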

In [0]:
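runtime_env = {
    # Intended to disable Ray's memory monitor for this job's workers
    "env_vars": {"RAY_memory_monitor_refresh_ms": "0"}
}
# Connect to the Ray cluster started above (the address is read from RAY_ADDRESS)
ray.init(runtime_env=runtime_env)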

2023-06-10 11:53:44,308	INFO worker.py:1334 -- Using address 10.68.186.210:9569 set in the environment variable RAY_ADDRESS
2023-06-10 11:53:44,309	INFO worker.py:1452 -- Connecting to existing Ray cluster at address: 10.68.186.210:9569...
2023-06-10 11:53:44,315	INFO worker.py:1627 -- Connected to Ray cluster. View the dashboard at http://10.68.186.210:9380

Python version: 3.10.6
Ray version: 2.5.0
Dashboard: http://10.68.186.210:9380


We cache the model on each node's local disk by calling `snapshot_download()` from the `huggingface_hub` library once per GPU node, so that workers do not download it from the Hugging Face Hub every time training starts.
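The node-selection logic used by `run_on_every_node` can be checked without a cluster. This sketch assumes node records shaped like the dictionaries `ray.nodes()` returns, relying only on the `Alive`, `NodeID`, and `Resources` keys that the helper reads; the `gpu_node_ids` name and the sample records are hypothetical.

```python
from typing import Any, Dict, List


def gpu_node_ids(nodes: List[Dict[str, Any]]) -> List[str]:
    """Return the IDs of live nodes that advertise at least one GPU.

    Mirrors the filter in `run_on_every_node`: a node is kept only if it is
    alive and its "Resources" dict has a truthy "GPU" entry.
    """
    return [
        node["NodeID"]
        for node in nodes
        if node["Alive"] and node["Resources"].get("GPU", None)
    ]


# Hypothetical records: a CPU-only head node, two GPU workers, and one dead node
sample_nodes = [
    {"NodeID": "head", "Alive": True, "Resources": {"CPU": 32.0}},
    {"NodeID": "worker-1", "Alive": True, "Resources": {"CPU": 96.0, "GPU": 4.0}},
    {"NodeID": "worker-2", "Alive": True, "Resources": {"CPU": 96.0, "GPU": 4.0}},
    {"NodeID": "worker-3", "Alive": False, "Resources": {"CPU": 96.0, "GPU": 4.0}},
]

print(gpu_node_ids(sample_nodes))  # ['worker-1', 'worker-2']
```

Because each selected node receives exactly one task pinned with a hard `NodeAffinitySchedulingStrategy` (`soft=False`), the model is downloaded once per GPU node rather than once per GPU.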

In [0]:

def force_on_node(node_id: str, remote_func_or_actor_class):
    scheduling_strategy = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
        node_id=node_id, soft=False
    )
    options = {"scheduling_strategy": scheduling_strategy}
    return remote_func_or_actor_class.options(**options)


def run_on_every_node(remote_func_or_actor_class, **remote_kwargs):
    refs = []
    for node in ray.nodes():
        if node["Alive"] and node["Resources"].get("GPU", None):
            refs.append(
                force_on_node(node["NodeID"], remote_func_or_actor_class).remote(
                    **remote_kwargs
                )
            )
    return ray.get(refs)


@ray.remote(num_gpus=1)
def download_model():
    snapshot_download(pretrained_model_name_or_path,resume_download=True) 
  

_ = run_on_every_node(download_model)

## Loading the dataset <a name="load"></a>

We will fine-tune the model on the Databricks crowd-sourced dataset (`databricks/databricks-dolly-15k` on the Hugging Face Hub), which comprises about 15,000 instruction and response pairs.

We will use [Ray Data](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-air/examples/data) for distributed preprocessing and data ingestion. We can convert the dataset obtained from the Hugging Face Hub to a Ray Dataset by using `ray.data.from_huggingface`.

Note that ingestion from Delta, Parquet, and CSV is also supported via the Ray Data API.

In [0]:
# Split the data into train and test sets
# `load_training_dataset` is assumed to come from the same local `training.trainer`
# module as the helpers imported in the preprocessing cell below
from training.trainer import load_training_dataset

current_dataset = load_training_dataset()
current_dataset = current_dataset.train_test_split(seed=seed)

# For a quick test run, subsample first, e.g. current_dataset["train"].select(range(1000))
current_dataset["train"].save_to_disk("/local_disk0/train.hf")
current_dataset["test"].save_to_disk("/local_disk0/test.hf")
del current_dataset

# Load the final data as Ray Datasets
train_dataset = ray.data.from_huggingface(load_from_disk("/local_disk0/train.hf"))
test_dataset = ray.data.from_huggingface(load_from_disk("/local_disk0/test.hf"))


In [0]:
dataset_type
d_type_split = dataset_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

In [0]:
d_base_type

'alpaca'

## Preprocess the dataset with Ray AIR <a name="preprocess"></a>

We will need to do some preprocessing. For that, we will define a [Ray AIR Preprocessor](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-air/examples/air-preprocessors) using the {class}`~ray.data.preprocessors.BatchMapper` API, which lets us define a function that will be applied to batches of data.

The `preprocess` function tokenizes each batch with the 🤗 tokenizer associated with the model (through the `preprocess_batch` helper), truncating each entry to `max_length` tokens. It then drops any record whose token count reaches `max_length`, since such a record may have been truncated and lost its end-of-response keyword. Padding each batch to a common length, which training requires, is handled later by the data collator.
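The truncate-then-filter step can be illustrated without a real tokenizer. This sketch assumes a hypothetical whitespace tokenizer (`toy_tokenize`) standing in for the Hugging Face tokenizer called with truncation, and keeps only records whose token count stays below `max_length`, the same condition as the `dataset.filter` call in `preprocess`.

```python
from typing import Dict, List


def toy_tokenize(text: str, max_length: int) -> List[str]:
    # Hypothetical stand-in for a Hugging Face tokenizer called with truncation=True
    return text.split()[:max_length]


def tokenize_and_filter(texts: List[str], max_length: int) -> List[Dict[str, List[str]]]:
    records = [{"input_ids": toy_tokenize(t, max_length)} for t in texts]
    # A record that reaches max_length may have been truncated (losing the
    # end-of-response keyword), so it is dropped, as in the notebook's filter
    return [rec for rec in records if len(rec["input_ids"]) < max_length]


texts = ["short answer", "one two three four five six"]
kept = tokenize_and_filter(texts, max_length=5)
print([rec["input_ids"] for rec in kept])  # [['short', 'answer']]
```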

```{note}
This preprocessing can be done in other ways. A common pattern is to tokenize first, and then split the obtained tokens into equally-sized blocks.
```
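The alternative pattern mentioned in the note (tokenize first, then split the tokens into equally-sized blocks) can be sketched in plain Python. The sketch assumes token IDs are already available as lists of integers; the `group_into_blocks` name and `block_size` parameter are illustrative, and the tail that does not fill a whole block is dropped, a common choice for this pattern.

```python
from itertools import chain
from typing import List


def group_into_blocks(token_lists: List[List[int]], block_size: int) -> List[List[int]]:
    # Concatenate all tokenized examples into one long stream
    stream = list(chain.from_iterable(token_lists))
    # Keep only as many tokens as fill complete blocks
    usable = (len(stream) // block_size) * block_size
    return [stream[i : i + block_size] for i in range(0, usable, block_size)]


blocks = group_into_blocks([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Compared with padding, this spends no compute on pad tokens, at the cost of letting a single example span two blocks.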

In [0]:
import ray
from pathlib import Path
from ray import tune
from datasets import Dataset
from ray.data.preprocessors import Chain, BatchMapper
from ray.air.util.check_ingest import DummyTrainer
from ray.air.config import ScalingConfig,RunConfig,CheckpointConfig


from training.trainer import load_tokenizer,preprocess_batch,\
                             DataCollatorForCompletionOnlyLM,get_model_tokenizer

def preprocess(batch):
  tokenizer = load_tokenizer(pretrained_model_name_or_path)
  # `seed` is defined in the variables cell above
  dataset = Dataset.from_pandas(batch)
  _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
  dataset = dataset.map(
      _preprocessing_function,
      batched=True,
      remove_columns=["instruction", "context", "response", "text", "category"],
  )

  # Make sure we don't have any truncated records, as this would mean the end keyword is missing.
  dataset = dataset.filter(lambda rec: len(rec["input_ids"]) < max_length)
  dataset = dataset.shuffle(seed=seed)
  return dataset.to_pandas()


preprocessor = Chain(
    BatchMapper(preprocess, batch_format="pandas")
)


## Fine-tuning the model with Ray AIR <a name="train"></a>

We can now configure Ray AIR's {class}`~ray.train.huggingface.huggingface_trainer.HuggingFaceTrainer` to perform distributed fine-tuning of the model. In order to do that, we specify a `trainer_init_per_worker` function, which creates a 🤗 Transformers `Trainer` that will be distributed by Ray using Distributed Data Parallelism (using the PyTorch Distributed backend internally). This means that each worker will have its own copy of the model but operate on different data. At the end of each step, all the workers will sync gradients.

Because pythia-12b is a relatively large model, it may not be possible to fit it on smaller GPU types (<=16 GB of GPU memory). To deal with that issue, we can use [DeepSpeed](https://github.com/microsoft/DeepSpeed), a library to optimize the training process and allow us to (among other things) offload and partition optimizer and parameter states, reducing GPU memory usage. Furthermore, DeepSpeed ZeRO Stage 3 partitions the model parameters across workers, which allows us to load large models without running out of memory.
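To see why partitioning matters, consider a rough estimate of the memory needed for model states alone. This sketch assumes mixed-precision training with the Adam optimizer at about 16 bytes per parameter (2 for 16-bit weights, 2 for 16-bit gradients, and 12 for fp32 master weights plus two Adam moments), the accounting used in the DeepSpeed ZeRO analysis. It ignores activations, buffers, and fragmentation, so real usage is higher.

```python
def model_state_gb(num_params: float, num_gpus: int = 1, bytes_per_param: int = 16) -> float:
    """Approximate per-GPU memory (GB) for weights, gradients, and Adam states.

    With ZeRO Stage 3 these model states are partitioned evenly across GPUs,
    so the total is divided by num_gpus. Activations are not included.
    """
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_gpus / 1e9


print(model_state_gb(12e9))               # 192.0 (GB on a single GPU)
print(model_state_gb(12e9, num_gpus=16))  # 12.0 (GB per GPU across 16 GPUs)
```

Under these assumptions a 12-billion-parameter model needs about 192 GB for model states, far beyond one 24 GB GPU, while partitioning across 16 GPUs brings it to roughly 12 GB each before CPU offload is considered.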

🤗 Transformers and Ray AIR's integration ({class}`~ray.train.huggingface.huggingface_trainer.HuggingFaceTrainer`) allow you to easily configure and use DDP and DeepSpeed. All you need to do is specify the DeepSpeed configuration in the [`TrainingArguments`](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) object.

```{tip}
There are many DeepSpeed settings that allow you to trade off speed for memory usage. The settings used below are tailored to the cluster setup used (4 g5.24xlarge worker nodes, 16 GPUs in total) and a per-device batch size of 10. Some things to keep in mind:
- If your GPUs support bfloat16, use it instead of float16 mixed precision to get better performance and prevent overflows. The code below sets `bf16=True`; on GPUs without bfloat16 support (such as V100), set `fp16=True` and `bf16=False` in `TrainingArguments` instead.
- If you are running out of GPU memory, try reducing the batch size (set in `trainer_init_config` in the trainer cell) and set `"overlap_comm": False` in the DeepSpeed config.
- If you are running out of RAM, add more nodes to your cluster, use nodes with more RAM, set `"pin_memory": False` in the DeepSpeed config, reduce the batch size, and remove `"offload_param"` from the DeepSpeed config.

For more information on DeepSpeed configuration, refer to [Hugging Face documentation](https://huggingface.co/docs/transformers/main_classes/deepspeed) and [DeepSpeed documentation](https://www.deepspeed.ai/docs/config-json/).

Additionally, if you prefer a lower-level API, the logic below can be expressed as an [Accelerate training loop](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/deepspeed_with_config_support.py) distributed by a Ray AIR {class}`~ray.train.torch.torch_trainer.TorchTrainer`.
```
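The contents of `config/ds_z3_bf16_config.json` are not shown in this notebook. The sketch below is an assumed ZeRO Stage 3 configuration in the style the 🤗 Transformers DeepSpeed integration accepts (`"auto"` values are filled in from `TrainingArguments`), plus a hypothetical `low_ram_variant` helper that applies the out-of-RAM suggestions from the tip. Treat the values as a starting point, not the configuration behind the reported timings.

```python
import copy
import json

# Assumed ZeRO Stage 3 config with CPU offload; "auto" lets the Hugging Face
# integration fill these values from TrainingArguments
ds_z3_bf16_config = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}


def low_ram_variant(config: dict) -> dict:
    """Hypothetical helper: return a copy with the tip's out-of-RAM changes."""
    cfg = copy.deepcopy(config)
    zero = cfg["zero_optimization"]
    zero.pop("offload_param", None)  # keep partitioned parameters on the GPUs
    if "offload_optimizer" in zero:
        zero["offload_optimizer"]["pin_memory"] = False  # no page-locked host memory
    return cfg


print(json.dumps(low_ram_variant(ds_z3_bf16_config)["zero_optimization"], indent=2))
```

A dict like this could be passed as the `"deepspeed"` entry of `trainer_init_config`, since `trainer_init_per_worker` writes whatever it receives to `/tmp/deepspeed.json`.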

### Training speed

As we are using data parallelism, each worker operates on its own shard of the data. The batch size set in `TrainingArguments` is the **per device batch size** (per worker batch size). By changing the number of workers, we can change the **effective batch size** and thus the time needed for training to complete. The effective batch size is then calculated as `per device batch size * number of workers * number of gradient accumulation steps`. As we add more workers, the effective batch size rises and thus we need less time to complete a full epoch. While the speedup is not exactly linear due to extra communication overheads, in many cases it can be close to linear.

The preprocessed dataset has ~15,000 examples. We have set the per-device batch size to 10.

* With 4 g5.24xlarge nodes (16 workers), the effective batch size was 160, which corresponds to 85 steps per epoch. Two epochs took 2.27 hours (including initialization and saving time).
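The effective batch size formula can be checked against the reported run. This sketch assumes 16 workers (one per GPU on the 4 g5.24xlarge nodes, matching `num_workers=16` in the trainer cell), a single gradient-accumulation step, and that a final partial batch still counts as a step.

```python
import math


def effective_batch_size(per_device: int, num_workers: int, grad_accum_steps: int = 1) -> int:
    # per device batch size * number of workers * gradient accumulation steps
    return per_device * num_workers * grad_accum_steps


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # Assumes the last, partially filled batch is kept
    return math.ceil(num_examples / batch_size)


batch = effective_batch_size(per_device=10, num_workers=16)
print(batch)                           # 160
print(steps_per_epoch(15_000, batch))  # 94
print(steps_per_epoch(13_600, batch))  # 85
```

If all ~15,000 examples reached training, an epoch would take 94 steps; the reported 85 steps corresponds to between 13,441 and 13,600 training examples, so somewhat fewer examples appear to reach each epoch.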

In [0]:
def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):

    set_seed(seed)

    # Use the actual number of CPUs assigned by Ray
    os.environ["OMP_NUM_THREADS"] = str(
        session.get_trial_resources().bundles[-1].get("CPU", 1)
    )
    # Enable tf32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True

    # Get config details

    epochs = config.get("epochs")
    lr = config.get("lr")
    per_device_train_batch_size = config.get("per_device_train_batch_size")
    per_device_eval_batch_size = config.get("per_device_eval_batch_size")
    logging_steps = config.get("logging_steps")
    save_strategy= config.get("save_strategy")
    evaluation_strategy = config.get("evaluation_strategy")
    save_steps = config.get("save_steps")
    eval_steps = config.get("eval_steps") 
    warmup_steps = config.get("warmup_steps")
    disable_tqdm=config.get("disable_tqdm")
    remove_unused_columns=config.get("remove_unused_columns")
    deepspeed = config.get("deepspeed")  # DeepSpeed config dict passed via trainer_init_config

    # Write the DeepSpeed config to a local file that TrainingArguments can read
    with open("/tmp/deepspeed.json", "w") as f:
      json.dump(deepspeed, f)

    print("Preparing training arguments")
    training_args = TrainingArguments(
        output_dir=local_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        fp16=False, # set to True (and bf16 to False) if using V100
        bf16=True, # set to False if using V100
        learning_rate=lr,
        num_train_epochs=epochs,
        deepspeed='/tmp'+'/deepspeed.json',
        gradient_checkpointing=gradient_checkpointing,
        logging_strategy=evaluation_strategy,
        logging_steps=logging_steps,
        evaluation_strategy=evaluation_strategy,
        eval_steps=eval_steps,
        save_strategy=save_strategy,
        save_steps=save_steps,
        load_best_model_at_end=False,
        disable_tqdm=disable_tqdm,
        remove_unused_columns=remove_unused_columns,
        warmup_steps=warmup_steps)

    print("Loading model")

    model, tokenizer = get_model_tokenizer(
        pretrained_model_name_or_path=pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing
    )

    print("Model loaded")
    print(f"Train data size: {len(train_dataset)}")
    if eval_dataset is not None:
        print(f"Test data size: {len(eval_dataset)}")

    data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    return trainer

With our `trainer_init_per_worker` complete, we can now instantiate the {class}`~ray.train.huggingface.huggingface_trainer.HuggingFaceTrainer`. Aside from the function, we set the `scaling_config`, controlling the number of workers and the resources used, and the `datasets` we will use for training and evaluation.

We pass the preprocessor we defined earlier, wrapped in a {class}`~ray.data.preprocessors.chain.Chain`, as an argument. The preprocessor will be included with the returned {class}`~ray.air.checkpoint.Checkpoint`, meaning it will also be applied during inference.

In [0]:
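# Get or create the MLflow experiment
def get_or_create_experiment(experiment_location: str) -> str:
    # Minimal helper in case it is not provided by a local module
    experiment = mlflow.get_experiment_by_name(experiment_location)
    if experiment is None:
        return mlflow.create_experiment(experiment_location)
    return experiment.experiment_id

get_or_create_experiment(experiment_location)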


sync_config = tune.SyncConfig(
  syncer=None
    # upload_dir="s3://one-env-eu-west-1/users/puneet.jain@databricks.com/",  # requires AWS credential
    )
#Create tags to log with mlflow
tags = dict(
  local_dir = f"/dbfs/{username}/dolly_train/job/",
  base_model_dir = pretrained_model_name_or_path,
  n_gpus = str(num_workers),
  num_cpu_cores_per_worker = str(num_cpu_cores_per_worker),
  num_gpu_per_worker = str(num_gpu_per_worker),  
  max_length = str(max_length),
  username = username )

root_path = os.getcwd()
deepspeed_config = os.path.join(root_path, "config/ds_z3_bf16_config.json")

with open(deepspeed_config) as json_data:
    deepspeed_config = json.load(json_data)

trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    trainer_init_config={
        "deepspeed": deepspeed_config, 
        "lr" : 1e-6,
        "per_device_train_batch_size" : 10,
        "per_device_eval_batch_size" : 10,
        "save_strategy" : "no",
        "evaluation_strategy" : "steps",
        "logging_steps" : 50,
        "save_steps" : 200,
        "eval_steps" : 50,
        "warmup_steps" : 25,
        "disable_tqdm" : True,
        "remove_unused_columns" :False,
        "epochs": 3},
    scaling_config=ScalingConfig(
        num_workers=16,
        use_gpu=use_gpu,
        resources_per_worker={"GPU": 1, 
                              "CPU": 22}), # (CPU cores per node / GPUs per node) - 2
    run_config = RunConfig(
                local_dir =  f"/dbfs/{username}/dolly_train/job/",
                callbacks=[MLflowLoggerCallback(experiment_name=experiment_location,
                                                tags = tags,
                                                save_artifact=False)],
                sync_config=sync_config,
                checkpoint_config = CheckpointConfig(num_to_keep = 1, 
                                                     checkpoint_score_attribute = 'eval_loss',
                                                     checkpoint_score_order = 'min') 
    ),
    datasets={"train": train_dataset ,
              "evaluation" : test_dataset},
    preprocessor=preprocessor,
)
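The comment on `resources_per_worker` gives a rule of thumb for the CPU count: cores per node divided by GPUs per node, minus 2. A sketch of that arithmetic, assuming a g5.24xlarge node with 96 vCPUs and 4 GPUs (instance figures not stated in this notebook) and one Ray Train worker per GPU; `cpus_per_worker` is a hypothetical name:

```python
def cpus_per_worker(cores_per_node: int, gpus_per_node: int, reserve: int = 2) -> int:
    # One training worker per GPU: each gets an equal share of the node's cores,
    # minus the small reserve suggested by the notebook's comment
    return cores_per_node // gpus_per_node - reserve


print(cpus_per_worker(96, 4))  # 22
```

This reproduces the `"CPU": 22` used in the `ScalingConfig` above.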

Finally, we call the {meth}`~ray.train.huggingface.huggingface_trainer.HuggingFaceTrainer.fit` method to start training with Ray AIR. We will save the {class}`~ray.air.Result` object to a variable so we can access metrics and checkpoints.

In [0]:
results = trainer.fit()

## Fetch the best model parameters

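You can use the returned {class}`~ray.air.Result` object to access metrics and the Ray AIR {class}`~ray.air.checkpoint.Checkpoint`. `results.checkpoint` is the checkpoint associated with the last iteration, while `results.best_checkpoints` lists the checkpoints retained under the `CheckpointConfig` above, together with their metrics.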
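`CheckpointConfig(checkpoint_score_attribute='eval_loss', checkpoint_score_order='min')` ranks checkpoints by their reported `eval_loss`. The pure-Python sketch below mimics that selection over `(checkpoint, metrics)` pairs, the shape assumed here for `results.best_checkpoints`; the `best_by_metric` helper, checkpoint names, and loss values are made up for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple


def best_by_metric(
    pairs: List[Tuple[Any, Dict[str, Any]]],
    metric: str = "eval_loss",
    mode: str = "min",
) -> Optional[Any]:
    # Ignore checkpoints that were reported without the metric
    scored = [(ckpt, metrics[metric]) for ckpt, metrics in pairs if metric in metrics]
    if not scored:
        return None
    pick = min if mode == "min" else max
    return pick(scored, key=lambda item: item[1])[0]


pairs = [("ckpt_a", {"eval_loss": 1.9}), ("ckpt_b", {"eval_loss": 1.4}), ("ckpt_c", {})]
print(best_by_metric(pairs))  # ckpt_b
```

In the notebook you would pass `results.best_checkpoints` instead of the sample pairs; with `num_to_keep=1`, that list should hold at most one entry.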

In [0]:
checkpoint = results.checkpoint
checkpoint

## Generate text from prompt with Ray AIR <a name="predict"></a>

In [0]:
import ray.data
import pandas as pd
from training.generate import PredictCallable

instructions = [
    "Write a love letter to Edgar Allan Poe.",
    "Write a tweet announcing Dolly, a large language model from Databricks.",
    "I'm selling my Nikon D-750, write a short blurb for my ad.",
    "Explain to me the difference between nuclear fission and fusion.",
    "Give me a list of 5 science fiction books I should read next."]

ds = ray.data.from_pandas(pd.DataFrame(pd.Series(instructions),columns=["prompt"]))

In [0]:
preds = (
    ds
    .repartition(5)
    .map_batches(
        PredictCallable,
        batch_size=1,
        fn_constructor_kwargs=dict(checkpoint=checkpoint.uri.split('file://')[1],
                                   torch_dtype = "bfloat16" ),#change to float16 when using V100
        batch_format="pandas",
        compute=ray.data.ActorPoolStrategy(),
        num_gpus=4,
    )
)

In [0]:
preds.take_all()