# Fine-tune Dolly-v2-3b with Ray AIR LightningTrainer and FSDP

In [1]:
# TODO(yunxuan): remove it
import os
os.environ["RAY_ML_DEV"] = "1"

## Set up ray cluster 
In this example, we are using a ray cluster with 16 g4dn.4xlarge instances. Each instance has one Tesla T4 GPU (16GiB Memory). 

We define a `runtime_env` to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers' base image.

In [None]:
import ray

ray.init(
    runtime_env={
        "pip": [
            "datasets",
            "evaluate",
            "transformers>=4.26.0",
            "torch>=1.12.0",
            "pytorch_lightning>=2.0",
        ]
    }
)

In [2]:
num_workers = 16
batch_size_per_worker = 8
MODEL_NAME = "databricks/dolly-v2-3b"

## Prepare your data 
We are using tiny_shakespeare for finetuning, which contains 40,000 lines of Shakespeare from a variety of Shakespeare's plays. Featured in Andrej Karpathy's blog post ['The Unreasonable Effectiveness of Recurrent Neural Networks'](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). 

Data samples:
```
BAPTISTA:
I know him well: you are welcome for his sake.

GREMIO:
Saving your tale, Petruchio, I pray,
Let us, that are poor petitioners, speak too:
Baccare! you are marvellous forward.

PETRUCHIO:
O, pardon me, Signior Gremio; I would fain be doing.
```

Here, we have adopted similar pre-processing logic from another demo: {ref}`GPT-J-6B Fine-Tuning with Ray AIR and DeepSpeed <gpt-j-6b-finetune-deepspeed>`.

In [None]:
import ray
import pandas as pd
from datasets import load_dataset
from ray.data.preprocessors import BatchMapper, Chain
from transformers import AutoTokenizer, AutoModelForCausalLM

def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    text = list(batch["text"])
    flat_text = "".join(text)
    split_text = [
        x.strip()
        for x in flat_text.split("\n")
        if x.strip() and not x.strip()[-1] == ":"
    ]
    return pd.DataFrame(split_text, columns=["text"])


def tokenize(batch: pd.DataFrame) -> dict:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    ret = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=256,
        padding="max_length",
        return_tensors="np",
    )
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)

splitter = BatchMapper(split_text, batch_format="pandas")
tokenizer = BatchMapper(tokenize, batch_format="pandas")
preprocessor = Chain(splitter, tokenizer)

hf_dataset = load_dataset("tiny_shakespeare")
ray_datasets = ray.data.from_huggingface(hf_dataset)

## Define your lightning model

We are using the Dolly-v2-3b model for finetuning. It is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use. We load the model weights from Huggingface Model Hub and encapsulate it as a `pl.LightningModule`.

:::{note}
Make sure you pass the FSDP wrapped model parameters `self.trainer.model.parameters()` into the optimizer, instead of `self.model.parameters()`. 
:::


In [3]:
import torch
import pytorch_lightning as pl

class DollyV2Model(pl.LightningModule):
    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps
        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        self.predictions = []
        self.references = []

    def forward(self, batch):
        outputs = self.model(
            batch["input_ids"], 
            attention_mask=batch["attention_mask"], 
            labels=batch["labels"]
        )
        loss = outputs[0]
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=self.lr, eps=self.eps)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/home/ray/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 1081.56it/s]
2023-05-01 00:56:50,821	INFO worker.py:1432 -- Connecting to existing Ray cluster at address: 10.0.43.253:6379...
2023-05-01 00:56:50,830	INFO worker.py:1607 -- Connected to Ray cluster. View the dashboard at https://console.anyscale-staging.com/api/v2/sessions/ses_m411tiqu8eluvt1k5ivfqj4q5r/services?redirect_to=dashboard 
2023-05-01 00:56:51,432	INFO packaging.py:520 -- Creating a file package for local directory '/tmp/ray_tmp_module/ray'.
2023-05-01 00:56:52,344	INFO packaging.py:347 -- Pushing file package 'gcs://_ray_pkg_be300ec1f2b48a35.zip' (155.57MiB) to Ray cluster...
2023-05-01 00:56:52,879	INFO packaging.py:360 -- Successfully pushed file package 'gcs://_ray_pkg_be300ec1f2b48a35.zip'.
2023-05-01 00:56:53,392	

## Configure your FSDP strategy
As Dolly-v2-3b is a relatively large model, it cannot be properly fit into a single commercial GPU. In this example, we use the FSDP strategy to shard model parameters across multiple workers. This allows us to avoid GPU out-of-memory issues and support a larger global batch size.

:::{note}
FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks. This was inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. You may refer to these blogs for more information:

