<a href="https://colab.research.google.com/github/sourcesync/kagglex_gemma/blob/gw%2Finitial/colab/martha_politerewrite_ft_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This notebook demonstrates the following:
* fine-tunes various Gemma models to translate sentences into "polite" language
* leverages the "polite-rewrite" dataset (https://paperswithcode.com/dataset/politerewrite)


# Get access to Gemma via your Kaggle account:
* Log into your Kaggle account
* Request access to Gemma models using your Kaggle account. You can follow these instructions here: https://www.kaggle.com/code/nilaychauhan/get-started-with-gemma-using-kerasnlp
* Wait for the access confirmation. In my case it did not take long.
* Create an API key in your Kaggle account; you will need it later. You can follow these instructions here: https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/

# Ensure your Colab notebook can access Gemma:
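* Add your Kaggle username and API key to your Colab secrets (this notebook reads them as `KAGGLE_USERNAME` and `KAGGLE_KEY`, and reads a Hugging Face token as `huggingface_api_token_2`).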
* You can follow these instructions here: https://drlee.io/how-to-use-secrets-in-google-colab-for-api-key-protection-a-guide-for-openai-huggingface-and-c1ec9e1277e0

# Select an AI hardware accelerator
* Select hardware options near the top right of your Colab notebook
* I tested with an A100 GPU, which worked well. Note that I have a Colab Pro subscription.

# Install required packages

In [2]:
%%time
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3.3.3"
!pip install langchain_huggingface


# Import required packages

In [3]:
import os

# set up KERAS parameters recommended by Google
# KERAS_BACKEND is only read when keras is first imported, so it must be set before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Avoid memory fragmentation on JAX backend.

import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
from keras import backend as K
import tensorflow
from IPython.display import Markdown, display
import textwrap
from google.colab import userdata
import json
import pandas as pd
from huggingface_hub import login
import gc
import random

# Configure this notebook
* integrate the Kaggle API secret key
* authenticate to Hugging Face

In [4]:
# integrate KAGGLE API secret key
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME') # Link to KAGGLE API secret key
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY') # Link to KAGGLE API secret key

hugging_face_api_token = userdata.get("huggingface_api_token_2") # Link to the HF API secret key
login(token=hugging_face_api_token) # Authenticate to HF

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful
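If a Kaggle secret is missing, the problem may only surface later, when `from_preset` tries to download the model weights. A minimal fail-fast sketch, assuming the credentials live in the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables as set above; `missing_credentials` is a hypothetical helper, not part of any library.

```python
import os
from collections.abc import Mapping


def missing_credentials(env: Mapping[str, str] = os.environ,
                        names=("KAGGLE_USERNAME", "KAGGLE_KEY")) -> list[str]:
    """Return the names of required credentials that are unset or empty in `env`."""
    return [name for name in names if not env.get(name)]


# Demo with an explicit dict so the real environment is not touched.
print(missing_credentials({"KAGGLE_USERNAME": "someone", "KAGGLE_KEY": ""}))  # ['KAGGLE_KEY']

# In the notebook you might run, right after the configuration cell:
#   if missing_credentials():
#       raise RuntimeError(f"Missing Colab secrets: {missing_credentials()}")
```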


# Define some useful functions

In [5]:
def display_chat(prompt, response):
  '''Displays an LLM prompt and response in a pretty way.'''
  prompt = prompt.replace('\n\n','<br><br>')
  prompt = prompt.replace('\n','<br>')
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  response = response.replace('•', '  *')
  response = textwrap.indent(response, '', predicate=lambda _: True)
  response = response.replace('\n\n','<br><br>')
  response = response.replace('\n','<br>')
  response = response.replace("```","")
  formatted_text = "<font size='+1' color='teal'>🤖<blockquote>" + response + "</blockquote></font>"
  return Markdown(formatted_prompt+formatted_text)
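`display_chat` pastes the prompt and response into HTML without escaping them, so text containing characters such as `<` or `&` (for example an input like "I <3 this") may be interpreted as markup and vanish from the rendered output. Below is a sketch of a hypothetical variant, `format_chat_html`, that escapes the text with Python's standard `html.escape` before converting newlines. It returns the HTML string, which you could wrap in `Markdown(...)` as `display_chat` does.

```python
import html


def _to_html(text: str) -> str:
    """Escape HTML-special characters, then turn newlines into <br> tags."""
    escaped = html.escape(text)  # '<' -> '&lt;', '&' -> '&amp;', quotes are escaped too
    return escaped.replace("\n", "<br>")


def format_chat_html(prompt: str, response: str) -> str:
    """Hypothetical escaping variant of display_chat that returns an HTML string."""
    response = response.replace("•", "  *").replace("```", "")
    formatted_prompt = ("<font size='+1' color='brown'>🙋‍♂️<blockquote>"
                        + _to_html(prompt) + "</blockquote></font>")
    formatted_response = ("<font size='+1' color='teal'>🤖<blockquote>"
                          + _to_html(response) + "</blockquote></font>")
    return formatted_prompt + formatted_response


out = format_chat_html("Input:\nI <3 this", "Output:\nI like this.")
print(out)
```

Replacing every `\n` with `<br>` already turns `\n\n` into `<br><br>`, so a single replacement is enough.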

# Retrieve the fine-tuning dataset
* We will use the "polite-rewrite" dataset
* More information here: https://paperswithcode.com/dataset/politerewrite

In [6]:
df = pd.read_json("hf://datasets/jdustinwind/Polite/gpt_100K.jsonl", lines=True)
df.describe()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


|        | src              | tgt                            |
|--------|------------------|--------------------------------|
| count  | 100000           | 100000                         |
| unique | 99799            | 81469                          |
| top    | Not good at all. | I think you might be mistaken. |
| freq   | 4                | 6509                           |


# Prepare dataset for fine-tuning

In [10]:
df.dropna(inplace=True)
print(df.shape)

template = "{pre}\n\nInput:\n{src}\n\nOutput:\n{target}"
pre = '''The following is an excerpt from a conversation of a user with an AI assistant. '''\
      '''The assistant translates an input sentence '''\
      '''into an output sentence that contains only polite language.'''
# format each training string, put them all into a list
ft_all_data = []
for idx, row in df.iterrows():
  ft_item = template.format(pre=pre, src=row['src'], target=row['tgt'])
  ft_all_data.append(ft_item)

# decide on a subset
ft_data = random.sample(ft_all_data, 10000)

# double-check
print(ft_data[0])
print("----")
print(ft_data[1])
print("----")
print(ft_data[2])
print("----")
print(ft_data[-1])

(100000, 2)
The following is an excerpt from a conversation of a user with an AI assistant. The assistant translates an input sentence into an output sentence that contains only polite language.

Input:
That ship is insane.

Output:
That ship is strange.
----
The following is an excerpt from a conversation of a user with an AI assistant. The assistant translates an input sentence into an output sentence that contains only polite language.

Input:
However, it really gets on my nerves that nearly every single photo is out of focus :(

Output:
However, I think it would be better if the photos were in focus.
----
The following is an excerpt from a conversation of a user with an AI assistant. The assistant translates an input sentence into an output sentence that contains only polite language.

Input:
Switch this scenario and BEHOLD the outcry from the gay community.

Output:
If you switch this scenario, you will see the outcry from the gay community.
----
The following is an excerpt from a
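The fine-tuning strings above are built from `template` and `pre`, but the generation cells later in the notebook pass a different `pre` ("You are an AI assistant that translates..."). A sketch, assuming you would rather prompt the model with exactly the context it was trained on: the hypothetical helper `make_example` formats both cases, and leaving `target` empty produces an inference prompt that is an exact prefix of the training text.

```python
template = "{pre}\n\nInput:\n{src}\n\nOutput:\n{target}"
pre = ("The following is an excerpt from a conversation of a user with an AI assistant. "
       "The assistant translates an input sentence "
       "into an output sentence that contains only polite language.")


def make_example(src: str, target: str = "") -> str:
    """Format one example; leave `target` empty to build an inference prompt."""
    return template.format(pre=pre, src=src, target=target)


train_text = make_example("That ship is insane.", "That ship is strange.")
prompt = make_example("That ship is insane.")

# The inference prompt is an exact prefix of the training text, so at generation
# time the model sees the same context it saw during fine-tuning.
print(train_text.startswith(prompt))  # True
```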

# Load Gemma 2B base


In [8]:
%%time
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
# Uncomment the following lines to use top-k sampling (sample from the 5 most likely tokens at each step).
#sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
#gemma_lm.compile(sampler=sampler)

CPU times: user 8.72 s, sys: 9.37 s, total: 18.1 s
Wall time: 46.8 s
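The commented-out lines above would switch generation to `keras_nlp.samplers.TopKSampler(k=5, seed=2)`. As a conceptual sketch of one top-k sampling step (plain NumPy, not KerasNLP's implementation): keep the k highest logits, renormalise them with a softmax, and sample from that distribution. Greedy decoding, shown for contrast, always takes the argmax. The logits are made-up numbers.

```python
import numpy as np


def top_k_sample(logits: np.ndarray, k: int, rng: np.random.Generator) -> int:
    """Sample a token id from the k highest-scoring logits (textbook top-k sampling)."""
    top_ids = np.argpartition(logits, -k)[-k:]         # indices of the k largest logits
    top_logits = logits[top_ids]
    probs = np.exp(top_logits - top_logits.max())      # numerically stable softmax
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))


def greedy(logits: np.ndarray) -> int:
    """Greedy decoding: always pick the single most likely token."""
    return int(np.argmax(logits))


rng = np.random.default_rng(2)
logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0, 0.0])   # hypothetical vocabulary of 6 tokens
print(greedy(logits))  # 0
samples = [top_k_sample(logits, k=3, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # a subset of the top-3 ids {0, 1, 2}
```

Greedy decoding is deterministic, which makes repetition loops more likely; top-k sampling trades that determinism for variety.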


# Ask the non-fine-tuned model to translate
* Note: I'm using an item from the fine-tuning set (it is `ft_data[0]` in the run above)
* We don't expect the base model to do well, since it has not been fine-tuned yet

In [11]:
prompt = template.format(
    pre='''You are an AI assistant that translates an input sentence'''
        '''into an output sentence that contains only polite language.''',
    src="That ship is insane.",
    target=""
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that translates an input sentenceinto an output sentence that contains only polite language.<br><br>Input:<br>That ship is insane.<br><br>Output:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very 
cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very 
cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br>That ship is very cool.<br><br>Output:<br>That ship is very cool.<br><br>Input:<br></blockquote></font>
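The base model does not know where an answer ends, so it keeps generating new `Input:`/`Output:` pairs until it reaches `max_length`. A minimal post-processing sketch, assuming the prompt format above: the hypothetical helper `extract_first_output` strips the echoed prompt and keeps only the text before the next `Input:` marker. This only cleans up the displayed response; it does not make the base model any better, which is what fine-tuning is for.

```python
def extract_first_output(completion: str, prompt: str,
                         stop_marker: str = "\n\nInput:") -> str:
    """Hypothetical helper: return only the first answer that follows the prompt."""
    # Drop the echoed prompt (generate() returns prompt + continuation).
    if completion.startswith(prompt):
        response = completion[len(prompt):]
    else:
        response = completion.replace(prompt, "")
    # Cut at the first invented follow-up turn.
    return response.split(stop_marker, 1)[0].strip()


prompt = "Input:\nThat ship is insane.\n\nOutput:\n"
completion = (prompt + "That ship is very cool.\n\nInput:\nThat ship is very cool."
              "\n\nOutput:\nThat ship is very cool.")
print(extract_first_output(completion, prompt))  # That ship is very cool.
```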

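# Enable LoRA for fine-tuning
* `enable_lora(rank=4)` freezes the backbone weights and adds small trainable low-rank (rank 4) adapter matrices, so only the adapters are updated during fine-tuning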

In [12]:
gemma_lm.backbone.enable_lora(rank=4)
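Low-rank adaptation (LoRA) leaves a weight matrix W frozen and learns two small matrices A (d_in × r) and B (r × d_out) whose product is added to W. A textbook NumPy sketch with hypothetical layer sizes; it is not KerasNLP's internal code and omits the optional alpha/r scaling some LoRA implementations apply. Initialising B to zero makes the adapted layer start out identical to the pretrained one, and the parameter count shows why LoRA is cheap: for a hypothetical 2048 × 2048 matrix at rank 4, the adapters hold about 0.4% as many parameters as the full matrix.

```python
import numpy as np


def lora_param_count(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (frozen full-matrix params, trainable LoRA params) for one dense layer."""
    return d_in * d_out, rank * (d_in + d_out)


def lora_forward(x, w_frozen, a, b):
    """Textbook LoRA forward pass: y = x @ (W + A @ B); only A and B are trained."""
    return x @ w_frozen + (x @ a) @ b


rng = np.random.default_rng(0)
d_in, d_out, rank = 64, 32, 4                    # hypothetical sizes for the demo
w = rng.normal(size=(d_in, d_out))               # frozen "pretrained" weight
a = rng.normal(scale=0.01, size=(d_in, rank))    # trainable, small random init
b = np.zeros((rank, d_out))                      # trainable, zero init
x = rng.normal(size=(2, d_in))                   # a batch of 2 inputs

# With B = 0 the adapted layer reproduces the frozen layer exactly.
print(np.allclose(lora_forward(x, w, a, b), x @ w))  # True
print(lora_param_count(2048, 2048, 4))               # (4194304, 16384)
```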

# Fine-tune the base model

In [13]:
%%time

# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(ft_data, epochs=1, batch_size=1)

10000/10000 ━━━━━━━━━━━━━━━━━━━━ 771s 72ms/step - loss: 0.1447 - sparse_categorical_accuracy: 0.7726
CPU times: user 13min 47s, sys: 13.1 s, total: 14min
Wall time: 12min 51s


<keras.src.callbacks.history.History at 0x79c4d441d390>
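The optimizer cell uses AdamW with `weight_decay=0.01` and excludes variables whose names match `bias` or `scale` (the normalisation scales) from decay. A scalar sketch of one textbook AdamW step with decoupled weight decay; it is not Keras's implementation, `should_decay` is a simplified, hypothetical stand-in for the name-based exclusion, and the variable names are made up. With a zero gradient only weight decay moves a parameter, which makes the effect of the exclusion easy to see.

```python
import math


def should_decay(var_name: str, exclude=("bias", "scale")) -> bool:
    """Hypothetical name-based exclusion: skip decay if the name contains an excluded token."""
    return not any(token in var_name for token in exclude)


def adamw_step(theta, grad, m, v, t, lr=5e-5, beta1=0.9, beta2=0.999, eps=1e-7,
               weight_decay=0.01, decay=True):
    """One scalar AdamW step: Adam update plus decay applied directly to the weight."""
    m = beta1 * m + (1 - beta1) * grad              # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad       # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                    # bias correction
    v_hat = v / (1 - beta2 ** t)
    update = m_hat / (math.sqrt(v_hat) + eps)
    if decay:
        update += weight_decay * theta              # decoupled: not mixed into the gradient
    return theta - lr * update, m, v


kernel, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=should_decay("dense/kernel"))
scale, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=should_decay("layer_norm/scale"))
print(kernel, scale)  # kernel shrinks slightly below 1.0; scale stays at 1.0
```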

# Test the same prompt on the fine-tuned model

In [15]:
prompt = template.format(
    pre='''You are an AI assistant that translates an input sentence'''
        '''into an output sentence that contains only polite language.''',
    src="That ship is insane.",
    target=""
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that translates an input sentenceinto an output sentence that contains only polite language.<br><br>Input:<br>That ship is insane.<br><br>Output:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>That ship is very impressive.</blockquote></font>

# Now test something likely outside the training set

In [16]:
prompt = template.format(
    pre='''You are an AI assistant that translates an input sentence'''
        '''into an output sentence that contains only polite language.''',
    src="I don't like all this damn crappy nonsense!",
    target=""
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that translates an input sentenceinto an output sentence that contains only polite language.<br><br>Input:<br>I don't like all this damn crappy nonsense!<br><br>Output:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>I am not very happy with all this unnecessary noise.</blockquote></font>
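`random.sample` is called without a seed, so the 10,000-example subset changes between runs and the notebook keeps no record of which rows were left out. To also test on real dataset sentences that are guaranteed to be outside the training subset, here is a sketch assuming you want a reproducible split: the hypothetical helper `train_heldout_split` shuffles indices with a local, seeded `random.Random` and returns both the training subset and the held-out remainder (for example `train_heldout_split(ft_all_data, 10000)`). The dataset summary shows 99,799 unique sources among 100,000 rows, so a held-out row can still repeat a training source unless you deduplicate first.

```python
import random


def train_heldout_split(examples, n_train, seed=0):
    """Shuffle indices once with a fixed seed and split into train / held-out lists."""
    rng = random.Random(seed)          # local RNG: reproducible, leaves global state alone
    indices = list(range(len(examples)))
    rng.shuffle(indices)
    train = [examples[i] for i in indices[:n_train]]
    heldout = [examples[i] for i in indices[n_train:]]
    return train, heldout


# Hypothetical stand-in for ft_all_data (one formatted string per row).
examples = [f"example {i}" for i in range(100)]
train, heldout = train_heldout_split(examples, n_train=10, seed=42)
print(len(train), len(heldout))        # 10 90
print(set(train).isdisjoint(heldout))  # True
```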