This notebook implements QLoRA fine-tuning with FlashAttention-2. I use Llama 2, but you can use any other LLM that Transformers supports with FlashAttention-2. Note that FlashAttention-2 supports attention head dimensions up to 256.

You need an Ampere GPU, or a more recent GPU, to run this notebook. On Google Colab, only the A100 is compatible.
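
One way to check the GPU before installing anything is to look at its CUDA compute capability: Ampere corresponds to 8.x, and more recent architectures have a higher major version. The snippet below is only a sketch; the helper name `is_ampere_or_newer` is my own, and it assumes PyTorch is already installed (as it is on Colab).

```python
def is_ampere_or_newer(capability):
    """Return True if a (major, minor) CUDA compute capability is 8.0 or higher."""
    major, _minor = capability
    return major >= 8


# PyTorch is optional here so that the helper can be used on its own
try:
    import torch
except ImportError:
    torch = None

if torch is not None and torch.cuda.is_available():
    capability = torch.cuda.get_device_capability(0)
    print(torch.cuda.get_device_name(0), capability,
          "compatible:", is_ampere_or_newer(capability))
else:
    print("No CUDA GPU detected (or PyTorch is not installed).")
```

For reference, the A100 reports (8, 0), while the T4 and V100 report (7, 5) and (7, 0).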

First, install the following dependencies:

In [None]:
!pip install -q -U bitsandbytes
!pip install -q -U transformers
!pip install -q -U peft
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U trl
!pip install -q -U flash-attn --no-build-isolation

If you use Llama 2, you need to enter your Hugging Face token to get access to the model on the hub.

In [None]:
from huggingface_hub import notebook_login

notebook_login()


Import the following for QLoRA fine-tuning:

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)

from trl import SFTTrainer



Load Llama 2's tokenizer and configure padding.

In [None]:
model_name = "meta-llama/Llama-2-7b-hf"
# Load the tokenizer; add_eos_token=True appends the EOS token to every encoded example
tokenizer = AutoTokenizer.from_pretrained(model_name, add_eos_token=True, use_fast=True)
# Llama 2 has no padding token, so reuse the existing UNK token for padding
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = 'left'
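
To make the effect of padding_side concrete, here is a small pure-Python sketch of left versus right padding. The `pad_batch` helper is hypothetical (the tokenizer does this internally when it pads a batch), the token IDs are toy values, and the pad ID 0 is chosen because that is the ID of the UNK token in Llama 2's vocabulary.

```python
def pad_batch(sequences, pad_id, side="left"):
    """Pad token-ID lists to the same length and build the matching attention masks."""
    width = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = width - len(seq)
        padding = [pad_id] * n_pad
        if side == "left":
            input_ids.append(padding + list(seq))
            attention_mask.append([0] * n_pad + [1] * len(seq))
        else:
            input_ids.append(list(seq) + padding)
            attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


batch = [[5, 6, 7], [8]]  # toy token IDs, not real Llama 2 tokens
left_ids, left_mask = pad_batch(batch, pad_id=0, side="left")
right_ids, right_mask = pad_batch(batch, pad_id=0, side="right")
print(left_ids, left_mask)
print(right_ids, right_mask)
```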

I use openassistant-guanaco for fine-tuning.

In [None]:
dataset = load_dataset("timdettmers/openassistant-guanaco")



Load and quantize Llama 2, then configure LoRA and the trainer. Note that I set use_flash_attention_2=True when loading the model to activate FlashAttention-2.
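
More recent versions of Transformers replace the use_flash_attention_2 flag with the attn_implementation argument. As far as I know, the change happened in version 4.36, so treat that threshold as an assumption and check the release notes for your version. The helper `flash_attention_kwargs` below is hypothetical; it only picks the keyword argument to pass to from_pretrained.

```python
def flash_attention_kwargs(transformers_version):
    """Pick the from_pretrained keyword that enables FlashAttention-2.

    Assumption: attn_implementation is available from Transformers 4.36 onward.
    """
    major, minor = (int(part) for part in transformers_version.split(".")[:2])
    if (major, minor) >= (4, 36):
        return {"attn_implementation": "flash_attention_2"}
    return {"use_flash_attention_2": True}


print(flash_attention_kwargs("4.34.1"))
print(flash_attention_kwargs("4.36.0"))
```

The returned dictionary can then be unpacked into the loading call, for example `AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map={"": 0}, **flash_attention_kwargs(transformers.__version__))`.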

I also enable packing (packing=True), which concatenates training examples into sequences of max_seq_length tokens. Longer sequences benefit more from FlashAttention.
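
The idea behind packing can be sketched in a few lines: tokenized examples are concatenated, separated by the EOS token, and the resulting stream is cut into blocks of exactly max_seq_length tokens. This is a simplified illustration, not TRL's actual implementation. The function name `pack_examples` and the toy token IDs are made up for the example, and leftover tokens at the end are simply dropped here.

```python
def pack_examples(tokenized_examples, eos_token_id, max_seq_length):
    """Concatenate examples (separated by EOS) and split them into fixed-length blocks."""
    stream = []
    for token_ids in tokenized_examples:
        stream.extend(token_ids)
        stream.append(eos_token_id)  # mark the boundary between two examples
    n_blocks = len(stream) // max_seq_length  # the incomplete last block is dropped
    return [
        stream[i * max_seq_length:(i + 1) * max_seq_length]
        for i in range(n_blocks)
    ]


# Toy data: three short "examples" packed into blocks of 4 tokens, with EOS = 0
packed = pack_examples([[1, 2, 3], [4, 5], [6, 7, 8, 9]], eos_token_id=0, max_seq_length=4)
print(packed)  # [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Every block is full, so no compute is spent on padding tokens and all sequences reach the maximum length.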


In [None]:
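# FlashAttention-2 only supports fp16 and bf16, so compute in float16
compute_dtype = torch.float16
# 4-bit NF4 quantization with double quantization, as in QLoRA
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
)
# Load the quantized model on GPU 0 with FlashAttention-2 enabled
model = AutoModelForCausalLM.from_pretrained(
          model_name, quantization_config=bnb_config, device_map={"": 0}, use_flash_attention_2=True
)

# Freeze the base model and prepare it for k-bit training
model = prepare_model_for_kbit_training(model)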
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules= ["q_proj","v_proj"]
)

training_arguments = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="steps",
        do_eval=True,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        log_level="debug",
        save_steps=20,
        logging_steps=20,
        learning_rate=4e-4,
        eval_steps=20,
        fp16=True,
        max_steps=100,
        warmup_steps=10,
        lr_scheduler_type="linear",
)

trainer = SFTTrainer(
        model=model,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=True
)

trainer.train()


max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
Currently training with a batch size of: 16
***** Running training *****
  Num examples = 9,846
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 100
  Number of trainable parameters = 8,388,608
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.


Step   Training Loss   Validation Loss
20     1.4196          1.352458
40     1.2865          1.312616
60     1.243           1.280513
80     1.261           1.272776
100    1.2853          1.270075


***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-20
tokenizer config file saved in ./results/checkpoint-20/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-20/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-40
tokenizer config file saved in ./results/checkpoint-40/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-40/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-60
tokenizer config file saved in ./results/checkpoint-60/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-60/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-80
tokenizer config file saved in ./results/checkpoint-80

TrainOutput(global_step=100, training_loss=1.2990720176696777, metrics={'train_runtime': 666.7799, 'train_samples_per_second': 2.4, 'train_steps_per_second': 0.15, 'total_flos': 6.50352940548096e+16, 'train_loss': 1.2990720176696777, 'epoch': 0.16})
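
Some of the numbers in this output can be checked by hand. The sketch below recomputes the number of trainable parameters and the fraction of an epoch covered by the run. It assumes the Llama 2 7B architecture (hidden size 4096 and 32 layers, with q_proj and v_proj both of shape 4096 x 4096); the variable names are my own.

```python
# Llama 2 7B architecture (assumed from the model configuration)
hidden_size = 4096
num_layers = 32

# LoRA settings from the LoraConfig above
r = 16
target_modules = ["q_proj", "v_proj"]

# Each adapted module gets two low-rank matrices: A (r x in) and B (out x r)
params_per_module = r * hidden_size + hidden_size * r
trainable = params_per_module * len(target_modules) * num_layers
print(f"Trainable parameters: {trainable:,}")  # 8,388,608

# Training length and throughput reported by the trainer
max_steps, batch_size, num_examples = 100, 16, 9846
train_runtime = 666.7799  # seconds, from TrainOutput
samples_seen = max_steps * batch_size
epoch_fraction = round(samples_seen / num_examples, 2)
samples_per_second = round(samples_seen / train_runtime, 1)
print(f"Epoch fraction: {epoch_fraction}")  # 0.16
print(f"Samples per second: {samples_per_second}")  # 2.4
```

These values match the "Number of trainable parameters", 'epoch', and 'train_samples_per_second' entries in the logs.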

# Appendix

This cell runs the same fine-tuning with the same configuration as above, but without FlashAttention.

I ran it to compare the training runtimes.

In [None]:
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
)
# Same loading call as in the FlashAttention run, but without use_flash_attention_2=True
model = AutoModelForCausalLM.from_pretrained(
          model_name, quantization_config=bnb_config, device_map={"": 0}
)

model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching

model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules= ["q_proj","v_proj"]
)

training_arguments = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="steps",
        do_eval=True,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        log_level="debug",
        save_steps=20,
        logging_steps=20,
        learning_rate=4e-4,
        eval_steps=20,
        fp16=True,
        max_steps=100,
        warmup_steps=10,
        lr_scheduler_type="linear",
)

trainer = SFTTrainer(
        model=model,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=True
)

trainer.train()


max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
Currently training with a batch size of: 16
***** Running training *****
  Num examples = 9,846
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 100
  Number of trainable parameters = 8,388,608
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step   Training Loss   Validation Loss
20     1.4199          1.352097
40     1.2863          1.312537
60     1.2432          1.280792
80     1.2611          1.272651
100    1.2853          1.270116


***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-20
tokenizer config file saved in ./results/checkpoint-20/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-20/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-40
tokenizer config file saved in ./results/checkpoint-40/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-40/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-60
tokenizer config file saved in ./results/checkpoint-60/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-60/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 518
  Batch size = 16
Saving model checkpoint to ./results/checkpoint-80
tokenizer config file saved in ./results/checkpoint-80

TrainOutput(global_step=100, training_loss=1.2991660499572755, metrics={'train_runtime': 855.9696, 'train_samples_per_second': 1.869, 'train_steps_per_second': 0.117, 'total_flos': 6.50352940548096e+16, 'train_loss': 1.2991660499572755, 'epoch': 0.16})
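
To put the two runs side by side, the following lines compute the speedup from the 'train_runtime' values reported in the two TrainOutput results. The final validation losses come from the two results tables. This is plain arithmetic on the logged numbers, and a single run on one GPU, so the exact figures will vary with hardware and sequence length.

```python
# Runtimes in seconds, copied from the two TrainOutput results
runtime_flash = 666.7799     # with FlashAttention-2
runtime_standard = 855.9696  # without FlashAttention

speedup = runtime_standard / runtime_flash
time_saved = 1 - runtime_flash / runtime_standard
print(f"Speedup: {speedup:.2f}x")                 # Speedup: 1.28x
print(f"Training time saved: {time_saved:.1%}")   # Training time saved: 22.1%

# Validation loss at step 100 in both runs
val_loss_flash = 1.270075
val_loss_standard = 1.270116
loss_gap = abs(val_loss_flash - val_loss_standard)
print(f"Validation loss gap: {loss_gap:.6f}")     # Validation loss gap: 0.000041
```

FlashAttention-2 reduces the training time by about 22% in this setup, while the validation losses of the two runs are almost identical.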