In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

## Project Overview

Large language models (LLMs) are trained on vast datasets, which gives them broad knowledge across many fields. However, they can still be made more knowledgeable in a specific domain by fine-tuning. In this project, we fine-tune Google's Gemma LLM on a question-and-answer dataset centered on IT Helpdesk topics.

The goal is to turn Gemma into a domain assistant for IT support and cybersecurity that can act as an IT Helpdesk. Fine-tuning is intended to improve Gemma’s ability to give accurate, helpful responses to IT Helpdesk-related queries.

More details about this project:
* https://github.com/yusufokunlola/it-helpdesk-chatbot


### Install dependencies


In [1]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
# Quote the requirement so the shell does not treat ">=3" as an output redirect.
!pip install -q -U "keras>=3"

### Select a backend

In [2]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

In [3]:
import keras
import keras_nlp
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
tqdm.pandas()

import plotly.graph_objs as go
import plotly.express as px
from IPython.display import display, Markdown

## Load and Prepare Dataset
The data used in this project was synthetically generated using a large language model (LLM) with the few-shot prompting technique. Few-shot prompting involves providing the LLM with a small set of example data points (i.e., samples) relevant to the task, and then using these examples to guide the model in generating additional synthetic data. To learn more about how the data was generated, see [this article](https://dev.to/victor_isaac_king/how-to-generate-high-quality-synthetic-data-for-fine-tuning-large-language-models-llms-3241).
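As a rough illustration of the few-shot prompting idea described above, the sketch below assembles a prompt from two rows of this dataset and asks for more rows in the same format. The instruction wording, the `build_few_shot_prompt` helper, and the choice of seed rows are assumptions made for illustration; the prompt the author actually used is not shown in this notebook.

```python
# Two seed rows taken from the dataset shown in this notebook.
# Which rows (and how many) were really used as seeds is an assumption.
SEED_EXAMPLES = [
    {
        "category": "Access and Security",
        "question": "Can I use a password manager?",
        "answer": (
            "Yes, password managers can securely store and autofill complex "
            "passwords. Check if your company approves specific software."
        ),
    },
    {
        "category": "Performance Issues",
        "question": "How do I optimize my computer's performance for gaming?",
        "answer": (
            "Update graphics drivers close unnecessary background programs and "
            "consider upgrading hardware components like RAM or storage."
        ),
    },
]


def build_few_shot_prompt(examples, n_new):
    """Assemble a few-shot prompt asking an LLM for n_new similar Q&A rows."""
    lines = [
        "You write IT helpdesk question-and-answer pairs.",
        "Here are some examples:",
        "",
    ]
    for ex in examples:
        lines.append(f"Category: {ex['category']}")
        lines.append(f"Question: {ex['question']}")
        lines.append(f"Answer: {ex['answer']}")
        lines.append("")
    lines.append(f"Generate {n_new} new pairs in exactly the same format.")
    return "\n".join(lines)


few_shot_prompt = build_few_shot_prompt(SEED_EXAMPLES, n_new=5)
print(few_shot_prompt)
```

The resulting string would be sent to whichever LLM generates the synthetic rows; the generated rows still need deduplication and review, as done in the cells below.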


In [7]:
df = pd.read_csv("/kaggle/input/it-helpdesk/it_helpdesk_dataset_510.csv",encoding="ISO-8859-1")
df.head(10)

Unnamed: 0,Question,Answer,Category
0,What are the key differences between WPA2 and ...,WPA3 offers improved security over WPA2 by pro...,Access and Security
1,Can I use a password manager?,"Yes, password managers can securely store and ...",Access and Security
2,How do I optimize my computer's performance fo...,Update graphics drivers close unnecessary back...,Performance Issues
3,How can I secure my company's Wi-Fi network fr...,"To secure your company's Wi-Fi network, change...",Access and Security
4,My printer isn't printing. What troubleshootin...,"Check if the printer is connected, ensure it's...",Printers and Peripherals
5,How do I prevent malware infections on my system?,Keep your operating system and software update...,Access and Security
6,How can I identify what's causing my system to...,1. Use Task Manager to check resource usage. 2...,Performance Issues
7,What is the best way to protect against ransom...,Ransomware attacks can be prevented by keeping...,Access and Security
8,What is the difference between WPA2 and WPA3 e...,WPA2 is an older encryption protocol that is s...,Access and Security
9,What are the steps to configure my Outlook ema...,"Open Outlook, click on 'File', go to 'Account ...",Application and Software Issues


In [8]:
# Check for duplicated data 
duplicate_rows = df[df.duplicated()]
duplicate_rows

Unnamed: 0,Question,Answer,Category
38,What steps should I take if an update causes s...,Roll back the update create a restore point be...,Application and Software Issues
47,Why am I getting disconnected from VPN frequen...,Check your internet connection update VPN soft...,Network and Connectivity
56,How do I optimize my computer's performance fo...,Update graphics drivers close unnecessary back...,Performance Issues
108,How do I troubleshoot issues with external har...,Check cable connections try different ports an...,Hardware Issues
122,What steps should I take if I suspect a virus ...,Run a full system scan using the company-appro...,Access and Security
143,How do I install software on multiple computer...,Create a network share with the installation f...,Application and Software Issues
167,How do I prevent malware infections on my system?,Keep your operating system and software update...,Access and Security
173,Can you explain how to configure firewall sett...,Add exceptions in Windows Firewall for the des...,System Configuration
194,What steps should I take if I encounter issues...,Check your internet connection verify service ...,Specialized Categories
205,Can you explain how to configure firewall sett...,Add exceptions in Windows Firewall for the des...,System Configuration


In [9]:
# drop duplicated data
df.drop_duplicates(inplace = True)

In [10]:
df.shape

(486, 3)

## Exploratory Data Analysis (EDA) 

In [32]:
import plotly.graph_objects as go
from plotly.offline import init_notebook_mode

# Enable Plotly in the notebook
init_notebook_mode(connected=True)

In [33]:
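import numpy as np
import pandas as pd
import plotly.express as px

# Count how many rows fall into each category, then collect the counts in a DataFrame.
unique_labels, label_counts = np.unique(df.Category.tolist(), return_counts=True)
df_plot = pd.DataFrame({"Category": unique_labels, "Count": label_counts})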

fig = px.bar(df_plot, x="Category", y="Count", text="Count", title="Category Distribution")
fig.update_traces(textposition="outside")
fig.update_layout(xaxis_title="Category", yaxis_title="Count")
fig.show()

In [34]:
print(df.Category.head())


0         Access and Security
1         Access and Security
2          Performance Issues
3         Access and Security
4    Printers and Peripherals
Name: Category, dtype: object


In [35]:
print(unique_labels)
print(label_counts)


['Access and Security' 'Application and Software Issues'
 'CPU Audio settings' 'Hardware Issues' 'Network and Connectivity'
 'Performance Issues' 'Printers and Peripherals' 'Specialized Categories'
 'System Configuration' 'System configuration']
[270  74   7  23  25  16  45  16   4   6]


In [28]:
# Get unique labels and their frequency
unique_labels, label_counts = np.unique(df.Category.tolist(), return_counts=True)

# Plotting
fig = go.Figure(data=go.Bar(x=unique_labels, y=label_counts))
fig.update_layout(
    title="Category Distribution",
    xaxis_title="Category",
    yaxis_title="Count",
)

fig.update_traces(text=label_counts, textposition="outside")
fig.show()

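Access and Security has by far the most questions (270 of 486), so the model is expected to perform best on access- and security-related queries. Note that `System Configuration` and `System configuration` are counted as separate categories because their capitalization differs.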
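The label counts printed above list `System Configuration` and `System configuration` separately. The sketch below shows one possible way to merge labels that differ only in casing or spacing and to recompute each category's share. The counts are copied from the output above; `normalize_label` is a hypothetical helper, not part of the original notebook.

```python
from collections import Counter

# Label counts copied from the np.unique output above.
raw_counts = {
    "Access and Security": 270,
    "Application and Software Issues": 74,
    "CPU Audio settings": 7,
    "Hardware Issues": 23,
    "Network and Connectivity": 25,
    "Performance Issues": 16,
    "Printers and Peripherals": 45,
    "Specialized Categories": 16,
    "System Configuration": 4,
    "System configuration": 6,
}


def normalize_label(label):
    """Collapse whitespace and casing differences between labels."""
    return " ".join(label.split()).casefold()


merged = Counter()
display_name = {}
for label, count in raw_counts.items():
    key = normalize_label(label)
    display_name.setdefault(key, label)  # keep the first spelling seen
    merged[key] += count

total = sum(merged.values())
for key, count in merged.most_common():
    print(f"{display_name[key]:<35} {count:>4} ({count / total:.1%})")
```

Applying the same normalization to the `Category` column before building prompts would keep the two spellings from being treated as different categories during fine-tuning.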

## Generating Prompts from DataFrame
The template is used to format each row of the DataFrame into a structured prompt.

In [15]:
template = "\n\nCategory:\nithelpdesk-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

In [16]:
df["prompt"] = df.progress_apply(lambda row: template.format(Category=row.Category,
                                                             Question=row.Question,
                                                             Answer=row.Answer), axis=1)
data = df.prompt.tolist()

  0%|          | 0/486 [00:00<?, ?it/s]
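The same template serves two purposes: with the answer filled in it produces a training example, and with an empty answer it produces an inference prompt. The sketch below makes that relationship explicit. `build_prompt` is a hypothetical wrapper around the template string used in this notebook.

```python
TEMPLATE = "\n\nCategory:\nithelpdesk-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"


def build_prompt(category, question, answer=""):
    """Fill the template; leave answer empty to get an inference prompt."""
    return TEMPLATE.format(Category=category, Question=question, Answer=answer)


training_example = build_prompt(
    "Access and Security",
    "Can I use a password manager?",
    "Yes, password managers can securely store and autofill complex passwords.",
)
inference_prompt = build_prompt("Access and Security", "Can I use a password manager?")

# The inference prompt is a strict prefix of the training example,
# so at inference time the model continues exactly where the answer starts.
print(training_example.startswith(inference_prompt))
```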

In [17]:
# View the data
data[1:6]

['\n\nCategory:\nithelpdesk-Access and Security\n\nQuestion:\nCan I use a password manager?\n\nAnswer:\nYes, password managers can securely store and autofill complex passwords. Check if your company approves specific software.',
 "\n\nCategory:\nithelpdesk-Performance Issues\n\nQuestion:\nHow do I optimize my computer's performance for gaming?\n\nAnswer:\nUpdate graphics drivers close unnecessary background programs and consider upgrading hardware components like RAM or storage.",
 "\n\nCategory:\nithelpdesk-Access and Security\n\nQuestion:\nHow can I secure my company's Wi-Fi network from unauthorized access?\n\nAnswer:\nTo secure your company's Wi-Fi network, change the default router password, enable WPA3 encryption, update the firmware regularly, create a guest network for visitors, use strong passwords, and disable WPS. Also, set up a VPN for secure remote access and limit the number of connected devices.",
 "\n\nCategory:\nithelpdesk-Printers and Peripherals\n\nQuestion:\nMy pri

## Load Model
We use Google's lightweight [Gemma 2B model](https://www.kaggle.com/models/keras/gemma/keras/gemma_2b_en) for this project. Gemma comes in two sizes: 2 billion and 7 billion parameters.

In [18]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## Inference before fine tuning
Before fine-tuning, let's test the base model on a few IT Helpdesk and cybersecurity queries to see how well it handles these topics.

### Performance issues related question

Ask the model how to optimize a computer for gaming.

In [19]:
prompt = template.format(
    Category="Performance Issues",
    Question="How do I optimize my computer's performance for gaming?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=256))



Category:
ithelpdesk-Performance Issues

Question:
How do I optimize my computer's performance for gaming?

Answer:
The following tips will help you optimize your computer's performance for gaming:

1. Update your graphics card drivers: Make sure your graphics card drivers are up to date. You can do this by visiting the manufacturer's website and downloading the latest drivers.

2. Disable unnecessary background programs: Close any unnecessary background programs that may be running in the background. This can help free up resources and improve performance.

3. Disable unnecessary services: Disable any unnecessary services that may be running in the background. This can help free up resources and improve performance.

4. Disable unnecessary startup programs: Disable any unnecessary startup programs that may be running when your computer starts up. This can help free up resources and improve performance.

5. Disable unnecessary services: Disable any unnecessary services that may be ru

The response offers reasonable guidance, but it repeats itself (steps 3 and 5 are identical) and is cut off mid-sentence. The model needs further training to function as an IT Helpdesk.

### Strong Password Prompt

Prompt the model to suggest a strong password.


In [20]:
prompt = template.format(
    Category="Access and Security",
    Question="What is a strong password? Give me an example.",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=256))



Category:
ithelpdesk-Access and Security

Question:
What is a strong password? Give me an example.

Answer:
A strong password is one that is at least 8 characters long, contains a combination of upper and lower case letters, numbers, and symbols, and is not easily guessed. For example, a strong password might be "password123" or "MyPassword!@".

Category:
ithelpdesk-Access and Security

Question:
What is a weak password? Give me an example.

Answer:
A weak password is one that is easy to guess or crack, such as "password" or "123456".

Category:
ithelpdesk-Access and Security

Question:
What is a password manager?

Answer:
A password manager is a software application that helps you create and store strong passwords for all your online accounts. It can also generate and store secure, random passwords for you.

Category:
ithelpdesk-Access and Security

Question:
What is a password vault?

Answer:
A password vault is a secure, encrypted container that stores your passwords and other sen

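The examples the model returns ("password123" and "MyPassword!@") are not strong passwords. The model also keeps going after the first answer, inventing further questions and answers of its own.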
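The generated text above contains the prompt itself, the first answer, and then further Category/Question/Answer blocks. The sketch below shows one way to keep only the first answer. It assumes the completion follows the template's exact markers (`"\n\nAnswer:\n"` and `"\n\nCategory:\n"`); `extract_answer` is a hypothetical helper, not part of KerasNLP.

```python
ANSWER_MARKER = "\n\nAnswer:\n"
NEXT_BLOCK_MARKER = "\n\nCategory:\n"


def extract_answer(generated_text):
    """Return only the first answer from a generated completion."""
    # Everything after the first "Answer:" marker is model output.
    _, _, tail = generated_text.partition(ANSWER_MARKER)
    # Cut off any further Category/Question/Answer blocks.
    answer, _, _ = tail.partition(NEXT_BLOCK_MARKER)
    return answer.strip()


# Shortened sample in the same shape as the output above.
sample_output = (
    "\n\nCategory:\nithelpdesk-Access and Security"
    "\n\nQuestion:\nWhat is a weak password? Give me an example."
    "\n\nAnswer:\nA weak password is one that is easy to guess or crack."
    "\n\nCategory:\nithelpdesk-Access and Security"
    "\n\nQuestion:\nWhat is a password manager?"
)
print(extract_answer(sample_output))
```

In the notebook this could be applied as `extract_answer(gemma_lm.generate(prompt, max_length=256))`. If the marker is missing, the helper returns an empty string.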

## LoRA Fine-tuning

Low-Rank Adaptation (LoRA) is a method for fine-tuning large language models (LLMs) with fewer computational resources. It freezes the pretrained weights and trains only small low-rank matrices added to selected layers, so far fewer parameters are updated. Fine-tuning the model with LoRA on the IT Helpdesk question-and-answer dataset should help it generate better responses. Read more about [LoRA fine-tuning](https://www.entrypointai.com/blog/lora-fine-tuning/).
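To give a feel for why LoRA is cheaper, the sketch below compares the number of trainable values in a full update of one weight matrix with a rank-4 LoRA update of the same matrix. The layer size of 2048 x 2048 is an illustrative assumption, not a value read from the Gemma checkpoint; the exact trainable-parameter count for this model is reported by `gemma_lm.summary()` after LoRA is enabled.

```python
def full_update_params(d_out, d_in):
    """Trainable values when the whole d_out x d_in weight matrix is updated."""
    return d_out * d_in


def lora_update_params(d_out, d_in, rank):
    """Trainable values for a low-rank update B @ A.

    B has shape (d_out, rank) and A has shape (rank, d_in).
    """
    return d_out * rank + rank * d_in


# Illustrative layer size (an assumption, not taken from the Gemma checkpoint).
d_out, d_in, rank = 2048, 2048, 4

full = full_update_params(d_out, d_in)
lora = lora_update_params(d_out, d_in, rank)
print(f"full fine-tuning: {full:,} trainable values")
print(f"LoRA (rank={rank}): {lora:,} trainable values ({lora / full:.2%} of full)")
```

A higher rank gives the update more capacity at the cost of more trainable values; rank 4 is the setting used in the next cell.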

In [21]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [22]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=2, batch_size=1)  # The initial fine-tuning run used 5 epochs.

Epoch 1/2
486/486 ━━━━━━━━━━━━━━━━━━━━ 373s 730ms/step - loss: 0.3096 - sparse_categorical_accuracy: 0.5795
Epoch 2/2
486/486 ━━━━━━━━━━━━━━━━━━━━ 355s 715ms/step - loss: 0.1840 - sparse_categorical_accuracy: 0.6970


<keras.src.callbacks.history.History at 0x7e44201a9ae0>
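The `486/486` in the training output is the number of optimizer steps per epoch: with `batch_size=1`, each of the 486 prompts is its own step. The sketch below shows how the step count would change with larger batches, assuming the final partial batch is kept. `steps_per_epoch` is a hypothetical helper for illustration only.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of optimizer steps needed to see every example once."""
    return math.ceil(num_examples / batch_size)


num_examples = 486  # rows left after dropping duplicates
for batch_size in (1, 2, 4):
    print(f"batch_size={batch_size}: {steps_per_epoch(num_examples, batch_size)} steps")
```

Larger batches reduce the number of steps but need more accelerator memory per step, which is one reason a batch size of 1 is a cautious choice here.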

## Inference after fine-tuning
After fine-tuning the model, let's see if it has learned from the dataset to give better responses to IT Helpdesk-related prompts.

### Wi-Fi security question


In [23]:
prompt = template.format(
    Category="Access and Security",
    Question="How can I secure my company's Wi-Fi network from unauthorized access?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=100))



Category:
ithelpdesk-Access and Security

Question:
How can I secure my company's Wi-Fi network from unauthorized access?

Answer:
To secure your company's Wi-Fi network from unauthorized access, use a strong password, enable Wi-Fi encryption, and implement a firewall. Regularly update your router's firmware and use a reputable Wi-Fi access point. Consider using a virtual private network (VPN) to encrypt your internet traffic and protect your


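The fine-tuned model gives a focused, relevant answer to the Access and Security question. Keep in mind that this exact question appears in the training data, so the result mainly shows that the model has learned the dataset's answer style.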

### Strong Password Prompt


In [24]:
prompt = template.format(
    Category="Authentication",
    Question="What is a strong password? Give me an example.",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=100))



Category:
ithelpdesk-Authentication

Question:
What is a strong password? Give me an example.

Answer:
A strong password is one that is at least 12 characters long, contains a mix of uppercase and lowercase letters, numbers, and symbols, and is not easily guessable. For example, "password123" or "MyNameIs123."


In [25]:
prompt = template.format(
    Category="Authentication",
    Question="I received an email that looks suspicious. How can I tell if it's a phishing attempt?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=150))



Category:
ithelpdesk-Authentication

Question:
I received an email that looks suspicious. How can I tell if it's a phishing attempt?

Answer:
Phishing emails often contain spelling and grammar errors, and the sender's email address may be spoofed. Check the sender's email address and domain name to verify the authenticity of the email. If you're unsure, contact the sender directly through a secure channel, such as a phone call or secure messaging.


In [26]:
prompt = template.format(
    Category="Application and Software Issues",
    Question="How can I create and use an email signature in Outlook?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=150))




Category:
ithelpdesk-Application and Software Issues

Question:
How can I create and use an email signature in Outlook?

Answer:
To create an email signature in Outlook, open the email message and click the "Format" tab. Select "Signature" from the "Mail Format" drop-down menu. Enter your signature text in the "Signature" field and click "OK" to save your changes.


In [27]:
prompt = template.format(
    Category="CPU Audio settings",
    Question="How can I configure my computer's audio settings?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=150))



Category:
ithelpdesk-CPU Audio settings

Question:
How can I configure my computer's audio settings?

Answer:
1. Open the Control Panel. 2. Click on the Hardware and Sound icon. 3. Select the Sound option. 4. Click on the Recording tab. 5. Select the input device you want to use. 6. Adjust the volume levels. 7. Click on the Playback tab. 8. Select the output device you want to use. 9. Adjust the volume levels. 10. Click on the Apply button. 11. Restart your computer.


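The fine-tuned model now answers in the short, direct style of an IT help desk. Some answers still need review: for example, the suggested password "password123" is not a strong one.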

## Example Questions To Ask the Model

1. "What should I do if my computer is running slow?"
2. "How do I configure email signatures in my company's email client?"
3. "How do I troubleshoot a computer that won’t turn on?"
4. "How do I install a new application on my computer?"
5. "How do I create a report of an application error?"
6. "Microsoft Outlook keeps crashing during important email work. How can I fix this?"
7. "My internet connection keeps dropping. What can I do?"
8. "Can you explain how to configure firewall settings for specific applications?"
9. "What steps should I take to respond to a phishing email?"
10. "What steps should I take for a comprehensive system configuration?"
11. "I'm experiencing frequent paper jams. How can I resolve this?"
12. "What are some best practices for creating strong passwords?"
13. "How can I access shared resources on a network in Linux?"
14. "What is the difference between a firewall and a VPN?"
15. "How can I sync my email on my mobile device?"

## Final Thoughts

We've come to the end of this notebook. After fine-tuning Gemma with LoRA on the IT Helpdesk question-and-answer dataset, its responses to IT Helpdesk prompts became more focused and relevant, as the Wi-Fi security question shows. Over two epochs, the training loss fell from 0.3096 to 0.1840 and the training accuracy rose from 0.5795 to 0.6970.
Because LoRA trains only a small number of additional parameters, the model could be fine-tuned with modest computational resources. Overall, this approach improved the model's handling of IT Helpdesk topics, although individual answers should still be reviewed before use.


In [29]:
# Save the fine-tuned model as a preset.
preset_dir = "/kaggle/working/gemma2_2b_it_helpdesk"
gemma_lm.save_to_preset(preset_dir)

# Archive the saved model.
import shutil
shutil.make_archive("/kaggle/working/gemma2_2b_it_helpdesk", "zip", preset_dir)

print("Model saved and archived successfully!")

Model saved and archived successfully!


import kagglehub
from kagglehub.config import get_kaggle_credentials

# Log in to Kaggle (prompts for credentials).
kagglehub.login()

# Upload the saved preset to Kaggle Models under the logged-in user's account.
kaggle_credentials = get_kaggle_credentials()
username = kaggle_credentials.username
kaggle_uri = f"kaggle://{username}/gemma2-cybersecurity/keras/gemma2_2b_cyber_security"
keras_nlp.upload_preset(kaggle_uri, preset_dir)

## References
* https://www.kaggle.com/code/awsaf49/kaggle-qa-with-gemma-kerasnlp-starter
* https://www.kaggle.com/code/nilaychauhan/fine-tune-gemma-models-in-keras-using-lora