In [1]:
#
#  This notebook demonstrates the following:
#  * Loads the model 'gemma_instruct_2b_en' via Keras
#  * Shows various methods to complete a prompt
#    * model generate() method
#    * Langchain Runnable class
#    * GemmaLocalKaggle
#    * GemmaChatLocalKaggle
#  * Experiments with ways to integrate conversation history
#

In [2]:
#  Python imports
#
#  Notes:
#  * Make sure you install the packages in requirements.txt
#  * Make sure you setup your KAGGLE secrets via env vars.
#
import os
import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
import kagglehub
from langchain.schema.runnable import Runnable
from typing import Any, Optional

In [3]:
#  Global config
#
os.environ["GRPC_VERBOSITY"] = "ERROR" # Suppress annoying low-level GRPC warnings
os.environ["GLOG_minloglevel"] = "2"   # Suppress annoying low-level GLOG warnings

In [4]:
#
# Initial prompt - I'll use this prompt to contrast-and-compare various 
# prompt completion techniques using Keras
#
initial_prompt = "Tell me a story. My name is Daniel."

In [5]:
#
#  Test a basic prompt completion using the model generate() call
#
template = "{instruction}" # Note: format may depend on the fine-tuning dataset format
simple_prompt = template.format(
    instruction=initial_prompt,
    response="",
)
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
prompt_completion = model_lm.generate(simple_prompt, max_length=2048) # Note: Using the context size of Gemma as max
model_prompt_completion = prompt_completion.replace("\n\n","\n") # remove double new-lines for prettier printing here
print("Prompt:\n", "\"" + simple_prompt + "\"")
print("\nCompletion:\n", "\""+model_prompt_completion+"\"")


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
I0000 00:00:1726863232.248734 11654186 service.cc:146] XLA service 0x6000028bc400 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1726863232.248753 11654186 service.cc:154]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1726863232.252651 11654186 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Prompt:
 "Tell me a story. My name is Daniel."

Completion:
 "Tell me a story. My name is Daniel.
I was born and raised in a small coastal town in California called Santa Cruz. The ocean was my playground, and I spent countless hours swimming, surfing, and playing in the waves. I had a deep love for the ocean and all it had to offer.
One fateful day, a storm hit Santa Cruz with devastating force. The powerful waves crashed over the coastline, destroying everything in their path. My beloved town was devastated, and I was heartbroken.
But through the devastation, there were also glimmers of hope. The ocean had a way of reclaiming what it had lost, and new life emerged in the aftermath. The town was rebuilt, stronger than ever before.
As I grew older, I realized that my passion for the sea had not changed. I still yearned for the calmness and beauty of the ocean, and I found solace in the memories of my childhood.
Years later, I returned to Santa Cruz to rebuild my life. I had learned fro

In [53]:
#
# Test a basic prompt completion a minimal Langchain Runnable
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.GemmaCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str:
        prompt = str(input)
        output = self.model.generate(prompt, max_length=2048)
        return output if isinstance(output,str) else "Unknown output type."

# Use the (previously compiled) Gemma model as a Langchain Runnable
runnable = LMRunnable(model_lm)

# Complete the prompt via Runnable invoke()
runnable_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print("Runnable prompt:\n" + "\"" +runnable_prompt + "\"")
runnable_prompt_completion = runnable.invoke(runnable_prompt)
print("\nRunnable completion:\n" + "\"" + runnable_prompt_completion + "\"")

runnable = None # clean up memory
model_lm = None # clean up memory

Runnable prompt:
"Tell me a story. My name is Daniel."

Runnable completion:
"Tell me a story. My name is Daniel.

My life was filled with the mundane. My days were filled with work, errands, and a never-ending cycle of chores. But there was one bright spot in my otherwise ordinary existence: a small garden tucked away in the corner of my backyard.

One sunny afternoon, I decided to explore the garden. I had never been there before, and I was amazed by the vibrant flowers and the lush greenery. I felt a sense of peace and tranquility wash over me as I sat down and took in the beauty of my surroundings.

As I sat there, I heard a gentle breeze rustling through the leaves. It was then that I noticed a small, furry creature scurrying across the garden path. It was the cutest little squirrel I had ever seen, and I reached out my hand to pet its head.

