In [9]:
#
#  This notebook demonstrates the following:
#  * Use the Keras built-in Gemma instruction fine-tuned model ('gemma_instruct_2b_en')
#  * Shows various methods to complete a prompt in Keras
#    * model generate() method
#    * Langchain Runnable class
#    * GemmaLocalKaggle
#    * GemmaChatLocalKaggle
#  * Experiments with ways to integrate conversation history
#

In [22]:
#  Python imports
#
#  Notes:
#  * Make sure you install the packages in requirements.txt
#  * Make sure you setup your KAGGLE secrets via env vars.
#
import os
import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
import kagglehub
import langchain_core
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.schema.runnable import Runnable
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from typing import Any, Optional
import tensorflow as tf
from keras.config import disable_interactive_logging
import gc

In [23]:
#  Global config
#
os.environ["GRPC_VERBOSITY"] = "ERROR" # Try to suppress annoying GRPC warnings
os.environ["GLOG_minloglevel"] = "3"   # Try to suppress annoying GLOG warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Try to suppress annoying TF CPP warnings
# Try to suppress general warnings sent through the warnings module
def warn(*args, **kwargs): 
    pass
import warnings
warnings.warn = warn
disable_interactive_logging() # Try to support keras warnings

In [24]:
# Useful functions
#
def makebold(txt):
    return '\x1b[1;30m'+txt+'\x1b[0m'
def makeblue(txt):
    return '\x1b[1;34m'+txt+'\x1b[0m'
def makegreen(txt):
    return '\x1b[1;32m'+txt+'\x1b[0m'
def makered(txt):
    return '\x1b[1;33m'+txt+'\x1b[0m'

In [20]:
#
# Initial prompt - I'll use this prompt to contrast-and-compare various 
# prompt completion techniques using Keras
#
initial_prompt = "Tell me a story. My name is Daniel."

In [6]:
#
#  Test a basic prompt completion using the model generate() call
#
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)

# Create the prompt via manual formatting
template = "{instruction}" # Note: format may depend on the fine-tuning dataset format
simple_prompt = template.format(
    instruction=initial_prompt,
    response="",
)
print(makebold("Prompt:\n"), makegreen("\"" + simple_prompt + "\"") )

# Complete the prompt via model generate()
model_prompt_completion = model_lm.generate(simple_prompt, max_length=2048) # Note: Using the context size of Gemma as max
print(makebold("\nCompletion:\n"), makeblue("\""+ model_prompt_completion + "\"" ))

model_lm = None # clean up memory
gc.collect() # Run python memory collector

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


