##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large in size (parameters in the order of billions). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685){:.external} is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model. It is adapted from the Gemma LoRA tutorial that uses the [Databricks Dolly 15k dataset](https://www.kaggle.com/datasets/databricks/databricks-dolly-15k){:.external}; this version instead fine-tunes on a question-and-answer dataset about quitting smoking (`augmented_dataset.csv`, with `Question` and `Answer` columns).

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

In particular, Gemma models are hosted by Kaggle, so you must request access to Gemma on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [1]:
from kaggle_secrets import UserSecretsClient

# Fetch the Google Cloud SDK credentials that Kaggle attaches to this notebook.
# NOTE(review): the next cell repeats this exact fetch (as `gcloud_sdk_auth`);
# `secret_value_0` is not referenced again in this notebook.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret(
    "__gcloud_sdk_auth__",
)

In [3]:
from kaggle_secrets import UserSecretsClient

# Kaggle exposes the automatically generated Google Cloud authentication
# credentials under this reserved secret name.
GCLOUD_SECRET_NAME = "__gcloud_sdk_auth__"
user_secrets = UserSecretsClient()
gcloud_sdk_auth = user_secrets.get_secret(GCLOUD_SECRET_NAME)

In [4]:
import json
import os

# Where the Google Cloud credentials are materialized for client libraries.
GCLOUD_AUTH_PATH = "/tmp/gcloud_auth.json"

# The file holds a secret and /tmp can be shared with other processes, so make
# it readable/writable by the current user only (0o600). The mode passed to
# `os.open` applies only when the file is created; `fchmod` also tightens a
# pre-existing file before any secret bytes are written to it.
_fd = os.open(GCLOUD_AUTH_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(_fd, "w") as f:
    os.fchmod(f.fileno(), 0o600)
    f.write(gcloud_sdk_auth)

# Point Google Cloud client libraries (Application Default Credentials) at the file.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = GCLOUD_AUTH_PATH

In [5]:
import os

from google.cloud import automl_v1 as automl

# Instantiate the AutoML client; it picks up the credentials configured via
# GOOGLE_APPLICATION_CREDENTIALS.
automl_client = automl.AutoMlClient()

# Project used by this notebook. Set GOOGLE_CLOUD_PROJECT to target a different
# project without editing the code; the default keeps the original value.
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", 'powerful-host-435901-c5')
# NOTE(review): constructing the client presumably makes no RPC, so this line only
# echoes the configured project id; it does not prove connectivity.
print(f"Connected to project: {project_id}")

Connected to project: powerful-host-435901-c5


In [6]:
import os

# Keras 3 reads KERAS_BACKEND when `keras` is first imported, so this must run
# before the import cell. Alternatives: "torch" or "tensorflow".
# XLA_PYTHON_CLIENT_MEM_FRACTION asks JAX to preallocate 100% of GPU memory
# (instead of its default fraction) to avoid memory fragmentation on JAX.
os.environ.update({
    "KERAS_BACKEND": "jax",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
})

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
#!pip install -q -U keras-nlp
#!pip install -q -U keras>=3
#!pip install nltk
#!pip install --upgrade nltk
#!pip install --upgrade --quiet keras-nlp
#!pip install --upgrade google-cloud-automl

In [None]:
#!pip uninstall -y google-cloud-bigquery google-cloud-storage pandas pydantic jupyterlab

In [None]:
# Install compatible versions
#!pip install pandas==1.5.3  # Compatible version for bigframes
#!pip install google-cloud-bigquery==3.10.0  # A version that works well
#!pip install google-cloud-storage==2.5.0  # Already compatible
#!pip install pydantic==1.10.0  # For compatibility with dataproc-jupyter-plugin
#!pip install jupyterlab==3.6.0  # Compatible version for beatrix-jupyterlab

In [None]:
#!pip install tqdm --upgrade

### Import packages

In [None]:
import random

import numpy as np
import pandas as pd
import keras
import keras_nlp
from tqdm import tqdm

# NLTK-based synonym augmentation is disabled in this notebook (its cells are
# commented out), so nltk/wordnet are not imported here.

# Register `progress_apply` on pandas objects so row-wise formatting shows a
# tqdm progress bar.
tqdm.pandas()

In [None]:
# Specify a custom directory for nltk_data
#nltk_data_dir = '/kaggle/working/nltk_data'
#if not os.path.exists(nltk_data_dir):
    #os.makedirs(nltk_data_dir)

# Point NLTK to the custom directory
#nltk.data.path.append(nltk_data_dir)

# Download required corpora to the specified directory
#nltk.download('wordnet', download_dir=nltk_data_dir)
#nltk.download('wordnet')
#!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/
#nltk.download('omw-1.4', download_dir=nltk_data_dir)
#nltk.download('omw-1.4')

## Load Dataset

In [None]:
# Load the augmented Q&A dataset (per its filename) attached to this notebook.
file = '/kaggle/input/new-dataset/augmented_dataset.csv'
df = pd.read_csv(file)

# The prompt-building steps below read these two columns; fail fast with a clear
# message instead of a KeyError deep inside `progress_apply`.
missing_columns = {'Question', 'Answer'} - set(df.columns)
assert not missing_columns, f"Dataset is missing required columns: {sorted(missing_columns)}"

df.head(10)  # View the first 10 rows

In [None]:
# Check the shape of the data
print(f"{df.shape}")  # (rows, columns) of the loaded dataset

In [None]:
#import unicodedata

# Function to get synonyms using WordNet
#def get_synonyms(word):
    #synonyms = set()
    #for syn in wordnet.synsets(word):
        #for lemma in syn.lemmas():
            #synonyms.add(lemma.name())
    #return list(synonyms)

# Function to clean text and fix encoding issues
#def clean_text(text):
    # Normalize text to remove special encoding artifacts
    #text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8', 'ignore')
    
    # Replace common problematic characters
    #text = text.replace('â€™', "'").replace('â€œ', '"').replace('â€�', '"')
    #return text

# Function to perform synonym replacement and clean augmented text
#def synonym_replacement(sentence):
    #words = sentence.split()
    #new_sentence = []
    #for word in words:
        #synonyms = get_synonyms(word)
        #if synonyms and random.random() > 0.7:  # 30% chance of replacement
            #new_word = random.choice(synonyms)
            #new_sentence.append(new_word)
        #else:
            #new_sentence.append(word)
    
    # Join the new sentence and clean it
    #augmented_sentence = ' '.join(new_sentence)
    
    # Clean any encoding issues that may have been introduced
    #return clean_text(augmented_sentence)

# Augment the 'Question' column of the DataFrame and clean the output
#augmented_questions = []

# Repeat augmentation 4 times for each question
#for question in df['Question']:
    #for _ in range(4):  # Repeat augmentation 4 times
        #augmented_question = synonym_replacement(question)
        #augmented_questions.append(augmented_question)

# Create a new DataFrame with augmented and clean data
#augmented_df = pd.DataFrame({'Question': augmented_questions, 'Answer': pd.concat([df['Answer']] * 4).reset_index(drop=True)})

# Concatenate the original and augmented dataframes to expand the dataset
#data = pd.concat([df, augmented_df], ignore_index=True)

# Save the expanded DataFrame to a CSV file with proper encoding
#data.to_csv('expanded_questions_clean.csv', index=False, encoding='utf-8')

# Output the new DataFrame
#print(data)

In [None]:
# Split the dataset into training and validation sets and build prompts.

from sklearn.model_selection import train_test_split

# Step 1: hold out 20% of the rows for validation (fixed seed for reproducibility).
train_data, val_data = train_test_split(df, test_size=0.2, random_state=42)

# Step 2: prompt template shared by training and inference. At inference time
# the answer slot is left empty so the model continues after "Answer:\n".
template = "\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"


def format_prompts(frame):
    """Render `template` for every row of `frame` (with a tqdm progress bar).

    Args:
        frame: DataFrame with `Question` and `Answer` columns.

    Returns:
        A Series of formatted prompt strings aligned with `frame`'s index.
    """
    return frame.progress_apply(
        lambda row: template.format(Question=row['Question'], Answer=row['Answer']),
        axis=1,
    )


# Steps 3-4: attach the prompts. `assign` returns new frames instead of writing a
# column into the slices returned by train_test_split (avoids chained-assignment
# / SettingWithCopyWarning issues).
train_data = train_data.assign(prompt=format_prompts(train_data))
val_data = val_data.assign(prompt=format_prompts(val_data))

# Step 5: arrays for training. The prompts already contain the answers.
train_x = np.array(train_data['prompt'].tolist())  # formatted prompts
val_x = np.array(val_data['prompt'].tolist())
# NOTE: GemmaCausalLM's preprocessor derives next-token targets from the prompt
# text itself, so these answer arrays are kept for reference, not as labels.
train_y = np.array(train_data['Answer'].tolist())
val_y = np.array(val_data['Answer'].tolist())

In [None]:
# Step 6a: the training prompts as a plain Python list.
# `train_data["prompt"]` already holds `template` rendered for every row, so it
# is reused instead of formatting every row a second time.
training_data = train_data["prompt"].tolist()

In [None]:
# Step 6b: the validation prompts as a plain Python list.
# `val_data["prompt"]` already holds `template` rendered for every row, so it is
# reused instead of formatting every row a second time.
validation_data = val_data["prompt"].tolist()

In [None]:
# Apply the template to each row, removing the reference to Category
#data["prompt"] = data.progress_apply(
    #lambda row: template.format(
        #Question=row['Question'],
        #Answer=row['Answer']
    #), axis=1
#)

In [None]:
#data["prompt"] = data.progress_apply(lambda row: template.format(Question=row.Question,
                                                             #Answer=row.Answer), axis=1)
#expanded_data = data.prompt.tolist()

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [None]:
# Step 7: load the pretrained Gemma 2B English base model and print its summary.
# "gemma_2b_en" is a first-generation Gemma preset (Gemma 2 presets are named
# like "gemma2_2b_en"), not "version 2".
GEMMA_PRESET = "gemma_2b_en"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(GEMMA_PRESET)
gemma_lm.summary()

The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture — a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

In [None]:
#version 7 of gemma
#gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
#gemma_lm.summary()

## Inference before fine tuning

In this section, you will query the model with various prompts to see how it responds.


### Stay on track to quit smoking

Query the model for how to stay on track to quit smoking.

In [None]:
# Query the base (not yet fine-tuned) model with one example question. The answer
# slot is left empty so the model continues the text after "Answer:\n".
test_prompt = template.format(Question="What is the right way to quit?", Answer="")
print(gemma_lm.generate(test_prompt, max_length=100))

In [None]:
# Ask the base model for a child-friendly explanation.
# NOTE(review): max_length=40 is small. If it presumably bounds prompt + response
# together, the answer will be cut off; consider raising it to match the
# post-fine-tuning prompts (200-256).
child_question = "Explain the process of quitting smoking in a way that a child could understand."
test_prompt = template.format(Question=child_question, Answer="")
print(gemma_lm.generate(test_prompt, max_length=40))

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the smoking-cessation question-and-answer dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 8. In practice, begin with a relatively small rank (such as 4, 8, or 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [None]:
# Switch the backbone to LoRA: the base weights are frozen and small trainable
# rank-LORA_RANK adapter matrices are added. The summary shows the reduced
# trainable-parameter count.
LORA_RANK = 8
gemma_lm.backbone.enable_lora(rank=LORA_RANK)
gemma_lm.summary()

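Note that enabling LoRA reduces the number of trainable parameters significantly (from about 2.5 billion to a few million). The LoRA parameter count grows linearly with the rank, so check the `gemma_lm.summary()` output above for the exact number at rank 8.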

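## Fine-tune, then run inference
The next cell fine-tunes the model with LoRA. The cells after it query the fine-tuned model; after fine-tuning, responses should follow the instruction provided in the prompt more closely.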

In [None]:
# Integrate learning-rate scheduling and early stopping into the training workflow.
from keras.callbacks import EarlyStopping, LearningRateScheduler

# Cap tokenized inputs at 512 tokens to keep memory use under control.
gemma_lm.preprocessor.sequence_length = 512

# AdamW (Adam with decoupled weight decay) is a common choice for transformer
# models; layer-norm scales and bias terms are exempted from the decay.
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.05)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# The model emits logits (hence from_logits=True). Accuracy is registered as a
# weighted metric so it uses the sample weights produced by the preprocessor.
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Define learning rate scheduler
def lr_schedule(epoch):
    """Step-decay learning rate for `keras.callbacks.LearningRateScheduler`.

    Starts at 5e-5 and multiplies by 0.2 every 2 epochs, never going below 1e-6:
    epochs 0-1 -> 5e-5, epochs 2-3 -> 1e-5, epochs 4-5 -> 2e-6, later -> 1e-6.

    Args:
        epoch: 0-based epoch index supplied by the callback.

    Returns:
        The learning rate (float) to use for that epoch.
    """
    initial_lr = 5e-5
    decay_factor = 0.2   # multiplicative reduction applied per step
    epochs_per_step = 2  # reduce every 2 epochs
    min_lr = 1e-6
    # Integer division counts completed decay steps, so the reduction compounds
    # instead of bouncing back to `initial_lr` on odd epochs.
    return max(initial_lr * decay_factor ** (epoch // epochs_per_step), min_lr)

from keras.callbacks import LambdaCallback

# Stop when validation loss has not improved for 2 consecutive epochs and roll
# back to the best weights observed.
# NOTE(review): restore_best_weights=True keeps a copy of the full model weights
# whenever val_loss improves; for a 2B-parameter model that copy is large. The
# per-epoch LoRA checkpoints below are a lighter-weight way to recover the best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=2,
    restore_best_weights=True,
)
# Apply `lr_schedule` at the start of every epoch.
lr_scheduler = LearningRateScheduler(lr_schedule)


def save_lora_checkpoint(epoch, logs=None):
    """Save the LoRA adapter weights at the end of `epoch` (0-based, as Keras passes it).

    Files are named `weights_epoch_XX.lora.h5` with a 1-based, zero-padded epoch number.
    """
    gemma_lm.backbone.save_lora_weights(
        f'/kaggle/working/weights_epoch_{epoch + 1:02d}.lora.h5'
    )


lora_checkpoint = LambdaCallback(on_epoch_end=save_lora_checkpoint)

num_epochs = 5
# Train in ONE `fit` call. Calling `fit(epochs=1)` in a Python loop restarts the
# callbacks each time, so the scheduler would always see epoch 0 and early
# stopping could never trigger.
# Only the formatted prompts are passed: the GemmaCausalLM preprocessor builds the
# next-token targets from the input text itself and ignores an explicit `y`.
history = gemma_lm.fit(
    x=train_x,
    validation_data=(val_x,),
    epochs=num_epochs,
    batch_size=1,
    callbacks=[early_stopping, lr_scheduler, lora_checkpoint],
)

In [None]:
#gemma_lm.backbone.load_lora_weights("/kaggle/working/weights_epoch_5.lora.h5") #important to run but if the last epoch is the one with the least loss, then running it wont make any difference

In [None]:
# Limit the input sequence length to 512 (to control memory usage).
#gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
#optimizer = keras.optimizers.AdamW(
    #learning_rate=5e-5,
    #weight_decay=0.05,
#)
# Exclude layernorm and bias terms from decay.
#optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

#gemma_lm.compile(
    #loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    #optimizer=optimizer,
    #weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
#)
#gemma_lm.fit(data, epochs=3, batch_size=1)

In [None]:
# Limit the input sequence length to 512 (to control memory usage).
#gemma_lm.preprocessor.sequence_length = 512

# Use AdamW (a common optimizer for transformer models).
#optimizer = keras.optimizers.AdamW(
    #learning_rate=5e-5,  # Adjust learning rate as needed
    #weight_decay=0.05,   # Adjust weight decay as needed
#)

# Exclude layernorm and bias terms from decay.
#optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Compile the model with the specified loss function and optimizer.
#gemma_lm.compile(
    #loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    #optimizer=optimizer,
    #weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
#)

# Fit the model with training and validation data
#gemma_lm.fit(
    #train_x,              # Training input features
    #train_y,              # Training labels
    #epochs=3,             # Number of epochs
    #batch_size=3,         # Set an appropriate batch size
    #validation_data=(val_x, val_y)  # Validation data
#)

### Stay on track to quit smoking (after fine-tuning)


In [None]:
# Ask the fine-tuned model the same question used before training, to compare answers.
question = "What is the right way to quit?"
prompt = template.format(Question=question, Answer="")
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
# Ask the fine-tuned model which quitting method works best.
prompt = template.format(Question="Which method is the best for quitting?", Answer="")
response = gemma_lm.generate(prompt, max_length=200)
print(response)


In [None]:
# Ask the fine-tuned model about coping with smoking triggers.
prompt = template.format(Question="How can I cope with or avoid triggers?", Answer="")
response = gemma_lm.generate(prompt, max_length=250)
print(response)

Note that for demonstration purposes, this tutorial fine-tunes the model for only a few epochs and with a modest LoRA rank. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.


## Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).


[Product Launch] Publish your fine-tuned Keras models to Kaggle in a few lines of code

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
#!pip install -q -U keras-nlp
#!pip install -q -U keras>=3
#!pip install nltk
#!pip install --upgrade nltk
#!pip install --upgrade --quiet keras-nlp
#!pip install --upgrade google-cloud-automl

In [None]:
# For building keras-nlp locally.
#!apt install -q python3.10-venv
# Install all deps
#!pip install -q -U tf-keras tensorflow-text kagglehub
# Install keras-nlp
#!pip install keras-nlp --upgrade
# Install keras
#!pip install -q -U keras

In [None]:
#import json

# Path to the uploaded kaggle.json file
#json_path = '/kaggle/input/json-key/kaggle.json'  

# Load the kaggle.json file
#with open(json_path, 'r') as f:
    #kaggle_api = json.load(f)

# Extract the username and key
#kaggle_username = kaggle_api['username']
#kaggle_key = kaggle_api['key']

#print("Username:", kaggle_username)
#print("API Key:", kaggle_key)  


In [None]:

import os

# Save the fine-tuned model as a KerasNLP preset directory.
preset_dir = "./finetuned_gemma"
gemma_lm.save_to_preset(preset_dir)

# Upload the preset as a new model variant on Kaggle. The username defaults to the
# original author's account; set KAGGLE_USERNAME to publish under your own.
kaggle_username = os.environ.get("KAGGLE_USERNAME", "rhodanankabirwa5")
kaggle_uri = f"kaggle://{kaggle_username}/gemma/keras/finetuned_gemma"
keras_nlp.upload_preset(kaggle_uri, preset_dir)


In [None]:
# List the contents of the current working directory
print("Working directory contents before saving:")
print(os.listdir('/kaggle/working/'))

In [None]:
# Check where the model was saved
print("Contents after saving:")
print(os.listdir('/kaggle/working/'))

In [None]:
# Show current GPU utilization, memory usage and processes (IPython shell escape).
!nvidia-smi

In [None]:
import gc
import keras

# Reset Keras' global state, then force a garbage-collection pass.
# NOTE(review): objects still referenced by notebook globals (e.g. `gemma_lm`)
# are not freed by this; `del` them first if the goal is to reclaim memory.
keras.backend.clear_session()
gc.collect()


In [None]:
import gc
import os

import keras_nlp
from keras import backend as K
from keras import mixed_precision

# Reset Keras' global state and garbage-collect before loading another model copy.
K.clear_session()
gc.collect()

# Enable mixed precision: float16 compute with float32 variables.
# NOTE(review): Gemma is commonly run in bfloat16; float16 activations may
# overflow. Consider 'mixed_bfloat16' if generations look degenerate.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Kaggle account that owns the uploaded preset; set KAGGLE_USERNAME to use your own.
kaggle_username = os.environ.get("KAGGLE_USERNAME", "rhodanankabirwa5")
model_path = f"kaggle://{kaggle_username}/gemma/keras/finetuned_gemma"

try:
    # Load the fine-tuned model
    finetuned_model = keras_nlp.models.GemmaCausalLM.from_preset(model_path)
except ValueError as e:
    # Best-effort: report the failure instead of aborting the notebook run.
    print(f"Error loading model: {e}")


In [8]:
import keras_nlp

In [10]:
# Replace with your actual Kaggle username
kaggle_username = "rhodanankabirwa5"
model_path = f"kaggle://{kaggle_username}/gemma/keras/finetuned_gemma"

In [11]:
# Load the model that was just uploaded to Kaggle
finetuned_model = keras_nlp.models.GemmaCausalLM.from_preset(f"kaggle://{kaggle_username}/gemma/keras/finetuned_gemma")

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [12]:
# Verify the model is loaded
print("Model loaded successfully:", finetuned_model)

Model loaded successfully: <GemmaCausalLM name=gemma_causal_lm, built=True>
