# Efficiently scale distributed training with PyTorch FSDP and Q-Lora

Open LLMs like Meta [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf), Mistral AI [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) & [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models or TII [Falcon](https://huggingface.co/tiiuae/falcon-40b) brought a lot of progress in the last year. We went from no OpenAI competitor to a whole zoo of LLMs, which can be used for a variety of tasks. However, most of the time you need to fine-tune a model on your data or use case to get its full potential. Fine-tuning smaller LLMs, like Mistral, became very accessible on a single GPU by using Q-Lora. But efficiently fine-tuning bigger models like Llama 2 70b or Mixtral stayed a challenge until now.
 
This blog post walks you through how to fine-tune an LLM using PyTorch FSDP and Q-Lora with the help of Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index), [peft](https://huggingface.co/docs/peft/index) & [datasets](https://huggingface.co/docs/datasets/index). In addition to FSDP we will use [Flash Attention v2 through the PyTorch SDPA](https://pytorch.org/blog/pytorch2-2/) implementation.



1. [Setup development environment](#1-setup-development-environment)
2. [Create and prepare the dataset](#2-create-and-prepare-the-dataset)
3. [Fine-tune the LLM with PyTorch FSDP, Q-Lora and SDPA](#3-fine-tune-the-llm-with-pytorch-fsdp-q-lora-and-sdpa)
4. [Test Model and run Inference](#4-test-model-and-run-inference)

_Note: This blog was created and validated on NVIDIA H100 and NVIDIA A10G GPUs._

As an example use case we will fine-tune the Llama 2 70b model on the [HuggingFaceH4/no_robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots) dataset. No Robots is a high-quality dataset of 10,000 instructions and demonstrations created by skilled human annotators. This data can be used for supervised fine-tuning (SFT) to make language models follow instructions better. No Robots was modelled after the instruction dataset described in OpenAI's [InstructGPT paper](https://huggingface.co/papers/2203.02155), and consists mostly of single-turn instructions.

**FSDP + Q-Lora Background**

In a collaboration between [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html), Tim Dettmers, the [Q-Lora creator](https://github.com/TimDettmers/bitsandbytes), and [Hugging Face](https://huggingface.co/), we are proud to share the support of Q-Lora and PyTorch FSDP (Fully Sharded Data Parallel). FSDP and Q-Lora now allow you to fine-tune Llama 2 70b or Mixtral 8x7B on 2x consumer GPUs (24GB). If you want to learn more about the background of this collaboration, take a look at [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html). Hugging Face PEFT is where the magic happens; read more about it in the [PEFT documentation](https://huggingface.co/docs/peft/v0.10.0/en/accelerate/fsdp).

* [PyTorch FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) is a data/model parallelism technique that shards the model across GPUs, reducing memory requirements and enabling the training of larger models more efficiently.
* Q-LoRA is a fine-tuning method that leverages quantization and Low-Rank Adapters to efficiently reduce computational requirements and memory footprint.
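
To give a feel for why low-rank adapters shrink the trainable footprint, here is a back-of-the-envelope sketch. It assumes a single square projection matrix of size 8192 x 8192 and a hypothetical adapter rank `r = 16`; neither value is taken from the training configuration in this post, and the helper name `lora_param_count` is purely illustrative.

```python
def lora_param_count(d_out: int, d_in: int, r: int) -> int:
    """Trainable parameters LoRA adds for one weight matrix W of shape (d_out, d_in).

    LoRA keeps W frozen and learns a low-rank update B @ A,
    with B of shape (d_out, r) and A of shape (r, d_in).
    """
    return r * (d_out + d_in)


# Assumed, illustrative sizes (not taken from the training config).
d_out, d_in, rank = 8192, 8192, 16

full = d_out * d_in
lora = lora_param_count(d_out, d_in, rank)

print(f"full matrix parameters:   {full:,}")
print(f"LoRA parameters (r={rank}): {lora:,}")
print(f"trainable fraction:       {lora / full:.4%}")
```

Under these assumptions the adapter trains well under 1% of the parameters of the matrix it is attached to, while the frozen base weights can be stored in quantized form.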


## 1. Setup development environment

Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune open LLMs, apply RLHF and align them.

In [None]:
# Install Pytorch for FSDP and FA/SDPA
!pip install "torch==2.2.2" tensorboard

# Install Hugging Face libraries
!pip install  --upgrade \
  "transformers==4.39.3" \
  "datasets==2.18.0" \
  "accelerate==0.29.2" \
  "evaluate==0.4.1" \
  "bitsandbytes==0.43.1" \
  "trl==0.8.3"  \
  "peft==0.10.0" 
  


In [None]:
!pip install git+https://github.com/huggingface/trl

Next we need to log in to Hugging Face to access the Llama 2 70b model. Access requires an account and accepted terms for the model. If you don't have an account yet, you can create one [here](https://huggingface.co/join).

In [None]:
!huggingface-cli login --token ""

## 2. Create and prepare the dataset

After our environment is set up, we can start creating and preparing our dataset. A fine-tuning dataset should have a diverse set of demonstrations of the task you want to solve. If you want to learn more about how to create a dataset, take a look at [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset).

We will use the [HuggingFaceH4/no_robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots) dataset. The dataset already has a `messages` column in the conversational format, which is supported by trl. Datasets that use this format can be used directly.

```json
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
```
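
The conversational format can be checked with a few lines of Python. The sketch below is illustrative only: `is_valid_conversation` is a hypothetical helper (not part of trl or datasets), and it assumes an optional leading `system` message followed by strictly alternating `user` and `assistant` turns.

```python
import json


def is_valid_conversation(sample: dict) -> bool:
    """Check for an optional system message followed by user/assistant pairs."""
    messages = sample.get("messages", [])
    # Skip the optional system message before checking the turns
    if messages and messages[0]["role"] == "system":
        messages = messages[1:]
    # At least one complete user/assistant pair is required
    if not messages or len(messages) % 2 != 0:
        return False
    expected = ("user", "assistant")
    return all(m["role"] == expected[i % 2] for i, m in enumerate(messages))


line = (
    '{"messages": [{"role": "system", "content": "You are..."}, '
    '{"role": "user", "content": "..."}, '
    '{"role": "assistant", "content": "..."}]}'
)
sample = json.loads(line)
print(is_valid_conversation(sample))
```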

The [no_robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots) dataset has 10,000 examples split into 9,500 training and 500 test examples. Some samples do not include a `system` message. We will load the dataset with the `datasets` library, add the missing `system` message where needed and, once every sample has one, save the splits as separate json files.

In [None]:
from datasets import load_dataset

# Convert dataset to OAI messages
system_message = """You are Llama, an AI assistant created by Philipp to be helpful and honest. Your knowledge spans a wide range of topics, allowing you to engage in substantive conversations and provide analysis on complex subjects."""

def create_conversation(sample):
    # Prepend the default system message if the conversation does not start with one
    if sample["messages"][0]["role"] == "system":
        return sample
    sample["messages"] = [{"role": "system", "content": system_message}] + sample["messages"]
    return sample

# Load dataset from the hub
dataset = load_dataset("HuggingFaceH4/no_robots")

# Add system message to each conversation
columns_to_remove = list(dataset["train_sft"].features)
columns_to_remove.remove("messages")
dataset = dataset.map(create_conversation, remove_columns=columns_to_remove, batched=False)

# Drop corrupted conversations: keep only those with an even number of turns after the system message
dataset["train_sft"] = dataset["train_sft"].filter(lambda x: len(x["messages"][1:]) % 2 == 0)

# save datasets to disk 
dataset["train_sft"].to_json("train_dataset.json", orient="records",force_ascii=False)
dataset["test_sft"].to_json("test_dataset.json", orient="records",force_ascii=False)
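
To see what the preprocessing does without downloading the dataset, the same two steps can be replayed on plain Python dictionaries. The toy samples, the shortened system message and the helper names below are made up for illustration.

```python
SYSTEM_MESSAGE = "You are a helpful assistant."  # shortened, illustrative


def add_system_message(sample: dict) -> dict:
    """Prepend a system message unless the conversation already starts with one."""
    messages = sample["messages"]
    if messages and messages[0]["role"] == "system":
        return sample
    return {"messages": [{"role": "system", "content": SYSTEM_MESSAGE}] + messages}


def has_complete_turns(sample: dict) -> bool:
    """Keep conversations with an even number of turns after the system message."""
    return len(sample["messages"][1:]) % 2 == 0


toy_samples = [
    {"messages": [{"role": "user", "content": "Hi"},
                  {"role": "assistant", "content": "Hello!"}]},
    {"messages": [{"role": "user", "content": "Dangling question"}]},  # no answer
]

prepared = [add_system_message(s) for s in toy_samples]
kept = [s for s in prepared if has_complete_turns(s)]
print(len(prepared), len(kept))
```

The second toy sample is dropped because, after the system message, it contains a user turn without a matching assistant turn.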

## 3. Fine-tune the LLM with PyTorch FSDP, Q-Lora and SDPA

We are now ready to fine-tune our model with PyTorch FSDP, Q-Lora and SDPA. Since we are running in a distributed setup, we need to use `torchrun` and a Python script to start the training.

We prepared a script `run_fsdp_qlora.py` which will load the dataset from disk, prepare the model and tokenizer, and start the training. It uses the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightforward to supervise fine-tune open LLMs, with features like:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support including Q-LoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
 
To make it easy for everyone to configure the training we used the new `TrlParser`, which allows you to provide your hyperparameters in a yaml file. If you want to make changes to the training, change the inputs below, or override the arguments from the config file by explicitly passing them to the CLI, e.g. `--num_train_epochs 10`.
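
The precedence between config file values and explicit CLI flags can be illustrated with the standard library alone. This is not how `TrlParser` is implemented; `merge_config_and_cli` is a hypothetical helper, and the config is written as a dictionary to avoid a YAML dependency.

```python
import argparse


def merge_config_and_cli(config: dict, argv: list[str]) -> dict:
    """Return the config values, replaced by any flags passed explicitly on the CLI."""
    parser = argparse.ArgumentParser()
    for key, value in config.items():
        # default=SUPPRESS keeps flags that were not passed out of the result,
        # so only explicit CLI arguments override the config
        parser.add_argument(f"--{key}", type=type(value), default=argparse.SUPPRESS)
    overrides = vars(parser.parse_args(argv))
    return {**config, **overrides}


config = {"num_train_epochs": 3, "learning_rate": 2e-4}
merged = merge_config_and_cli(config, ["--num_train_epochs", "10"])
print(merged)
```

Here `num_train_epochs` is taken from the command line, while `learning_rate` keeps the value from the config.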

In [None]:
%%writefile llama_70b_fsdp_qlora.yaml
# script parameters
model_id: "meta-llama/Llama-2-70b-hf"  # Hugging Face model id
dataset_path: "./"                     # path to dataset
max_seq_len: 3072                      # max sequence length for model and packing of the dataset
# training parameters
output_dir: "./llama-7b-hf-no-robot"  # Temporary output directory for model checkpoints
report_to: "tensorboard"              # report metrics to tensorboard
learning_rate: 2e-4                   # learning rate
lr_scheduler_type: "constant"             # learning rate scheduler
num_train_epochs: 3                   # number of training epochs
per_device_train_batch_size: 1        # batch size per device during training
gradient_accumulation_steps: 4        # number of steps before performing a backward/update pass
optim: adamw_torch_fused              # use fused adamw optimizer
logging_steps: 10                     # log every 10 steps
save_strategy: epoch                  # save checkpoint every epoch
max_grad_norm: 0.3                    # max gradient norm
warmup_ratio: 0.03                    # warmup ratio
bf16: true                            # use bfloat16 precision
tf32: true                            # use tf32 precision
gradient_checkpointing: true          # use gradient checkpointing to save memory
# FSDP parameters: https://huggingface.co/docs/transformers/main/en/fsdp
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP # auto wrap policy
  fsdp_backward_prefetch: BACKWARD_PRE # backward prefetch
  fsdp_cpu_ram_efficient_loading: true # cpu ram efficient loading
  fsdp_forward_prefetch: false # forward prefetch
  fsdp_offload_params: false # offload params
  fsdp_sharding_strategy: FULL_SHARD # sharding strategy
  fsdp_state_dict_type: SHARDED_STATE_DICT # state dict type
  fsdp_sync_module_states: true # sync module states
  fsdp_use_orig_params: false # use original params

_Note: At the end of the training there will be a slight increase in GPU memory usage (~10%). This is caused by saving the model correctly. Make sure to have enough memory left on your GPU to save the model. [REF](https://huggingface.co/docs/peft/v0.10.0/en/accelerate/fsdp#memory-usage)_

**Expected Memory usage:**
* Full fine-tuning with FSDP => ~16x80GB GPUs
* FSDP + LoRA => 8x80GB GPUs
* FSDP + Q-Lora => 2x40GB GPUs
* FSDP + Q-Lora + CPU offloading => 2x24GB GPUs, with 19.6 GB/GPU and overall CPU RAM of around 107 GB

_Note: To use CPU offloading you need to change the value of `fsdp_offload_params` to `true` in `llama_70b_fsdp_qlora.yaml`._
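
A rough way to see where numbers of this size come from is to estimate the memory needed for the model weights alone. The sketch below assumes 70 billion parameters and ignores activations, optimizer state, adapter weights and quantization metadata, so it gives a lower bound rather than a prediction. The helper name is illustrative.

```python
def weight_memory_gb(num_params: float, bits_per_param: int, num_gpus: int = 1) -> float:
    """Memory for the weights alone, in GB (1e9 bytes), evenly sharded across GPUs."""
    total_bytes = num_params * bits_per_param / 8
    return total_bytes / num_gpus / 1e9


params = 70e9  # assumed parameter count for a 70B model

print(f"16-bit, 1 GPU:          {weight_memory_gb(params, 16):.1f} GB")
print(f"4-bit, 1 GPU:           {weight_memory_gb(params, 4):.1f} GB")
print(f"4-bit, sharded over 2:  {weight_memory_gb(params, 4, num_gpus=2):.1f} GB per GPU")
```

Sharding the 4-bit weights across two GPUs leaves room on 40GB cards for activations and adapter training, which is in line with the Q-Lora entry in the list above.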

In [None]:
!torchrun --nproc_per_node=4 ./scripts/run_fsdp_qlora.py --config llama_70b_fsdp_qlora.yaml
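
With these settings, the number of sequences contributing to each optimizer step follows from three numbers: the per-device batch size, the gradient accumulation steps, and the number of processes launched by `torchrun`. A quick sketch of that arithmetic (the helper name is illustrative):

```python
def effective_batch_size(per_device: int, grad_accum: int, num_processes: int) -> int:
    """Sequences that contribute to a single optimizer update."""
    return per_device * grad_accum * num_processes


# Values from the config (per_device_train_batch_size, gradient_accumulation_steps)
# and from the launch command (--nproc_per_node)
batch = effective_batch_size(per_device=1, grad_accum=4, num_processes=4)
print(batch)  # 16

# With sequences of up to max_seq_len tokens, this bounds the tokens per update
max_seq_len = 3072
print(batch * max_seq_len)  # 49152
```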

The training with Flash Attention for 3 epochs with a dataset of 10k samples took 18:29:58 on a `g5.12xlarge`. The instance costs `1.212$/h`, which brings us to a total cost of roughly `22.4$`.
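
Training cost follows directly from wall-clock time and the hourly instance price. A minimal sketch, using the duration quoted above and an hourly price of 1.212 dollars as inputs (the helper name is illustrative):

```python
def training_cost(duration_hms: str, price_per_hour: float) -> float:
    """Cost of a run given its duration as HH:MM:SS and an hourly price."""
    hours, minutes, seconds = (int(part) for part in duration_hms.split(":"))
    total_hours = hours + minutes / 60 + seconds / 3600
    return total_hours * price_per_hour


cost = training_cost("18:29:58", 1.212)
print(f"{cost:.2f}")
```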

### _Optional: Merge LoRA adapter into the original model_

When using QLoRA, we only train adapters and not the full model. This means that when saving the model during training we only save the adapter weights. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This will save a standard model, which can be used for inference.

_Note: You might require > 192GB CPU Memory._

In [None]:
#### UNCOMMENT TO MERGE PEFT AND BASE MODEL ####
# import torch
# from peft import AutoPeftModelForCausalLM
#
# output_dir = "./llama-7b-hf-no-robot"  # `output_dir` from the training config
#
# # Load PEFT model on CPU
# model = AutoPeftModelForCausalLM.from_pretrained(
#     output_dir,
#     torch_dtype=torch.float16,
#     low_cpu_mem_usage=True,
# )
# # Merge LoRA and base model and save
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained(output_dir, safe_serialization=True, max_shard_size="2GB")

## 4. Test Model and run Inference

After the training is done we want to test our model. We will load a random sample from the test dataset, generate an answer with the fine-tuned model and compare it with the original answer.

_Note: Evaluating Generative AI models is not a trivial task since one input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out the [Evaluate LLMs and RAG a practical example using Langchain and Hugging Face](https://www.philschmid.de/evaluate-llm) blog post._



In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline 

peft_model_id = "./llama-7b-hf-no-robot"  # `output_dir` from the training config
# peft_model_id = args.output_dir

# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
  peft_model_id,
  device_map="auto",
  torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Let's load our test dataset and try to generate a response.

In [None]:
from datasets import load_dataset 
from random import randint


# Load our test dataset
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
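# randint is inclusive on both ends, so subtract 1 to stay within the dataset
rand_idx = randint(0, len(eval_dataset) - 1)

# Test on sample: the prompt consists of the system and user messages only
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
# Greedy decoding: sampling parameters (temperature, top_k, top_p) have no effect when do_sample=False
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)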

print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")