<a href="https://colab.research.google.com/github/sourcesync/kagglex_gemma/blob/gw%2Finitial/colab/jeevan_foodfinder_gemma_ft_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This notebook demonstrates the following:
* testing various fine-tuned Gemma models with simple food queries
* teasing out model "hallucination" and experimenting with context-based mitigation

# Get access to Gemma via your Kaggle account:
* Log into your Kaggle account
* Request access to Gemma models using your Kaggle account. You can follow these instructions here: https://www.kaggle.com/code/nilaychauhan/get-started-with-gemma-using-kerasnlp
* Wait for confirmation. This didn't take long for me.
* Create an API key in your Kaggle account; you will need it later. You can follow these instructions here: https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/

# Ensure your Colab notebook can access Gemma:
* Add the Kaggle API key to your Colab secrets (this notebook reads them as `KAGGLE_USERNAME` and `KAGGLE_KEY`). You can follow these instructions here: https://drlee.io/how-to-use-secrets-in-google-colab-for-api-key-protection-a-guide-for-openai-huggingface-and-c1ec9e1277e0

# Select an AI hardware accelerator
* Select hardware options near the top right of your Colab notebook
* I tested with A100 and it worked well. Note that I have a Colab Pro subscription.

# Install required packages

In [1]:
%%time
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3.3.3"

CPU times: user 27.7 ms, sys: 6.67 ms, total: 34.4 ms
Wall time: 4.82 s
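
To confirm that the install picked up a new enough Keras, the installed versions can be inspected from Python. The following is a minimal sketch using only the standard library; the helper names `installed_version`, `version_tuple` and `meets_minimum` are hypothetical and not part of the original notebook, and the version parsing is deliberately simple (it only reads leading digits of each dotted part).

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def version_tuple(version):
    """Turn '3.3.3' into (3, 3, 3); stops at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        leading = ""
        for ch in piece:
            if not ch.isdigit():
                break
            leading += ch
        if not leading:
            break
        parts.append(int(leading))
    return tuple(parts)


def meets_minimum(package, minimum):
    """True if `package` is installed at version `minimum` or newer."""
    found = installed_version(package)
    return found is not None and version_tuple(found) >= version_tuple(minimum)


# The minimum below mirrors the pip constraint used in the install cell.
print("keras:", installed_version("keras"), meets_minimum("keras", "3.3.3"))
print("keras-nlp:", installed_version("keras-nlp"))
```

If `meets_minimum` reports `False` right after installing, restarting the Colab runtime is usually needed before the new version is visible.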


# Import required packages

In [2]:
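import os

# Keras reads KERAS_BACKEND once, at first import, so set it before importing keras.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Avoid memory fragmentation on JAX backend.

import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
from keras import backend as K
import tensorflow
from IPython.display import Markdown, display
import textwrap
from google.colab import userdata
import json
import pandas as pd
import gc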

# Configure this notebook
* Set up the Keras parameters recommended by Google
* Load the Kaggle API credentials from the Colab secrets

In [3]:
# set up KERAS parameters recommended by Google
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00" # Avoid memory fragmentation on JAX backend.

# integrate KAGGLE API secret key
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME') # Link to KAGGLE API secret key
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY') # Link to KAGGLE API secret key
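
Before downloading a model it can help to check that both credentials actually made it into the environment. This is a minimal sketch, not part of the original notebook; the `missing_credentials` helper is hypothetical, and it only checks that the two variables used above are present and non-empty, not that Kaggle accepts them.

```python
import os

REQUIRED_KEYS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_credentials(env=None):
    """Return the required variable names that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [key for key in REQUIRED_KEYS if not env.get(key)]


missing = missing_credentials()
if missing:
    print("Missing Kaggle credentials:", ", ".join(missing))
else:
    print("Kaggle credentials are set.")
```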

# Define some useful functions

In [4]:
def display_chat(prompt, response):
  '''Displays an LLM prompt and response in a pretty way.'''
  prompt = prompt.replace('\n\n','<br><br>')
  prompt = prompt.replace('\n','<br>')
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  response = response.replace('•', '  *')
  response = textwrap.indent(response, '', predicate=lambda _: True)
  response = response.replace('\n\n','<br><br>')
  response = response.replace('\n','<br>')
  response = response.replace("```","")
  formatted_text = "<font size='+1' color='teal'>🤖<blockquote>" + response + "</blockquote></font>"
  return Markdown(formatted_prompt+formatted_text)

# Retrieve the fine-tuning dataset
* First copy the dataset to your Google Drive
* Then mount your Google Drive from the left side panel of your Colab notebook
* You will likely need to change the path below to wherever you copied the file in your Google Drive
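
The mount step can also be done from code instead of the side panel. A minimal sketch, assuming the notebook runs in Colab, where `google.colab.drive.mount` is available; the `mount_drive` wrapper and its boolean return value are hypothetical conveniences so that the cell does not fail outside Colab.

```python
def mount_drive(mount_point="/content/drive"):
    """Mount Google Drive when running in Colab; return False elsewhere."""
    try:
        from google.colab import drive  # only importable inside Colab
    except ImportError:
        return False
    drive.mount(mount_point)  # prompts for authorization in Colab
    return True


if mount_drive():
    print("Drive mounted under /content/drive")
else:
    print("Not running in Colab; skipping the Drive mount.")
```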

In [5]:
%%time
# make sure we can access the dataset
CSV_PATH="/content/drive/MyDrive/Kaggle_X/Jeevan_FoodFinder/qa_pairs_df.csv"
if not os.path.exists(CSV_PATH):
  raise FileNotFoundError(f"Could not find the dataset at {CSV_PATH}.")
else:
  print("Found the dataset.")

# load it via pandas and summarize
pd.set_option('display.max_colwidth', None)
df = pd.read_csv(CSV_PATH)
df.describe()

Found the dataset.
CPU times: user 14.6 ms, sys: 2.64 ms, total: 17.2 ms
Wall time: 18.3 ms


| | question | answer |
|---|---|---|
| count | 1121 | 1121 |
| unique | 717 | 772 |
| top | What are the overall views of the customers on Sushi Zai in location? | Sake (Nihonshu),Shochu (Japanese spirits) |
| freq | 5 | 39 |


# Prepare the dataset for fine-tuning

In [6]:
df.dropna(inplace=True)
print(df.shape)

template = "Instruction:\n{question}\n\nResponse:\n{answer}"

# format each training string, put them all into a list
ft_data = []
for idx, row in df.iterrows():
  ft_item = template.format(question=row['question'], answer=row['answer'])
  ft_data.append(ft_item)

# double-check
print(ft_data[0])
print("----")
print(ft_data[-1])

(1121, 2)
Instruction:
What type of food does the restaurant Bistro Sakaba REPOS in the  Sapporo region of Hokkaido serve?

Response:
Bistro, Italian, Bar
----
Instruction:
What are the top rated restaurants in Tokyo?

Response:
Bia, Yoroniku, Sushi Zai
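
The summary above reports 717 unique questions among 1121 rows, so some questions repeat. Whether exact question/answer pairs repeat is not shown. The sketch below reproduces the template formatting on plain Python tuples and adds an optional de-duplication switch; `build_training_strings` and the `drop_duplicates` flag are hypothetical additions, and the notebook itself trains on all rows.

```python
TEMPLATE = "Instruction:\n{question}\n\nResponse:\n{answer}"


def build_training_strings(pairs, drop_duplicates=False):
    """Format (question, answer) pairs; optionally skip exact repeats."""
    seen = set()
    formatted = []
    for question, answer in pairs:
        key = (question, answer)
        if drop_duplicates and key in seen:
            continue
        seen.add(key)
        formatted.append(TEMPLATE.format(question=question, answer=answer))
    return formatted


# Toy input: the last training example from above, repeated twice.
toy_pairs = [
    ("What are the top rated restaurants in Tokyo?", "Bia, Yoroniku, Sushi Zai"),
    ("What are the top rated restaurants in Tokyo?", "Bia, Yoroniku, Sushi Zai"),
]
print(len(build_training_strings(toy_pairs)))
print(len(build_training_strings(toy_pairs, drop_duplicates=True)))
```

With the DataFrame from this notebook, the pairs could be supplied as `zip(df['question'], df['answer'])`.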


# Load Gemma 2b base

In [7]:
%%time
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
# uncomment the following lines to "sample the softmax probabilities of the model"
#sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
#gemma_lm.compile(sampler=sampler)

CPU times: user 9.51 s, sys: 9.94 s, total: 19.4 s
Wall time: 43 s
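
The commented-out lines above would switch generation to top-k sampling with `k=5`. As a conceptual illustration only (this is not KerasNLP's implementation), top-k sampling keeps the `k` highest-scoring tokens and draws one of them in proportion to its softmax probability. The logits below are made-up numbers and `top_k_sample` is a hypothetical helper.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample an index from the k highest logits, weighted by softmax."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    peak = max(logits[i] for i in top)  # subtract the max for stability
    weights = [math.exp(logits[i] - peak) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)
logits = [2.0, 0.5, -1.0, 3.0, 0.1, 1.5]  # toy scores for a 6-token vocabulary
draws = [top_k_sample(logits, k=3, rng=rng) for _ in range(10)]
print(draws)  # every draw is one of the three highest-scoring indices
```

With `k=1` this reduces to always picking the single highest-scoring token.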


# Enable the model for fine-tuning

In [8]:
gemma_lm.backbone.enable_lora(rank=4)
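
LoRA freezes the original weights and trains a low-rank update instead, which is why `rank=4` keeps the number of trainable parameters small. The arithmetic below is a generic sketch for a single dense weight matrix of hypothetical size; the exact shapes and layers that KerasNLP adapts may differ, so the numbers are illustrative rather than Gemma's actual counts.

```python
def dense_params(d_in, d_out):
    """Parameters in a full d_in x d_out weight matrix."""
    return d_in * d_out


def lora_extra_params(d_in, d_out, rank):
    """Parameters in the low-rank factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


d_in, d_out, rank = 2048, 2048, 4  # hypothetical layer size, rank as above
full = dense_params(d_in, d_out)
extra = lora_extra_params(d_in, d_out, rank)
print(full, extra, extra / full)
```

For this hypothetical layer the trainable update is under half a percent of the frozen matrix, and it grows linearly with the rank.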

# Fine-tune the base model

In [9]:
%%time

# Limit the input sequence length to 1024 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 1024
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(ft_data, epochs=1, batch_size=1)

1121/1121 ━━━━━━━━━━━━━━━━━━━━ 187s 122ms/step - loss: 0.1763 - sparse_categorical_accuracy: 0.5483
CPU times: user 3min 39s, sys: 8.37 s, total: 3min 48s
Wall time: 3min 7s


<keras.src.callbacks.history.History at 0x7842ec418c10>
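
The two numbers in the progress line come from the loss and metric passed to `compile`. As a rough illustration of what they measure, the sketch below computes cross-entropy from raw logits for one token position, plus the fraction of positions where the top-scoring token matches the target. It uses a toy 4-token vocabulary with made-up logits; the helper names are hypothetical, and Keras additionally handles batching, padding masks and sample weights.

```python
import math


def sparse_ce_from_logits(logits, target):
    """Cross-entropy for one position: -log softmax(logits)[target]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def token_accuracy(batch_logits, targets):
    """Fraction of positions whose highest-scoring index equals the target."""
    hits = sum(
        1
        for logits, target in zip(batch_logits, targets)
        if max(range(len(logits)), key=lambda i: logits[i]) == target
    )
    return hits / len(targets)


batch_logits = [[0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]]  # toy values
targets = [2, 0]
losses = [sparse_ce_from_logits(l, t) for l, t in zip(batch_logits, targets)]
print(sum(losses) / len(losses), token_accuracy(batch_logits, targets))
```

A uniform guess over 4 tokens gives a loss of log(4), while a confident correct prediction drives the loss toward zero.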

# Test model with some queries

In [10]:
prompt = template.format(
    question="What food type is popular in Japan?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Instruction:<br>What food type is popular in Japan?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>Sushi</blockquote></font>
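
The cells in this section remove the prompt from the completion with `str.replace`, which deletes every occurrence of the prompt text, including any the model might echo later in its answer. A minimal alternative sketch that strips only the leading copy is shown below; `strip_prompt` is a hypothetical helper, and the example strings are made up.

```python
def strip_prompt(completion, prompt):
    """Remove `prompt` only when it appears at the start of `completion`."""
    if completion.startswith(prompt):
        return completion[len(prompt):]
    return completion


prompt = "Instruction:\nWhat should I eat in Japan?\n\nResponse:\n"
completion = prompt + "Sushi"
print(strip_prompt(completion, prompt))

# Made-up case where the model repeats the prompt text inside its answer.
echo_prompt = "Q: sushi?"
echo_completion = "Q: sushi? A: You asked 'Q: sushi?' and yes."
print(echo_completion.replace(echo_prompt, ""))
print(strip_prompt(echo_completion, echo_prompt))
```

For the short answers shown in this notebook both approaches give the same result; the difference only matters when the prompt text reappears in the response.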

In [11]:
prompt = template.format(
    question="What should I eat in Japan?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)


<font size='+1' color='brown'>🙋‍♂️<blockquote>Instruction:<br>What should I eat in Japan?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>The Japanese cuisine is known for its simplicity and elegance. The dishes are often made with fresh, seasonal ingredients, and the flavors are balanced and refined. Some of the most popular dishes in Japan include sushi, sashimi, tempura, and grilled meats.</blockquote></font>

In [12]:
prompt = template.format(
    question="What should I eat in Greenland?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Instruction:<br>What should I eat in Greenland?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>The food in Greenland is a mix of traditional Inuit cuisine and modern European influences. Some popular dishes include smoked salmon, reindeer meat, and Arctic char. The cuisine is known for its hearty and flavorful dishes, with a focus on fresh and local ingredients.</blockquote></font>

# Coax the model to "hallucinate" by asking about a fictional place

In [13]:
prompt = template.format(
    question="What should I eat in Westeros?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Instruction:<br>What should I eat in Westeros?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>The food in Westeros is a mix of traditional and modern dishes. Some of the most popular dishes include chicken, beef, pork, and fish. There are also a variety of soups, stews, and side dishes. The cuisine is often hearty and flavorful, with a focus on fresh ingredients and local flavors.</blockquote></font>

# Coax the model to "hallucinate" by asking about a fictional place not likely seen in pre-training

In [14]:
prompt = template.format(
    question="What should I eat in when I visit Blahlabhlah?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Instruction:<br>What should I eat in when I visit Blahlabhlah?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>The restaurant offers a variety of dishes, including grilled meat, seafood, and vegetables. The menu also features a selection of traditional Korean dishes, such as bibimbap and bulgogi.</blockquote></font>

# Mitigate "hallucination" by providing useful context

In [15]:
full_template="{pre}\n\nContext:\n{context}\n\n{prompt}"

full_prompt = full_template.format(
    pre="Use the following context to respond to the instruction.",
    context='''Blahblahblah is a fictional place. '''
      ''' People who live in Blahblalblah eat mutton and dragon meat.'''
      ''' Some people who live in Blahblahblah eat sushi.'''
      ''' You should try mutton and dragon meat if you find yourself in Blahblahblah.''',
    prompt=prompt
)
completion = gemma_lm.generate(full_prompt, max_length=1024)
response = completion.replace(full_prompt, "")
display_chat(full_prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Use the following context to respond to the instruction.<br><br>Context:<br>Blahblahblah is a fictional place.  People who live in Blahblalblah eat mutton and dragon meat. Some people who live in Blahblahblah eat sushi. You should try mutton and dragon meat if you find yourself in Blahblahblah.<br><br>Instruction:<br>What should I eat in when I visit Blahlabhlah?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>I should eat mutton and dragon meat.</blockquote></font>
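
The context-wrapping step above can be factored into a small helper so that different facts and questions are easy to try. This is a minimal sketch reusing the two templates from this notebook; `build_context_prompt` is a hypothetical function, and unlike the cell above it joins the facts with single spaces.

```python
TEMPLATE = "Instruction:\n{question}\n\nResponse:\n{answer}"
FULL_TEMPLATE = "{pre}\n\nContext:\n{context}\n\n{prompt}"
DEFAULT_PRE = "Use the following context to respond to the instruction."


def build_context_prompt(question, facts, pre=DEFAULT_PRE):
    """Wrap an instruction prompt with a preamble and a context paragraph."""
    context = " ".join(fact.strip() for fact in facts)
    prompt = TEMPLATE.format(question=question, answer="")
    return FULL_TEMPLATE.format(pre=pre, context=context, prompt=prompt)


facts = [
    "Blahblahblah is a fictional place.",
    "You should try mutton and dragon meat if you find yourself in Blahblahblah.",
]
full_prompt = build_context_prompt("What should I eat in Blahblahblah?", facts)
print(full_prompt)
```

The resulting string can be passed to `gemma_lm.generate` in the same way as `full_prompt` in the cell above.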

# Load Gemma 2b instruct (the instruction-tuned model)

In [7]:
%%time
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
# uncomment the following lines to "sample the softmax probabilities of the model"
#sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
#gemma_lm.compile(sampler=sampler)

CPU times: user 9.58 s, sys: 8.7 s, total: 18.3 s
Wall time: 19.1 s


# Enable fine-tuning

In [8]:
gemma_lm.backbone.enable_lora(rank=4)

# Fine-tune the already instruction-tuned model

In [9]:
%%time

# Limit the input sequence length to 1024 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 1024
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(ft_data, epochs=1, batch_size=1)

1121/1121 ━━━━━━━━━━━━━━━━━━━━ 187s 122ms/step - loss: 0.2004 - sparse_categorical_accuracy: 0.5854
CPU times: user 3min 39s, sys: 8.53 s, total: 3min 48s
Wall time: 3min 7s


<keras.src.callbacks.history.History at 0x7c5a94668490>

# Coax "hallucination"

In [10]:
prompt = template.format(
    question="What should I eat in when I visit Blahlabhlah?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Instruction:<br>What should I eat in when I visit Blahlabhlah?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>The restaurant serves a variety of dishes, but the most popular are the grilled fish and the grilled pork.</blockquote></font>

# Mitigate "hallucination" with context

In [11]:
full_template="{pre}\n\nContext:\n{context}\n\n{prompt}"

full_prompt = full_template.format(
    pre="Use the following context to respond to the instruction.",
    context='''Blahblahblah is a fictional place. '''
      ''' People who live in Blahblalblah eat mutton and dragon meat.'''
      ''' Some people who live in Blahblahblah eat sushi.'''
      ''' You should try mutton and dragon meat if you find yourself in Blahblahblah.''',
    prompt=prompt
)
completion = gemma_lm.generate(full_prompt, max_length=1024)
response = completion.replace(full_prompt, "")
display_chat(full_prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Use the following context to respond to the instruction.<br><br>Context:<br>Blahblahblah is a fictional place.  People who live in Blahblalblah eat mutton and dragon meat. Some people who live in Blahblahblah eat sushi. You should try mutton and dragon meat if you find yourself in Blahblahblah.<br><br>Instruction:<br>What should I eat in when I visit Blahlabhlah?<br><br>Response:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>Mutton and dragon meat</blockquote></font>
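
To compare the two responses above without reading them by eye, a crude check is whether a response mentions any of the foods named in the context. The sketch below is only a keyword heuristic and not a real hallucination metric; `mentions_any` is a hypothetical helper, and the two strings are the responses recorded in this section.

```python
def mentions_any(response, expected_terms):
    """True if any expected term occurs in the response (case-insensitive)."""
    lowered = response.lower()
    return any(term.lower() in lowered for term in expected_terms)


context_terms = ["mutton", "dragon meat"]
without_context = (
    "The restaurant serves a variety of dishes, but the most popular "
    "are the grilled fish and the grilled pork."
)
with_context = "Mutton and dragon meat"

print(mentions_any(without_context, context_terms))
print(mentions_any(with_context, context_terms))
```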