# Fine-tune Dolly-v2-3b with Ray AIR LightningTrainer and FSDP

First, install the required packages.

In [None]:
## Requirements
! pip install "datasets" "evaluate" "transformers>=4.26.0" "torch>=1.12.0" "pytorch_lightning>=2.0"

In [None]:
import ray
import torch
import evaluate

import pytorch_lightning as pl
import torch.nn.functional as F

from datasets import load_dataset, load_metric
from typing import Any
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForCausalLM
from ray.tune.syncer import SyncConfig
from ray.data.preprocessors import Chain

In [1]:
import os
# Set the RAY_ML_DEV environment variable for this process.
# NOTE(review): nothing in this notebook reads the flag; presumably it is a Ray
# development switch — confirm it is still required.
os.environ["RAY_ML_DEV"] = "1"

In [2]:
# Model identifier used to load both the tokenizer and the causal LM, and to
# name the training run. The 7b variant is kept as a commented alternative.
# MODEL_NAME = "databricks/dolly-v2-7b"
MODEL_NAME = "databricks/dolly-v2-3b"

In [None]:
import numpy as np
import pandas as pd
from ray.data.preprocessors import BatchMapper


def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    """Turn the raw text of a batch into one stripped line per row.

    The values of the ``text`` column are concatenated without a separator and
    the result is split on newlines. Each line is stripped; empty lines and
    lines whose last character is ``:`` are discarded.

    Args:
        batch: DataFrame with a ``text`` column of strings.

    Returns:
        A DataFrame with a single ``text`` column holding the kept lines.
    """
    joined = "".join(list(batch["text"]))
    kept_lines = []
    for raw_line in joined.split("\n"):
        line = raw_line.strip()
        # Skip blank lines and lines ending with a colon.
        if not line or line.endswith(":"):
            continue
        kept_lines.append(line)
    return pd.DataFrame(kept_lines, columns=["text"])


def tokenize(batch: pd.DataFrame) -> dict:
    """Tokenize the ``text`` column of a batch for causal-LM training.

    A tokenizer for ``MODEL_NAME`` is loaded with left padding, using the EOS
    token as the pad token. Every text is truncated/padded to 256 tokens and
    returned as NumPy arrays. ``labels`` is a copy of ``input_ids``.

    Args:
        batch: DataFrame with a ``text`` column of strings.

    Returns:
        A plain dict with the tokenizer outputs plus a ``labels`` entry.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    # Reuse the EOS token for padding.
    hf_tokenizer.pad_token = hf_tokenizer.eos_token
    texts = list(batch["text"])
    encoded = hf_tokenizer(
        texts,
        truncation=True,
        max_length=256,
        padding="max_length",
        return_tensors="np",
    )
    features = dict(encoded)
    # Labels are the input ids themselves.
    features["labels"] = features["input_ids"].copy()
    return features

# Preprocessing pipeline: split the raw text into lines, then tokenize each line.
# NOTE(review): the name `tokenizer` is rebound to a Hugging Face tokenizer in the
# inference cell further down; the BatchMapper is only reachable via `preprocessor` after that.
splitter = BatchMapper(split_text, batch_format="pandas")
tokenizer = BatchMapper(tokenize, batch_format="pandas")
preprocessor = Chain(splitter, tokenizer)

# Load the tiny_shakespeare dataset and convert it for use with Ray.
hf_dataset = load_dataset("tiny_shakespeare")
ray_datasets = ray.data.from_huggingface(hf_dataset)

In [3]:
class DollyV2Model(pl.LightningModule):
    """LightningModule that fine-tunes the causal LM identified by ``MODEL_NAME``.

    Args:
        lr: Learning rate passed to AdamW.
        eps: Epsilon passed to AdamW.
    """

    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        # Record lr/eps as hyperparameters so they are stored in checkpoints and
        # restored by load_from_checkpoint.
        self.save_hyperparameters()
        self.lr = lr
        self.eps = eps
        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    def forward(self, batch):
        """Run the model on a batch and return its loss.

        Args:
            batch: Mapping providing ``input_ids``, ``attention_mask`` and ``labels``.

        Returns:
            The loss computed by the model for the given labels.
        """
        outputs = self.model(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self.forward(batch)
        self.log("train_loss", loss)
        # Print the loss from rank 0 every 10 batches.
        if self.global_rank == 0 and batch_idx % 10 == 0:
            print("loss = ", loss.item())
        return loss

    def configure_optimizers(self):
        """Build the AdamW optimizer over the trainer's model parameters."""
        # Parameters are taken from self.trainer.model rather than self —
        # presumably so the optimizer sees the strategy-wrapped (FSDP) parameters.
        return torch.optim.AdamW(
            self.trainer.model.parameters(), lr=self.lr, eps=self.eps
        )

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/home/ray/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 1081.56it/s]
2023-05-01 00:56:50,821	INFO worker.py:1432 -- Connecting to existing Ray cluster at address: 10.0.43.253:6379...
2023-05-01 00:56:50,830	INFO worker.py:1607 -- Connected to Ray cluster. View the dashboard at https://console.anyscale-staging.com/api/v2/sessions/ses_m411tiqu8eluvt1k5ivfqj4q5r/services?redirect_to=dashboard 
2023-05-01 00:56:51,432	INFO packaging.py:520 -- Creating a file package for local directory '/tmp/ray_tmp_module/ray'.
2023-05-01 00:56:52,344	INFO packaging.py:347 -- Pushing file package 'gcs://_ray_pkg_be300ec1f2b48a35.zip' (155.57MiB) to Ray cluster...
2023-05-01 00:56:52,879	INFO packaging.py:360 -- Successfully pushed file package 'gcs://_ray_pkg_be300ec1f2b48a35.zip'.
2023-05-01 00:56:53,392	

In [None]:
# Number of training workers; used for the scaling config and the steps-per-epoch estimate.
num_workers = 16
# Batch size each worker receives per step.
batch_size_per_worker = 8

In [None]:
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
import functools

# Auto-wrap policy that treats GPTNeoXLayer as the transformer layer class
# when FSDP decides how to wrap the model.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={GPTNeoXLayer}
)

# Assemble the LightningTrainer configuration: module, trainer, strategy, checkpointing.
lightning_config = (
    LightningConfigBuilder()
    # Module class and its constructor arguments.
    .module(cls=DollyV2Model, lr=2e-5, eps=1e-8)
    # One epoch on GPU with 16-bit mixed precision, logging every step.
    .trainer(
        max_epochs=1,
        accelerator="gpu",
        log_every_n_steps=1,
        precision="16-mixed",
    )
    # Fully sharded data parallel with prefetching in both directions.
    .strategy(
        name="fsdp",
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
        auto_wrap_policy=auto_wrap_policy,
    )
    .checkpointing(save_last=True)
    .build()
)

In [None]:
from pytorch_lightning.callbacks import TQDMProgressBar

class DollyV2ProgressBar(TQDMProgressBar):
    """TQDM progress bar whose training total is set explicitly each epoch.

    Args:
        num_iters_per_epoch: Total shown on the training bar for every epoch.
        *args, **kwargs: Forwarded unchanged to ``TQDMProgressBar``.
    """
    def __init__(self, num_iters_per_epoch, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_iters_per_epoch = num_iters_per_epoch
    
    def on_train_epoch_start(self, trainer, *_):
        """Reset the training bar to ``num_iters_per_epoch`` when an epoch starts."""
        super().on_train_epoch_start(trainer, *_)
        # Override the total set by the parent class.
        # NOTE(review): presumably needed because the total cannot be inferred
        # from the Ray data iterator — confirm.
        self.train_progress_bar.reset(self.num_iters_per_epoch)

# Count the rows the splitter produces for the train split.
# NOTE(review): despite its name, total_train_batches holds the value of .count()
# on the split dataset (rows), not a number of batches.
total_train_batches = splitter.fit_transform(ray_datasets["train"]).count()
# Steps per epoch = row count // global batch size (workers * per-worker batch size).
num_iters_per_epoch = total_train_batches // (num_workers * batch_size_per_worker)
progress_bar = DollyV2ProgressBar(num_iters_per_epoch)
# NOTE(review): progress_bar is not referenced by any other cell in this notebook,
# so it is never attached to the trainer — confirm whether it should be passed as a callback.
lightning_config

In [None]:

from ray.tune.syncer import SyncConfig

# Run configuration: experiment name, checkpoint storage location and checkpoint options.
# NOTE(review): CheckpointConfig() uses defaults and only a "train" dataset is passed
# to the trainer below, so checkpoints are not ranked by validation performance.
# NOTE(review): MODEL_NAME contains a "/", so the run name does too — confirm this
# produces the intended storage path.
run_config = RunConfig(
    name=f"finetune-{MODEL_NAME}",
    storage_path="s3://yunxuanx-test/model-checkpoint",
    checkpoint_config=CheckpointConfig(),
)

# Scale the FSDP training workload across `num_workers` workers, one GPU each.
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=num_workers, use_gpu=True, resources_per_worker={"CPU": 14, "GPU": 1}
)


# Build the trainer; the preprocessor (split + tokenize) is applied to the train dataset
# and batches of `batch_size_per_worker` rows are fed to each worker.
trainer = LightningTrainer(
    lightning_config=lightning_config,
    run_config=run_config,
    scaling_config=scaling_config,
    datasets={"train": ray_datasets["train"]},
    datasets_iter_config={"batch_size": batch_size_per_worker},
    preprocessor=preprocessor,
)
# Launch training and keep the result object.
result = trainer.fit()

result


In [9]:
# Remote location of the checkpoint to load. The commented line takes it from the
# training result; the active line hard-codes the URI of a previously saved run.
# checkpoint_uri = result.checkpoint.uri
checkpoint_uri = "s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/"
# Local directory the checkpoint is downloaded into.
checkpoint_local_dir = "/home/ray/s3/checkpoint"

In [19]:
import subprocess

# Download the checkpoint directory from S3 into the local directory.
# The arguments are passed as a list (no shell parsing of the paths), and
# check=True raises CalledProcessError if the sync fails instead of letting the
# notebook continue with a missing or partial checkpoint.
sync_result = subprocess.run(
    ["aws", "s3", "sync", checkpoint_uri, checkpoint_local_dir],
    check=True,
)
sync_result.returncode

download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/.is_checkpoint to ../s3/checkpoint/.is_checkpoint
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/.metadata.pkl to ../s3/checkpoint/.metadata.pkl
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/_preprocessor to ../s3/checkpoint/_preprocessor
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/.tune_metadata to ../s3/checkpoint/.tune_metadata
download: s3://yunxuanx-test/model-checkpoint/finetune-dolly-v2/LightningTrainer_18e2e_00000_0_2023-04-30_17-57-08/checkpoint_000000/model to ../s3/checkpoint/model


0

In [11]:
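# Alternative (kept for reference, disabled): load the model directly from the
# AIR checkpoint URI. The get_model call is very slow, so the checkpoint is
# synced to local disk and loaded from there instead.
# from ray.train.lightning import LightningCheckpoint
# checkpoint = LightningCheckpoint.from_uri(checkpoint_uri)
# # Very slow!
# air_model = checkpoint.get_model(DollyV2Model)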

In [4]:
import torch
from transformers import AutoTokenizer, pipeline
# Tokenizer for generation, loaded for the same model id with right padding.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
# Restore the fine-tuned module from the local checkpoint and move it to the GPU.
dolly = DollyV2Model.load_from_checkpoint(f"{checkpoint_local_dir}/model").cuda()
# Text-generation pipeline around the wrapped model, placed on device 0.
nlp_pipeline = pipeline(task="text-generation", model=dolly.model, tokenizer=tokenizer, device=0)

In [8]:
# Generate up to 30 new tokens for a few short prompts and print each result.
for prompt in ("This is", "I am", "Once more"):
    generated = nlp_pipeline(
        prompt, max_new_tokens=30, pad_token_id=tokenizer.eos_token_id
    )
    print(generated)

[{'generated_text': 'This is the day that I was born, and this is the day that I shall die.'}]
[{'generated_text': 'I am a poor man, sir, and I am a soldier.'}]
[{'generated_text': 'Once more, my lord, I am your servant.'}]