- [Getting Started with Fully Sharded Data Parallel(FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#:~:text=FSDP%20is%20a%20type%20of,sizes%20for%20our%20training%20job.)
- [Fully Sharded Data Parallel: faster AI training with fewer GPUs](https://engineering.fb.com/2021/07/15/open-source/fsdp/)
- [PyTorch FSDP Tutorial](https://www.youtube.com/watch?v=8_k76AHu__s&list=PL_lsbAsL_o2BT6aerEKgIoufVD_fodnuT)
:::

To start trainig with Lightning's [FSDPStrategy](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html#lightning.pytorch.strategies.FSDPStrategy), you only need to provide the initialization arguments in `LightningConfigBuilder.strategy()`. Behind the scenes, LightningTrainer handles the cluster environment settings and job launching.


:::{tips}
Some tips for FSDP configutarion:
- `sharding_strategy`:
    - `ShardingStrategy.NO_SHARD`: Parameters, gradients, and optimizer states are not sharded. Similar to DDP.
    - `ShardingStrategy.SHARD_GRAD_OP`: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation.
    - `ShardingStrategy.FULL_SHARD`: Parameters, gradients, and optimizer states are sharded. It has minimal GRAM usage among the 3 options.
- `auto_wrap_policy`:
    - Model layers are often wrapped with FSDP in a layered fashion. This means that only the layers in a single FSDP instance are required to aggregate all parameters to a single device during forwarding or backward calculations.
    - Use `transformer_auto_wrap_policy` to automatically wrap each Transformer Block into a single FSDP instance. 
- `backward_prefetch` and `forward_prefetch`:
    - Overlap the upcoming all-gather while executing the current forward/backward pass. It can improve throughput but may slightly increase peak memory usage.
:::

In [None]:
import functools
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

# Define the model sharding policy:
# Wrap every GPTNeoXLayer as its own FSDP instance
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls = {GPTNeoXLayer}
)

# Aggregate all arguments for LightningTrainer
lightning_config = (
    LightningConfigBuilder()
    .module(cls=DollyV2Model, lr=2e-5, eps=1e-8)
    .trainer(
        max_epochs=1, 
        accelerator="gpu", 
        precision="16-mixed",
    )
    .strategy(
        name="fsdp",
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
        auto_wrap_policy=auto_wrap_policy,
    )
    .checkpointing(save_last=True)
)

In [None]:
from pytorch_lightning.callbacks import TQDMProgressBar

# Create a customized progress bar for LightningTrainer
class DollyV2ProgressBar(TQDMProgressBar):
    def __init__(self, num_iters_per_epoch, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_iters_per_epoch = num_iters_per_epoch
    
    def on_train_epoch_start(self, trainer, *_):
        super().on_train_epoch_start(trainer, *_)
        self.train_progress_bar.reset(self.num_iters_per_epoch)

total_batches = splitter.fit_transform(ray_datasets["train"]).count()
num_iters_per_epoch = total_batches // (num_workers * batch_size_per_worker)
lightning_config.trainer(callbacks=[DollyV2ProgressBar(num_iters_per_epoch)])

## Fine-tune with LightningTrainer

```{note}
Here we upload the checkpoints to cloud storage by setting S3 bucket URI to {class}`air.RunConfig(storage_path) <ray.air.RunConfig>`. You can also write to your local file system. See {ref}`train-run-config` for an example.
```

In [None]:
# Save AIR checkpoints according to the performance on validation set
run_config = RunConfig(
    name=f"finetune_dolly-v2-3b",
    storage_path="s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/ray-lightning-results/",
    checkpoint_config=CheckpointConfig(),
)

# Scale the DDP training workload across 16 GPUs
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=num_workers, use_gpu=True, resources_per_worker={"CPU": 12, "GPU": 1}
)

trainer = LightningTrainer(
    lightning_config=lightning_config.build(),
    run_config=run_config,
    scaling_config=scaling_config,
    datasets={"train": ray_datasets["train"]},
    datasets_iter_config={"batch_size": batch_size_per_worker},
    preprocessor=preprocessor,
)
result = trainer.fit()

result


In [19]:
checkpoint_uri = result.checkpoint.uri
checkpoint_local_dir = "/tmp/model-checkpoint"
# checkpoint_uri = "s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/"

os.makedirs(checkpoint_local_dir, exist_ok=True)
os.system(f"aws s3 sync {checkpoint_uri} {checkpoint_local_dir}")

download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/.is_checkpoint to ../s3/checkpoint/.is_checkpoint
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/.metadata.pkl to ../s3/checkpoint/.metadata.pkl
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/_preprocessor to ../s3/checkpoint/_preprocessor
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/.tune_metadata to ../s3/checkpoint/.tune_metadata
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/model to ../s3/checkpoint/model


0

Now, let's generate some text and see if our fine-tuned model can write like Shakespeare:

In [8]:
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
dolly = DollyV2Model.load_from_checkpoint(f"{checkpoint_local_dir}/model")

nlp_pipeline = pipeline(task="text-generation", model=dolly.model, tokenizer=tokenizer, device=0)

for prompt in ["This is", "I am", "Once more"]:
    print(nlp_pipeline(prompt, max_new_tokens=30, pad_token_id=tokenizer.eos_token_id))

[{'generated_text': 'This is the day that I was born, and this is the day that I shall die.'}]
[{'generated_text': 'I am a poor man, sir, and I am a soldier.'}]
[{'generated_text': 'Once more, my lord, I am your servant.'}]
