<a href="https://colab.research.google.com/github/tuhinmallick/AI-for-Fashion/blob/main/Fine_tune_Gemma_2_on_Your_Computer_With_Transformers_and_Unsloth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


*More details in this article: [Fine-tune Gemma 2 on Your Computer with LoRA and QLoRA](https://newsletter.kaitchup.com/p/fine-tune-gemma-2-on-your-computer)*

This notebook shows how to fine-tune Gemma 2 with QLoRA and LoRA. It works on a 24 GB GPU with Transformers or a 16 GB GPU with Unsloth.

The notebook is organized in 3 parts:
* QLoRA fine-tuning with Transformers
* LoRA fine-tuning with Transformers
* QLoRA fine-tuning with Unsloth

I don't provide code for LoRA fine-tuning with Unsloth because it raised errors when I wrote this notebook.
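
The GPU sizes quoted above can be sanity-checked with a rough, weight-only estimate. The sketch below uses the layer shapes that `print(model)` reports later in this notebook. It makes several assumptions that do not come from the notebook: the output head shares its weights with the embedding, only the linear layers are quantized to 4 bits (0.5 byte per weight), and normalization weights, quantization constants, activations, LoRA weights, and optimizer states are ignored. Treat the result as a lower bound, not a measurement.

```python
# Weight-only memory estimate for Gemma 2 9B (a sketch under the stated assumptions).
HIDDEN, VOCAB, LAYERS = 3584, 256_000, 42

# (in_features, out_features) of the seven linear modules in each decoder layer,
# copied from the print(model) output in this notebook.
LINEAR_SHAPES = {
    "q_proj": (3584, 4096),
    "k_proj": (3584, 2048),
    "v_proj": (3584, 2048),
    "o_proj": (4096, 3584),
    "gate_proj": (3584, 14336),
    "up_proj": (3584, 14336),
    "down_proj": (14336, 3584),
}

linear_params = LAYERS * sum(i * o for i, o in LINEAR_SHAPES.values())
embed_params = VOCAB * HIDDEN  # assumption: tied with the output head

# LoRA: every weight is kept in 16 bits (2 bytes).
bf16_gb = (linear_params + embed_params) * 2 / 1e9
# QLoRA: linear layers in 4 bits (assumed 0.5 byte each), embedding still in 16 bits.
qlora_gb = (linear_params * 0.5 + embed_params * 2) / 1e9

print(f"Linear parameters:    {linear_params:,}")
print(f"Embedding parameters: {embed_params:,}")
print(f"16-bit weights: {bf16_gb:.1f} GB")
print(f"4-bit weights:  {qlora_gb:.1f} GB")
```

Under these assumptions the 16-bit weights alone take roughly 18.5 GB, which leaves little room on a 24 GB GPU and is consistent with the LoRA section using a per-device batch size of 1. The 4-bit weights take roughly 6 GB.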

# With Transformers

We will need the following packages:


In [None]:
!pip install --upgrade bitsandbytes transformers peft accelerate datasets trl

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Collecting transformers
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
Collecting peft
  Downloading peft-0.12.0-py3-none-any.whl.metadata (13 kB)
Collecting accelerate
  Downloading accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)
Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting trl
  Downloading trl-0.9.6-py3-none-any.whl.metadata (12 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.w
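
Because `pip install --upgrade` pulls whatever is newest, a later run may install versions different from the ones in the log above. The following is a minimal sketch for comparing an environment against the logged versions. The helper names `installed_version` and `compare_versions` are hypothetical, and a mismatch does not necessarily mean the notebook will fail.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions that appear in the install log above.
LOGGED_VERSIONS = {
    "bitsandbytes": "0.43.3",
    "transformers": "4.44.2",
    "peft": "0.12.0",
    "accelerate": "0.33.0",
    "datasets": "2.21.0",
    "trl": "0.9.6",
}

def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

def compare_versions(expected):
    """Return {package: (expected, installed)} for every package that differs."""
    mismatches = {}
    for package, wanted in expected.items():
        found = installed_version(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches

for package, (wanted, found) in compare_versions(LOGGED_VERSIONS).items():
    print(f"{package}: log shows {wanted}, environment has {found}")
```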

## QLoRA Fine-tuning

The code is very similar to the one I used for fine-tuning Llama 3, except that we don't need to configure a pad token.

[Fine-tune Llama 3 on Your Computer](https://kaitchup.substack.com/p/fine-tune-llama-3-on-your-computer)
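
To illustrate the pad-token remark, here is a minimal sketch of the kind of check involved. `ensure_pad_token` is a hypothetical helper, not part of Transformers, and the two stand-in objects only mimic the `pad_token` and `eos_token` attributes of a tokenizer. Their token strings are illustrative values. Falling back to the EOS token is one common choice, not necessarily the one used in the Llama 3 article.

```python
from types import SimpleNamespace

def ensure_pad_token(tokenizer):
    """Configure a pad token only if the tokenizer lacks one.

    Returns True if a pad token had to be set, False if one already existed.
    """
    if tokenizer.pad_token is not None:
        return False
    # Assumed fallback: reuse the EOS token as the pad token.
    tokenizer.pad_token = tokenizer.eos_token
    return True

# Stand-ins for real tokenizers (illustrative values only).
with_pad = SimpleNamespace(pad_token="<pad>", eos_token="<eos>")
without_pad = SimpleNamespace(pad_token=None, eos_token="<|end_of_text|>")

print(ensure_pad_token(with_pad))     # pad token already defined, nothing to do
print(ensure_pad_token(without_pad))  # pad token was missing and has been set
```

With a real tokenizer, the same function can be called right after `AutoTokenizer.from_pretrained(...)`.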

In [None]:
import torch, os, multiprocessing
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from trl import SFTTrainer, SFTConfig

if torch.cuda.is_bf16_supported():
  os.system('pip install flash_attn')
  compute_dtype = torch.bfloat16
  attn_implementation = 'flash_attention_2'
else:
  compute_dtype = torch.float16
  attn_implementation = 'sdpa'

model_name = "google/gemma-2-9b"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, add_eos_token=True, use_fast=True)
tokenizer.padding_side = 'right'

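ds = load_dataset("timdettmers/openassistant-guanaco")

#Add the EOS token
def process(row):
    row["text"] = row["text"]+tokenizer.eos_token
    return row

ds = ds.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)
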
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
          model_name, quantization_config=bnb_config, device_map={"": 0}, attn_implementation=attn_implementation
)

model = prepare_model_for_kbit_training(model, gradient_checkpointing_kwargs={'use_reentrant':True})

print(model)

peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.00,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]
)


training_arguments = SFTConfig(
        output_dir="./gemma-2-9b/r16a16_QLoRA",
        eval_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=4,
        log_level="debug",
        save_strategy="epoch",
        logging_steps=20,
        learning_rate=1e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        eval_steps=20,
        num_train_epochs=1,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        dataset_text_field="text",
        max_seq_length=512,
)

trainer = SFTTrainer(
        model=model,
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        peft_config=peft_config,
        tokenizer=tokenizer,
        args=training_arguments,
)

trainer.train()

Repo card metadata block was not found. Setting CardData to empty.

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear4bit(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear4bit(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=3584, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear4bit(in_features=3584, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=3584, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=3584, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSN


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Using auto half precision backend
Currently training with a batch size of: 4
***** Running training *****
  Num examples = 9,846
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 8
  Total optimization steps = 307
  Number of trainable parameters = 54,018,048
It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss | Validation Loss |
|---|---|---|
| 20 | 1.5414 | 1.430857 |
| 40 | 1.3586 | 1.357276 |
| 60 | 1.3045 | 1.346512 |
| 80 | 1.3312 | 1.341357 |
| 100 | 1.3109 | 1.338429 |
| 120 | 1.3019 | 1.33629 |
| 140 | 1.3075 | 1.33346 |
| 160 | 1.2969 | 1.332452 |
| 180 | 1.3048 | 1.330233 |
| 200 | 1.3196 | 1.32865 |



***** Running Evaluation *****
  Num examples = 518
  Batch size = 4

Saving model checkpoint to ./gemma-2-9b/r16a16_QLoRA/checkpoint-307
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--google--gemma-2-9b/snapshots/beb0c08e9eeb0548f3aca2ac870792825c357b7d/config.json
Model config Gemma2Config {
  "architectures": [
    "Gemma2ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": 50.0,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "eos_token_id": 1,
  "final_logit_softcapping": 30.0,
  "head_dim": 256,
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 3584,
  "init

TrainOutput(global_step=307, training_loss=1.322845154554137, metrics={'train_runtime': 13497.883, 'train_samples_per_second': 0.729, 'train_steps_per_second': 0.023, 'total_flos': 2.2335744778621747e+17, 'train_loss': 1.322845154554137, 'epoch': 0.9975629569455727})
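
The `Number of trainable parameters = 54,018,048` line in the log can be reproduced from the printed layer shapes. For each targeted linear layer, LoRA adds two low-rank matrices, one of shape `r x in_features` and one of shape `out_features x r`, so each module contributes `r * (in_features + out_features)` trainable parameters. The sketch below applies this with `r=16` to the seven target modules across the 42 decoder layers. It assumes no bias terms and no other trainable modules, which matches `bias="none"` in the configuration above.

```python
R = 16       # LoRA rank, as set in LoraConfig
LAYERS = 42  # number of decoder layers in the printed model

# (in_features, out_features) for each module listed in target_modules
TARGET_SHAPES = {
    "q_proj": (3584, 4096),
    "k_proj": (3584, 2048),
    "v_proj": (3584, 2048),
    "o_proj": (4096, 3584),
    "gate_proj": (3584, 14336),
    "up_proj": (3584, 14336),
    "down_proj": (14336, 3584),
}

def lora_params(in_features, out_features, r):
    """Parameters of the two low-rank matrices added to one linear layer."""
    return r * (in_features + out_features)

per_layer = sum(lora_params(i, o, R) for i, o in TARGET_SHAPES.values())
total = LAYERS * per_layer

print(f"Per decoder layer: {per_layer:,}")
print(f"Total trainable:   {total:,}")
```

The total matches the count reported by the trainer. Since the count scales linearly with `r`, doubling the rank would double the number of trainable parameters.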

## LoRA Fine-tuning

In [None]:
import torch, os, multiprocessing
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from trl import SFTTrainer, SFTConfig

if torch.cuda.is_bf16_supported():
  os.system('pip install flash_attn')
  compute_dtype = torch.bfloat16
  attn_implementation = 'flash_attention_2'
else:
  compute_dtype = torch.float16
  attn_implementation = 'sdpa'

model_name = "google/gemma-2-9b"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, add_eos_token=True, use_fast=True)
tokenizer.padding_side = 'right'

ds = load_dataset("timdettmers/openassistant-guanaco")

#Add the EOS token
def process(row):
    row["text"] = row["text"]+tokenizer.eos_token
    return row

ds = ds.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)


model = AutoModelForCausalLM.from_pretrained(
          model_name, device_map={"": 0}, torch_dtype=compute_dtype, attn_implementation=attn_implementation
)

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

print(model)

peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.00,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]
)

training_arguments = SFTConfig(
        output_dir="./gemma-2-9b/r16a16_LoRA",
        eval_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        per_device_eval_batch_size=1,
        log_level="debug",
        save_strategy="epoch",
        logging_steps=20,
        learning_rate=1e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        eval_steps=20,
        num_train_epochs=1,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        dataset_text_field="text",
        max_seq_length=512,
)

trainer = SFTTrainer(
        model=model,
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        peft_config=peft_config,
        tokenizer=tokenizer,
        args=training_arguments,
)

trainer.train()

Repo card metadata block was not found. Setting CardData to empty.

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=3584, out_features=14336, bias=False)
          (up_proj): Linear(in_features=3584, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=3584, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSNorm()
        (pre_feedforwa

Using auto half precision backend
Currently training with a batch size of: 1
***** Running training *****
  Num examples = 9,846
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 32
  Total optimization steps = 307
  Number of trainable parameters = 54,018,048
It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss | Validation Loss |
|---|---|---|
| 20 | 1.7054 | 1.538012 |
| 40 | 1.4425 | 1.446113 |
| 60 | 1.3809 | 1.433512 |
| 80 | 1.3952 | 1.42845 |
| 100 | 1.4007 | 1.424349 |
| 120 | 1.3795 | 1.423465 |
| 140 | 1.378 | 1.421013 |
| 160 | 1.3755 | 1.417949 |
| 180 | 1.3793 | 1.416622 |
| 200 | 1.3744 | 1.41442 |



***** Running Evaluation *****
  Num examples = 518
  Batch size = 1


TrainOutput(global_step=307, training_loss=1.4054545517464803, metrics={'train_runtime': 8449.164, 'train_samples_per_second': 1.165, 'train_steps_per_second': 0.036, 'total_flos': 1.441453151470725e+17, 'train_loss': 1.4054545517464803, 'epoch': 0.9977655900873451})
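
Both Transformers runs report 307 optimization steps even though their batch settings differ (4 x 8 for QLoRA, 1 x 32 for LoRA), because the effective batch size is 32 in both cases. The sketch below reproduces the step count and the final `epoch` value from the logged numbers. The counting rule is an assumption inferred from these logs for a single-GPU run, not a copy of the Trainer implementation: batches per epoch are rounded up, and leftover batches that do not fill a complete accumulation cycle are dropped.

```python
import math

def optimization_steps(num_examples, per_device_batch_size, grad_accum_steps, epochs=1):
    """Return (optimizer steps, fraction of an epoch actually covered)."""
    batches_per_epoch = math.ceil(num_examples / per_device_batch_size)
    # Only complete accumulation cycles trigger an optimizer update.
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    covered = updates_per_epoch * grad_accum_steps / batches_per_epoch
    return updates_per_epoch * epochs, covered * epochs

qlora_steps, qlora_epoch = optimization_steps(9846, 4, 8)
lora_steps, lora_epoch = optimization_steps(9846, 1, 32)

print(f"QLoRA: {qlora_steps} steps, epoch = {qlora_epoch:.6f}")
print(f"LoRA:  {lora_steps} steps, epoch = {lora_epoch:.6f}")
```

This also explains why the logged `epoch` values fall slightly short of 1.0: the last few batches of the epoch do not complete an accumulation cycle.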

# Unsloth

Install Unsloth and its dependencies. The commands below work with the Colab configuration at the time of writing, but the required versions change often. If they fail, follow the [installation instructions on GitHub](https://github.com/unslothai/unsloth).

In [None]:
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.27" trl peft accelerate bitsandbytes

Collecting unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-ntfb2hp1/unsloth_98f31e582da241fc93556f25610a3d2e
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-ntfb2hp1/unsloth_98f31e582da241fc93556f25610a3d2e
  Resolved https://github.com/unslothai/unsloth.git to commit 92dce38e8b3c1db209cef860d90b60188e95f0f9
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting tyro (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Downloading tyro-0.8.5-py3-none-any.whl (103 kB)
Collecting transformers>=4.42.3 (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.

## QLoRA


In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
import multiprocessing
import torch
max_seq_length = 512
dtype = None
load_in_4bit = True


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "google/gemma-2-9b",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
)

ds = load_dataset("timdettmers/openassistant-guanaco")

#Add the EOS token
def process(row):
    row["text"] = row["text"]+tokenizer.eos_token
    return row

ds = ds.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)


training_arguments = SFTConfig(
        output_dir="./gemma-2-9b/r16a16_QLoRA_Unsloth",
        eval_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=4,
        log_level="debug",
        save_strategy="epoch",
        logging_steps=20,
        learning_rate=1e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        eval_steps=20,
        num_train_epochs=1,
        warmup_ratio=0.1,
        dataset_text_field="text",
        max_seq_length=512,
        lr_scheduler_type="linear",
)



trainer = SFTTrainer(
        model=model,
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        tokenizer=tokenizer,
        args=training_arguments,
)

trainer.train()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth: Fast Gemma2 patching release 2024.7
   \\   /|    GPU: NVIDIA L4. Max memory: 22.168 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.9. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth 2024.7 patched 42 layers with 42 QKV layers, 42 O layers and 42 MLP layers.
Repo card metadata block was not found. Setting CardData to empty.


Using auto half precision backend
Currently training with a batch size of: 4
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 9,846 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 4 | Gradient Accumulation steps = 8
\        /    Total batch size = 32 | Total steps = 307
 "-____-"     Number of trainable parameters = 54,018,048


| Step | Training Loss | Validation Loss |
|---|---|---|
| 20 | 1.5411 | 1.429181 |
| 40 | 1.358 | 1.356442 |
| 60 | 1.3036 | 1.345572 |
| 80 | 1.3303 | 1.340242 |
| 100 | 1.31 | 1.337499 |
| 120 | 1.301 | 1.335561 |
| 140 | 1.3066 | 1.333249 |
| 160 | 1.2966 | 1.331671 |
| 180 | 1.3042 | 1.329507 |
| 200 | 1.3189 | 1.32786 |



***** Running Evaluation *****
  Num examples = 518
  Batch size = 4


KeyError: None
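
Since the Unsloth run uses the same hyperparameters as the QLoRA run with Transformers, the two validation curves can be compared directly. The sketch below does this with the validation losses copied from the two logs in this notebook (steps 20 to 200). It only measures the gap between the logged values and makes no claim about statistical significance or about later steps.

```python
# Validation losses copied from the training logs in this notebook.
STEPS = [20, 40, 60, 80, 100, 120, 140, 160, 180, 200]
TRANSFORMERS_QLORA = [1.430857, 1.357276, 1.346512, 1.341357, 1.338429,
                      1.33629, 1.33346, 1.332452, 1.330233, 1.32865]
UNSLOTH_QLORA = [1.429181, 1.356442, 1.345572, 1.340242, 1.337499,
                 1.335561, 1.333249, 1.331671, 1.329507, 1.32786]

# Absolute gap between the two runs at each logged step.
gaps = [abs(t - u) for t, u in zip(TRANSFORMERS_QLORA, UNSLOTH_QLORA)]
max_gap = max(gaps)

for step, gap in zip(STEPS, gaps):
    print(f"step {step:>3}: |difference| = {gap:.6f}")
print(f"largest gap: {max_gap:.6f}")
```

The largest gap is below 0.002, so over the logged steps the two frameworks produce nearly identical validation curves.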