In [1]:
%pip install --upgrade -q \
    keras-nlp==0.12.1 \
    keras==3.3.3 \
    jaxlib==0.4.30 \
    jax[cuda12]==0.4.30 \
    git+https://github.com/google-deepmind/gemma.git@a24194737dcb54b7392091e9ba772aea1cb68ffb \
    \
    kagglehub==0.2.5


Note: you may need to restart the kernel to use updated packages.


## Download model

In [2]:
import os

# TODO: Create Kaggle API token from https://www.kaggle.com/settings
os.environ["KAGGLE_USERNAME"] = "[TODO]"
os.environ["KAGGLE_KEY"] = "[TODO]"

import kagglehub

# Download latest version
model_path = kagglehub.model_download("keras/gemma/keras/gemma_instruct_2b_en")

print("Path to model files:", model_path)



Path to model files: /home/work/.cache/kagglehub/models/keras/gemma/keras/gemma_instruct_2b_en/2


## Download dataset

In [3]:
!mkdir -p datasets
!wget \
    https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl \
    -O datasets/databricks-dolly-15k.jsonl


--2024-06-24 00:51:50--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 13.225.131.35, 13.225.131.94, 13.225.131.6, ...
Connecting to huggingface.co (huggingface.co)|13.225.131.35|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1719449510&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxOTQ0OTUxMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkO

## Pre-process dataset

In [4]:
import json
from pathlib import Path

import pandas as pd
from tqdm.notebook import tqdm

num_samples = 1000

dataset_path = Path().parent / "datasets" / "databricks-dolly-15k.jsonl"
data = pd.read_json(dataset_path, lines=True)
print(data.shape)

prompt_template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

preprocessed_data = []
for _, row in tqdm(data.iterrows()):
    preprocessed_data.append(
        prompt_template.format(
            instruction=row["instruction"],
            response=row["response"],
        )
    )

# Only use a limited number of training examples
preprocessed_data = preprocessed_data[:num_samples]


(15011, 4)


0it [00:00, ?it/s]

## Fine-tune

In [5]:
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras
import keras_nlp

batch_size = 1

model = keras_nlp.models.GemmaCausalLM.from_preset(str(model_path))
model.summary()

model.backbone.enable_lora(rank=4)
model.summary()

model.preprocessor.sequence_length = 512
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(preprocessed_data, epochs=1, batch_size=batch_size, verbose=1)


2024-06-24 00:51:55.491401: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
2024-06-24 00:52:06.671942: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2024-06-24 00:52:06.674296: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


2024-06-24 00:52:08.298407: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
2024-06-24 00:54:14.215545: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module gemm_fusion_dot.639] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-06-24 00:54:15.572648: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m1.357171695s

********************************
[Compiling module gemm_fusion_dot.639] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m174s[0m 42ms/step - loss: 0.5734 - sparse_categorical_accuracy: 0.4935


2024-06-24 00:55:02.705848: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<keras.src.callbacks.history.History at 0x7f1ea0613d90>

## Evaluate

In [6]:
prompts = [
    prompt_template.format(
        instruction="What should I do on a trip to Europe?",
        response="",
    ),
    prompt_template.format(
        instruction="Explain the process of photosynthesis in a way that a child could understand.",
        response="",
    ),
]

sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
model.compile(sampler=sampler)

for prompt in prompts:
    print(model.generate(prompt, max_length=256))


Instruction:
What should I do on a trip to Europe?

Response:
There are two main types of trips to Europe: short trips and long trips. If you are looking to spend a weekend in Europe, there are many cities to choose from such as London, Paris, Rome, and Amsterdam. If you are looking to spend several weeks in Europe, there are many cities and countries to explore such as Barcelona, Berlin, and Prague.
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Sure, here's the process of photosynthesis explained in simpler terms.
Sure, photosynthesis is when plants and other organisms use sunlight to convert water, carbon dioxide and energy to make food, or glucose. It's a process that helps us to get the food that we need to survive.
It's done by special cells called chloroplasts in plant and algal cells called chloroplasts in plant and algal cells.
The chloroplasts contain chlorophyll, a green pigment that absorbs the energy from the Sun.
When