(lightning_advanced_example)=

# Finetune a BERT Text Classifier with LightningTrainer

:::{note}

This is an advanced example for {class}`LightningTrainer <ray.train.lightning.LightningTrainer>` that demonstrates how to use LightningTrainer with `Datastream` and `BatchPredictor`.

If you just want to quickly convert your existing PyTorch Lightning scripts to Ray AIR, refer to this starter example:
{ref}`Train a PyTorch Lightning Image Classifier <lightning_mnist_example>`.

:::

This demo shows how to finetune a pretrained BERT model as a text classifier on the [CoLA (Corpus of Linguistic Acceptability)](https://nyu-mll.github.io/CoLA/) dataset.
In particular, we will:
- Create Ray Data from the original CoLA dataset.
- Define a preprocessor to tokenize the sentences.
- Finetune a BERT model using LightningTrainer.
- Construct a BatchPredictor with the checkpoint and preprocessor.
- Do batch prediction on multiple GPUs, and evaluate the results.

In [35]:
# If True, train for only 2 epochs on a 10% sample of each split (see step 3).
SMOKE_TEST = True

In [3]:
import ray
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, load_metric
import numpy as np

## 1. Pre-process CoLA Datastream

CoLA is a binary sentence classification task with about 10.6K examples in total, of which 8,551 form the GLUE training split. First, download the dataset and metric using the Hugging Face API, then create Ray Data for each split.

In [4]:
dataset = load_dataset("glue", "cola")
metric = load_metric("glue", "cola")


In [None]:
ray_datasets = ray.data.from_huggingface(dataset)

Next, define a preprocessor that tokenizes the input sentences with the `bert-base-cased` tokenizer and pads or truncates each ID sequence to length 128. The LightningTrainer later applies this preprocessor to every dataset we pass to it.

In [6]:
from ray.data.preprocessors import BatchMapper

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_sentence(batch):
    encoded_sent = tokenizer(
        batch["sentence"].tolist(),
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    batch["input_ids"] = encoded_sent["input_ids"].numpy()
    batch["attention_mask"] = encoded_sent["attention_mask"].numpy()
    batch["label"] = np.array(batch["label"])
    batch.pop("sentence")
    return batch


preprocessor = BatchMapper(tokenize_sentence, batch_format="numpy")
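
To make the effect of `max_length`, `truncation`, and `padding="max_length"` concrete, here is a minimal pure-Python sketch of fixed-length padding with an attention mask. The helper `pad_or_truncate`, the `PAD_ID` value, and the token IDs are hypothetical and only for illustration. A real tokenizer also preserves special tokens when truncating, which this sketch does not.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical padding ID for this toy example


def pad_or_truncate(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Return fixed-length IDs plus a mask: 1 for real tokens, 0 for padding."""
    kept = token_ids[:max_length]        # corresponds to truncation=True
    n_pad = max_length - len(kept)       # corresponds to padding="max_length"
    return {
        "input_ids": kept + [PAD_ID] * n_pad,
        "attention_mask": [1] * len(kept) + [0] * n_pad,
    }


# A short sequence is padded; a long one is cut to max_length.
short = pad_or_truncate([101, 7592, 102], max_length=5)
long = pad_or_truncate([101, 1, 2, 3, 4, 5, 102], max_length=5)
print(short)
print(long)
```

Because every row ends up with the same length, the batches can be stacked into rectangular arrays, and the attention mask tells the model which positions to ignore.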

## 2. Define a PyTorch Lightning Model

You don't need to change your `LightningModule` definition. Just copy and paste your code here:

In [7]:
class SentimentModel(pl.LightningModule):
    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps
        self.num_classes = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=self.num_classes
        )
        self.metric = load_metric("glue", "cola")
        self.predictions = []
        self.references = []

    def forward(self, batch):
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        outputs = self.model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        return logits

    def training_step(self, batch, batch_idx):
        labels = batch["label"]
        logits = self.forward(batch)
        loss = F.cross_entropy(logits.view(-1, self.num_classes), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        logits = self.forward(batch)
        preds = torch.argmax(logits, dim=1)
        self.predictions.append(preds)
        self.references.append(labels)

    def on_validation_epoch_end(self):
        predictions = torch.concat(self.predictions).view(-1)
        references = torch.concat(self.references).view(-1)
        matthews_correlation = self.metric.compute(
            predictions=predictions, references=references
        )

        # self.metric.compute() returns a dictionary:
        # e.g. {"matthews_correlation": 0.53}
        self.log_dict(matthews_correlation, sync_dist=True)
        self.predictions.clear()
        self.references.clear()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
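
The validation metric logged by this module is the Matthews correlation coefficient. As a reference for what that number means, here is a minimal sketch of the binary case computed from the confusion-matrix counts. The function name is hypothetical, and returning 0.0 when the denominator is zero is a convention assumed here; the GLUE metric used above may handle edge cases differently.

```python
import math
from typing import Sequence


def matthews_corrcoef_binary(preds: Sequence[int], refs: Sequence[int]) -> float:
    """Matthews correlation for 0/1 labels, computed from confusion counts."""
    pairs = list(zip(preds, refs))
    tp = sum(1 for p, r in pairs if p == 1 and r == 1)
    tn = sum(1 for p, r in pairs if p == 0 and r == 0)
    fp = sum(1 for p, r in pairs if p == 1 and r == 0)
    fn = sum(1 for p, r in pairs if p == 0 and r == 1)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Assumed convention: return 0.0 when a row or column of the matrix is empty.
    return (tp * tn - fp * fn) / denom if denom else 0.0


perfect = matthews_corrcoef_binary([1, 0, 1, 0], [1, 0, 1, 0])
inverted = matthews_corrcoef_binary([0, 1, 0, 1], [1, 0, 1, 0])
constant = matthews_corrcoef_binary([1, 1, 1, 1], [1, 0, 1, 0])
print(perfect, inverted, constant)
```

The score ranges from -1 (every prediction wrong) to 1 (every prediction right), and a model that always predicts the same class scores 0, which makes it more informative than accuracy on imbalanced data.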

## 3. Finetune the model with LightningTrainer

Define a LightningTrainer with the necessary configurations, including hyperparameters, checkpointing, and compute resources.

You may find the API of {class}`LightningConfigBuilder <ray.train.lightning.LightningConfigBuilder>` useful.


In [8]:
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig

# Define the configs for LightningTrainer
lightning_config = (
    LightningConfigBuilder()
    .module(cls=SentimentModel, lr=1e-5, eps=1e-8)
    .trainer(max_epochs=5, accelerator="gpu")
    .checkpointing(save_on_train_epoch_end=False)
    .build()
)

# Save AIR checkpoints according to performance on the validation set
run_config = RunConfig(
    name="ptl-sent-classification",
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="matthews_correlation",
        checkpoint_score_order="max",
    ),
)

# Scale the training workload across 4 GPUs
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)
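
The checkpoint settings above ask to keep the 2 checkpoints with the highest `matthews_correlation`. The following sketch illustrates that retention policy in plain Python. It is not Ray's implementation, the helper `keep_best` is hypothetical, and the per-epoch scores are made-up values rather than results from this run.

```python
import heapq
from typing import List, Tuple

Checkpoint = Tuple[str, float]  # (checkpoint name, score)


def keep_best(history: List[Checkpoint], num_to_keep: int, order: str = "max") -> List[Checkpoint]:
    """Return the num_to_keep checkpoints with the best score, best first."""
    select = heapq.nlargest if order == "max" else heapq.nsmallest
    return select(num_to_keep, history, key=lambda item: item[1])


# Made-up validation scores, one per epoch, for illustration only.
history = [
    ("checkpoint_000000", 0.41),
    ("checkpoint_000001", 0.49),
    ("checkpoint_000002", 0.47),
    ("checkpoint_000003", 0.52),
    ("checkpoint_000004", 0.50),
]
kept = keep_best(history, num_to_keep=2, order="max")
print(kept)
```

Because the score comes from the validation loop, the checkpoint must be written after validation finishes, which is why the config sets `save_on_train_epoch_end=False`.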

In [9]:
if SMOKE_TEST:
    lightning_config = (
        LightningConfigBuilder()
        .module(cls=SentimentModel, lr=1e-5, eps=1e-8)
        .trainer(max_epochs=2, accelerator="gpu")
        .checkpointing(save_on_train_epoch_end=False)
        .build()
    )

    for split, ds in ray_datasets.items():
        ray_datasets[split] = ds.random_sample(0.1)
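
Note that a fraction-based sample does not return an exact number of rows. The sketch below assumes that `random_sample` keeps each row independently with the given probability; the helper `bernoulli_sample` and the seed are hypothetical and only illustrate that behavior.

```python
import random
from typing import List, Sequence


def bernoulli_sample(rows: Sequence[int], fraction: float, seed: int) -> List[int]:
    """Keep each row independently with probability `fraction`."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]


rows = list(range(1000))
sample = bernoulli_sample(rows, fraction=0.1, seed=0)
# The size is close to 100 but is not guaranteed to be exactly 100.
print(len(sample))
```

For a smoke test this variation is harmless, but it means the number of training steps can differ slightly between runs.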

Train the model with the configuration we specified above. 

To feed data into LightningTrainer, we need to configure the following arguments:

- datasets: A dictionary of the input Ray datasets, with special keys "train" and "val".
- datasets_iter_config: The keyword arguments for {meth}`iter_torch_batches() <ray.data.Datastream.iter_torch_batches>`, which define how each worker iterates over its dataset shard.
- preprocessor: The preprocessor that will be applied to the input dataset.
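
Conceptually, each worker receives one shard of the training data and iterates over it in batches of the configured size. The sketch below shows that idea with plain lists. Round-robin sharding and both helper names are assumptions made for illustration, not a description of how Ray splits the data internally.

```python
from typing import Iterator, List, Sequence


def shard_round_robin(rows: Sequence[int], num_workers: int) -> List[List[int]]:
    """Assign row i to worker i % num_workers (an assumed sharding scheme)."""
    return [list(rows[i::num_workers]) for i in range(num_workers)]


def iter_batches(rows: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(rows), batch_size):
        yield list(rows[start:start + batch_size])


rows = list(range(10))
shards = shard_round_robin(rows, num_workers=2)
batches_worker0 = list(iter_batches(shards[0], batch_size=2))
print(shards)
print(batches_worker0)
```

With `datasets_iter_config={"batch_size": 16}`, the batch size applies per worker, so the effective global batch size is the per-worker size multiplied by the number of workers.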

:::{note}
This example uses Datastream for data ingestion to speed up preprocessing, but you can also keep using a native PyTorch `DataLoader` or `LightningDataModule`. See {ref}`this example <lightning_mnist_example>`.

:::


Now, call `trainer.fit()` to initiate the training process.

In [None]:
trainer = LightningTrainer(
    lightning_config=lightning_config,
    run_config=run_config,
    scaling_config=scaling_config,
    datasets={"train": ray_datasets["train"], "val": ray_datasets["validation"]},
    datasets_iter_config={"batch_size": 16},
    preprocessor=preprocessor,
)
result = trainer.fit()

In [11]:
result

Result(
  metrics={'_report_on': 'validation_end', 'train_loss': 0.05989973247051239, 'matthews_correlation': 0.5175218541439164, 'epoch': 4, 'step': 670, 'should_checkpoint': True, 'done': True, 'trial_id': '5ae4c_00000', 'experiment_tag': '0'},
  path='/home/ray/ray_results/ptl-sent-classification/LightningTrainer_5ae4c_00000_0_2023-04-05_12-45-05',
  checkpoint=LightningCheckpoint(local_path=/home/ray/ray_results/ptl-sent-classification/LightningTrainer_5ae4c_00000_0_2023-04-05_12-45-05/checkpoint_000004)
)
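
As a sanity check, the reported `step` value can be reproduced with simple arithmetic. This sketch assumes that the run shown used the full GLUE CoLA training split (8,551 examples), that is, with `SMOKE_TEST` disabled, that the data is split evenly across the 4 workers, and that every batch counts as one optimizer step.

```python
import math

num_train_examples = 8551  # assumption: size of the GLUE CoLA training split
num_workers = 4            # from ScalingConfig(num_workers=4)
batch_size = 16            # from datasets_iter_config
max_epochs = 5             # from .trainer(max_epochs=5)

# Each worker sees roughly a quarter of the data per epoch.
examples_per_worker = math.ceil(num_train_examples / num_workers)
steps_per_epoch = math.ceil(examples_per_worker / batch_size)
total_steps = steps_per_epoch * max_epochs
print(examples_per_worker, steps_per_epoch, total_steps)
```

Under these assumptions the total is 670, which matches `'step': 670` in the result above, and `'epoch': 4` is the zero-based index of the fifth epoch.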

## 4. Do Batch Inference with a Saved Checkpoint

Now that we have fine-tuned the model, we can load the checkpoint into a BatchPredictor and run inference on multiple GPUs. Calling `predict()` distributes the inference workload across multiple workers, which run prediction on separate shards of the data in parallel.

You can find more details in [Using Predictors for Inference](air-predictors).

In [8]:
from ray.train.batch_predictor import BatchPredictor
from ray.train.lightning import LightningCheckpoint, LightningPredictor

# Use in-memory checkpoint object
checkpoint = result.checkpoint

# You can also load a checkpoint from disk:
# checkpoint = LightningCheckpoint.from_directory("YOUR_CHECKPOINT_DIR")

batch_predictor = BatchPredictor(
    checkpoint=checkpoint,
    predictor_cls=LightningPredictor,
    use_gpu=True,
    model_class=SentimentModel,
    preprocessor=preprocessor,
)

In [33]:
# Use 2 GPUs for batch inference
predictions = batch_predictor.predict(
    ray_datasets["validation"],
    feature_columns=["input_ids", "attention_mask", "label"],
    keep_columns=["label"],
    batch_size=16,
    min_scoring_workers=2,
    max_scoring_workers=2,
    num_gpus_per_worker=1,
)

`batch_predictor.predict()` returns a Ray dataset containing the predictions. We can now evaluate the results in a few lines of code:

In [31]:
# Internally, BatchPredictor calls the forward() method of the LightningModule,
# so each prediction is a vector of logits. Convert it to a label with argmax.
def argmax(batch):
    batch["predictions"] = batch["predictions"].apply(lambda x: np.argmax(x))
    return batch


results = predictions.map_batches(argmax).to_pandas()

matthews_corr = metric.compute(
    predictions=results["predictions"], references=results["label"]
)
print(matthews_corr)

{'matthews_correlation': 0.5175218541439164}
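
Taking the argmax of the raw logits is sufficient because softmax is monotonic: it turns logits into probabilities without changing which class scores highest. The sketch below checks this on a few made-up logit pairs. The helper functions and the values are hypothetical and not taken from the model.

```python
import math
from typing import List


def argmax(values: List[float]) -> int:
    """Index of the largest value; ties resolve to the first index."""
    return max(range(len(values)), key=values.__getitem__)


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for three sentences and two classes.
logits = [[2.1, -0.3], [-1.0, 0.5], [0.2, 0.2]]
labels_from_logits = [argmax(row) for row in logits]
labels_from_probs = [argmax(softmax(row)) for row in logits]
print(labels_from_logits, labels_from_probs)
```

Skipping the softmax saves a little computation during batch inference when only the predicted label is needed.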
