# Fine-tune Gemma models in Keras using LoRA

In [1]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
# Quote the version specifier so the shell does not read ">" as a redirect.
!pip install -q -U "keras>=3"

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.

In [2]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [3]:
import keras
import keras_nlp



## Load Dataset

In [6]:
import json

# Format each example as a single string.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
data[:2]

['Instruction:\nWhich is a species of fish? Tope or Rope\n\nResponse:\nTope',
 'Instruction:\nWhy can camels survive for long without water?\n\nResponse:\nCamels use the fat in their humps to keep them filled with energy and hydration for long periods of time.']
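
To test the formatting logic without the Kaggle dataset, the loading loop can be factored into a small function. This is a minimal sketch: `format_dolly_examples` is a hypothetical helper, the two records are made up, and it assumes each JSONL line carries `instruction`, `context`, and `response` keys, as the loading cell does. Stopping at `limit` yields the same list as slicing with `data[:1000]` afterwards, without reading the rest of the file.

```python
import json

# Same prompt template as the loading cell.
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def format_dolly_examples(lines, limit=1000):
    """Hypothetical helper: turn Dolly-style JSONL lines into training strings.

    Examples with a non-empty "context" field are skipped, and at most
    `limit` formatted examples are returned.
    """
    data = []
    for line in lines:
        features = json.loads(line)
        if features["context"]:
            continue  # keep only context-free examples
        # Extra keys such as "category" are ignored by str.format.
        data.append(TEMPLATE.format(**features))
        if len(data) >= limit:
            break  # stop early instead of slicing afterwards
    return data


# Made-up records in the same shape as databricks-dolly-15k.jsonl.
sample_lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the text.", "context": "Some passage.",
                "response": "A summary.", "category": "summarization"}),
]
examples = format_dolly_examples(sample_lines)
print(examples)
```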

## Load Model

In [7]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()


In [8]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
1. Take a trip to Europe.
2. Take a trip to Europe.
3. Take a trip to Europe.
4. Take a trip to Europe.
5. Take a trip to Europe.
6. Take a trip to Europe.
7. Take a trip to Europe.
8. Take a trip to Europe.
9. Take a trip to Europe.
10. Take a trip to Europe.
11. Take a trip to Europe.
12. Take a trip to Europe.
13. Take a trip to Europe.
14. Take a trip to Europe.
15. Take a trip to Europe.
16. Take a trip to Europe.
17. Take a trip to Europe.
18. Take a trip to Europe.
19. Take a trip to Europe.
20. Take a trip to Europe.
21. Take a trip to Europe.
22. Take a trip to Europe.
23. Take a trip to Europe.
24. Take a trip to Europe.
25. Take a trip to


In [9]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make glucose. The glucose is then used by the plant to make energy and grow.

Explanation:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make gluc

The response contains words that a child might not easily understand, such as "chlorophyll" and "glucose", and it repeats the same paragraph under an "Explanation:" heading.
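
Both prompts reuse the training template with an empty `response`, so the text passed to `generate` ends right after `Response:` and the model continues from there. A minimal sketch of the exact prompt string:

```python
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# An empty response leaves the prompt ending with "Response:\n",
# which is where the model is expected to start writing its answer.
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(repr(prompt))
```

Assuming `max_length` bounds the total number of tokens, prompt included (our reading of the KerasNLP `generate` documentation), this cap is also why both outputs above stop mid-word.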

In [10]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
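
Conceptually, LoRA keeps a pretrained weight matrix `W` frozen and learns a low-rank update `A @ B`, where `A` has shape `(d_in, r)`, `B` has shape `(r, d_out)`, and `r` is the rank (4 above). The NumPy sketch below illustrates the idea on a single dense layer. It is not the Keras implementation: the toy shapes, the zero initialization of `B`, and the absence of any extra scaling factor are simplifying assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 8, 4

# Frozen pretrained weight: not updated during fine-tuning.
W = rng.normal(size=(d_in, d_out))

# Trainable low-rank factors. B starts at zero (an assumed initialization),
# so the adapted layer initially computes exactly what the original did.
A = rng.normal(scale=0.01, size=(d_in, rank))
B = np.zeros((rank, d_out))


def lora_dense(x, W, A, B):
    """Dense layer with a LoRA update: x @ (W + A @ B), without forming W + A @ B."""
    return x @ W + (x @ A) @ B


x = rng.normal(size=(2, d_in))
print(np.allclose(lora_dense(x, W, A, B), x @ W))  # True: B is still zero

# Stand-in for trained values of B (random, for illustration only).
B = rng.normal(size=(rank, d_out))

# The learned update can be merged into a single matrix for inference.
W_merged = W + A @ B
print(np.allclose(lora_dense(x, W, A, B), x @ W_merged))  # True
print(np.linalg.matrix_rank(A @ B))  # at most 4
```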

Note that enabling LoRA freezes the original weights and trains only the small low-rank adapter matrices, which reduces the number of trainable parameters from about 2.5 billion to about 1.3 million.
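
The size of that reduction follows from simple arithmetic: a rank-`r` adapter on a `d_in × d_out` matrix adds `r * (d_in + d_out)` trainable values instead of `d_in * d_out`. The sketch below checks the 1.3 million figure under assumptions this notebook does not confirm: LoRA is applied only to the query and value projections in each of Gemma 2B's 18 layers (hidden size 2048, 8 query heads, 1 key/value head, head dimension 256), and the `A` factor keeps all leading kernel dimensions while `B` maps the rank to the last one. `lora_params` is a hypothetical helper.

```python
def lora_params(kernel_shape, rank):
    """Trainable values in a LoRA adapter for a kernel of `kernel_shape`.

    Assumed layout: A has shape kernel_shape[:-1] + (rank,) and
    B has shape (rank, kernel_shape[-1]).
    """
    a = rank
    for dim in kernel_shape[:-1]:
        a *= dim
    b = rank * kernel_shape[-1]
    return a + b


# Plain 2-D case: rank * (d_in + d_out) versus d_in * d_out.
print(lora_params((2048, 2048), rank=4), 2048 * 2048)  # prints: 16384 4194304

# Assumed Gemma 2B attention kernel shapes: (num_heads, hidden_dim, head_dim).
num_layers, hidden, head_dim = 18, 2048, 256
query_kernel = (8, hidden, head_dim)
value_kernel = (1, hidden, head_dim)

per_layer = lora_params(query_kernel, 4) + lora_params(value_kernel, 4)
total = num_layers * per_layer
print(f"{total:,}")  # prints: 1,363,968
```

Under these assumptions the count comes to about 1.36 million, in line with the figure quoted above.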

In [11]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude normalization scales and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 757s 732ms/step - loss: 0.4593 - sparse_categorical_accuracy: 0.5236


<keras.src.callbacks.history.History at 0x7a53580257e0>
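
AdamW applies weight decay directly to the weights, separately from the Adam gradient step, and `exclude_from_weight_decay` turns that decay off for variables whose names match the given strings. The pure-Python sketch below shows one update step for a single scalar weight. The variable names, the `should_decay` substring check (Keras does its own name matching), and the `beta_1`, `beta_2`, and `epsilon` values are illustrative assumptions; only `learning_rate=5e-5` and `weight_decay=0.01` come from the cell above.

```python
import math


def should_decay(var_name, exclude_names=("bias", "scale")):
    """Hypothetical helper: decay unless the name contains an excluded string."""
    return not any(name in var_name for name in exclude_names)


def adamw_step(w, g, m, v, t, var_name, lr=5e-5, weight_decay=0.01,
               beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """One AdamW update for scalar weight `w` with gradient `g` at step t >= 1."""
    # Decoupled weight decay: shrink the weight directly, independent of g.
    if should_decay(var_name):
        w = w - lr * weight_decay * w
    # Standard Adam moment estimates with bias correction.
    m = beta_1 * m + (1 - beta_1) * g
    v = beta_2 * v + (1 - beta_2) * g * g
    m_hat = m / (1 - beta_1 ** t)
    v_hat = v / (1 - beta_2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + epsilon)
    return w, m, v


# With a zero gradient, only the decay term can change the weight.
w_kernel, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, 1, "lora_kernel_a")
w_scale, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, 1, "final_normalization/scale")
print(w_kernel, w_scale)
```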

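The loss uses `from_logits=True` because the causal LM returns unnormalized scores, and accuracy is passed through `weighted_metrics` so that per-token sample weights apply to it. Assuming, as we understand the KerasNLP preprocessor to do, that padding positions get a weight of 0, the NumPy sketch below shows how such a mask keeps padding out of the metrics. The toy logits and labels are made up, and `masked_token_metrics` is a hypothetical helper; it normalizes the loss by the total weight, whereas Keras's actual loss normalization depends on the loss's `reduction` setting.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def masked_token_metrics(logits, labels, sample_weight):
    """Weighted mean cross-entropy and accuracy over token positions.

    logits: (seq_len, vocab_size); labels and sample_weight: (seq_len,).
    """
    log_probs = log_softmax(logits)
    # Negative log-probability of the correct next token at each position.
    token_loss = -log_probs[np.arange(len(labels)), labels]
    correct = (logits.argmax(axis=-1) == labels).astype(float)
    total_weight = sample_weight.sum()
    loss = (token_loss * sample_weight).sum() / total_weight
    accuracy = (correct * sample_weight).sum() / total_weight
    return loss, accuracy


# Toy example: vocabulary of 3 tokens, 4 positions, the last one is padding.
logits = np.array([
    [2.0, 0.1, 0.1],  # predicts token 0 (label 0: correct)
    [0.1, 2.0, 0.1],  # predicts token 1 (label 1: correct)
    [2.0, 0.1, 0.1],  # predicts token 0 (label 2: wrong)
    [0.1, 0.1, 2.0],  # padding position, weight 0
])
labels = np.array([0, 1, 2, 0])
sample_weight = np.array([1.0, 1.0, 1.0, 0.0])

loss, accuracy = masked_token_metrics(logits, labels, sample_weight)
print(loss, accuracy)
```

Changing the logits at the padding position leaves both values unchanged, which is the point of the mask.
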
### Europe Trip Prompt


In [12]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
There are so many things to do in Europe, but here are a few suggestions:

1. Visit the Eiffel Tower in Paris
2. Take a river cruise on the Rhine River in Germany
3. Visit the Colosseum in Rome
4. Take a train ride through the Swiss Alps
5. Visit the Vatican City in Rome

### Photosynthesis Prompt

In [13]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants convert light energy into chemical energy. The chemical energy is stored in the form of glucose, which is used by the plant to grow and reproduce. The process of photosynthesis involves the following steps:
1. Light energy is absorbed by chlorophyll molecules in the leaves of the plant.
2. The absorbed light energy is converted into chemical energy in the form of ATP (adenosine triphosphate) molecules.
3. The ATP molecules are used to power the process of photosynthesis.
4. The carbon dioxide gas and water molecules are combined to form glucose molecules.
5. The glucose molecules are used by the plant to grow and reproduce.
