# What is this notebook for?
* A test notebook for running Gemma models locally on a Windows machine.
* Before you start, make sure you have successfully completed the installation steps in README.md.

# Pre-requisites

* You should have a machine with enough RAM to load Gemma models (for example, Gemma 2B requires around 8 GB of RAM). I tested on a Windows laptop with 8 GB of RAM.
* You should already have Google's permission to use Gemma (request it through your Kaggle account).
* You should have created your Kaggle API keys. You will need them in the next step.

# Configure your Kaggle API access

* Create a file called ".env" in the directory from which you launched this notebook.
* Enter your Kaggle API keys in that file as key/value pairs on separate lines, for example:
  * KAGGLE_USERNAME=xxxxxx...
  * KAGGLE_KEY=xxxxxx...
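
As a rough illustration of the file format described above, the sketch below parses a ".env"-style file using only the standard library and reports which of the two keys are missing. The helper names `read_env_file` and `missing_kaggle_keys` are hypothetical and are not part of python-dotenv. The notebook itself relies on `load_dotenv()`, whose parsing rules are more complete than this.

```python
from pathlib import Path

REQUIRED_KEYS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def read_env_file(path):
    """Parse simple KEY=VALUE lines (hypothetical helper, not python-dotenv)."""
    values = {}
    for raw_line in Path(path).read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        # Skip blank lines, comments, and lines without a key/value separator.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip().strip("'\"")
    return values


def missing_kaggle_keys(path=".env"):
    """Return the required keys that are absent or empty in the file."""
    if not Path(path).is_file():
        return list(REQUIRED_KEYS)
    values = read_env_file(path)
    return [key for key in REQUIRED_KEYS if not values.get(key)]
```

Calling `missing_kaggle_keys()` from the notebook's directory should return an empty list once both keys are in place.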


In [3]:
import os

# Set up the Keras parameters recommended by Google.
# These must be set BEFORE keras is imported, because the backend is chosen at import time.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Avoid memory fragmentation on the JAX backend.

import keras
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
from keras import backend as K
import tensorflow
from IPython.display import Markdown, display
import textwrap
import json
import pandas as pd
import gc
from dotenv import load_dotenv

In [4]:
# Integrate the Kaggle API secret key.
load_dotenv()  # Reads the key/value pairs from the ".env" file in the current directory into os.environ.

# Fail early with a clear message if either credential is missing.
for name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if not os.getenv(name):
        raise RuntimeError(f'{name} is not set; add it to the ".env" file.')

In [5]:
def display_chat(prompt, response):
  '''Displays an LLM prompt and response in a pretty way.'''
  # Convert newlines to HTML line breaks so the prompt keeps its layout.
  prompt = prompt.replace('\n', '<br>')
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  # Replace bullet characters, convert newlines, and drop code-fence markers.
  response = response.replace('•', '  *')
  response = response.replace('\n', '<br>')
  response = response.replace("```", "")
  formatted_text = "<font size='+1' color='teal'>🤖<blockquote>" + response + "</blockquote></font>"
  return Markdown(formatted_prompt + formatted_text)

In [None]:
%%time
# Load the Gemma 2 2B base model (preset "gemma2_2b_en"); the files are downloaded from Kaggle.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_2b_en/1/download/config.json...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:00<00:00, 574kB/s]
2024-10-05 16:58:16.912128: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2359296000 exceeds 10% of free system memory.
2024-10-05 16:58:20.599016: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2359296000 exceeds 10% of free system memory.
2024-10-05 16:58:26.833785: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2359296000 exceeds 10% of free system memory.
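
Each of the allocator warnings above reports 2,359,296,000 bytes. One plausible reading, offered as an assumption rather than something the log states, is that this is the token embedding table held in float32: a vocabulary of 256,000 entries times a hidden size of 2,304 (both figures are assumed here for the `gemma2_2b_en` preset) times 4 bytes per value. The arithmetic can be checked as follows:

```python
# Assumed figures for the gemma2_2b_en preset; they are not stated in this notebook.
VOCABULARY_SIZE = 256_000
HIDDEN_DIM = 2_304
BYTES_PER_FLOAT32 = 4

embedding_bytes = VOCABULARY_SIZE * HIDDEN_DIM * BYTES_PER_FLOAT32
print(embedding_bytes)          # 2359296000, the size reported in the warnings
print(embedding_bytes / 2**30)  # the same size expressed in GiB
```

A single tensor of a little over 2 GiB is a large share of an 8 GB machine, which would explain why the allocator flags it as exceeding 10% of free system memory. These lines are warnings, not errors.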


In [None]:
%%time
template = "Instruction:\n{question}\n\nResponse:\n{answer}"

prompt = template.format(
    question="What should I eat when I visit Blahlabhlah?",
    answer="",
)
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)
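
The cell above fills the template with an empty answer so that the prompt ends at `Response:`, then removes the prompt from the completion to leave only the generated text. The sketch below reproduces that string handling without loading a model. `fake_completion` is a made-up stand-in for the output of `gemma_lm.generate`, and `strip_prompt` is a hypothetical helper that removes the prompt only when it appears as a prefix, a slightly stricter variant of the `replace` call used in the notebook.

```python
template = "Instruction:\n{question}\n\nResponse:\n{answer}"


def strip_prompt(completion, prompt):
    """Return the text generated after the prompt (hypothetical helper)."""
    if completion.startswith(prompt):
        return completion[len(prompt):]
    # Fall back to the full completion if the prompt is not echoed verbatim.
    return completion


prompt = template.format(question="What is Keras?", answer="")
# Stand-in for gemma_lm.generate(prompt, max_length=1024); not real model output.
fake_completion = prompt + "Keras is a deep learning API."

print(repr(prompt))
print(strip_prompt(fake_completion, prompt))
```

Slicing off the prefix avoids accidentally deleting a later repetition of the prompt text inside the response, which `str.replace` would also remove.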