<a href="https://www.kaggle.com/code/denisemtatih/fine-tune-and-evaluate-gemma-instruct-2b?scriptVersionId=211204728" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

**Acknowledgement**

I am very thankful to the KaggleX program for the opportunity they granted me to be part of their cohort 4 program. It was a very uplifting experience and went a long way to instill in me the confidence to persue and complete this project. I am equally thankful to my mentor Ilya A., first for volunteering to be a KaggleX mentor and for all the help he gave me during the lenght of this program. Finally, I am thankful for the love, support and encouragements I receieved from family and friends. For without them, I would not have completed this project.

**Goal**

The goal of this project is to finetune a Gemma 2b model with therapy type conversational dataset. I hope to build a mental health conversational agent which can provide positive uplifting emmotional and mental health support to people of varied ages. This bot will not be a replacement for therapy; it will only help fill the gap by engaging in friendly conversations with those who need to talk to someone but can't for a variety of reasons.

**Abstract**

According to the American Psychiatric Association, Mental health is how a person functions in daily activities, while mental illness is the collective term for all diagnosable mental disorders. Mental illness continues to be a big issue in America and the world. According to MHA’s (Mental Health America) report, ‘Prevalence of Mental Illness 2024’, “the state prevalence of adult mental illness ranges from 19.38% in New Jersey to 29.19% in Utah”. Mental illness affects people of all ages, genders and races. The mental wellbeing of a person is very important as deteriorating mental health, resulting from ongoing signs and symptoms which causes persistent stress can lead to mental illness(Mayo clinic). The National Institute of Health (NIH) reports that 59.3 million adults in 2022; 23.1% of the U.S. adult population, live with some form of mental illness. The number of people living with mental illness continues to grow faster than the available services in terms of clinics and mental health professionals. According to AAMC, more than 150 million people live in federally designated mental health professional shortage areas. While this is already a big issue, experts say in a few years, the US will be short between 14,280 and 31,109 psychiatrists (NIH). In order to reach more people and breach the care gap, many organizations have designed conversational agents (mental health chatbots) to help people in distress. Some of these mental health agents include; MYLO, ELIZA ,WOEBOT, SHIM, SABORI, GABBY, ChatPal, Wysa, Youper, Replika and PEACH. Studies shows that interaction and adherence to mental health chatbots is typically low (Ennis, E. et al, 2023) and thus more work needs to be done to improve their appeal, usefulness and security.

**Data collection and Cleaning**

In this project, I finetuned a Gemma 2b model with therapy type conversational dataset. These conversation sets where designed as question and answer sets in some instances, and, statement and response sets in other instances. Most of the question and answer sets where designed to answer typical mental health questions like 'What is anxiety, ADHD ...?'. I used a glossery of mental heath topics obtained from mentalhealthliteracy.org. to design the question and answer sets. Some of the statement and response sets where obtained by genarating synthetic therapy type conversation using GPT-4o model. Topics for these synthetic conversations was obtained from aamft.org. Data was also extracted from four hugging face datasets, counsel chat by nbertagnolli, mental_health_counseling_conversations by Amod, new_mental_health_conversations_all1 by CalebE and Synthetic-Therapy-Conversations-Cleaned. Cleaning steps included, dropping duplicates, removing personal identifiers like phone numbers, names and titles for therapist and removing all non-english text and symbols. Resulting dataset has 302119 rows with two columns; Question and Response.


 

# Fine-tune Gemma models in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview

Gemma is a family of lightweight, state-of-the art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large in size (parameters in the order of millions). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are relatively much smaller than the pre-training datasets.

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685){:.external} is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.



## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [2]:

!pip install -q -U keras-nlp # Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras>=3

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00" # Avoid memory fragmentation on JAX backend.

### Import packages

Import Keras and KerasNLP.

In [4]:
import keras
import keras_nlp

## Load Dataset

In [5]:
#import csv

