# GPT-J-6B Batch Prediction with Ray AIR

This example shows how to use Ray AIR for **GPT-J batch inference**. GPT-J is a GPT-2-like causal language model with 6 billion parameters, trained on the Pile dataset. For more information, see the [GPT-J documentation](https://huggingface.co/docs/transformers/model_doc/gptj).

We use Ray Data and a pretrained model from the Hugging Face Hub. You can adapt this example to other similar models.

It is highly recommended to read [Ray AIR Key Concepts](air-key-concepts) and [Ray Data Key Concepts](data_key_concepts) before starting this example.

If you are interested in serving (online inference), see {doc}`/ray-air/examples/gptj_serving`.

```{note}
To run this example, make sure your Ray cluster has access to at least one GPU with 16 GB or more of memory. The amount of memory needed depends on the model.
```

In [1]:
model_id = "EleutherAI/gpt-j-6B"
revision = "float16"  # use float16 weights to fit in 16GB GPUs
prompt = (
    "In a shocking finding, scientists discovered a herd of unicorns living in a remote, "
    "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
    "researchers was the fact that the unicorns spoke perfect English."
)

In [2]:
import ray

We define a {ref}`runtime environment <runtime-environments>` to ensure that the Ray workers have access to all the necessary packages. You can omit the `runtime_env` argument if you have all of the packages already installed on each node in your cluster.

In [None]:
ray.init(
    runtime_env={
        "pip": [
            "accelerate>=0.16.0",
            "transformers>=4.26.0",
            "numpy<1.24",  # remove when mlflow updates beyond 2.2
            "torch",
        ]
    }
)

For the purposes of this example, we will use a very small toy dataset composed of multiple copies of our prompt. Ray Data can handle much bigger datasets with ease.

In [4]:
import ray.data
import pandas as pd

ds = ray.data.from_pandas(pd.DataFrame([prompt] * 10, columns=["prompt"]))

Since we will be using a pretrained model from the Hugging Face Hub, the simplest approach is to use {meth}`map_batches <ray.data.Dataset.map_batches>` with a [callable class UDF](transforming_data_actors). This saves time by initializing the model just once and then feeding it multiple batches of data.

In [5]:
from typing import Optional


class PredictCallable:
    def __init__(self, model_id: str, revision: Optional[str] = None):
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",  # automatically makes use of all GPUs available to the Actor
        )
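        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Reuse EOS as the pad token and pad on the left so that prompts of
        # different lengths can be batched together for generation.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        tokenized = self.tokenizer(
            list(batch["prompt"]), return_tensors="pt", padding=True
        )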
        input_ids = tokenized.input_ids.to(self.model.device)
        attention_mask = tokenized.attention_mask.to(self.model.device)

        gen_tokens = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            temperature=0.9,
            max_length=100,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return pd.DataFrame(
            self.tokenizer.batch_decode(gen_tokens), columns=["responses"]
        )
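
The callable-class pattern used by `PredictCallable` can be illustrated without a GPU or any model download. The following is a minimal, dependency-free sketch: `EchoPredictor` and its `suffix` argument are hypothetical stand-ins (not part of Ray or Transformers), and the loop only mimics how an actor would reuse one instance across batches.

```python
class EchoPredictor:
    """Hypothetical stand-in for PredictCallable.

    Expensive setup belongs in __init__ (run once per actor);
    per-batch work belongs in __call__ (run once per batch).
    """

    init_count = 0  # counts how many times the "model" was loaded

    def __init__(self, suffix: str):
        # In PredictCallable, this is where the model and tokenizer are loaded.
        type(self).init_count += 1
        self.suffix = suffix

    def __call__(self, batch: dict) -> dict:
        # In PredictCallable, this is where tokenization and generation happen.
        return {"responses": [p + self.suffix for p in batch["prompt"]]}


# One instance is created, then reused for every batch it receives.
predictor = EchoPredictor(suffix=" [generated text]")
batches = [{"prompt": ["a", "b"]}, {"prompt": ["c"]}]
outputs = [predictor(b) for b in batches]
```

With a plain function UDF, the setup work would instead be repeated for every batch, which is prohibitively slow for a model of this size.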

All that is left is to run the `map_batches` method on the dataset. We specify that we want to use one GPU for each Ray Actor that will be running our callable class.

Also notice that we repartition the dataset into 100 partitions before mapping batches. This ensures there are enough parallel tasks to take advantage of all the GPUs. 100 is an arbitrary number; you can pick any other number as long as it is greater than the number of available GPUs in the cluster.
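
To see why the partition count matters, consider a deliberately simplified model of the situation. The helpers `split_into_blocks` and `busy_workers` below are hypothetical and do not reflect how Ray Data actually splits or schedules blocks; the sketch only assumes that each worker processes whole blocks, so the number of non-empty blocks caps how many workers can be busy at once.

```python
def split_into_blocks(rows, num_blocks):
    """Split rows into num_blocks contiguous chunks; some may be empty."""
    base, extra = divmod(len(rows), num_blocks)
    blocks, start = [], 0
    for i in range(num_blocks):
        size = base + (1 if i < extra else 0)
        blocks.append(rows[start:start + size])
        start += size
    return blocks


def busy_workers(rows, num_blocks, num_workers):
    """Upper bound on workers that can run concurrently in this toy model."""
    non_empty = [b for b in split_into_blocks(rows, num_blocks) if b]
    return min(len(non_empty), num_workers)


rows = list(range(10))  # ten prompts, as in the toy dataset
single_block = busy_workers(rows, num_blocks=1, num_workers=4)
many_blocks = busy_workers(rows, num_blocks=100, num_workers=4)
```

In this model, a single block keeps only one of four workers busy, while 100 partitions over ten rows yield ten non-empty blocks, enough to occupy all four.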

```{tip}
If you have access to large GPUs, you may want to increase the batch size to better saturate them.

If you want to use intra-node model parallelism, you can also increase `num_gpus`. Because we created the model with `device_map="auto"`, it will be placed on the correct devices automatically. Note that this requires nodes with multiple GPUs.
```

In [6]:
preds = (
    ds
    .repartition(100)
    .map_batches(
        PredictCallable,
        batch_size=4,
        fn_constructor_kwargs=dict(model_id=model_id, revision=revision),
        batch_format="pandas",
        compute=ray.data.ActorPoolStrategy(),
        num_gpus=1,
    )
)

After `map_batches` is done, we can view our generated text.

In [7]:
preds.take_all()

2023-02-28 10:40:50,530	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(PredictCallable)]
MapBatches(PredictCallable), 0 actors [0 locality hits, 1 misses]: 100%|██████████| 1/1 [12:10<00:00, 730.80s/it]


[{'responses': 'In a shocking finding, scientists discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.\n\nThe finding comes from the team of researchers, which includes Dr. Michael Goldberg, a professor and chair of the Zoology Department at the University of Maryland. Dr. Goldberg spent a year collecting and conducting research in the Ecuadorian Andes, including the Pinchahu'},
 {'responses': 'In a shocking finding, scientists discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.\n\nThe team of British, Argentine and Chilean scientists found that the elusive unicorns had been living in the valley for at least 50 years, and had even interacted with humans.\n\nThe team’s findings published in the journa

You may notice that we are not using an AIR {class}`Predictor <ray.train.predictor.Predictor>` here. This is because Predictors are mainly intended to be used with AIR {class}`Checkpoints <ray.air.checkpoint.Checkpoint>`, which we don't use in this example. See {ref}`air-predictors` for more information and usage examples.