<a href="https://colab.research.google.com/github/vanderbilt-data-science/ai-training-day/blob/main/dsi_ai_training_fine_tune_gemma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Parameter-Efficient Fine-Tuning Gemma

*By Myranda Shirk, Senior Data Scientist, Vanderbilt Data Science Institute*

Notebook created with help from [Gemma fine-tuning documentation](https://github.com/huggingface/notebooks/blob/main/peft/gemma_7b_english_quotes.ipynb).

## Fine-Tuning in Google Colab

According to [Gemma's HuggingFace Space](https://huggingface.co/google/gemma-7b), this fine-tuning code can be run on a free instance of Google Colab using the available GPU runtime. To change your runtime to GPU, select "Runtime"-> ""Change Runtime Type" -> GPU.

If for any reason you are not able to use a GPU, you may see the cells indicated for use on CPU.

### Libraries and APIs

To access Gemma, you need a [HuggingFace](www.huggingface.co) account and a HuggingFace API Token with read-only permissions (In HF: Profile -> Settings -> Access Tokens). Additionally, you need to visit [Gemma's HuggingFace Space](https://huggingface.co/google/gemma-7b) and click the button to accept their terms of use. After accepting the terms, access should be immediately granted.

In [1]:
import os
import getpass
os.environ["HF_TOKEN"] = getpass.getpass("Enter your HuggingFace token: ")

Enter your HuggingFace token: ··········


In [2]:

!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.1

### Model Setup and Training Objective

For this example, we will be fine-tuning Gemma on the English Quotes Dataset. We want our model to output a quote and its author given the start of a quote.

First, we can access Gemma-7B (or any Gemma model - simply change to "gemma-2b" for the 2B parameter model, etc) through HuggingFace (this is where you need your HF authenitication, which we set above).

**NOTE**: The below cells needs connection to a GPU, which you can access by selecting "Runtime" -> "Change Runtime Type" -> GPU

If you are not connected to a GPU, the error will say something along the lines of "you must have accelerate and bitsandbytes installed."

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer

model_id = "google/gemma-7b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])

tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [None]:
# FOR CPU: DELETE QUOTES AND RUN THIS CELL
'''
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")

'''


Let's see how Gemma does on this task without any fine-tuning. We will give it the start of a quote.

In [18]:
text = "Quote: Imagination is"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Quote: Imagination is more important than knowledge.

Albert Einstein

I am a creative and curious person who loves to learn


As we can see above, the model does finish the quote and attribute an author, but then it continues on with another quote without us prompting. Not exactly what we want!

### Data and Training Functions

Next, we will set up our training configuation for [LoRA] (https://www.run.ai/guides/generative-ai/lora-fine-tuning), a highly efficient training method.

In [11]:
os.environ["WANDB_DISABLED"] = "true"

In [12]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

We can load our dataset through the HF datasets library.

In [8]:
# may not use this dataset
from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/647k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/2508 [00:00<?, ? examples/s]

Now we can define our Supervised Fine-Tuning (SFT) trainer below and start the training!

In [9]:
import transformers
from trl import SFTTrainer

def formatting_func(example):
    text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"
    return [text]

trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)
trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Map:   0%|          | 0/2508 [00:00<?, ? examples/s]



Step,Training Loss
1,1.373
2,0.4893
3,0.7016
4,0.4656
5,0.2462
6,0.5595
7,0.4917
8,0.1513
9,0.3819
10,0.3452


TrainOutput(global_step=10, training_loss=0.520535697042942, metrics={'train_runtime': 18.7183, 'train_samples_per_second': 2.137, 'train_steps_per_second': 0.534, 'total_flos': 21135849891840.0, 'train_loss': 0.520535697042942, 'epoch': 6.67})

### Evaluation

Let's see how our model does after fine-tuning. We will run the same example we did at the beginning. Remember that we want our model to give us the rest of the quote and its author.

In [11]:
text = "Quote: Imagination is"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Quote: Imagination is more important than knowledge.
Author: Albert Einstein
From: The World as I See It



Great! This is exactly what we want. We can now save our model as a .pt file to access later.

In [None]:
torch.save(model.state_dict(), "gemma-7b-peft-quotes.pt")

## Conclusion

You have just successfully fine-tuned Google's Gemma model on our English Quotes dataset. Feel free to adapt this process for SFT on your own project.

## Further Reading

- [PEFT Llama-2](https://colab.research.google.com/drive/1kKodzt_KZUXQA_dBBHyp4KS0WEY4SLpA#scrollTo=ZlFLWOPAxXbB)