In [None]:
import os
from google.colab import userdata


# Copy the Kaggle credentials from Colab's `userdata` store into the
# environment variables of the same name (presumably read later when the
# model preset is fetched from Kaggle -- verify against your setup).
for secret_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    os.environ[secret_name] = userdata.get(secret_name)

In [None]:
! pip install -q -U keras-hub
! pip install  -q -U keras

In [None]:
# Runtime settings for Keras/JAX; applied before `keras` is imported below.
os.environ.update(
    {
        # Keras backend to use: "jax" here, or "torch" or "tensorflow".
        "KERAS_BACKEND": "jax",
        # Avoid memory fragmentation on JAX backend.
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
    }
)

In [None]:
import keras
import keras_hub

In [None]:
# Load the pretrained model.
# Available Gemma 3 presets: https://www.kaggle.com/models/keras/gemma3
#
# Alternative kept for reference (legacy keras_nlp API, Gemma 2B; `keras_nlp`
# is not imported in this notebook):
# gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

# Preset name passed to `from_preset`.
GEMMA_PRESET = "gemma3_instruct_1b"
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(GEMMA_PRESET)

# Print the model summary.
gemma_lm.summary()

In [None]:
# Baseline: generate a response with the model before any fine-tuning.
# `template` is reused by the post-fine-tuning inference cell.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# The response slot is left empty so the prompt ends right after "Response:".
prompt = template.format(instruction="What should I do on a trip to Europe?", response="")

# Top-k sampling (k=5) with a fixed seed.
sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

generated_text = gemma_lm.generate(prompt, max_length=256)
print(generated_text)

In [None]:
# load dataset jsonl file
! wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

In [None]:
import json

# Path of the JSONL file fetched by the download cell.
DATASET_PATH = "databricks-dolly-15k.jsonl"
# Cap on the number of training examples, to reduce execution time.
MAX_EXAMPLES = 1000

prompts = []
responses = []

# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default locale encoding.
with open(DATASET_PATH, encoding="utf-8") as file:
    for line in file:
        if len(prompts) >= MAX_EXAMPLES:
            break

        # Tolerate blank lines (e.g. a trailing newline) instead of letting
        # json.loads fail on them.
        if not line.strip():
            continue

        examples = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if examples["context"]:
            continue
        # Collect the instruction/response pair into parallel lists.
        prompts.append(examples["instruction"])
        responses.append(examples["response"])

# Number of examples kept. Despite the name this counts accepted records,
# not lines read; the variable is retained for any cell that refers to it.
line_count = len(prompts)

# NOTE(review): the prompts are the raw instructions, whereas the inference
# cells wrap the instruction in the "Instruction:/Response:" template --
# confirm this difference is intended.
data = {
    "prompts": prompts,
    "responses": responses,
}

In [None]:
# Fine-tuning setup: enable LoRA on the model's backbone.
# A low rank (4) is used here; a higher rank is one of the knobs to try later.
LORA_RANK = 4
gemma_lm.backbone.enable_lora(rank=LORA_RANK)

In [None]:
# Cap the input sequence length at 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256

# AdamW is a common optimizer choice for transformer models.
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
# Keep layernorm ("scale") and bias terms out of weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# The loss is configured with from_logits=True.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

gemma_lm.compile(
    optimizer=optimizer,
    loss=loss_fn,
    weighted_metrics=[accuracy_metric],
)

In [None]:
# Fine-tune on the prepared prompt/response data: one epoch, batch size 1.
gemma_lm.fit(x=data, batch_size=1, epochs=1)

In [None]:
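# Uncomment the line below to enable mixed precision training on GPUs.
# Note: the global policy only applies to layers created after it is set, so
# run it before the model is loaded and fine-tuned, not after training.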
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

In [None]:
# Generate again with the same prompt, sampler settings and seed as the
# pre-fine-tuning run, so the two outputs can be compared.
prompt_fields = {
    "instruction": "What should I do on a trip to Europe?",
    "response": "",
}
prompt = template.format(**prompt_fields)

# Recompile with a fresh top-k sampler (k=5, fixed seed).
sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

generated_text = gemma_lm.generate(prompt, max_length=256)
print(generated_text)

In [None]:
# Improving fine-tune results
#
# For demonstration purposes, this tutorial fine-tunes the model on a small
# subset of the dataset for just one epoch and with a low LoRA rank value.
# To get better responses from the fine-tuned model, you can experiment with:
#
# - Increasing the size of the fine-tuning dataset.
# - Training for more steps (epochs).
# - Setting a higher LoRA rank.
# - Modifying hyperparameter values such as learning_rate and weight_decay.