Text generation with Gemma using KerasNLP, a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.
| Use Case | Framework | Model Repo | Branch/Commit/Tag | Optional Patch |
|---|---|---|---|---|
| Inference | Keras | gemma | - | - |
The model checkpoints are available through Kaggle at https://www.kaggle.com/models/google/gemma. Select one of the Keras model variations, click the ⤓ button to download the model archive, then extract the contents to a local directory. The archive contains the model weights, the tokenizer, and other files necessary to load the model. For example, the extracted archive of the `gemma_2b_en` Keras model looks like:
```
assets
config.json
metadata.json
model.weights.h5  # Model weights
tokenizer.json    # Tokenizer
```
Once you have downloaded `archive.tar.gz`, untar the file and point `MODEL_PATH` to the extracted directory.
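For example, a minimal sketch of the extraction step (the archive name and destination directory are illustrative assumptions; yours may differ):

```bash
# Extract the downloaded Kaggle archive into a local directory.
mkdir -p /tmp/gemma_2b_en
tar -xf archive.tar.gz -C /tmp/gemma_2b_en

# MODEL_PATH (used in the steps below) should point at this directory.
ls /tmp/gemma_2b_en   # assets  config.json  metadata.json  model.weights.h5  tokenizer.json
```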
- Clone the repository:
  ```bash
  git clone https://github.com/IntelAI/models.git
  ```
- Go to the workload directory:
  ```bash
  cd models/models_v2/jax/gemma/inference/cpu
  ```
- Create virtual environment `venv` and activate it:
  ```bash
  python3 -m venv venv
  . ./venv/bin/activate
  ```
- Set up the required environment variables for setup:

  | Environment Variable | Purpose | export command |
  |---|---|---|
  | JAX_NIGHTLY (optional) | Set to 1 to install the nightly release of JAX. If not set to 1, defaults to the public release of JAX | `export JAX_NIGHTLY=1` |
- Run `setup.sh`:
  ```bash
  ./setup.sh
  ```
- Set up the required environment variables for running the model:

  | Environment Variable | Purpose | export command |
  |---|---|---|
  | PRECISION | Determine the precision for inference | `export PRECISION=<fp32/fp16/bfloat16>` |
  | MODEL_PATH | Local path to the downloaded model weights & tokenizer | `export MODEL_PATH=/tmp/gemma_2b_en` |
  | KERAS_BACKEND | Determine the backend framework for Keras | `export KERAS_BACKEND=<tensorflow/jax>` |
  | OUTPUT_DIR | Local path to save the output logs | `export OUTPUT_DIR=/tmp/keras_gemma_output` |
  | MAX_LENGTH (optional) | Max length of the generated sequence | `export MAX_LENGTH=64` |
- Run `run_model.sh`. This will run `N` instances of `generate.py`, where `N` is the number of sockets on the system (1 instance per socket); see the sketch after this list for a worked example.
  ```bash
  ./run_model.sh
  ```
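Putting the last two steps together, a minimal end-to-end sketch (the exported values are illustrative assumptions, not defaults):

```bash
# Illustrative values; pick the precision and backend for your setup.
export PRECISION=bfloat16
export MODEL_PATH=/tmp/gemma_2b_en
export KERAS_BACKEND=jax
export OUTPUT_DIR=/tmp/keras_gemma_output
export MAX_LENGTH=64          # optional

# run_model.sh launches one generate.py instance per socket;
# you can check the socket count of your system with:
lscpu | grep "Socket(s)"

./run_model.sh
```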
Output of `run_model.sh` typically looks as below. Note that the value is the sum of the throughput of all instances:
```
Total throughput: 0.390845 inputs/sec
```
Output of any single instance typically looks like the following, followed by the prompt and its corresponding output:
```
Time taken for first generate (warmup): 10.724524021148682
Time taken for second generate: 10.216123819351196
Latency: 10.216123819351196 sec
Throughput: 0.1957689663286614 inputs/sec
```
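These sample numbers are consistent with throughput being computed as the number of inputs divided by latency: assuming each generate call here processed two input prompts, 2 / 10.216 s ≈ 0.1958 inputs/sec, which matches the per-instance throughput shown (an inference from the sample output, not a documented guarantee).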
Final results of the inference run can be found in the `results.yaml` file:
```yaml
results:
 - key: total throughput
   value: 0.390845
   unit: inputs/sec
```
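For example, to pull the final result out in a script (assuming `results.yaml` is written under `OUTPUT_DIR`; the exact location may differ on your setup):

```bash
grep -A 2 "total throughput" "${OUTPUT_DIR}/results.yaml"
```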
- `keras_nlp` installs the stock version of the latest public `tensorflow` as a dependency. If you're running with a custom-built or nightly version of TensorFlow, you will need to uninstall `tensorflow` after installing `keras-nlp` and then force-reinstall your version of `tensorflow` (see the first sketch below).
- There are other ways to load the model using the Kaggle APIs, such as `KaggleHub`, the `Kaggle CLI`, or `cURL`, or by configuring your Kaggle API key (see the second sketch below).
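A minimal sketch of the TensorFlow swap described in the first note (the wheel path is a hypothetical placeholder):

```bash
pip install keras-nlp                 # pulls in the stock public tensorflow
pip uninstall -y tensorflow           # remove the stock build
# Hypothetical path; substitute your custom-built or nightly wheel.
pip install --force-reinstall /path/to/your/tensorflow-custom.whl
```

And a hedged example of one alternative download path, using cURL against the Kaggle models API (the `<version>` segment is a placeholder; this assumes your Kaggle credentials are set and that you have accepted the Gemma license on Kaggle):

```bash
curl -L -u "$KAGGLE_USERNAME:$KAGGLE_KEY" \
  -o gemma_2b_en.tar.gz \
  "https://www.kaggle.com/api/v1/models/google/gemma/keras/gemma_2b_en/<version>/download"
```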