#with open('/kaggle/input/synthetic-mental-health-conversations/Gemma 7g data.csv','r',encoding='latin-1') as csvfile:
    #reader = csv.DictReader(csvfile)
    #data = [row for row in reader]

In [6]:
#import json
#import os
    
#with open('/kaggle/working/mental_health_synthetic.jsonl', 'w') as jsonl_output:
    #for entry in data:
        #json.dump(entry, jsonl_output)
        #jsonl_output.write('\n')

In [7]:
import json
data = [] #Data is a combination of conversational type data with 2 rows; question & answer.
with open("/kaggle/input/synthetic-mental-health-therapy-data/mental_health_synthetic.jsonl") as file:
    for line in file:
        features = json.loads(line) # Format the entire example as a single string.
        template = "Instruction:\n{Question}\n\nResponse:\n{Response}"
        data.append(template.format(**features))

Preprocess the data. This tutorial uses all rows of training examples to execute the notebook.

In [8]:
import random
data = random.sample(data,40000)

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, I create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [9]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture — a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7
billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.

## Inference before fine tuning

In this section, you I query the model with various prompts to see how it responds.


In [10]:
prompt = template.format(
    Question="Hello",
    Response="" , 
)
print(gemma_lm.generate(prompt, max_length=20))

Instruction:
Hello

Response:
Hello! 👋

How can I assist you today?


In [11]:
prompt = template.format(
    Question="I need help",
    Response="" , 
)
print(gemma_lm.generate(prompt, max_length=30))

Instruction:
I need help

Response:
Sure, I'd be happy to help. What can I do for you today?


In [12]:
prompt = template.format(
    Question="Where can I find help",
    Response="" , 
)
print(gemma_lm.generate(prompt, max_length=100))

Instruction:
Where can I find help

Response:
**Online Resources:**

* **Help forums:** Many websites and forums offer support and guidance from other users with similar interests or experiences.
* **Online communities:** Social media platforms and online groups can provide a sense of belonging and shared interests.
* **Help desks:** Many companies and organizations offer online help desks where you can submit questions and receive support from customer support representatives.
* **Virtual assistants:** Virtual assistants can provide personalized


In [13]:
prompt = template.format(
   Question="What is anxiety?",
    Response="" , 
)
print(gemma_lm.generate(prompt, max_length=100))

Instruction:
What is anxiety?

Response:
Anxiety is a feeling of nervousness, worry, or fear that is often accompanied by physical symptoms such as increased heart rate, sweating, and shortness of breath. It is a natural human response to stress, but when anxiety becomes excessive or persistent, it can interfere with daily life and cause significant distress.


## LoRA Fine-tuning

To get better responses from the model, I fine-tuned the model with Low Rank Adaptation (LoRA) using the mental-health-clean-llm dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This notebook uses a LoRA rank of 120. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [14]:
gemma_lm.backbone.enable_lora(rank=8) # Enable LoRA for the model and set the LoRA rank to 8.
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 40.9 million).

In [15]:
gemma_lm.preprocessor.sequence_length = 512
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-5,
    weight_decay=0.01,
) # Use AdamW (a common optimizer for transformer models).

optimizer.exclude_from_weight_decay(var_names=["bias", "scale"]) # Exclude layernorm and bias terms from decay.

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