[1;30mPrompt:
[0m [1;32m"Tell me a story. My name is Daniel."[0m


I0000 00:00:1726940455.888266 11907544 service.cc:146] XLA service 0x600002b87f00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1726940455.888289 11907544 service.cc:154]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1726940455.892887 11907544 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1;30m
Completion:
[0m [1;34m"Tell me a story. My name is Daniel.

I was born and raised in a small coastal town in California called Santa Cruz. The ocean was my playground, and I spent countless hours swimming, surfing, and playing in the waves. I had a deep love for the ocean and all it had to offer.

One fateful day, a storm hit Santa Cruz with devastating force. The powerful waves crashed over the coastline, destroying everything in their path. My beloved town was devastated, and I was heartbroken.

But through the devastation, there were also glimmers of hope. The ocean had a way of reclaiming what it had lost, and new life emerged in the aftermath. The town was rebuilt, stronger than ever before.

As I grew older, I realized that my passion for the sea had not changed. I still yearned for the calmness and beauty of the ocean, and I found solace in the memories of my childhood.

Years later, I returned to Santa Cruz to rebuild my life. I had learned from my previous experience

96399

In [7]:
#
# Test a basic prompt completion via a minimal Langchain Runnable
# with minimal interaction
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.GemmaCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str:
        prompt = str(input)
        output = self.model.generate(prompt, max_length=2048)
        return output 

# Use the Gemma model as a Langchain Runnable
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
runnable = LMRunnable(model_lm)

# Create the prompt via manual formatting
runnable_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print(makebold("Runnable prompt:\n") +  makegreen("\"" + runnable_prompt + "\"") )

# Complete the prompt via Runnable invoke()
runnable_prompt_completion = runnable.invoke(runnable_prompt)
print(makebold("\nRunnable completion:\n") + makeblue("\"" + runnable_prompt_completion + "\""))

runnable = None # release model handle
model_lm = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mRunnable prompt:
[0m[1;32m"Tell me a story. My name is Daniel."[0m
[1;30m
Runnable completion:
[0m[1;34m"Tell me a story. My name is Daniel.

I was born and raised in a small coastal town in California called Santa Cruz. The ocean was my playground, and I spent countless hours swimming, surfing, and playing in the waves. I had a deep love for the ocean and all it had to offer.

One fateful day, a storm hit Santa Cruz with devastating force. The powerful waves crashed over the coastline, destroying everything in their path. My beloved town was devastated, and I was heartbroken.

But through the devastation, there were also glimmers of hope. The ocean had a way of reclaiming what it had lost, and new life emerged in the aftermath. The town was rebuilt, stronger than ever before.

As I grew older, I realized that my passion for the sea had not changed. I still yearned for the calmness and beauty of the ocean, and I found solace in the memories of my childhood.

Years later, I

90275

In [25]:
#
# Test a basic prompt completion via a minimal Langchain Runnable and ChatPromptTemplate
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.GemmaCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str: 
        if not isinstance(input,langchain_core.prompt_values.PromptValue):
            raise Exception("Unknown input type")
        # Note the implementor of this class needs to know about the template
        # in order to form the raw prompt to present to the model
        raw_prompt = input.messages[0].content + "\n\n" + input.messages[1].content
        output = self.model.generate(raw_prompt, max_length=2048)
        return output 

# Use the Gemma model as a Langchain Runnable
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
runnable = LMRunnable(model_lm)

# Create the prompt via a ChatPromptTemplate
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a writer of function."),
    MessagesPlaceholder("msgs")
])

# Make a prompt using the template just defined
runnable_prompt_templ = prompt_template.invoke({"msgs": [HumanMessage(content=initial_prompt)]})
print( makebold("Runnable prompt from template:\n") \
      + makegreen( "\"" + runnable_prompt_templ.messages[0].content + "\n\n" + \
                  runnable_prompt_templ.messages[1].content + "\"") )

# Complete the prompt
runnable_prompt_completion_templ = runnable.invoke(runnable_prompt_templ)
print(makebold("\nRunnable completion:\n") + makeblue("\"" + runnable_prompt_completion_templ + "\""))

runnable = None # release model handle
model_lm = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mRunnable prompt from template:
[0m[1;32m"You are a writer of function.

Tell me a story. My name is Daniel."[0m
[1;30m
Runnable completion:
[0m[1;34m"You are a writer of function.

Tell me a story. My name is Daniel.

I have a wife, Sarah, and we have two beautiful children, Emily and Daniel. I work as a software developer, and I am very passionate about my work. I am also very active in my community, and I volunteer my time to help others.

My life has been filled with joy and sorrow, but through it all, I have learned that the most important thing is to be kind to others and to always be there for those I love.

What is your story?"[0m


90183

In [8]:
#
# Test a basic prompt completion using Langchain's GemmaLocalKaggle class
# Documentation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaLocalKaggle.html
#
from langchain_google_vertexai import GemmaLocalKaggle
glk_model = GemmaLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")

# Create the prompt via manual formatting
glk_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print(makebold("GemmaLocalKaggle prompt:\n") + makegreen("\"" + glk_prompt + "\""))

# Complete the prompt via invoke()
glk_prompt_completion = glk_model.invoke(glk_prompt)
print(makebold("\nGemmaLocalKaggle completion:\n"), makeblue("\""+ glk_prompt_completion+ "\"") )

glk_model = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mGemmaLocalKaggle prompt:
[0m[1;32m"Tell me a story. My name is Daniel."[0m
[1;30m
GemmaLocalKaggle completion:
[0m [1;34m"Tell me a story. My name is Daniel.

I woke up in a small cottage nestled amongst the rolling hills. The morning sun painted the walls with warm golden hues, and the birds sang melodies that filled the air with joy.

As I opened my eyes, I noticed a small cottage nestled amongst the rolling hills. The cottage was made of wood and stone, with a thatched roof that seemed to dance in the sunlight. A small door was open, inviting me inside.

I stepped through the door and into the cozy interior. The walls were adorned with paintings of animals and landscapes, and the floor was covered with a soft rug. A fireplace crackled in the corner, casting warm shadows on the walls.

A warm meal awaited me in the kitchen, along with a cozy bed to rest my head. I sat down at the table and enjoyed the simple yet delicious meal.

As the sun began to set, I took a stroll o

89844

In [9]:
#
# Test a basic prompt completion using Langchain's GemmaChatLocalKaggle class
# Documentation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaChatLocalKaggle.html
#
from langchain_google_vertexai import GemmaChatLocalKaggle
gclk_model = GemmaChatLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")

# Create the prompt via manual formatting
gclk_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print(makebold("GemmaChatLocalKaggle prompt:\n"), makegreen("\""+ gclk_prompt + "\"") )

# Complete the prompt via invoke
ai_message = gclk_model.invoke(gclk_prompt) # return a special aimessage object
print(makebold("\nGemmaChatLocalKaggle aimessage response object:"))
ai_message.pretty_print()
# extract just the model completion from aimessage - vertexai appears to insert "turn" tags
turn_prefix = "<start_of_turn>model\n"
pos = ai_message.content.find(turn_prefix)
gclk_prompt_completion = ai_message.content[ pos + len(turn_prefix): ]

print(makebold("\nGemmaChatLocalKaggle response only:\n") + makeblue("\"" + gclk_prompt_completion + "\""))

gclk_model = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mGemmaChatLocalKaggle prompt:
[0m [1;32m"Tell me a story. My name is Daniel."[0m
[1;30m
GemmaChatLocalKaggle aimessage response object:[0m

<start_of_turn>user
Tell me a story. My name is Daniel.<end_of_turn>
<start_of_turn>model
My name is Daniel, and I have a story to tell. I was born in a small village nestled amidst rolling hills and shimmering rivers. My childhood was filled with laughter and adventure, as I explored the lush forests and hidden waterfalls that surrounded our home.

As I grew older, I became fascinated by the world beyond our village. I would spend countless hours gazing up at the vast expanse of the sky, wondering about the mysteries of the universe. My curiosity led me to explore nearby towns and cities, where I learned about different cultures and traditions.

My thirst for knowledge never waned. I enrolled in a prestigious university, where I excelled in my studies. I was particularly drawn to the field of literature, where I discovered the power of 

89845

In [10]:
#
# The rest of this notebook attempts to integrate previous "chat" history
# in various ways.    
#

In [11]:
#
# Continuation via model generate()
#

model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)

# Extract the last two sentences of previous response as 'context' for the next prompt in the conversation/story.
cherry_picked_history = " ".join( model_prompt_completion.strip().split("\n\n")[-2:] )
print(makebold("Cherry-picked history:\n"), makered("\"" + cherry_picked_history + "\""))

# Form the follow-up prompt.
continuation_prompt = template.format(
    instruction= f"You are a writer of fiction. Use the following context and continue telling the story.\n\n{cherry_picked_history}",
    response="",
) # Note: That format may depend additional on the dataset used for fine-tuning
print(makebold("\nContinuation prompt:\n"), makegreen("\"" + continuation_prompt + "\""))

# Complete the prompt via model generate()
model_continuation_prompt_completion = model_lm.generate(continuation_prompt, max_length=2048) 
print(makebold("\nPrompt completion:\n"), makeblue("\"" + model_continuation_prompt_completion + "\""))

model_lm = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mCherry-picked history:
[0m [1;33m"It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean. I am Daniel, the son of Santa Cruz, and I will always be proud of my roots."[0m
[1;30m
Continuation prompt:
[0m [1;32m"You are a writer of fiction. Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean. I am Daniel, the son of Santa Cruz, and I will always be proud of my roots."[0m
[1;30m
Prompt completion:
[0m [1;34m"You are a writer of fiction. Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean. I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

**Continue the Story**

The sun rose over the shimmering waters of Santa Cruz, casting an ethereal glow upon the bustling streets below.

90273

In [12]:
#
# Continuation via a basic Langchain Runnable
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.GemmaCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str:
        prompt = str(input)
        output = self.model.generate(prompt, max_length=2048)
        return output if isinstance(output,str) else "Unknown output type."

# Use the (previously compiled) Gemma model as a Langchain Runnable
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
runnable = LMRunnable(model_lm)

# Extract the last two sentences of previous response as 'context' for the next prompt in the conversation/story.
cherry_picked_history = " ".join( runnable_prompt_completion.strip().split("\n\n")[-2:] )
print(makebold("Cherry-picked history:\n"), makered( "\"" + cherry_picked_history + "\""))

# Create the prompt via manual formatting
continuation_prompt = f"You are a writer of fiction. Use the following context and continue telling the story.\n\n{cherry_picked_history}"
print(makebold("\nContinuation prompt:\n"), makegreen("\"" + continuation_prompt + "\""))

# Complete the prompt via Runnable invoke()
runnable_continuation_prompt_completion = runnable.invoke(continuation_prompt)
print(makebold("\nPrompt completiont:\n"), makeblue("\"" + runnable_continuation_prompt_completion + "\""))

runnable = None # release model handle
model_lm = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mCherry-picked history:
[0m [1;33m"It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean. I am Daniel, the son of Santa Cruz, and I will always be proud of my roots."[0m
[1;30m
Continuation prompt:
[0m [1;32m"You are a writer of fiction. Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean. I am Daniel, the son of Santa Cruz, and I will always be proud of my roots."[0m
[1;30m
Prompt completiont:
[0m [1;34m"You are a writer of fiction. Use the following context and continue telling the story.

It was then that I realized that my story was one of hope, resilience, and the enduring power of the ocean. I am Daniel, the son of Santa Cruz, and I will always be proud of my roots.

**Continue the Story**

The sun rose over the shimmering waters of Santa Cruz, casting an ethereal glow upon the bustling streets below

90284

In [13]:
#
# Continuation via Langchain's GemmaLocalKaggle class
# Documentation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaLocalKaggle.html
#
from langchain_google_vertexai import GemmaLocalKaggle
glk_model = GemmaLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")

# Extract the last two sentences of previous response as 'context' for the next prompt in the conversation/story.
cherry_picked_history = " ".join( glk_prompt_completion.strip().split("\n\n")[-2:] )
print(makebold("Cherry-picked history:\n"), makered( "\"" + cherry_picked_history + "\""))

# Complete the prompt via Runnable invoke()
glk_continuation_prompt = f"You are a writer of fiction. Use the following context and continue telling the story.\n\n{cherry_picked_history}" # Note: format may depend on the fine-tuning dataset format
print(makebold("\nGemmaLocalKaggle prompt:\n") + makegreen("\"" + glk_continuation_prompt + "\""))

# Complete the prompt via invoke()
glk_continuation_prompt_completion = glk_model.invoke(glk_continuation_prompt)
print(makebold("\nGemmaLocalKaggle completion:\n"), makeblue("\"" + glk_continuation_prompt_completion) + "\"")

glk_model # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mCherry-picked history:
[0m [1;33m"As the sun rose the next morning, I bid farewell to Sarah and the charming cottage. I knew that I would cherish the memories I had made during my stay. And so, my journey began, guided by the gentle hand of the sun and the kindness of strangers. I was Daniel, a traveler who found a home in the rolling hills and the hearts of the people I met along the way."[0m
[1;30m
GemmaLocalKaggle prompt:
[0m[1;32m"You are a writer of fiction. Use the following context and continue telling the story.

As the sun rose the next morning, I bid farewell to Sarah and the charming cottage. I knew that I would cherish the memories I had made during my stay. And so, my journey began, guided by the gentle hand of the sun and the kindness of strangers. I was Daniel, a traveler who found a home in the rolling hills and the hearts of the people I met along the way."[0m
[1;30m
GemmaLocalKaggle completion:
[0m [1;34m"You are a writer of fiction. Use the following

40

In [14]:
#
# Continuation via Langchain's GemmaChatLocalKaggle class
# Documentation: https://python.langchain.com/api_reference/google_vertexai/gemma/langchain_google_vertexai.gemma.GemmaChatLocalKaggle.html
#
from langchain_google_vertexai import GemmaChatLocalKaggle
gclk_model = GemmaChatLocalKaggle(model_name="gemma_instruct_2b_en", keras_backend="jax")

# Extract the last two sentences of previous response as 'context' for the next prompt in the conversation/story.
cherry_picked_history = " ".join( gclk_prompt_completion.strip().split("\n")[-2:] )
print(makebold("Cherry-picked history:\n"), makered("\"" + cherry_picked_history + "\""))

# Create the prompt via manual formatting
gclk_continuation_prompt = f"You are a storyteller.  Use the following context and continue telling the story.\n\n{cherry_picked_history}" # Note: format may depend on the fine-tuning dataset format
print(makebold("\nGemmaChatLocalKaggle prompt:\n"), makegreen( "\"" + gclk_continuation_prompt + "\"") )

# Complete the prompt via invoke()
ai_message = gclk_model.invoke(gclk_continuation_prompt) # return a special aimessage object
print(makebold("\nGemmaChatLocalKaggle aimessage response object:"))
ai_message.pretty_print()
# extract just the model completion from aimessage - vertexai appears to insert "turn" tags
turn_prefix = "<start_of_turn>model\n"
pos = ai_message.content.find(turn_prefix)
gclk_prompt_completion = ai_message.content[ pos + len(turn_prefix): ]

print(makebold("\nGemmaChatLocalKaggle response only:\n") + makeblue( "\"" + gclk_prompt_completion + "\""))

gclk_model = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mCherry-picked history:
[0m [1;33m" My story is a testament to the power of curiosity, the importance of lifelong learning, and the enduring strength of love. It is a story that I will continue to tell, for it is a story that belongs to all of us."[0m
[1;30m
GemmaChatLocalKaggle prompt:
[0m [1;32m"You are a storyteller.  Use the following context and continue telling the story.

 My story is a testament to the power of curiosity, the importance of lifelong learning, and the enduring strength of love. It is a story that I will continue to tell, for it is a story that belongs to all of us."[0m
[1;30m
GemmaChatLocalKaggle aimessage response object:[0m

<start_of_turn>user
You are a storyteller.  Use the following context and continue telling the story.

 My story is a testament to the power of curiosity, the importance of lifelong learning, and the enduring strength of love. It is a story that I will continue to tell, for it is a story that belongs to all of us.<end_of_turn

89845