The squirrel was startled but seemed to sense my good intentions. It wagged its tail and let out a tiny squeak, as if to say hello. I smiled

In [55]:
#
# Test a basic prompt completion using Langchain's GemmaLocalKaggle class
# Documenation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaLocalKaggle.html
#
from langchain_google_vertexai import GemmaLocalKaggle
glk_model = GemmaLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")
glk_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print("GemmaLocalKaggle prompt:\n" + "\"" + glk_prompt + "\"")
glk_prompt_completion = glk.invoke(glk_prompt)
print("\nGemmaLocalKaggle completion:\n", glk_prompt_completion)

GemmaLocalKaggle prompt:
"Tell me a story. My name is Daniel."

GemmaLocalKaggle completion:
 Tell me a story. My name is Daniel.

I woke up in a small cottage nestled amongst the rolling hills. The morning sun painted the walls with warm golden hues, and the birds sang melodies that filled the air with joy.

As I opened my eyes, I noticed a small cottage nestled amongst the rolling hills. The cottage was made of wood and stone, with a thatched roof that seemed to dance in the sunlight. A small door was open, inviting me inside.

I stepped through the door and into the cozy interior. The walls were adorned with paintings of animals and landscapes, and the floor was covered with a soft rug. A fireplace crackled in the corner, casting warm shadows on the walls.

A warm meal awaited me in the kitchen, along with a cozy bed to rest my head. I sat down at the table and enjoyed the simple yet delicious meal.

As the sun began to set, I took a stroll outside the cottage, enjoying the fresh ai

In [58]:
#
# Test a basic prompt completion using Langchain's GemmaChatLocalKaggle class
# Documentation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaChatLocalKaggle.html
#
from langchain_google_vertexai import GemmaChatLocalKaggle
gclk_model = GemmaChatLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")
gclk_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print("GemmaChatLocalKaggle prompt:", gclk_prompt)
ai_message = gclk.invoke(gclk_prompt) # return a special aimessage object
print("\nGemmaChatLocalKaggle aimessage response object:")
ai_message.pretty_print()
# extract just the model completion from aimessage - vertexai appears to insert "turn" tags
turn_prefix = "<start_of_turn>model\n"
pos = ai_message.content.find(turn_prefix)
gclk_prompt_completion = ai_message.content[ pos + len(response_prefix): ]


GemmaChatLocalKaggle prompt: Tell me a story. My name is Daniel.

GemmaChatLocalKaggle aimessage response object:

<start_of_turn>user
Tell me a story. My name is Daniel.<end_of_turn>
<start_of_turn>model
My name is Daniel, and I have a story to tell. I was born in a small village nestled amidst rolling hills and shimmering rivers. My childhood was filled with laughter and adventure, as I explored the lush forests and hidden waterfalls that surrounded our home.

As I grew older, I became fascinated by the world beyond our village. I would spend countless hours gazing up at the vast expanse of the sky, wondering about the mysteries of the universe. My curiosity led me to explore nearby towns and cities, where I learned about different cultures and traditions.

My thirst for knowledge never waned. I enrolled in a prestigious university, where I excelled in my studies. I was particularly drawn to the field of literature, where I discovered the power of words to transport me to different w

In [57]:
#
# The rest of this notebook attempts to integrate previous "chat" history
# in various ways.    
#


In [15]:
# Extract the last two sentences of previous response as 'context' for the next prompt
# in the conversation/story.
cherry_picked_history = "".join( prompt_completion.strip().split("\n\n")[-2:] )
print(cherry_picked_history)

**What is the next part of the story?**I turned to the sky and looked up, admiring the vast expanse of the ocean. I thought about the countless adventures I had had in this vast world, and I felt a deep sense of gratitude for the beauty and wonder that surrounded me. I knew that my story was one of hope, resilience, and the enduring power of the ocean. I would always be proud of my roots, and I would always find solace in the vastness of the ocean.


In [8]:
#
# Continution via model generate()
#

# Extract the last two sentences of previous response as 'context' for the next prompt in the conversation/story.
cherry_picked_history = "".join( model_prompt_completion.strip().split("\n\n")[-2:] )
print(cherry_picked_history)