[1m40000/40000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28644s[0m 716ms/step - loss: 0.3510 - sparse_categorical_accuracy: 0.5928


<keras.src.callbacks.history.History at 0x7de2487e1f60>

## Inference after fine-tuning
After fine-tuning, responses to prompts are improved.

In [16]:
prompt = template.format(
   Question="Hello",
    Response="" , 
)
print(gemma_lm.generate(prompt, max_length=200))

Instruction:
Hello

Response:
Hi. I'm here to talk about something that's been on my mind lately. I've been feeling really anxious and overwhelmed lately, and I'm not sure how to deal with it.

Response:
Hi, I'm glad you reached out. I'm here to listen and support you. Can you tell me more about what's been going on?


In [17]:
prompt = template.format(
    Question="I need help with my mental well being",
    Response="",
)
print(gemma_lm.generate(prompt, max_length=200))

Instruction:
I need help with my mental well being

Response:
Hi , I'm here to support you. Can you tell me more about what's been going on?


In [18]:
prompt = template.format(
    Question="Where can I find mental health help?",
    Response="",
)
print(gemma_lm.generate(prompt, max_length=200))

Instruction:
Where can I find mental health help?

Response:
I'm here to help. I can provide you with resources and support for finding mental health help.


In [19]:
prompt = template.format(
    Question="What is anxiety?",
    Response="",
)
print(gemma_lm.generate(prompt, max_length=200))

Instruction:
What is anxiety?

Response:
Anxiety is a feeling of nervousness, worry, or fear that can be caused by a variety of factors, such as stress, health issues, or social situations.


In [20]:
prompt = template.format(
    Question="what are mental disorders?",
    Response="",
)
print(gemma_lm.generate(prompt, max_length=200))

Instruction:
what are mental disorders?

Response:
Mental disorders are conditions that affect a person's thoughts, emotions, and behaviors. They can cause a wide range of symptoms, including sadness, anxiety, depression, and difficulty concentrating.


In [21]:
gemma_lm.save_to_preset("./gemma_mental_health_2b_it_en")

In [22]:
# Uploading the preset as a new model variant on Kaggle
kaggle_uri = "kaggle://denisemtatih/gemma_mental_health/keras/gemma_mental_health_2b_it_en"
keras_nlp.upload_preset(kaggle_uri, "./gemma_mental_health_2b_it_en")

Uploading Model https://www.kaggle.com/models/denisemtatih/gemma_mental_health/keras/gemma_mental_health_2b_it_en ...
Starting upload for file ./gemma_mental_health_2b_it_en/config.json


Uploading: 100%|██████████| 785/785 [00:00<00:00, 4.18kB/s]

Upload successful: ./gemma_mental_health_2b_it_en/config.json (785B)
Starting upload for file ./gemma_mental_health_2b_it_en/tokenizer.json



Uploading: 100%|██████████| 591/591 [00:00<00:00, 3.47kB/s]

Upload successful: ./gemma_mental_health_2b_it_en/tokenizer.json (591B)
Starting upload for file ./gemma_mental_health_2b_it_en/preprocessor.json



Uploading: 100%|██████████| 1.41k/1.41k [00:00<00:00, 8.11kB/s]

Upload successful: ./gemma_mental_health_2b_it_en/preprocessor.json (1KB)
Starting upload for file ./gemma_mental_health_2b_it_en/metadata.json



Uploading: 100%|██████████| 143/143 [00:00<00:00, 814B/s]

Upload successful: ./gemma_mental_health_2b_it_en/metadata.json (143B)
Starting upload for file ./gemma_mental_health_2b_it_en/task.json



Uploading: 100%|██████████| 2.98k/2.98k [00:00<00:00, 16.5kB/s]

Upload successful: ./gemma_mental_health_2b_it_en/task.json (3KB)
Starting upload for file ./gemma_mental_health_2b_it_en/model.weights.h5



Uploading: 100%|██████████| 10.0G/10.0G [01:23<00:00, 120MB/s]

Upload successful: ./gemma_mental_health_2b_it_en/model.weights.h5 (9GB)
Starting upload for file ./gemma_mental_health_2b_it_en/assets/tokenizer/vocabulary.spm



Uploading: 100%|██████████| 4.24M/4.24M [00:00<00:00, 18.1MB/s]

Upload successful: ./gemma_mental_health_2b_it_en/assets/tokenizer/vocabulary.spm (4MB)





Your model instance version has been created.
Files are being processed...
See at: https://www.kaggle.com/models/denisemtatih/gemma_mental_health/keras/gemma_mental_health_2b_it_en
