In [1]:
#
#  This notebook demonstrates the following:
#  * Use the Keras built-in OPT model ('opt_125m_en')
#  * Shows various methods to complete a prompt in Keras
#    * model generate() method
#    * Langchain Runnable class
#    * ChatVertexAI
#  * Experiments with ways to integrate conversation history
#

In [2]:
#  Python imports
#
#  Notes:
#  * Make sure you install the packages in requirements.txt
#  * Make sure you setup your KAGGLE secrets via env vars.
#
import os
import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone, OPTBackbone
from keras.models import load_model
import kagglehub
from langchain.schema.runnable import Runnable
import langchain_core
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from typing import Any, Optional
import tensorflow as tf
from keras.config import disable_interactive_logging
from langchain_google_vertexai import ChatVertexAI
import gc

In [3]:
#  Global config
#
os.environ["GRPC_VERBOSITY"] = "ERROR" # Try to suppress annoying GRPC warnings
os.environ["GLOG_minloglevel"] = "3"   # Try to suppress annoying GLOG warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Try to suppress annoying TF CPP warnings
# Try to suppress general warnings sent through the warnings module
def warn(*args, **kwargs): 
    pass
import warnings
warnings.warn = warn
disable_interactive_logging() # Try to support keras warnings

In [4]:
# Useful functions
#
def makebold(txt):
    return '\x1b[1;30m'+txt+'\x1b[0m'
def makeblue(txt):
    return '\x1b[1;34m'+txt+'\x1b[0m'
def makegreen(txt):
    return '\x1b[1;32m'+txt+'\x1b[0m'
def makered(txt):
    return '\x1b[1;33m'+txt+'\x1b[0m'

In [5]:
#
# Initial prompt - I'll use this prompt to contrast-and-compare various 
# prompt completion techniques using Keras
#
initial_prompt = "Tell me a story. My name is Daniel."

In [6]:
#
#  Test a basic prompt completion using the model generate() call
#
model_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)

# Create the prompt via manual formatting
template = "{instruction}" # Note: format may depend on the fine-tuning dataset format
simple_prompt = template.format(
    instruction=initial_prompt,
    response="",
)
print(makebold("Prompt:\n"), makegreen("\"" + simple_prompt + "\"") )

# Complete the prompt via model generate()
model_prompt_completion = model_lm.generate(simple_prompt, max_length=2048) # Note: Using the context size of Gemma as max
print(makebold("\nCompletion:\n"), makeblue("\""+ model_prompt_completion + "\"" ))

model_lm = None # clean up memory
gc.collect() # Run python memory collector

[1;30mPrompt:
[0m [1;32m"Tell me a story. My name is Daniel."[0m


I0000 00:00:1726947013.688816 11938918 service.cc:146] XLA service 0x600000d98300 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1726947013.688843 11938918 service.cc:154]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1726947013.693087 11938921 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1;30m
Completion:
[0m [1;34m"Tell me a story. My name is Daniel. I am a young woman. I am a student of the University of California, Berkeley. I live in the United States of America. My name is Daniel.

I was a student in the University of California, Berkeley when I was born. I was born in a very poor neighborhood. I was born in a very poor neighborhood. I was born a poor boy.

My father, my grandfather, was a very poor man. He was a very poor boy.

My father, my grandfather, was a very poor man. He was a very poor boy.

My father, my grandfather, was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy."[0m


58304

In [7]:
#
# Test a basic prompt completion via a minimal Langchain Runnable
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.OPTCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str:
        prompt = str(input)
        output = self.model.generate(prompt, max_length=2048)
        return output if isinstance(output,str) else "Unknown output type."

# Use the Gemma model as a Langchain Runnable
model_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
runnable = LMRunnable(model_lm)

# Create the prompt via manual formatting
runnable_prompt = f"{initial_prompt}" # Note: format may depend on the fine-tuning dataset format
print(makebold("Runnable prompt:\n") +  makegreen("\"" + runnable_prompt + "\"") )

# Complete the prompt via Runnable invoke()
runnable_prompt_completion = runnable.invoke(runnable_prompt)
print(makebold("\nRunnable completion:\n") + makeblue("\"" + runnable_prompt_completion + "\""))

runnable = None # release model handle
model_lm = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mRunnable prompt:
[0m[1;32m"Tell me a story. My name is Daniel."[0m
[1;30m
Runnable completion:
[0m[1;34m"Tell me a story. My name is Daniel. I am a young woman. I am a student of the University of California, Berkeley. I live in the United States of America. My name is Daniel.

I was a student in the University of California, Berkeley when I was born. I was born in a very poor neighborhood. I was born in a very poor neighborhood. I was born a poor boy.

My father, my grandfather, was a very poor man. He was a very poor boy.

My father, my grandfather, was a very poor man. He was a very poor boy.

My father, my grandfather, was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy.

I was a very poor boy."[0m


58272

In [8]:
#
# Test a basic prompt completion via a minimal Langchain Runnable and ChatPromptTemplate
#
class LMRunnable(Runnable):
    def __init__(self, model: keras_nlp.models.OPTCausalLM):
        self.model = model

    def invoke(self, input: Any, config: Optional[dict] = None) -> str: 
        if not isinstance(input,langchain_core.prompt_values.PromptValue):
            raise Exception("Unknown input type")
        # Note the implementor of this class needs to know about the template
        # in order to form the raw prompt to present to the model
        raw_prompt = input.messages[0].content + "\n\n" + input.messages[1].content
        output = self.model.generate(raw_prompt, max_length=2048)
        return output 

# Use the Gemma model as a Langchain Runnable
model_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
sampler = keras_nlp.samplers.TopKSampler(k=3, seed=2)
model_lm.compile(sampler=sampler)
runnable = LMRunnable(model_lm)

# Create the prompt via a ChatPromptTemplate
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a writer of function."),
    MessagesPlaceholder("msgs")
])

# Make a prompt using the template just defined
runnable_prompt = prompt_template.invoke({"msgs": [HumanMessage(content=initial_prompt)]})
print( makebold("Runnable prompt from template:\n") \
      + makegreen( "\"" + runnable_prompt.messages[0].content + "\n\n" + runnable_prompt.messages[1].content + "\"") )

# Complete the prompt
runnable_prompt_completion = runnable.invoke(runnable_prompt)
print(makebold("\nRunnable completion:\n") + makeblue("\"" + runnable_prompt_completion + "\""))

runnable = None # release model handle
model_lm = None # try to clean up model memory
gc.collect() # Run python memory collector

[1;30mRunnable prompt from template:
[0m[1;32m"You are a writer of function.

Tell me a story. My name is Daniel."[0m
[1;30m
Runnable completion:
[0m[1;34m"You are a writer of function.

Tell me a story. My name is Daniel. I am a writer of function. I am the author of function and function. I am the author of function and function.

Tell me a story. I am a writer of function.

I am the author of function and function.

Tell me a story.

I am the author of function and function.

I am the author of function and function.

Tell me a story.

I am the author of function and function.

Tell me a story.

I am the author of function and function.

Tell me a story.

I am the author of function and function.

I am the author of function and function.

Tell us about your story.

I am the author of function and function.

I am the author of function and function.

Tell us about your story.

I am the author of function and function.

Tell us about your story.

I am the author of function an

58282

In [9]:
#
# TODO: Test a basic prompt completion using Langchain's ChatVertexAI class
# Documentation: TODO
#

# TODO cvai_model = ChatVertexAI(model_name="opt_125m_en", keras_backend="jax")
