In [24]:
#
#  This notebook demonstrates the following:
#  * gemma_2b_en vs gemma_1.1_instruct_2b_en
#

In [17]:
#  Python imports
#
#  Notes:
#  * Make sure you install the packages in requirements.txt
#  * Make sure you setup your KAGGLE secrets via env vars.
#
import os
import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
import kagglehub
from langchain.schema.runnable import Runnable
from typing import Any, Optional
import tensorflow as tf
from keras.config import disable_interactive_logging
import gc
from IPython.display import Markdown
import textwrap

In [32]:
# Some useful functions we will use later to display the interaction with the model
#
def display_chat(prompt, text):
  prompt = prompt.replace('\n\n','<br><br>')
  prompt = prompt.replace('\n','<br>')
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '', predicate=lambda _: True)
  text = text.replace('\n\n','<br><br>')
  text = text.replace('\n','<br>')
  text = text.replace("```","")
  formatted_text = "<font size='+1' color='teal'>🤖<blockquote>" + text + "</blockquote></font>"
  return Markdown(formatted_prompt+formatted_text)
def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '', predicate=lambda _: True))

In [14]:
# Load the default Gemma model, not fine-tuned
#
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

In [19]:
# Ask a simple math problem
#
prompt = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
completion = model_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Tell me, in a few words,  how to compute all prime numbers up to 1000?</blockquote></font><font size='+1' color='teal'>🤖<blockquote><br><br>Answer:<br><br>Step 1/5<br>1. Start by finding the first prime number, which is 2.<br><br>Step 2/5<br>2. Check if 2 is a prime number by dividing it by all numbers less than or equal to 2. If it is not divisible by any number, then it is a prime number.<br><br>Step 3/5<br>3. If 2 is not a prime number, then find the next prime number, which is 3.<br><br>Step 4/5<br>4. Check if 3 is a prime number by dividing it by all numbers less than or equal to 3. If it is not divisible by any number, then it is a prime number.<br><br>Step 5/5<br>5. Continue this process until you reach 1000. Here is a Python code to do this: ``` prime_numbers = [2] for i in range(3, 1000): if i % 2 == 0: continue if i % 3 == 0: continue if i % 5 == 0: continue if i % 7 == 0: continue if i % 11 == 0: continue if i % 13 == 0: continue if i % 17 == 0: continue if i % 19 == 0: continue if i % 23 == 0: continue if i % 29 == 0: continue if i % 31 == 0: continue if i % 37 == 0: continue if i % 41 == 0: continue if i % 43 == 0: continue if i % 47 == 0: continue if i % 53 == 0: continue if i % 59 == 0: continue if i % 61 == 0: continue if i % 67 == 0: continue if i % 71 == 0: continue if i % 73 == 0: continue if i % 79 == 0: continue if i % 83 == 0: continue if i % 89 == 0: continue if i % 97 == 0: continue if i % 101 == 0: continue prime_numbers.append(i) print(prime_numbers) ``` This code will print all prime numbers up to 1000.</blockquote></font>

In [20]:
# Note above that the default "gemma_2b_en" model already has the ability to take requests and commands from the user
# and returns a suitable response.  In this case, it answered our question about a simple math problem.  
# Now, lets try to ask a follow up question in order to re-fine our original request.

In [22]:
# Ask a follow-up question
#
prompt = "Now in Python! No numpy, please!"
completion = model_lm.generate(prompt)
response = completion.replace(prompt, "")
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Now in Python! No numpy, please!</blockquote></font><font size='+1' color='teal'>🤖<blockquote><br><br><h1>Introduction</h1><br><br>This is a simple implementation of the <strong><em>K-Means</em></strong> clustering algorithm in Python.<br><br><h1>What is K-Means?</h1><br><br>K-Means is a <strong><em>clustering algorithm</em></strong> that groups a set of data points into <strong><em>k</em></strong> clusters.<br><br>The algorithm is based on the <strong><em>inertia</em></strong> of the clusters.<br><br>The <strong><em>inertia</em></strong> is the sum of the squared distances of each data point to its nearest cluster center.<br><br>The goal of the algorithm is to find the <strong><em>k</em></strong> cluster centers that minimize the <strong><em>inertia</em></strong>.<br><br><h1>How does it work?</h1><br><br>The algorithm starts by randomly selecting <strong><em>k</em></strong> cluster centers.<br><br>Then, it iteratively assigns each data point to the cluster center that is closest to it.<br><br>This process is repeated until the <strong><em>inertia</em></strong> does not change significantly.<br><br><h1>Implementation</h1><br><br>The code is available on my GitHub.<br><br><h1>Conclusion</h1><br><br>This is a simple implementation of the K-Means algorithm in Python.<br><br>It is not the most efficient or the most accurate implementation, but it should be enough for most practical applications.<br><br>If you have any questions or suggestions, feel free to leave a comment.<br><br>Thanks for reading!</blockquote></font>

In [25]:
# Notice that the answer to the folow-up question had nothing to do with our original request.
# In order to support a more conversational interaction with Gemma, Google has release a version
# of the model that's more helpful.  Let's try that one, called "gemma_1.1_instruct_2b_en" now.

In [34]:
# The model "gemma_1.1_instruct_2b_en" supports adding conversation context into the prompt but it
# requires using a specifc set of tags in the prompt.  The following code makes it easy to do this
# by encapsulating all of that formatting and it also retains the conversation history with the model.
#
class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """
  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"
  def __init__(self, model, system=""):
    """
    Initializes the chat state.
    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []
  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)
  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)
  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])
  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.
    Args:
        message: The user's message.
    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response    
    self.add_to_history_as_model(result)
    return result

In [42]:
# Let's instantiate the model and the chat helper code
#
model_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
chat = ChatState(model_lm)

In [43]:
# Now let's repeat the conversation with "gemma_1.1_instruct_2b_en" to see if it performs better.
# Let's start with the question...
#
prompt = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
response = chat.send_message(prompt)
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Tell me, in a few words,  how to compute all prime numbers up to 1000?</blockquote></font><font size='+1' color='teal'>🤖<blockquote>The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.</blockquote></font>

In [44]:
# Now let's ask the follow-up question...
#
prompt = "Now in Python! No numpy, please!"
response = chat.send_message(prompt)
display_chat(prompt, response)

<font size='+1' color='brown'>🙋‍♂️<blockquote>Now in Python! No numpy, please!</blockquote></font><font size='+1' color='teal'>🤖<blockquote>python<br>def prime(n):<br>    if n <= 1:<br>        return False<br>    for i in range(2, int(n**0.5) + 1):<br>        if n % i == 0:<br>            return False<br>    return True<br></blockquote></font>

In [45]:
# Release resources that consume lots of memory
chat = None
model_lm = None
gc.collect

<function gc.collect(generation=2)>