# Install and import dependencies

In [None]:
! nvidia-smi -L

## Install dependencies

In [None]:
%%time

from IPython.display import clear_output

! pip install -qq -U kaggle
! pip install -qq -U keras-nlp
! pip install -qq -U keras>=3
! pip install -qq -U datasets

clear_output()

## Kaggle Config


In [None]:
# copy kaggle.json to /root/.kaggle/ folder so that kaggle cli can access it.
!mkdir /.kaggle
!mv kaggle.json /.kaggle
!mv /.kaggle /root/
!chmod 600 ~/.kaggle/kaggle.json

## Select a Backend

In [None]:
import os

# Configure the runtime through environment variables in one go:
#   KERAS_BACKEND                  -> run Keras on the JAX backend
#   XLA_PYTHON_CLIENT_MEM_FRACTION -> memory fraction handed to the XLA client (1.00 = all)
# This cell is placed before the cell that imports keras.
os.environ.update(
    {
        "KERAS_BACKEND": "jax",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
    }
)

## Import dependencies

In [None]:
import keras
import keras_nlp

# Load Dataset

In [None]:
import json

# Prompt template shared by the training examples and the inference prompts
# later in the notebook. Defined once, outside the loop, so it exists even if
# the file yields no usable examples.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Formatted training examples, one string per kept record.
data = []

with open("/content/databricks-dolly-15k.jsonl", encoding="utf-8") as file:
  for line in file:
    # Tolerate blank lines (e.g. a trailing newline at the end of the file).
    if not line.strip():
      continue

    features = json.loads(line)

    # Filter out examples with context, to keep it simple: the template has no
    # slot for the context, so only context-free examples are kept.
    if features["context"]:
      continue

    # Format the entire example as a single string.
    data.append(template.format(**features))


# Load Model

In [None]:
%%time

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

# Inference before fine tuning

In [None]:
# Brazil trip prompt (before fine-tuning).
# The response slot is left empty so the model writes the answer itself.
prompt = template.format(instruction="What should I do on a trip to Brazil ?", response="")

generated_text = gemma_lm.generate(prompt, max_length=256)
print(generated_text)

In [None]:
# ELI5 Photosynthesis Prompt (before fine-tuning).
# The instruction text must match the post-fine-tuning cell exactly, otherwise
# the before/after comparison is run on two different prompts.
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response=""
)

print(gemma_lm.generate(prompt, max_length=256))

# LoRA Fine-Tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.



In [None]:
# Turn on LoRA for the model's backbone with rank 4, then print the summary
# again so the parameter counts can be compared with the earlier one.
lora_rank = 4
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

In [None]:
# Cap the input sequence length at 512 to keep memory usage under control.
gemma_lm.preprocessor.sequence_length = 512

# AdamW is a common optimizer choice for transformer models.
optimizer = keras.optimizers.AdamW(learning_rate=5e-4, weight_decay=0.01)

# Layernorm scales and bias terms are left out of weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Build the loss and metric up front, then hand everything to compile().
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy_metric = keras.metrics.SparseCategoricalAccuracy()
gemma_lm.compile(loss=loss_fn, optimizer=optimizer, weighted_metrics=[accuracy_metric])

# A single epoch over the formatted examples with a batch size of 1.
gemma_lm.fit(data, epochs=1, batch_size=1)

# Inference after fine-tuning

In [None]:
# Brazil trip prompt (after fine-tuning): same instruction as before tuning,
# with the response slot left empty for the model to complete.
prompt = template.format(instruction="What should I do on a trip to Brazil ?", response="")

generated_text = gemma_lm.generate(prompt, max_length=256)
print(generated_text)

In [None]:
# ELI5 photosynthesis prompt (after fine-tuning), response slot left empty
# for the model to complete.
prompt = template.format(
    response="",
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
)

generated_text = gemma_lm.generate(prompt, max_length=256)
print(generated_text)