##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/get_started"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
    <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/get_started.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/get_started.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Get started with Gemma using KerasNLP

This tutorial shows you how to get started with Gemma using [KerasNLP](https://keras.io/keras_nlp/). Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in [Keras](https://keras.io/) and runnable on JAX, PyTorch, and TensorFlow.

In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read [Getting started with Keras](https://keras.io/getting_started/) before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial.

## Setup

### Gemma setup

To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on kaggle.com.
* Select a Colab runtime with sufficient resources to run
  the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.


### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.

In [None]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
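If you're not using Colab, one option is to read the credentials from a Kaggle API token file. The sketch below assumes a JSON file with `username` and `key` fields, which is the layout of the `kaggle.json` token you can download from your Kaggle account settings. The helper name `load_kaggle_credentials` and the default path are illustrative and not part of the Gemma setup instructions.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json"):
    """Copy the username and key from a Kaggle token file into os.environ.

    Assumes the file is JSON with "username" and "key" fields.
    """
    token_path = Path(path).expanduser()
    with token_path.open() as f:
        token = json.load(f)
    os.environ["KAGGLE_USERNAME"] = token["username"]
    os.environ["KAGGLE_KEY"] = token["key"]
    return token["username"]
```

Call `load_kaggle_credentials()` once, before you create the model, so that the preset download can find your credentials.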

### Install dependencies

Install Keras and KerasNLP.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. [Keras 3](https://keras.io/keras_3) lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial. Set `KERAS_BACKEND` before you import Keras, because the backend can't be changed after the package has been imported.

In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

### Import packages

Import Keras and KerasNLP.

In [None]:
import keras
import keras_nlp
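Keras reads `KERAS_BACKEND` when the package is first imported, so the order of these two steps matters. As a minimal sketch, the guard below refuses to set the variable if `keras` is already loaded in the current process. The helper name `select_backend` is hypothetical and uses only the standard library.

```python
import os
import sys

SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND, refusing if Keras has already been imported."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    if "keras" in sys.modules:
        # Too late: Keras picked its backend when it was imported.
        raise RuntimeError(
            "Keras is already imported; restart the runtime and set the backend first."
        )
    os.environ["KERAS_BACKEND"] = name
    return name
```

If the guard raises `RuntimeError` in a notebook, restart the runtime and run the backend cell before any cell that imports Keras.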

## Create a model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
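To make "predicts the next token based on previous tokens" concrete, here is a toy sketch that doesn't use Gemma at all. The probability table is made up for illustration, and it conditions only on the last token, whereas a real causal language model such as Gemma conditions on all previous tokens.

```python
# Hypothetical toy "model": for each token, the probability of each next token.
NEXT_TOKEN_PROBS = {
    "the": {"brain": 0.6, "meaning": 0.4},
    "brain": {"is": 0.9, "works": 0.1},
    "is": {"complex": 0.7, "the": 0.3},
    "complex": {"<end>": 1.0},
}


def generate_greedy(prompt_tokens, max_length):
    """Append the most likely next token until max_length or an end token."""
    tokens = list(prompt_tokens)
    while len(tokens) < max_length:
        candidates = NEXT_TOKEN_PROBS.get(tokens[-1])
        if not candidates:
            break  # The toy table has no entry for this token.
        next_token = max(candidates, key=candidates.get)
        if next_token == "<end>":
            break
        tokens.append(next_token)
    return tokens


print(generate_greedy(["the"], max_length=8))
```

Each step looks only at what has already been generated, which is what makes the model causal.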

Create the model using the `from_preset` method:

In [None]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")


Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...


`from_preset` instantiates the model from a preset architecture and weights. In the code above, the string `"gemma_2b_en"` specifies the preset architecture: a Gemma model with 2 billion parameters. (A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.)

Use `summary` to get more info about the model:

In [None]:
gemma_lm.summary()

As you can see from the summary, the model has 2.5 billion trainable parameters.
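The parameter count gives a rough idea of how much memory the weights need. The sketch below is back-of-the-envelope arithmetic only: it uses the rounded figure of 2.5 billion parameters, assumes 4 bytes per `float32` value and 2 bytes per `bfloat16` value, and ignores activations and any other runtime overhead.

```python
# Rounded parameter count reported by the model summary.
NUM_PARAMS = 2.5e9

# Storage size of one parameter for two common dtypes (assumed, not read from the model).
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_gb(num_params, dtype):
    """Estimate the memory needed for the weights alone, in gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(dtype, weight_memory_gb(NUM_PARAMS, dtype), "GB")
```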

## Generate text

Now it's time to generate some text! The model has a `generate` method that generates text based on a prompt. The optional `max_length` argument specifies the maximum length of the generated sequence.

Try it out with the prompt `"What is the meaning of life?"`.

In [None]:
gemma_lm.generate("What is the meaning of life?", max_length=64)

'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Try calling `generate` again with a different prompt.

In [None]:
gemma_lm.generate("How does the brain work?", max_length=64)

'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

If you're running on the JAX or TensorFlow backend, you'll notice that the second `generate` call returns almost instantly. This is because each call to `generate` for a given batch size and `max_length` is compiled with XLA. The first run is expensive, but subsequent runs are much faster.
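One way to picture this is as a cache keyed by batch size and `max_length`. The toy below is not XLA and not how KerasNLP is implemented; the function names are hypothetical, and the "compiled" function is a placeholder that returns its input. It only illustrates why a repeated call with the same shape skips the expensive step, while a new batch size or `max_length` pays for it again.

```python
import functools

compile_log = []  # Records every (batch_size, max_length) that had to be "compiled".


@functools.lru_cache(maxsize=None)
def get_compiled_generate(batch_size, max_length):
    """Stand-in for the slow compilation step, run once per distinct key."""
    compile_log.append((batch_size, max_length))

    def run(prompts):
        return list(prompts)  # Placeholder for the compiled generation function.

    return run


def generate(prompts, max_length):
    """Look up (or build) the function for this batch size and max_length."""
    return get_compiled_generate(len(prompts), max_length)(prompts)


generate(["What is the meaning of life?"], max_length=64)  # Compiles for (1, 64).
generate(["How does the brain work?"], max_length=64)      # Reuses (1, 64).
generate(["a", "b"], max_length=64)                        # New batch size: compiles again.
print(compile_log)
```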

You can also provide batched prompts using a list as input:

In [None]:
gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)

['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']
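As the outputs above show, each returned string begins with its prompt. If you only want the continuation, you can strip the prompt afterwards. The helper below is a small sketch with a hypothetical name, `strip_prompt`; it works on plain strings and doesn't call the model.

```python
def strip_prompt(prompt, generated):
    """Return the generated text without the leading prompt, if it is present."""
    if generated.startswith(prompt):
        return generated[len(prompt):].lstrip()
    return generated


prompts = ["What is the meaning of life?", "How does the brain work?"]
# Shortened stand-ins for the strings returned by `generate`.
outputs = [
    "What is the meaning of life?\n\nThe question is one of the most important questions in the world.",
    "How does the brain work?\n\nThe brain is the most complex organ in the human body.",
]
continuations = [strip_prompt(p, o) for p, o in zip(prompts, outputs)]
print(continuations)
```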

### Optional: Try a different sampler

You can control the generation strategy for `GemmaCausalLM` by setting the `sampler` argument on `compile()`. By default, [`"greedy"`](https://keras.io/api/keras_nlp/samplers/greedy_sampler/#greedysampler-class) sampling will be used.

As an experiment, try setting a [`"top_k"`](https://keras.io/api/keras_nlp/samplers/top_k_sampler/) strategy:

In [None]:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)

'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

The default greedy algorithm always picks the token with the highest probability. The top-K algorithm instead randomly samples the next token from the K most probable tokens.
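The difference between the two strategies can be sketched on a made-up next-token distribution. This is an illustration in plain Python, not the KerasNLP implementation; the function names are hypothetical, and weighting the top-K draw by each token's probability is an assumption of this sketch.

```python
import random


def greedy_pick(probs):
    """Always return the single most probable token."""
    return max(probs, key=probs.get)


def top_k_pick(probs, k, rng=random):
    """Sample one token from the k most probable, weighted by probability."""
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    tokens = [token for token, _ in top]
    weights = [p for _, p in top]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Hypothetical probabilities for the token that follows some prompt.
next_token_probs = {"life": 0.5, "work": 0.3, "love": 0.15, "cheese": 0.05}

rng = random.Random(0)  # Seeded so repeated runs draw the same samples.
print(greedy_pick(next_token_probs))
print([top_k_pick(next_token_probs, k=2, rng=rng) for _ in range(5)])
```

Greedy selection returns the same token every time, while top-K can return any of the K candidates, which is why the sampled response above differs from the earlier greedy one.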

You don't have to specify a sampler, and you can ignore the last code snippet if it's not helpful to your use case. If you'd like to learn more about the available samplers, see [Samplers](https://keras.io/api/keras_nlp/samplers/).

## What's next

In this tutorial, you learned how to generate text using KerasNLP and Gemma. Here are a few suggestions for what to learn next:

* Learn how to [fine-tune a Gemma model](https://ai.google.dev/gemma/docs/lora_tuning).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).