# Form the follow-up prompt.
continuation_prompt = template.format(
    instruction= f"You are a storyteller.  Use the following context and continue telling the story.\n\n{cherry_picked_history}",
    response="",
) # Note: That format may depend additional on the dataset used for fine-tuning
print("Continuation prompt:\n", continuation_prompt)

model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
model_continuation_prompt_completion = model_lm.generate(continuation_prompt, max_length=2048) 
print("\nPrompt completion:\n", "\""+model_continuation_prompt_completion+"\"")

Continuation prompt:
 You are a storyteller.  Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean.I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

Prompt completion:
 "You are a storyteller.  Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean.I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

I was born and raised in the heart of the bustling city life, yet I always found myself drawn to the serenity of the ocean and the vastness of nature. The ocean was my sanctuary, my escape, and my guiding light.

As I grew older, my curiosity led me to explore the world beyond the city walls. I traveled to distant islands, where I witnessed the breathtaking beauty of coral reefs and the vibrant colors of marine life. I 

In [10]:
#
# Continuation via a basic Langchain Runnable
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.GemmaCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str:
        prompt = str(input)
        output = self.model.generate(prompt, max_length=2048)
        return output if isinstance(output,str) else "Unknown output type."

# Use the (previously compiled) Gemma model as a Langchain Runnable
runnable = LMRunnable(model_lm)

# Complete the prompt via Runnable invoke()
continuation_prompt = f"You are a storyteller.  Use the following context and continue telling the story.\n\n{cherry_picked_history}"
print("Continuation prompt:\n", continuation_prompt)
runnable_prompt_completion = runnable.invoke(continuation_prompt)
print("\nPrompt completiont:\n", runnable_prompt_completion)

Continuation prompt:
 You are a storyteller.  Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean.I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

Prompt completiont:
 You are a storyteller.  Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean.I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

The ocean was my playground, and the waves were my teachers. They taught me the rhythm of the tides, the secrets of the currents, and the importance of always being prepared for whatever came my way.

As I grew older, I became a skilled fisherman, learning the art of casting my net and reeling in the bounty of the sea. I learned the language of the ocean, understanding its moods and predicting its changes.

Years later,

In [13]:
#
# Continuation via Langchain's GemmaLocalKaggle class
# Documenation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaLocalKaggle.html
#
from langchain_google_vertexai import GemmaLocalKaggle
glk_model = GemmaLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")
glk_continuation_prompt = f"You are a storyteller.  Use the following context and continue telling the story.\n\n{cherry_picked_history}" # Note: format may depend on the fine-tuning dataset format
print("GemmaLocalKaggle prompt:\n" + "\"" + glk_continuation_prompt + "\"")
glk_prompt_completion = glk_model.invoke(glk_continuation_prompt)
print("\nGemmaLocalKaggle completion:\n", glk_prompt_completion)

GemmaLocalKaggle prompt:
"You are a storyteller.  Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean.I am Daniel, the son of Santa Cruz, and I will always be proud of my roots."

GemmaLocalKaggle completion:
 You are a storyteller.  Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean.I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

**What happened next?**

I stood on the shore, watching the waves caress the shore, carrying away the worries of the world. The sun was shining brightly, casting a warm glow on the sand, and the air was filled with the sound of birds singing. I smiled, feeling a sense of contentment wash over me.

**What is the next part of the story?**

I turned to the sky and looked up, admiring the vast expanse of the o

In [None]:
#
# Test a basic prompt completion using Langchain's GemmaChatLocalKaggle class
# Documentation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaChatLocalKaggle.html
#
from langchain_google_vertexai import GemmaChatLocalKaggle
gclk_model = GemmaChatLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")
gclk_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print("GemmaChatLocalKaggle prompt:", gclk_prompt)
ai_message = gclk.invoke(gclk_prompt) # return a special aimessage object
print("\nGemmaChatLocalKaggle aimessage response object:")
ai_message.pretty_print()
# extract just the model completion from aimessage - vertexai appears to insert "turn" tags
turn_prefix = "<start_of_turn>model\n"
pos = ai_message.content.find(turn_prefix)
gclk_prompt_completion = ai_message.content[ pos + len(response_prefix): ]
